Source code for scimba_torch.utils.typing_protocols
"""Common typing protocols that replace ``mypy_extensions`` usage.

The ``mypy_extensions`` helpers (``Arg``, ``VarArg``, ``KwArg``) used to
describe callable signatures are replaced here by type-safe alternatives
built on the standard library's ``typing.Protocol``.
"""

from typing import Any, Protocol

import torch
class VarArgCallable(Protocol):
    """Structural type for a callable over any number of tensors.

    For static type checking, a callable matches this protocol when it
    accepts an arbitrary number of positional ``torch.Tensor`` arguments and
    returns a ``torch.Tensor``.

    Replaces: ``Callable[[VarArg(torch.Tensor)], torch.Tensor]``
    """

    def __call__(self, *args: torch.Tensor) -> torch.Tensor:
        """Evaluate the callable on the given tensors."""
        ...
class VarArgAnyCallable(Protocol):
    """Structural type for a tensor-returning callable over arbitrary args.

    For static type checking, a callable matches this protocol when it
    accepts any number of positional arguments of any type and returns a
    ``torch.Tensor``.

    Replaces: ``Callable[[VarArg(Any)], torch.Tensor]``
    """

    def __call__(self, *args: Any) -> torch.Tensor:
        """Evaluate the callable on the given positional arguments."""
        ...
class FuncTypeCallable(Protocol):
    """Structural type for a function of a tensor ``x`` plus keyword options.

    For static type checking, a matching callable takes a ``torch.Tensor``
    first argument that is passable by keyword under the name ``x``, accepts
    arbitrary additional keyword arguments, and returns a ``torch.Tensor``.

    Replaces: ``Callable[[Arg(torch.Tensor, "x"), KwArg(Any)], torch.Tensor]``
    """

    def __call__(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Evaluate the function at ``x`` with extra keyword arguments."""
        ...
class FuncFuncArgsCallable(Protocol):
    """Structural type for a higher-order tensor function.

    For static type checking, a matching callable receives a
    ``VarArgCallable`` as its first argument (``func``), followed by any
    number of positional ``torch.Tensor`` arguments, and returns a
    ``torch.Tensor``.

    Replaces: ``Callable[[TYPE_FUNC_ARGS, VarArg(TYPE_ARGS)], torch.Tensor]``
    (the extra positional arguments are typed as ``torch.Tensor`` here).
    """

    def __call__(self, func: VarArgCallable, *args: torch.Tensor) -> torch.Tensor:
        """Apply the higher-order function to ``func`` and extra tensors."""
        ...
# Aliases kept for backward compatibility: existing code refers to these
# callable types by the names below. They are plain assignments so that type
# checkers treat each one as an implicit type alias of its protocol class.
TYPE_FUNC_ARGS = VarArgCallable
FUNC_TYPE = FuncTypeCallable
TYPE_FUNC_FUNC_ARGS = FuncFuncArgsCallable