r"""Solve a parametric 2D Fisher-KPP (reaction-diffusion) equation.

.. math::

    \partial_t u - D \Delta u - r u (1 - u) = 0

on the square :math:`(0, 50)^2`, with homogeneous Neumann boundary conditions
:math:`\nabla u \cdot n = 0` and a Gaussian initial condition centered at
:math:`(10, 10)`. The diffusion coefficient :math:`D` and the reaction rate
:math:`r` are inputs of the network.

The time variable is rescaled to :math:`[0, 1]`. The physical final time
:math:`T` therefore multiplies the spatial operator, and the residual actually
minimized is

.. math::

    \partial_t u - T \left( D \Delta u + r u (1 - u) \right) = 0.

The problem is solved with a space-time PINN trained by a Natural Gradient
optimizer. The same network is trained on a sequence of parameter domains,
each stage continuing from the previous one.
"""

import torch

from scimba_torch.approximation_space.nn_space import NNxtSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.integration.monte_carlo_time import UniformTimeSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.temporal_pde.pinns import (
    NaturalGradientTemporalPinns,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    StrongFormEllipticPDE,
)
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import (
    FirstOrderTemporalPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor


def zeros_rhs(w, t, x, mu, nb_func: int = 1):
    """Return a zero source term with ``nb_func`` components per point of ``x``."""
    return torch.zeros(x.shape[0], nb_func)


def zeros_bc_rhs(w, t, x, n, mu, nb_func: int = 1):
    """Return zero boundary data (homogeneous Neumann condition)."""
    return torch.zeros(x.shape[0], nb_func)


class FisherKPPSpaceComponent(StrongFormEllipticPDE):
    """Spatial part of the time-rescaled Fisher-KPP operator.

    Args:
        space: The approximation space.
        f: Source term, called as ``f(w, t, x, mu)`` (same signature as the
            one used by :class:`FisherKPP`).
        g: Boundary data, called as ``g(w, t, x, n, mu)``.
        final_time: Physical final time ``T``. It scales the spatial operator
            because the time variable is rescaled to ``[0, 1]``.
        **kwargs: Forwarded to :class:`StrongFormEllipticPDE`.
    """

    def __init__(self, space, f, g, final_time, **kwargs):
        # NOTE(review): the reaction term r * u * (1 - u) is nonlinear, yet the
        # operator is declared linear=True; confirm how the solver uses it.
        super().__init__(
            space, linear=True, residual_size=1, bc_residual_size=1, **kwargs
        )
        self.f = f
        self.g = g
        self.final_time = final_time

    def rhs(self, w, x, mu):
        """Evaluate the source term ``f``.

        This component has no time argument, so ``t`` is passed as ``None``.
        ``f`` must therefore not rely on ``t`` when called through here.
        """
        return self.f(w, None, x, mu)

    def operator(self, w, x, mu):
        """Return ``-T * (D * Laplacian(u) + r * u * (1 - u))``."""
        u = w.get_components()
        D, r = mu.get_components()
        grad_u = torch.cat(tuple(self.grad(u, x)), dim=-1)
        space_dim = grad_u.shape[1]
        # Laplacian = sum over i of d/dx_i (du/dx_i): keep the i-th component
        # of the gradient of the i-th partial derivative.
        second_derivatives = [
            tuple(self.grad(grad_u[:, i], x))[i] for i in range(space_dim)
        ]
        laplacian_u = torch.sum(torch.stack(second_derivatives, dim=-1), dim=-1)
        return -self.final_time * (D * laplacian_u + r * u * (1 - u))

    def bc_rhs(self, w, x, n, mu):
        """Evaluate the boundary data ``g`` (``t`` is passed as ``None``)."""
        return self.g(w, None, x, n, mu)

    def bc_operator(self, w, x, n, mu):
        """Return the normal derivative ``grad(u) . n`` (Neumann operator)."""
        u = w.get_components()
        grad_u = torch.cat(tuple(self.grad(u, x)), dim=-1)
        n_ = torch.cat(tuple(n.get_components()), dim=-1)
        return torch.sum(grad_u * n_, dim=-1, keepdim=True)


class FisherKPP(FirstOrderTemporalPDE):
    """Time-dependent Fisher-KPP problem with parameters ``mu = (D, r)``.

    Args:
        space: The approximation space.
        init: Initial condition, called as ``init(x, mu)``.
        final_time: Physical final time ``T`` (time is rescaled to [0, 1]).
        f: Source term ``f(w, t, x, mu)``; zero by default.
        g: Neumann boundary data ``g(w, t, x, n, mu)``; zero by default.
        **kwargs: Forwarded to :class:`FirstOrderTemporalPDE`.
    """

    def __init__(self, space, init, final_time, f=zeros_rhs, g=zeros_bc_rhs, **kwargs):
        # NOTE(review): the PDE is nonlinear but declared linear=True; confirm.
        super().__init__(space, linear=True, **kwargs)
        self.space_component = FisherKPPSpaceComponent(space, f, g, final_time)
        self.f = f
        self.g = g
        self.final_time = final_time
        self.init = init
        self.ic_residual_size = 1

    def space_operator(self, w, t, x, mu):
        """Delegate to the spatial component (``t`` is not used)."""
        return self.space_component.operator(w, x, mu)

    def bc_operator(self, w, t, x, n, mu):
        """Delegate the Neumann operator to the spatial component."""
        return self.space_component.bc_operator(w, x, n, mu)

    def rhs(self, w, t, x, mu: LabelTensor):
        """Evaluate the source term ``f(w, t, x, mu)``."""
        return self.f(w, t, x, mu)

    def bc_rhs(self, w, t, x, n, mu):
        """Evaluate the boundary data ``g(w, t, x, n, mu)``."""
        return self.g(w, t, x, n, mu)

    def initial_condition(self, x, mu):
        """Evaluate the initial condition ``init(x, mu)``."""
        return self.init(x, mu)

    def functional_operator(self, func, t, x, mu, theta):
        """Pointwise residual ``du/dt - T (D Lap(u) + r u (1 - u))`` via torch.func.

        Assumes ``func`` maps a single point ``(t, x, mu)`` and parameters
        ``theta`` to a 1-element output. ``mu[0]`` is ``D`` and ``mu[1]`` is
        ``r``; ``x`` is presumably 2D (TODO confirm).
        """
        u = func(t, x, mu, theta)
        D, r = mu[0], mu[1]
        time_op = torch.func.jacrev(func, 0)(t, x, mu, theta)
        grad_u = torch.func.jacrev(func, 1)
        # Hessian with respect to x; its trace is the Laplacian.
        grad_grad_u = torch.func.jacrev(grad_u, 1)(t, x, mu, theta).squeeze()
        laplacian_u = torch.einsum("ii", grad_grad_u)
        return time_op[0] - self.final_time * (D * laplacian_u[None] + r * u * (1 - u))

    def functional_operator_bc(self, func, t, x, n, mu, theta):
        """Pointwise Neumann residual ``grad_x(u) @ n``."""
        grad_u = torch.func.jacrev(func, 1)(t, x, mu, theta)
        return grad_u @ n

    def functional_operator_ic(self, func, x, mu, theta):
        """Pointwise initial value ``u(t=0, x, mu)``."""
        return func(torch.zeros(1), x, mu, theta)


def f_exact(t: LabelTensor, x: LabelTensor, mu: LabelTensor):
    """Reference used for plotting: the initial condition, for every ``t``.

    No closed-form solution is used here, so "errors" computed against this
    reference measure the deviation from the initial state, not a true error.
    """
    return f_ini(x, mu)


def f_ini(x: LabelTensor, mu):
    """Gaussian initial condition ``exp(-|x - (10, 10)|^2 / 10)``."""
    center = [10, 10]
    x1, x2 = x.get_components()
    r2 = (x1 - center[0]) ** 2 + (x2 - center[1]) ** 2
    return torch.exp(-r2 / 10.0)


def pre_processing(t, x: LabelTensor, mu: LabelTensor):
    """Build the network input ``(t, x1, x2, |x - 10|^2 / 100, D, r)``."""
    r2 = torch.sum((x.x - 10) ** 2, dim=-1, keepdim=True) / 10**2
    return torch.cat([t.x, x.x, r2, mu.x], dim=-1)


def functional_pre_processing(*args):
    """Plain-tensor counterpart of :func:`pre_processing`.

    ``args`` is presumably ``(t, x, mu)`` as raw tensors (TODO confirm).
    """
    r2 = torch.sum((args[1] - 10) ** 2, dim=-1, keepdim=True) / 10**2
    return torch.cat([args[0], args[1], r2, args[2]], dim=-1)


def create_sampler(domain_mu):
    """Create the domains dict and a tensorized (t, x, mu) uniform sampler.

    Space is ``(0, 50)^2`` and the (rescaled) time interval is ``[0, 1]``.
    """
    domain_x = Square2D([(0, 50), (0, 50)], is_main_domain=True)
    t_min, t_max = 0.0, 1.0
    domains = {"x": domain_x, "mu": domain_mu, "t": (t_min, t_max)}
    sampler = TensorizedSampler(
        [
            UniformTimeSampler((t_min, t_max)),
            DomainSampler(domain_x),
            UniformParametricSampler(domain_mu),
        ]
    )
    return domains, sampler


def create_pinn(domains, sampler, final_time=50):
    """Build the space-time network and its Natural Gradient PINN solver."""
    space = NNxtSpace(
        1,
        2,
        GenericMLP,
        domains["x"],
        sampler,
        layer_sizes=[16, 16, 16],
        pre_processing=pre_processing,
        pre_processing_out_size=6,  # t, x1, x2, r^2, D, r
    )
    pde = FisherKPP(space, init=f_ini, final_time=final_time)
    return NaturalGradientTemporalPinns(
        pde,
        ic_type="weak",
        bc_type="weak",
        bc_weight=100,
        ic_weight=100,
        functional_pre_processing=functional_pre_processing,
    )


def solve(
    list_of_parameter_domains: list[list[tuple[float, float]]],
    plot_each_domain: bool = True,
) -> tuple[dict, NaturalGradientTemporalPinns]:
    """Solve the reaction-diffusion problem on successive parameter domains.

    The same PINN is trained on each parameter domain in turn, so each stage
    continues from the weights obtained at the previous one.

    Args:
        list_of_parameter_domains: Sequence of ``[(D_min, D_max), (r_min, r_max)]``.
        plot_each_domain: Plot the approximation after each stage.

    Returns:
        The ``domains`` dict (with ``"mu"`` set to the last domain) and the
        trained PINN.

    Raises:
        ValueError: If ``list_of_parameter_domains`` is empty.
    """
    if not list_of_parameter_domains:
        raise ValueError("list_of_parameter_domains must not be empty")

    n_colloc = {
        "n_collocation": 3_000,
        "n_bc_collocation": 2_000,
        "n_ic_collocation": 2_000,
    }
    domains, sampler = create_sampler(list_of_parameter_domains[0])
    pinn = create_pinn(domains, sampler)

    for domain_mu in list_of_parameter_domains:
        print(f"Solving for domain_mu: {domain_mu}")
        domains["mu"] = domain_mu
        pinn.update_parameter_bounds(domain_mu)
        pinn.solve(epochs=200, **n_colloc)
        if plot_each_domain:
            plot_abstract_approx_spaces(
                pinn.space,
                domains["x"],
                domains["mu"],
                domains["t"],
                time_values=3,
                loss=pinn.losses,
                exact=f_exact,
                error=f_exact,
            )
    return domains, pinn


def evaluate_at_t_mu(pinn, t, D, r, nx=100, x_min=0.0, x_max=50.0):
    """Evaluate the trained PINN on a regular ``nx`` x ``nx`` grid.

    Args:
        pinn: Trained solver exposing ``evaluate(t, x, mu)``.
        t: Rescaled time in ``[0, 1]`` (the network is trained on that range).
        D: Diffusion coefficient.
        r: Reaction rate.
        nx: Number of grid points per direction.
        x_min: Lower bound of the grid in both directions.
        x_max: Upper bound of the grid in both directions.

    Returns:
        Tensor of shape ``(nx, nx)``. ``u[i, j]`` is the value at
        ``(X[i], X[j])`` with ``X = linspace(x_min, x_max, nx)``; this is
        ``"ij"`` indexing, so the first axis runs along x1.
    """
    coords = torch.linspace(x_min, x_max, nx)
    x1, x2 = torch.meshgrid(coords, coords, indexing="ij")
    x = torch.cat([x1.reshape(-1, 1), x2.reshape(-1, 1)], dim=-1)
    ones = torch.ones((x.shape[0], 1))
    t = ones * t
    mu = torch.cat([ones * D, ones * r], dim=-1)
    u = pinn.evaluate(LabelTensor(t), LabelTensor(x), LabelTensor(mu)).w
    return u.reshape((nx, nx))


# %%
if __name__ == "__main__":
    domains, pinn = solve(
        [
            [(0.1, 0.1), (0.1, 0.1)],  # D and r
            [(0.1, 0.2), (0.05, 0.15)],
            [(0.05, 0.5), (0.05, 0.15)],
            [(0.05, 1.0), (0.05, 0.15)],
        ]
    )

    plot_abstract_approx_spaces(
        pinn.space,
        domains["x"],
        domains["mu"],
        domains["t"],
        time_values=3,
        loss=pinn.losses,
        exact=f_exact,
        error=f_exact,
    )

    # t is rescaled time in [0, 1] (physical time = t * final_time).
    u = evaluate_at_t_mu(pinn, t=0.8, D=0.75, r=0.125)

    import matplotlib.pyplot as plt

    # u is "ij"-indexed (first axis = x1), but imshow puts the first axis on
    # the vertical axis: transpose so x1 is horizontal and x2 vertical.
    plt.imshow(
        u.detach().numpy().T, extent=(0, 50, 0, 50), origin="lower", cmap="turbo"
    )
    plt.colorbar()
    plt.show()