Source code for scimba_torch.approximation_space.abstract_space
"""Defines an abstract class for an approximation space."""fromabcimportABC,abstractmethodfromcopyimportdeepcopyfromtypingimportGeneratorfromwarningsimportwarnimporttorchfromscimba_torch.integration.monte_carloimportTensorizedSamplerfromscimba_torch.utils.scimba_tensorsimportLabelTensor,MultiLabelTensor### Describe a numerical model with where we gives a transition between 2 times### We assume that the formulation between two time can be non explicit
class AbstractApproxSpace(ABC):
    """Abstract class for an approximation space.

    This class provides the base structure for approximation spaces, including
    methods for gradient computation, evaluation, and handling degrees of freedom.

    The saving/restoring helpers (`dict_for_save`, `load_from_dict`,
    `update_best_approx`, `load_from_best_approx`) assert that the concrete
    subclass is also a `torch.nn.Module`.

    Args:
        nb_unknowns: Number of unknowns in the approximation space.
        **kwargs: Additional keyword arguments (accepted but not used here).
    """

    type_space: str  #: Type of the approximation space.
    integrator: TensorizedSampler  #: An integrator for tensorized sampling.
    ndof: int  #: Number of degrees of freedom.

    def __init__(self, nb_unknowns: int, **kwargs):
        #: Number of unknowns in the approximation space.
        self.nb_unknowns: int = nb_unknowns
        #: dictionary to store the best approximation state.
        self.best_approx: dict = {}

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | Generator[torch.Tensor, None, None]:
        """Computes the gradient of `w` with respect to `y`.

        Args:
            w: The scalar quantity to differentiate: a tensor, or a
                `MultiLabelTensor` whose `.w` attribute is used.
            y: The points with respect to which the gradient is computed. Must
                be a `LabelTensor`; its `.x` attribute is differentiated.

        Returns:
            If `y.x` has a single column, the derivative along that column
            (shape `(n, 1)` for a 2D `y.x`). Otherwise, a generator yielding
            the derivative along each column of `y.x`, one `(n, 1)` slice at
            a time. If `w` does not depend on `y.x`, the derivative is zero.

        Raises:
            ValueError: If `y` is not a `LabelTensor`, or if `w` has more than
                one column.
        """
        if isinstance(w, MultiLabelTensor):
            w = w.w

        if not isinstance(y, LabelTensor):
            raise ValueError(
                "y must be a LabelTensor. You must differentiate with respect to "
                "all coordinates."
            )

        if (w.ndim > 1) and (not w.shape[1] == 1):
            raise ValueError(
                "this function must call on a scalar unknown (shape =(batch,1)). "
                "Call get_component before to extract the component"
            )

        # y is known to be a LabelTensor here: differentiate w.r.t. its tensor.
        x = y.x
        ones = torch.ones_like(w)
        grad_output = torch.autograd.grad(
            w, x, ones, create_graph=True, allow_unused=True
        )[0]

        # With allow_unused=True, autograd returns None when w does not depend
        # on x. The derivative is then identically zero; materialize it (as
        # torch's `materialize_grads` does) so that slicing below works and
        # higher-order derivatives keep chaining to zero.
        if grad_output is None:
            grad_output = torch.zeros_like(x, requires_grad=True)

        if x.size(1) > 1:
            return (grad_output[:, i, None] for i in range(x.size(1)))
        return grad_output[:, 0, None]

    @abstractmethod
    def evaluate(
        self, *args: LabelTensor, with_last_layer: bool = True
    ) -> MultiLabelTensor:
        """Evaluates the approximation space.

        Args:
            *args: Input tensors for evaluation.
            with_last_layer: Whether to include the last layer in evaluation.
                (Default value = True)

        Returns:
            The result of the evaluation.
        """
        pass

    @abstractmethod
    def jacobian(self, *args: LabelTensor) -> torch.Tensor:
        """Computes the Jacobian of the approximation space.

        Args:
            *args: Input tensors for Jacobian computation.

        Returns:
            The Jacobian tensor.
        """
        pass

    @abstractmethod
    def set_dof(self, theta: torch.Tensor, flag_scope: str) -> None:
        """Sets the degrees of freedom for the approximation space.

        Args:
            theta: Tensor representing the degrees of freedom.
            flag_scope: Scope flag for setting degrees of freedom.
        """
        pass

    @abstractmethod
    def get_dof(self, flag_scope: str, flag_format: str) -> torch.Tensor | list:
        """Gets the degrees of freedom for the approximation space.

        Args:
            flag_scope: Scope flag for getting degrees of freedom.
            flag_format: Format flag for the degrees of freedom.

        Returns:
            The degrees of freedom.
        """
        pass

    def dict_for_save(self) -> dict:
        """Returns a dictionary representing the space that can be stored/saved.

        The dictionary always holds a deep copy of the current state dict under
        `"current_state_dict"`, and additionally a deep copy of the best
        approximation under `"best_state_dict"` if one has been recorded.

        Returns:
            A dictionary representing the space.
        """
        # Save/restore relies on the nn.Module state-dict machinery.
        assert isinstance(self, torch.nn.Module)
        state_dict = {"current_state_dict": deepcopy(self.state_dict())}
        if "model_state_dict" in self.best_approx:
            state_dict["best_state_dict"] = deepcopy(
                self.best_approx["model_state_dict"]
            )
        return state_dict

    def load_from_dict(self, checkpoint: dict) -> None:
        """Restores the space from a dictionary.

        Loads `checkpoint["current_state_dict"]` into the module, restores the
        best approximation from `checkpoint["best_state_dict"]` if present
        (otherwise clears it), and switches the module to eval mode.

        Args:
            checkpoint: dictionary containing the state to restore, e.g. as
                produced by `dict_for_save`.

        Raises:
            KeyError: If `checkpoint` has no `"current_state_dict"` entry.
        """
        assert isinstance(self, torch.nn.Module)
        self.load_state_dict(checkpoint["current_state_dict"])
        if "best_state_dict" in checkpoint:
            self.best_approx["model_state_dict"] = checkpoint["best_state_dict"]
        else:
            self.best_approx = {}
        self.eval()

    def update_best_approx(self) -> None:
        """Updates the best approximation state to the current approximation state.

        A deep copy is stored, so later training steps do not alter it.
        """
        assert isinstance(self, torch.nn.Module)
        self.best_approx["model_state_dict"] = deepcopy(self.state_dict())

    def load_from_best_approx(self) -> None:
        """Loads the current approximation state from the best approximation state.

        On success, the module is also switched to eval mode.

        Notes:
            If no best approximation has been saved with `update_best_approx()`
            yet, raises a warning and does nothing.
        """
        assert isinstance(self, torch.nn.Module)
        if "model_state_dict" in self.best_approx:
            self.load_state_dict(self.best_approx["model_state_dict"])
            self.eval()
        else:
            # stacklevel=2 attributes the warning to the caller's line.
            warn(
                "self.best_approx has no key model_state_dict; nothing will happen; "
                "perhaps update_best_approx has not been called",
                RuntimeWarning,
                stacklevel=2,
            )