diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index 43f18078f..8494df8b0 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -18,6 +18,8 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
+    "UnrollInstructions",
 ]
 
 from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +43,8 @@
     DeepEnsemblePINN,
 )
 from .garom import GAROM
+from .autoregressive_solver import (
+    AutoregressiveSolver,
+    AutoregressiveSolverInterface,
+    UnrollInstructions,
+)
diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..ac0d60a12
--- /dev/null
+++ b/pina/solver/autoregressive_solver/__init__.py
@@ -0,0 +1,11 @@
+__all__ = [
+    "AutoregressiveSolver",
+    "AutoregressiveSolverInterface",
+    "UnrollInstructions",
+]
+
+from .autoregressive_solver import AutoregressiveSolver
+from .autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+    UnrollInstructions,
+)
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..d91e1e254
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,225 @@
+"""Module for the Autoregressive solver."""
+
+from typing import List
+
+import torch
+
+from ..solver import SingleSolverInterface
+from ...condition import DataCondition
+from .autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+    UnrollInstructions,
+)
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    r"""
+    Autoregressive solver for learning dynamical systems.
+
+    This solver learns a one-step transition function
+    :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
+    a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.
+
+    During training, the model is unrolled over multiple time steps to
+    learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
+    the model generates predictions recursively:
+
+    .. math::
+        \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
+        \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0
+
+    The loss is computed over the entire unroll window:
+
+    .. math::
+        \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2
+
+    where the weights :math:`w_t \propto \exp(-\varepsilon \sum_{s \leq t}
+    \ell_s)` are computed from the accumulated step losses (if ``eps`` is
+    specified) and down-weight later predictions to stabilize training;
+    if ``eps`` is ``None``, uniform weights are used.
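+
+    :Example:
+        A minimal usage sketch; the network, condition name, and problem
+        below are illustrative and not part of the library:
+
+        >>> model = torch.nn.Linear(3, 3)  # any state-to-state map
+        >>> solver = AutoregressiveSolver(
+        ...     unroll_instructions_list=[
+        ...         UnrollInstructions(condition_name="data", unroll_length=5)
+        ...     ],
+        ...     problem=problem,  # problem with a DataCondition named "data"
+        ...     model=model,
+        ... )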
+    """
+
+    accepted_conditions_types = DataCondition
+
+    def __init__(
+        self,
+        unroll_instructions_list: List[UnrollInstructions],
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param list[UnrollInstructions] unroll_instructions_list: List of
+            :class:`UnrollInstructions` specifying how to create training
+            windows for each condition.
+        :param AbstractProblem problem: The problem instance containing
+            the time series data conditions.
+        :param torch.nn.Module model: Neural network that predicts the
+            next state given the current state.
+        :param torch.nn.Module loss: Loss function to minimize.
+            If ``None``, :class:`torch.nn.MSELoss` is used.
+            Default is ``None``.
+        :param TorchOptimizer optimizer: Optimizer for training.
+            If ``None``, :class:`torch.optim.Adam` is used.
+            Default is ``None``.
+        :param TorchScheduler scheduler: Learning rate scheduler.
+            If ``None``, no scheduling is applied. Default is ``None``.
+        :param WeightingInterface weighting: Weighting scheme for
+            combining losses from multiple conditions.
+            If ``None``, uniform weighting is used. Default is ``None``.
+        :param bool use_lt: Whether to use LabelTensors.
+            Default is ``False``.
+        """
+        super().__init__(
+            unroll_instructions_list=unroll_instructions_list,
+            problem=problem,
+            model=model,
+            loss=loss,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+
+    def loss_data(self, data, unroll_instructions: UnrollInstructions):
+        """
+        Compute the data loss for the recursive autoregressive solver.
+
+        Creates unroll windows from the data, then iteratively predicts
+        each next state and computes the loss against the ground truth.
+
+        :param torch.Tensor data: Time series with shape
+            ``[n_timesteps, n_features]``.
+        :param UnrollInstructions unroll_instructions: Configuration
+            for window creation and loss weighting.
+        :return: Weighted sum of step losses.
+        :rtype: torch.Tensor
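+
+        .. note::
+            As an illustration of the weighting, assume ``eps=1.0`` and
+            step losses ``[0.1, 0.2, 0.3]``. The unnormalized weights are
+            ``exp(-1.0 * cumsum([0.1, 0.2, 0.3])) = exp([-0.1, -0.3, -0.6])``,
+            which are then normalized to sum to one, so steps with larger
+            accumulated error contribute less to the total loss.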
+        """
+        initial_data, unroll_data = self.create_unroll_windows(
+            data, unroll_instructions
+        )
+        current_state = initial_data  # [num_unrolls, features]
+
+        losses = []
+        for step in range(unroll_instructions.unroll_length):
+            predicted_state = self.forward(
+                current_state
+            )  # [num_unrolls, features]
+            target_state = unroll_data[:, step, :]  # [num_unrolls, features]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+            current_state = predicted_state
+
+        step_losses = torch.stack(losses)  # [unroll_length]
+
+        # The weights are treated as constants: gradients flow through
+        # the step losses only, not through the weighting itself.
+        with torch.no_grad():
+            eps = unroll_instructions.eps
+            if eps is None:
+                weights = torch.ones_like(step_losses)
+            else:
+                weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))
+            weights = weights / weights.sum()
+
+        return (step_losses * weights).sum()
+
+    def create_unroll_windows(
+        self, data, unroll_instructions: UnrollInstructions
+    ):
+        """
+        Create unroll windows from time series data.
+
+        Slices the input time series into overlapping windows, each
+        consisting of an initial state and subsequent target states.
+
+        :param torch.Tensor data: Time series with shape
+            ``[n_timesteps, n_features]``.
+        :param UnrollInstructions unroll_instructions: Configuration
+            specifying window length and count.
+        :return: Tuple of ``(initial_data, unroll_data)`` where:
+
+            - ``initial_data``: Shape ``[num_unrolls, n_features]``
+            - ``unroll_data``: Shape
+              ``[num_unrolls, unroll_length, n_features]``
+
+        :rtype: tuple[torch.Tensor, torch.Tensor]
+        """
+        unroll_length = unroll_instructions.unroll_length
+
+        start_list = []
+        unroll_list = []
+        for starting_index in self.decide_starting_indices(
+            data, unroll_instructions
+        ):
+            idx = starting_index.item()
+            start_list.append(data[idx])
+            unroll_list.append(data[idx + 1 : idx + 1 + unroll_length, :])
+
+        initial_data = torch.stack(start_list)  # [num_unrolls, features]
+        unroll_data = torch.stack(
+            unroll_list
+        )  # [num_unrolls, unroll_length, features]
+        return initial_data, unroll_data
+
+    def decide_starting_indices(
+        self, data, unroll_instructions: UnrollInstructions
+    ):
+        """
+        Determine starting indices for unroll windows.
+
+        Computes valid starting positions, ensuring each window has
+        enough subsequent time steps for the specified unroll length.
+
+        :param torch.Tensor data: Time series with shape
+            ``[n_timesteps, n_features]``.
+        :param UnrollInstructions unroll_instructions: Configuration
+            with ``unroll_length``, ``num_unrolls``, and ``randomize``.
+        :return: 1D tensor of starting indices.
+        :rtype: torch.Tensor
+        """
+        n_step = data.shape[0]
+        num_unrolls = unroll_instructions.num_unrolls
+        if unroll_instructions.unroll_length >= n_step:
+            # TODO: decide if it is better to raise an error here
+            return torch.empty(0, dtype=torch.long, device=data.device)
+
+        max_start = n_step - unroll_instructions.unroll_length
+        indices = torch.arange(max_start, device=data.device)
+
+        # Shuffle before truncating, so that a subset of windows is
+        # sampled from the whole series rather than only from its start.
+        if unroll_instructions.randomize:
+            indices = indices[torch.randperm(len(indices), device=data.device)]
+
+        if num_unrolls is not None and num_unrolls < len(indices):
+            indices = indices[:num_unrolls]
+
+        return indices
+
+    def predict(self, initial_state, num_steps):
+        """
+        Generate predictions by recursively applying the model.
+
+        Starting from the initial state, applies the model repeatedly
+        to generate a trajectory of predicted states.
+
+        :param torch.Tensor initial_state: Starting state with shape
+            ``[n_features]``.
+        :param int num_steps: Number of future time steps to predict.
+        :return: Predicted trajectory with shape
+            ``[num_steps + 1, n_features]``, where the first row is
+            the initial state.
+        :rtype: torch.Tensor
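+
+        :Example:
+            A sketch assuming a trained ``solver`` on a 3-feature state:
+
+            >>> y0 = torch.rand(3)
+            >>> trajectory = solver.predict(y0, num_steps=100)
+            >>> trajectory.shape
+            torch.Size([101, 3])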
) + """ + + condition_name: str + unroll_length: int + num_unrolls: Optional[int] = None + randomize: bool = True + eps: Optional[float] = None + + +class AutoregressiveSolverInterface(SolverInterface): + """ + Base class for autoregressive solvers. + + This interface defines solvers that learn to predict the next state + of a dynamical system given the current state. The solver uses an + unrolling strategy where predictions are made recursively over + multiple time steps during training. + The ``AutoregressiveSolverInterface`` is compatible with problems + containing :class:`~pina.condition.data_condition.DataCondition` + conditions, where the input represents a time series trajectory. + """ + + def __init__(self, unroll_instructions_list, loss=None, **kwargs): + """ + Initialization of the :class:`AutoregressiveSolverInterface` class. + + :param list[UnrollInstructions] unroll_instructions_list: List of + :class:`UnrollInstructions` objects, one for each condition + in the problem. Each instruction specifies how to create + unroll windows for training. + ::param torch.nn.Module loss: Loss function to minimize. + If ``None``, :class:`torch.nn.MSELoss` is used. + Default is ``None``. + :param kwargs: Additional keyword arguments passed to + :class:`~pina.solver.solver.SolverInterface`. + """ + + super().__init__(**kwargs) + + if loss is None: + loss = torch.nn.MSELoss() + + check_consistency(loss, (LossInterface, _Loss), subclass=False) + self._loss_fn = loss + self._unroll_instructions_list = unroll_instructions_list + + def optimization_cycle(self, batch): + """ + Optimization cycle for this family of solvers. + Iterates over each condition and each time applies the specialized loss_data function. + + :param list[tuple[str, dict]] batch: List of tuples where each + tuple contains a condition name and a dictionary with the + ``"input"`` key mapping to the time series tensor. + :return: Dictionary mapping condition names to computed loss values. + :rtype: dict[str, torch.Tensor] + """ + + condition_loss = {} + for condition_name, points in batch: + # find unroll instructions for this condition + unroll_instructions = next( + ui + for ui in self._unroll_instructions_list + if ui.condition_name == condition_name + ) + loss = self.loss_data( + points["input"], + unroll_instructions, + ) + condition_loss[condition_name] = loss + return condition_loss + + @abstractmethod + def loss_data(self, input, unroll_instructions: UnrollInstructions): + """ + Compute the data loss for each condition. + This method must be implemented by subclasses to define the + specific loss computation strategy. + + :param torch.Tensor input: Time series data with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + for creating unroll windows from the input data. + :return: Scalar loss value for this condition. + :rtype: torch.Tensor + """ + pass + + @abstractmethod + def predict(self, initial_state, num_steps): + """ + Generate recursive predictions from an initial state. + + Starting from the initial state, repeatedly applies the model + to predict subsequent states. + + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]`` or ``[batch_size, n_features]``. + :param int num_steps: Number of future steps to predict. + :return: Tensor of predictions with shape + ``[num_steps + 1, n_features]``, including the initial state. + :rtype: torch.Tensor + """ + pass + + @property + def loss(self): + """ + The loss function to be minimized. 
+        """
+        condition_loss = {}
+        for condition_name, points in batch:
+            # Find the unroll instructions registered for this condition.
+            unroll_instructions = next(
+                (
+                    ui
+                    for ui in self._unroll_instructions_list
+                    if ui.condition_name == condition_name
+                ),
+                None,
+            )
+            if unroll_instructions is None:
+                raise ValueError(
+                    "No UnrollInstructions found for condition "
+                    f"'{condition_name}'."
+                )
+            condition_loss[condition_name] = self.loss_data(
+                points["input"], unroll_instructions
+            )
+        return condition_loss
+
+    @abstractmethod
+    def loss_data(self, data, unroll_instructions: UnrollInstructions):
+        """
+        Compute the data loss for a single condition.
+
+        This method must be implemented by subclasses to define the
+        specific loss computation strategy.
+
+        :param torch.Tensor data: Time series data with shape
+            ``[n_timesteps, n_features]``.
+        :param UnrollInstructions unroll_instructions: Configuration
+            for creating unroll windows from the input data.
+        :return: Scalar loss value for this condition.
+        :rtype: torch.Tensor
+        """
+
+    @abstractmethod
+    def predict(self, initial_state, num_steps):
+        """
+        Generate recursive predictions from an initial state.
+
+        Starting from the initial state, repeatedly applies the model
+        to predict subsequent states.
+
+        :param torch.Tensor initial_state: Starting state with shape
+            ``[n_features]`` or ``[batch_size, n_features]``.
+        :param int num_steps: Number of future steps to predict.
+        :return: Tensor of predictions with shape
+            ``[num_steps + 1, n_features]``, including the initial state.
+        :rtype: torch.Tensor
+        """
+
+    @property
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
+        return self._loss_fn
diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py
new file mode 100644
index 000000000..467f90afe
--- /dev/null
+++ b/tests/test_solver/test_autoregressive_solver.py
@@ -0,0 +1,213 @@
+import pytest
+import torch
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver, UnrollInstructions
+
+NUM_TIMESTEPS = 10
+NUM_FEATURES = 3
+
+
+@pytest.fixture
+def y_data():
+    torch.manual_seed(42)
+    y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
+    y[0] = torch.rand(NUM_FEATURES)
+    for t in range(NUM_TIMESTEPS - 1):
+        y[t + 1] = 0.95 * y[t]
+    return y
+
+
+class ExactModel(torch.nn.Module):
+    """
+    Test model implementing the exact transition rule y[t+1] = 0.95 * y[t],
+    so the expected loss is zero.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Dummy parameter so the optimizer has something to update.
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x
+        # Keep the dummy parameter in the graph without changing the output.
+        return next_state + 0.0 * self.dummy_param
+
+
+@pytest.fixture
+def solver(y_data):
+    """Create a minimal solver for testing internal methods."""
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    return AutoregressiveSolver(
+        unroll_instructions_list=[
+            UnrollInstructions(condition_name="data", unroll_length=3)
+        ],
+        problem=Problem(),
+        model=ExactModel(),
+    )
+
+
+# Tests start here ==============================================
+
+
+def test_exact_model(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {
+            "data_condition": DataCondition(input=y_data),
+        }
+
+    unroll_instruction = UnrollInstructions(
+        condition_name="data_condition",
+        unroll_length=5,
+    )
+
+    solver = AutoregressiveSolver(
+        unroll_instructions_list=[unroll_instruction],
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01),
+    )
+
+    loss = solver.loss_data(y_data, unroll_instruction)
+    assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6)
+
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=5,
+        accelerator="cpu",
+        shuffle=False,
+        enable_model_summary=False,
+    )
+    trainer.train()
+
+    loss_after_training = solver.loss_data(y_data, unroll_instruction)
+    assert torch.isclose(loss_after_training, torch.tensor(0.0), atol=1e-6)
+
+    predictions = solver.predict(
+        initial_state=y_data[0], num_steps=NUM_TIMESTEPS - 1
+    )
+    expected_predictions = y_data
+    assert torch.allclose(predictions, expected_predictions, atol=1e-6)
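+
+
+# Additional check (sketch): with eps set, loss_data should still return
+# a finite scalar; the exact value depends on the exponential weighting.
+def test_eps_weighting_returns_finite_scalar(solver, y_data):
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        randomize=False,
+        eps=0.5,
+    )
+    loss = solver.loss_data(y_data, instructions)
+    assert loss.ndim == 0
+    assert torch.isfinite(loss)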
+
+
+def test_indices_sequential_when_no_randomize(solver, y_data):
+    """Indices should be [0, 1, 2, ...] when randomize=False."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        randomize=False,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+
+    # y_data has 10 timesteps, unroll_length=3 → max_start = 10 - 3 = 7
+    expected = torch.arange(7)
+    assert torch.equal(indices, expected)
+
+
+def test_indices_permuted_when_randomize(solver, y_data):
+    """Indices should contain the same values, permuted, when randomize=True."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        randomize=True,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+
+    expected_values = set(range(7))
+    actual_values = set(indices.tolist())
+    assert actual_values == expected_values
+
+
+def test_num_unrolls_parameter(solver, y_data):
+    """num_unrolls should limit the number of indices returned."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        num_unrolls=3,
+        randomize=False,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+
+    assert len(indices) == 3
+    assert torch.equal(indices, torch.arange(3))
+
+
+def test_num_unrolls_greater_than_max_possible(solver, y_data):
+    """num_unrolls > max possible should return all possible indices."""
+    unroll_length = 3
+    # 10 timesteps - unroll_length 3 = 7 valid starting positions
+    maximum_number_of_unrolls = NUM_TIMESTEPS - unroll_length
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=unroll_length,
+        num_unrolls=100,
+        randomize=False,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+
+    assert len(indices) == maximum_number_of_unrolls
+
+
+def test_no_valid_indices_when_unroll_too_long(solver, y_data):
+    """When unroll_length >= n_timesteps, no valid indices exist."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=NUM_TIMESTEPS + 1,
+        randomize=False,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+    assert len(indices) == 0
+
+
+def test_unroll_window_shape(solver, y_data):
+    """Unroll windows should have correct shapes."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=4,
+        num_unrolls=2,
+        randomize=False,
+    )
+    initial_data, unroll_data = solver.create_unroll_windows(
+        y_data, instructions
+    )
+
+    assert initial_data.shape == (2, NUM_FEATURES)  # [num_unrolls, features]
+    assert unroll_data.shape == (
+        2,
+        4,
+        NUM_FEATURES,
+    )  # [num_unrolls, unroll_length, features]
+
+
+def test_unroll_windows_content(solver, y_data):
+    """Verify the actual content of the unroll windows."""
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        num_unrolls=2,
+        randomize=False,
+    )
+    initial_data, unroll_data = solver.create_unroll_windows(
+        y_data, instructions
+    )
+
+    # initial_data[i] should be y_data[i]
+    assert torch.equal(initial_data[0], y_data[0])
+    assert torch.equal(initial_data[1], y_data[1])
+
+    # unroll_data[i] should be y_data[i+1 : i+1+unroll_length]
+    assert torch.equal(unroll_data[0], y_data[1:4])
+    assert torch.equal(unroll_data[1], y_data[2:5])
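+
+
+# Sketch: when both randomize and num_unrolls are set, the returned
+# indices should be a subset of all valid starting positions.
+def test_num_unrolls_with_randomize_subset(solver, y_data):
+    instructions = UnrollInstructions(
+        condition_name="data",
+        unroll_length=3,
+        num_unrolls=3,
+        randomize=True,
+    )
+    indices = solver.decide_starting_indices(y_data, instructions)
+    assert len(indices) == 3
+    assert set(indices.tolist()) <= set(range(7))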