# Source code for scimba_torch.neural_nets.coordinates_based_nets.scimba_module
"""Base module for Scimba neural networks."""importtorchimporttorch.nnasnn
class ScimbaModule(nn.Module):
    """Abstract class for Scimba neural networks.

    Subclasses are expected to assign :attr:`output_layer`; the
    ``"last_layer"`` and ``"except_last_layer"`` parameter scopes split the
    network's parameters into those owned by that module and all the others.

    Args:
        in_size: Input dimension.
        out_size: Output dimension.
        **kwargs: Additional keyword arguments (accepted so subclasses can
            forward extra options; ignored here).
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__()
        self.in_size = in_size  #: Input dimension.
        self.out_size = out_size  #: Output dimension.
        self.output_layer = None  #: Output layer module (to be set by subclasses).

    def _select_parameters(
        self, flag_scope: str, recurse: bool = True
    ) -> list[nn.Parameter]:
        """Return the parameters selected by ``flag_scope``, in registration order.

        Args:
            flag_scope: 'all', 'last_layer' or 'except_last_layer'.
            recurse: Forwarded to the underlying ``parameters`` calls.

        Returns:
            The selected parameters as a list.

        Raises:
            ValueError: If flag_scope is unknown, or if 'last_layer' is
                requested while ``output_layer`` has not been set.
        """
        if flag_scope == "all":
            return list(super().parameters(recurse=recurse))

        if flag_scope == "last_layer":
            if self.output_layer is None:
                raise ValueError(
                    "flag_scope='last_layer' requires self.output_layer to be set"
                )
            return list(self.output_layer.parameters(recurse=recurse))

        if flag_scope == "except_last_layer":
            # Exclude by identity rather than by name prefix: a name check such
            # as startswith("output_layer") would also drop unrelated
            # parameters like "output_layer_scale", and would miss weights
            # tied to the output layer but registered under another name.
            if self.output_layer is None:
                excluded = set()
            else:
                excluded = {id(p) for p in self.output_layer.parameters()}
            return [
                p
                for p in super().parameters(recurse=recurse)
                if id(p) not in excluded
            ]

        raise ValueError(f"Unknown flag_scope: {flag_scope}")

    def parameters(
        self,
        flag_scope: str = "all",
        flag_format: str = "list",
        recurse: bool = True,
    ):
        """Get parameters of the neural net.

        Overrides :meth:`torch.nn.Module.parameters`. Unlike the base method,
        it returns a list (or a single flat tensor) rather than an iterator.

        Args:
            flag_scope: Specifies which parameters to return. Options: 'all',
                'last_layer', 'except_last_layer'.
            flag_format: Specifies the format. Options: 'list', 'tensor'.
            recurse: Same meaning as in :meth:`torch.nn.Module.parameters`;
                kept so generic callers using ``parameters(recurse=...)``
                still work on Scimba modules.

        Returns:
            list[nn.Parameter] or torch.Tensor: the selected parameters, either
            as a list or flattened into one 1-D tensor via
            ``torch.nn.utils.parameters_to_vector``.

        Raises:
            ValueError: If flag_scope or flag_format is not one of the
                supported options, or if 'last_layer' is requested while
                ``output_layer`` is not set.
        """
        params = self._select_parameters(flag_scope, recurse=recurse)
        if flag_format == "list":
            return params
        if flag_format == "tensor":
            return torch.nn.utils.parameters_to_vector(params)
        raise ValueError(f"Unknown flag_format: {flag_format}")

    def set_parameters(self, new_params: torch.Tensor, flag_scope: str = "all"):
        """Set parameters from a flat tensor.

        The values are consumed in the same order as returned by
        ``parameters(flag_scope)``, so ``set_parameters(parameters(scope,
        "tensor"), scope)`` round-trips.

        Args:
            new_params: New parameter values; its total number of elements
                must equal the number of scalars selected by ``flag_scope``.
            flag_scope: 'all', 'last_layer', 'except_last_layer'

        Raises:
            ValueError: If flag_scope is not one of the supported options, if
                'last_layer' is requested while ``output_layer`` is not set, or
                if ``new_params`` does not have exactly the required size.
            TypeError: If ``new_params`` is not a ``torch.Tensor``.
        """
        params = self._select_parameters(flag_scope)
        if not isinstance(new_params, torch.Tensor):
            raise TypeError(
                f"expected torch.Tensor, but got: {torch.typename(new_params)}"
            )
        expected = sum(p.numel() for p in params)
        # vector_to_parameters does not check the length itself: surplus
        # values would be silently ignored and missing ones fail obscurely.
        if new_params.numel() != expected:
            raise ValueError(
                f"new_params has {new_params.numel()} elements but "
                f"flag_scope='{flag_scope}' selects {expected} parameter values"
            )
        torch.nn.utils.vector_to_parameters(new_params.reshape(-1), params)