From 78128b10623a9cc51f4f7946bc30f0b5a8523035 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Wed, 3 Dec 2025 17:57:18 +0100 Subject: [PATCH 1/5] implement autoregressive condition, time_weighting, solver --- pina/condition/__init__.py | 3 + pina/condition/autoregressive_condition.py | 91 ++++++++++++++++++ pina/loss/__init__.py | 10 ++ pina/loss/time_weighting.py | 57 ++++++++++++ pina/loss/time_weighting_interface.py | 24 +++++ pina/solver/__init__.py | 5 + pina/solver/autoregressive_solver/__init__.py | 4 + .../autoregressive_solver.py | 88 ++++++++++++++++++ .../autoregressive_solver_interface.py | 93 +++++++++++++++++++ 9 files changed, 375 insertions(+) create mode 100644 pina/condition/autoregressive_condition.py create mode 100644 pina/loss/time_weighting.py create mode 100644 pina/loss/time_weighting_interface.py create mode 100644 pina/solver/autoregressive_solver/__init__.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver_interface.py diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..502c34ae9 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -15,6 +15,7 @@ "DataCondition", "GraphDataCondition", "TensorDataCondition", + "AutoregressiveCondition", ] from .condition_interface import ConditionInterface @@ -37,3 +38,5 @@ GraphDataCondition, TensorDataCondition, ) + +from .autoregressive_condition import AutoregressiveCondition diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py new file mode 100644 index 000000000..1d48b638d --- /dev/null +++ b/pina/condition/autoregressive_condition.py @@ -0,0 +1,91 @@ +import torch +from .condition_interface import ConditionInterface +from ..loss import TimeWeightingInterface, ConstantTimeWeighting +from ..utils import check_consistency + + +class AutoregressiveCondition(ConditionInterface): + """ + A 
specialized condition for autoregressive tasks. + It generates input/unroll pairs from a single time-series tensor. + """ + + __slots__ = ["input", "unroll"] + + def __init__( + self, + data, + unroll_length, + num_unrolls=None, + randomize=True, + time_weighting=None, + ): + """ + Create an AutoregressiveCondition. + """ + super().__init__() + + self._n_timesteps, n_features = data.shape + self._unroll_length = unroll_length + self._requested_num_unrolls = num_unrolls + self._randomize = randomize + + # time weighting: weight the loss differently along the unroll + if time_weighting is None: + self._time_weighting = ConstantTimeWeighting() + else: + check_consistency(time_weighting, TimeWeightingInterface) + self._time_weighting = time_weighting + + # windows creation + initial_data = [] + unroll_data = [] + + for starting_index in self.starting_indices: + initial_data.append(data[starting_index]) + target_start = starting_index + 1 + unroll_data.append( + data[target_start : target_start + self._unroll_length, :] + ) + + self.input = torch.stack(initial_data) # [num_unrolls, features] + self.unroll = torch.stack( + unroll_data + ) # [num_unrolls, unroll_length, features] + + @property + def unroll_length(self): + return self._unroll_length + + @property + def time_weighting(self): + return self._time_weighting + + @property + def max_start_idx(self): + max_start_idx = self._n_timesteps - self._unroll_length + assert max_start_idx > 0, "Provided data sequence too short" + return max_start_idx + + @property + def num_unrolls(self): + if self._requested_num_unrolls is None: + return self.max_start_idx + else: + assert ( + self._requested_num_unrolls < self.max_start_idx + ), "too many samples requested" + return self._requested_num_unrolls + + @property + def starting_indices(self): + all_starting_indices = torch.arange(self.max_start_idx) + + if self._randomize: + perm = torch.randperm(len(all_starting_indices)) + return all_starting_indices[perm[: 
self.num_unrolls]] + else: + selected_indices = torch.linspace( + 0, len(all_starting_indices) - 1, self.num_unrolls + ).long() + return all_starting_indices[selected_indices] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index d91cf7ab0..2d8ab288e 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -9,6 +9,10 @@ "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", "LinearWeighting", + "TimeWeightingInterface", + "ConstantTimeWeighting", + "ExponentialTimeWeighting", + "LinearTimeWeighting", ] from .loss_interface import LossInterface @@ -19,3 +23,9 @@ from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting from .linear_weighting import LinearWeighting +from .time_weighting_interface import TimeWeightingInterface +from .time_weighting import ( + ConstantTimeWeighting, + ExponentialTimeWeighting, + LinearTimeWeighting, +) diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py new file mode 100644 index 000000000..0b1d1ed65 --- /dev/null +++ b/pina/loss/time_weighting.py @@ -0,0 +1,57 @@ +"""Module for the Time Weighting.""" + +import torch +from .time_weighting_interface import TimeWeightingInterface + + +class ConstantTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme that assigns equal weight to all time steps. + """ + + def __call__(self, num_steps, device): + return torch.ones(num_steps, device=device) / num_steps + + +class ExponentialTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme change exponentially with time. + gamma > 1.0: increasing weights + 0 < gamma < 1.0: decreasing weights + weight at time t is gamma^t + """ + + def __init__(self, gamma=0.9): + """ + Initialization of the :class:`ExponentialTimeWeighting` class. + :param float gamma: The decay factor. Default is 0.9. 
+ """ + self.gamma = gamma + + def __call__(self, num_steps, device): + steps = torch.arange(num_steps, device=device, dtype=torch.float32) + weights = self.gamma**steps + return weights / weights.sum() + + +class LinearTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme that changes linearly from a start weight to an end weight. + """ + + def __init__(self, start=0.1, end=1.0): + """ + Initialization of the :class:`LinearDecayTimeWeighting` class. + + :param float start: The starting weight. Default is 0.1. + :param float end: The ending weight. Default is 1.0. + """ + self.start = start + self.end = end + + def __call__(self, num_steps, device): + if num_steps == 1: + return torch.ones(1, device=device) + + weights = torch.linspace(self.start, self.end, num_steps, device=device) + return weights / weights.sum() diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py new file mode 100644 index 000000000..9d9781351 --- /dev/null +++ b/pina/loss/time_weighting_interface.py @@ -0,0 +1,24 @@ +"""Module for the Time Weighting Interface.""" + +from abc import ABCMeta, abstractmethod +import torch + + +class TimeWeightingInterface(metaclass=ABCMeta): + """ + Abstract base class for all time weighting schemas. All time weighting + schemas should inherit from this class. + """ + + @abstractmethod + def __call__(self, num_steps, device): + """ + Compute the weights for the time steps. + + :param int num_steps: The number of time steps. + :param torch.device device: The device on which the weights should be + created. + :return: The weights for the time steps. 
+ :rtype: torch.Tensor + """ + pass diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 43f18078f..e7d48e2b3 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -18,6 +18,7 @@ "DeepEnsembleSupervisedSolver", "DeepEnsemblePINN", "GAROM", + "AutoregressiveSolver", ] from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface @@ -41,3 +42,7 @@ DeepEnsemblePINN, ) from .garom import GAROM +from .autoregressive_solver import ( + AutoregressiveSolver, + AutoregressiveSolverInterface, +) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py new file mode 100644 index 000000000..9ef7c43e1 --- /dev/null +++ b/pina/solver/autoregressive_solver/__init__.py @@ -0,0 +1,4 @@ +__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"] + +from .autoregressive_solver import AutoregressiveSolver +from .autoregressive_solver_interface import AutoregressiveSolverInterface diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py new file mode 100644 index 000000000..d0a46c310 --- /dev/null +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -0,0 +1,88 @@ +import torch +from torch.nn.modules.loss import _Loss + +from pina.utils import check_consistency +from pina.solver.solver import SingleSolverInterface +from pina.condition import AutoregressiveCondition +from pina.loss import ( + LossInterface, + TimeWeightingInterface, + ConstantTimeWeighting, +) +from .autoregressive_solver_interface import AutoregressiveSolverInterface + + +class AutoregressiveSolver( + AutoregressiveSolverInterface, SingleSolverInterface +): + """ + Autoregressive Solver class. 
+ """ + + accepted_conditions_types = AutoregressiveCondition + + def __init__( + self, + problem, + model, + loss=None, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=False, + ): + """ + Initialization of the :class:`AutoregressiveSolver` class. + """ + super().__init__( + problem=problem, + model=model, + loss=loss, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) + + def loss_data(self, input, target, unroll_length, time_weighting): + """ + Compute the data loss for the recursive autoregressive solver. + This will be applied to each condition individually. + """ + steps_to_predict = unroll_length - 1 + # weights are passed from the condition + weights = time_weighting(steps_to_predict, device=input.device) + + total_loss = 0.0 + current_state = input + + for step in range(steps_to_predict): + + predicted_next_state = self.forward( + current_state + ) # [batch_size, features] + actual_next_state = target[:, step, :] # [batch_size, features] + + step_loss = self.loss(predicted_next_state, actual_next_state) + + total_loss += step_loss * weights[step] + + current_state = predicted_next_state.detach() + + return total_loss + + def predict(self, initial_state, num_steps): + """ + Make recursive predictions starting from an initial state. 
+ """ + self.eval() # Set model to evaluation mode + + current_state = initial_state + predictions = [current_state] # Store initial state without batch dim + with torch.no_grad(): + for step in range(num_steps): + next_state = self.forward(current_state) + predictions.append(next_state) # Keep batch dim for storage + current_state = next_state + + return torch.stack(predictions) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py new file mode 100644 index 000000000..e895705fe --- /dev/null +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -0,0 +1,93 @@ +"""Module for the Autoregressive solver interface.""" + +from abc import abstractmethod +import torch +from torch.nn.modules.loss import _Loss + +from ..solver import SolverInterface +from ...utils import check_consistency +from ...loss.loss_interface import LossInterface +from ...loss import TimeWeightingInterface, ConstantTimeWeighting +from ...condition import AutoregressiveCondition + + +class AutoregressiveSolverInterface(SolverInterface): + + accepted_conditions_types = AutoregressiveCondition + + def __init__(self, loss=None, **kwargs): + + if loss is None: + loss = torch.nn.MSELoss() + + super().__init__(**kwargs) + + check_consistency(loss, (LossInterface, _Loss), subclass=False) + self._loss_fn = loss + + def optimization_cycle(self, batch): + """ + Optimization cycle for this family of solvers. + Iterates over each conditions and each time applies the specialized loss_data function. 
+ """ + + condition_loss = {} + for condition_name, points in batch: + condition = self.problem.conditions[condition_name] + + unroll_length = getattr(condition, "unroll_length", None) + time_weighting = getattr(condition, "time_weighting", None) + + if "unroll" in points: + loss = self.loss_data( + points["input"], + points["unroll"], + unroll_length, + time_weighting, + ) + condition_loss[condition_name] = loss + return condition_loss + + @abstractmethod + def loss_data(self, input, target, unroll_length, time_weighting): + """ + Computes the data loss for each condition. + N.B.: unroll_length and time_weighting are attributes of the condition. + + :param torch.Tensor input: Initial states. + :param torch.Tensor target: Target sequences. + :param int unroll_length: The number of steps to unroll (attribute of the condition). + :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition). + :return: The average loss over all unroll steps. + :rtype: torch.Tensor + """ + pass + + @abstractmethod + def predict(self, initial_state, num_steps): + """ + Make recursive predictions starting from an initial state. + + :param torch.Tensor initial_state: Initial state tensor. + :param int num_steps: Number of steps to predict ahead. + :return: Tensor of predictions. + :rtype: torch.Tensor + """ + pass + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn + + @property + def time_weighting(self): + """ + The time weighting strategy. 
+ """ + return self._time_weighting From 76e592567c65590ad3cc6356660ddd8a6b9804f7 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 16 Dec 2025 18:12:41 +0100 Subject: [PATCH 2/5] implement everything into solver --- autoregressive_prova_generic_condition.py | 149 ++++++++++++++++++ pina/condition/__init__.py | 3 - pina/condition/autoregressive_condition.py | 91 ----------- pina/loss/__init__.py | 10 -- pina/loss/time_weighting.py | 57 ------- pina/loss/time_weighting_interface.py | 24 --- .../autoregressive_solver.py | 142 +++++++++++++---- .../autoregressive_solver_interface.py | 54 +++---- 8 files changed, 285 insertions(+), 245 deletions(-) create mode 100644 autoregressive_prova_generic_condition.py delete mode 100644 pina/condition/autoregressive_condition.py delete mode 100644 pina/loss/time_weighting.py delete mode 100644 pina/loss/time_weighting_interface.py diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py new file mode 100644 index 000000000..3c0796bbc --- /dev/null +++ b/autoregressive_prova_generic_condition.py @@ -0,0 +1,149 @@ +import torch +import matplotlib.pyplot as plt + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver + +NUM_TIMESTEPS = 100 +NUM_FEATURES = 15 +USE_TEST_MODEL = False + +# ============================================================================ +# DATA +# ============================================================================ + +torch.manual_seed(42) + +y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) +y[0] = torch.rand(NUM_FEATURES) # Random initial state + +for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) + +# ============================================================================ +# TRAINING +# 
============================================================================ + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(y.shape[1], 20), + torch.nn.ReLU(), + torch.nn.Dropout(0.2), + torch.nn.Linear(20, y.shape[1]), + ) + + def forward(self, x): + return x + self.layers(x) + + +class TestModel(torch.nn.Module): + """ + Debug model that implements the EXACT transformation rule. + y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) + return next_state + 0.0 * self.dummy_param + + +class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = { + "data_condition_0":DataCondition(input=y), + "data_condition_1":DataCondition(input=y), + } + +problem = Problem() + +#for each condition, define unroll instructions with these keys: +# - unroll_length: length of each unroll window +# - num_unrolls: number of unroll windows to create (if None, use all possible) +# - randomize: whether to randomize the starting indices of the unroll windows +unroll_instructions = { + "data_condition_0": { + "unroll_length": 10, + "num_unrolls": 89, + "randomize": True, + "eps": 5.0 + }, + "data_condition_1": { + "unroll_length": 20, + "num_unrolls": 79, + "randomize": True, + "eps": 10.0 + }, +} + +solver = AutoregressiveSolver( + unroll_instructions=unroll_instructions, + problem=problem, + model=TestModel() if USE_TEST_MODEL else SimpleModel(), + optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), + eps=10.0, +) + +trainer = Trainer( + solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False +) +trainer.train() + +# ============================================================================ +# VISUALIZATION +# 
============================================================================ + +test_start_idx = 50 +num_prediction_steps = 30 + +initial_state = y[test_start_idx] # Shape: [features] +predictions = solver.predict(initial_state, num_prediction_steps) +actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] + +total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) +print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") + +# viauzlize single dof +dof_to_plot = [0, 3, 6, 9, 12] +colors = [ + "r", + "g", + "b", + "c", + "m", + "y", + "k", +] +plt.figure(figsize=(10, 6)) +for dof, color in zip(dof_to_plot, colors): + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + actual[:, dof].numpy(), + label="Actual", + marker="o", + color=color, + markerfacecolor="none", + ) + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + predictions[:, dof].numpy(), + label="Predicted", + marker="x", + color=color, + ) + +plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") +plt.legend() +plt.xlabel("Timestep") +plt.savefig(f"autoregressive_predictions.png") +plt.close() diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 502c34ae9..4e57811fb 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -15,7 +15,6 @@ "DataCondition", "GraphDataCondition", "TensorDataCondition", - "AutoregressiveCondition", ] from .condition_interface import ConditionInterface @@ -38,5 +37,3 @@ GraphDataCondition, TensorDataCondition, ) - -from .autoregressive_condition import AutoregressiveCondition diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py deleted file mode 100644 index 1d48b638d..000000000 --- a/pina/condition/autoregressive_condition.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -from .condition_interface import ConditionInterface -from ..loss import TimeWeightingInterface, 
ConstantTimeWeighting -from ..utils import check_consistency - - -class AutoregressiveCondition(ConditionInterface): - """ - A specialized condition for autoregressive tasks. - It generates input/unroll pairs from a single time-series tensor. - """ - - __slots__ = ["input", "unroll"] - - def __init__( - self, - data, - unroll_length, - num_unrolls=None, - randomize=True, - time_weighting=None, - ): - """ - Create an AutoregressiveCondition. - """ - super().__init__() - - self._n_timesteps, n_features = data.shape - self._unroll_length = unroll_length - self._requested_num_unrolls = num_unrolls - self._randomize = randomize - - # time weighting: weight the loss differently along the unroll - if time_weighting is None: - self._time_weighting = ConstantTimeWeighting() - else: - check_consistency(time_weighting, TimeWeightingInterface) - self._time_weighting = time_weighting - - # windows creation - initial_data = [] - unroll_data = [] - - for starting_index in self.starting_indices: - initial_data.append(data[starting_index]) - target_start = starting_index + 1 - unroll_data.append( - data[target_start : target_start + self._unroll_length, :] - ) - - self.input = torch.stack(initial_data) # [num_unrolls, features] - self.unroll = torch.stack( - unroll_data - ) # [num_unrolls, unroll_length, features] - - @property - def unroll_length(self): - return self._unroll_length - - @property - def time_weighting(self): - return self._time_weighting - - @property - def max_start_idx(self): - max_start_idx = self._n_timesteps - self._unroll_length - assert max_start_idx > 0, "Provided data sequence too short" - return max_start_idx - - @property - def num_unrolls(self): - if self._requested_num_unrolls is None: - return self.max_start_idx - else: - assert ( - self._requested_num_unrolls < self.max_start_idx - ), "too many samples requested" - return self._requested_num_unrolls - - @property - def starting_indices(self): - all_starting_indices = torch.arange(self.max_start_idx) - 
- if self._randomize: - perm = torch.randperm(len(all_starting_indices)) - return all_starting_indices[perm[: self.num_unrolls]] - else: - selected_indices = torch.linspace( - 0, len(all_starting_indices) - 1, self.num_unrolls - ).long() - return all_starting_indices[selected_indices] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 2d8ab288e..d91cf7ab0 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -9,10 +9,6 @@ "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", "LinearWeighting", - "TimeWeightingInterface", - "ConstantTimeWeighting", - "ExponentialTimeWeighting", - "LinearTimeWeighting", ] from .loss_interface import LossInterface @@ -23,9 +19,3 @@ from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting from .linear_weighting import LinearWeighting -from .time_weighting_interface import TimeWeightingInterface -from .time_weighting import ( - ConstantTimeWeighting, - ExponentialTimeWeighting, - LinearTimeWeighting, -) diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py deleted file mode 100644 index 0b1d1ed65..000000000 --- a/pina/loss/time_weighting.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Module for the Time Weighting.""" - -import torch -from .time_weighting_interface import TimeWeightingInterface - - -class ConstantTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme that assigns equal weight to all time steps. - """ - - def __call__(self, num_steps, device): - return torch.ones(num_steps, device=device) / num_steps - - -class ExponentialTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme change exponentially with time. - gamma > 1.0: increasing weights - 0 < gamma < 1.0: decreasing weights - weight at time t is gamma^t - """ - - def __init__(self, gamma=0.9): - """ - Initialization of the :class:`ExponentialTimeWeighting` class. - :param float gamma: The decay factor. Default is 0.9. 
- """ - self.gamma = gamma - - def __call__(self, num_steps, device): - steps = torch.arange(num_steps, device=device, dtype=torch.float32) - weights = self.gamma**steps - return weights / weights.sum() - - -class LinearTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme that changes linearly from a start weight to an end weight. - """ - - def __init__(self, start=0.1, end=1.0): - """ - Initialization of the :class:`LinearDecayTimeWeighting` class. - - :param float start: The starting weight. Default is 0.1. - :param float end: The ending weight. Default is 1.0. - """ - self.start = start - self.end = end - - def __call__(self, num_steps, device): - if num_steps == 1: - return torch.ones(1, device=device) - - weights = torch.linspace(self.start, self.end, num_steps, device=device) - return weights / weights.sum() diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py deleted file mode 100644 index 9d9781351..000000000 --- a/pina/loss/time_weighting_interface.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Module for the Time Weighting Interface.""" - -from abc import ABCMeta, abstractmethod -import torch - - -class TimeWeightingInterface(metaclass=ABCMeta): - """ - Abstract base class for all time weighting schemas. All time weighting - schemas should inherit from this class. - """ - - @abstractmethod - def __call__(self, num_steps, device): - """ - Compute the weights for the time steps. - - :param int num_steps: The number of time steps. - :param torch.device device: The device on which the weights should be - created. - :return: The weights for the time steps. 
- :rtype: torch.Tensor - """ - pass diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index d0a46c310..0606a3fd6 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -1,14 +1,7 @@ import torch -from torch.nn.modules.loss import _Loss - from pina.utils import check_consistency from pina.solver.solver import SingleSolverInterface -from pina.condition import AutoregressiveCondition -from pina.loss import ( - LossInterface, - TimeWeightingInterface, - ConstantTimeWeighting, -) +from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface @@ -19,12 +12,14 @@ class AutoregressiveSolver( Autoregressive Solver class. """ - accepted_conditions_types = AutoregressiveCondition + accepted_conditions_types = DataCondition def __init__( self, + unroll_instructions, problem, model, + eps=None, loss=None, optimizer=None, scheduler=None, @@ -33,8 +28,19 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. + :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + this is supposed to map condition names to dict objects with unroll instructions. + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The model to be trained. + :param torch.nn.Module or LossInterface or None loss: The loss function to be minimized. If None, defaults to MSELoss. + :param TorchOptimizer or None optimizer: The optimizer to be used. If None, no optimization is performed. + :param TorchScheduler or None scheduler: The learning rate scheduler to be used. If None, no scheduling is performed. + :param Weighting or None weighting: The weighting scheme for combining losses from different conditions. If None, equal weighting is applied. + :param bool use_lt: Whether to use learning rate tuning. 
""" + super().__init__( + unroll_instructions=unroll_instructions, problem=problem, model=model, loss=loss, @@ -44,45 +50,123 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, input, target, unroll_length, time_weighting): + def loss_data(self, data, condition_unroll_instructions): """ Compute the data loss for the recursive autoregressive solver. This will be applied to each condition individually. + :param torch.Tensor data: all training data. + :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :return: Computed loss value. + :rtype: torch.Tensor """ - steps_to_predict = unroll_length - 1 - # weights are passed from the condition - weights = time_weighting(steps_to_predict, device=input.device) - total_loss = 0.0 - current_state = input + initial_data, unroll_data = self.create_unroll_windows( + data, condition_unroll_instructions + ) + + unroll_length = condition_unroll_instructions["unroll_length"] + current_state = initial_data # [num_unrolls, features] + + losses = [] + for step in range(unroll_length): + + predicted_state = self.forward(current_state) # [num_unrolls, features] + target_state = unroll_data[:, step, :] # [num_unrolls, features] + step_loss = self._loss_fn(predicted_state, target_state) + losses.append(step_loss) + current_state = predicted_state + + step_losses = torch.stack(losses) # [unroll_length] + + with torch.no_grad(): + weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions) + + weighted_loss = (step_losses * weights).sum() + return weighted_loss - for step in range(steps_to_predict): + def create_unroll_windows(self, data, condition_unroll_instructions): + """ + Create unroll windows for each condition from the data based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. 
+ :return: Tuple of initial data and unroll data tensors. + :rtype: (torch.Tensor, torch.Tensor) + """ - predicted_next_state = self.forward( - current_state - ) # [batch_size, features] - actual_next_state = target[:, step, :] # [batch_size, features] + unroll_length = condition_unroll_instructions["unroll_length"] + + start_list = [] + unroll_list = [] + for starting_index in self.decide_starting_indices( + data, condition_unroll_instructions + ): + idx = starting_index.item() + start = data[idx] + target_start = idx + 1 + unroll = data[target_start : target_start + unroll_length, :] + start_list.append(start) + unroll_list.append(unroll) + initial_data = torch.stack(start_list) # [num_unrolls, features] + unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] + return initial_data, unroll_data + + def decide_starting_indices(self, data, condition_unroll_instructions): + """ + Decide the starting indices for unrolling based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :return: Tensor of starting indices. 
+ :rtype: torch.Tensor + """ + n_step, n_features = data.shape + num_unrolls = condition_unroll_instructions.get("num_unrolls", None) + unroll_length = condition_unroll_instructions["unroll_length"] + randomize = condition_unroll_instructions.get("randomize", True) - step_loss = self.loss(predicted_next_state, actual_next_state) + max_start = n_step - unroll_length + indices = torch.arange(max_start, device=data.device) - total_loss += step_loss * weights[step] + if num_unrolls is not None and num_unrolls < len(indices): + indices = indices[:num_unrolls] - current_state = predicted_next_state.detach() + if randomize: + indices = indices[torch.randperm(len(indices), device=data.device)] - return total_loss + return indices + + def compute_adaptive_weights(self, step_losses, condition_unroll_instructions): + """ + Compute adaptive weights for each time step based on cumulative losses. + :param torch.Tensor step_losses: Tensor of shape [unroll_length] containing losses at each time step. + :return: Tensor of shape [unroll_length] containing normalized weights. + :rtype: torch.Tensor + """ + num_steps = len(step_losses) + eps = condition_unroll_instructions.get("eps", None) + if eps is None: + weights = torch.ones_like(step_losses) + else: + weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) + + return weights / weights.sum() def predict(self, initial_state, num_steps): """ Make recursive predictions starting from an initial state. + :param torch.Tensor initial_state: Initial state tensor. + :param int num_steps: Number of steps to predict ahead. + :return: Tensor of predictions. 
+ :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode - + current_state = initial_state - predictions = [current_state] # Store initial state without batch dim + predictions = [current_state] + with torch.no_grad(): for step in range(num_steps): next_state = self.forward(current_state) - predictions.append(next_state) # Keep batch dim for storage + predictions.append(next_state) current_state = next_state - - return torch.stack(predictions) + + return torch.stack(predictions) \ No newline at end of file diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index e895705fe..d0a6f919a 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -7,58 +7,57 @@ from ..solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface -from ...loss import TimeWeightingInterface, ConstantTimeWeighting -from ...condition import AutoregressiveCondition +from ...condition import DataCondition class AutoregressiveSolverInterface(SolverInterface): - accepted_conditions_types = AutoregressiveCondition + def __init__(self, unroll_instructions, loss=None, **kwargs): + """ + Initialization of the :class:`AutoregressiveSolverInterface` class. + :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + this is supposed to map condition names to dict objects with unroll instructions. + :param loss: The loss function to be minimized. If None, defaults to MSELoss. 
+ :type loss: torch.nn.Module or LossInterface, optional + """ - def __init__(self, loss=None, **kwargs): + super().__init__(**kwargs) if loss is None: loss = torch.nn.MSELoss() - super().__init__(**kwargs) - check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss + self._unroll_instructions = unroll_instructions def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. Iterates over each conditions and each time applies the specialized loss_data function. + :param dict batch: A dictionary mapping condition names to data batches. + :return: A dictionary mapping condition names to computed loss values. + :rtype: dict """ condition_loss = {} for condition_name, points in batch: - condition = self.problem.conditions[condition_name] - - unroll_length = getattr(condition, "unroll_length", None) - time_weighting = getattr(condition, "time_weighting", None) - - if "unroll" in points: - loss = self.loss_data( + condition_unroll_instructions = self._unroll_instructions[condition_name] + loss = self.loss_data( points["input"], - points["unroll"], - unroll_length, - time_weighting, + condition_unroll_instructions, ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, target, unroll_length, time_weighting): + def loss_data(self, input, condition_unroll_instructions): """ Computes the data loss for each condition. - N.B.: unroll_length and time_weighting are attributes of the condition. + N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. - :param torch.Tensor input: Initial states. - :param torch.Tensor target: Target sequences. - :param int unroll_length: The number of steps to unroll (attribute of the condition). - :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition). - :return: The average loss over all unroll steps. + :param torch.Tensor input: all training data. 
+ :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :return: Computed loss value. :rtype: torch.Tensor """ pass @@ -83,11 +82,4 @@ def loss(self): :return: The loss function to be minimized. :rtype: torch.nn.Module """ - return self._loss_fn - - @property - def time_weighting(self): - """ - The time weighting strategy. - """ - return self._time_weighting + return self._loss_fn \ No newline at end of file From bb2f925f7388dd89ef40a56a169196ded3699954 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Mon, 12 Jan 2026 18:21:42 +0100 Subject: [PATCH 3/5] add dataclass for managing unroll settings --- autoregressive_prova_generic_condition.py | 35 ++++----- pina/solver/__init__.py | 1 + pina/solver/autoregressive_solver/__init__.py | 1 + .../autoregressive_solver.py | 74 +++++++------------ .../autoregressive_solver_interface.py | 28 +++++-- 5 files changed, 69 insertions(+), 70 deletions(-) diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py index 3c0796bbc..4812048fb 100644 --- a/autoregressive_prova_generic_condition.py +++ b/autoregressive_prova_generic_condition.py @@ -5,7 +5,7 @@ from pina.optim import TorchOptimizer from pina.problem import AbstractProblem from pina.condition.data_condition import DataCondition -from pina.solver import AutoregressiveSolver +from pina.solver import AutoregressiveSolver,UnrollInstructions NUM_TIMESTEPS = 100 NUM_FEATURES = 15 @@ -71,27 +71,28 @@ class Problem(AbstractProblem): # - unroll_length: length of each unroll window # - num_unrolls: number of unroll windows to create (if None, use all possible) # - randomize: whether to randomize the starting indices of the unroll windows -unroll_instructions = { - "data_condition_0": { - "unroll_length": 10, - "num_unrolls": 89, - "randomize": True, - "eps": 5.0 - }, - "data_condition_1": { - "unroll_length": 20, - "num_unrolls": 79, - "randomize": True, - "eps": 10.0 - }, -} 
+unroll_instructions_list = [ + UnrollInstructions( + condition_name="data_condition_0", + unroll_length=10, + num_unrolls=89, + randomize=True, + eps=5.0 + ), + UnrollInstructions( + condition_name="data_condition_1", + unroll_length=20, + num_unrolls=79, + randomize=True, + eps=10.0 + ), +] solver = AutoregressiveSolver( - unroll_instructions=unroll_instructions, + unroll_instructions_list=unroll_instructions_list, problem=problem, model=TestModel() if USE_TEST_MODEL else SimpleModel(), optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), - eps=10.0, ) trainer = Trainer( diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index e7d48e2b3..8494df8b0 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -45,4 +45,5 @@ from .autoregressive_solver import ( AutoregressiveSolver, AutoregressiveSolverInterface, + UnrollInstructions, ) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py index 9ef7c43e1..ac0d60a12 100644 --- a/pina/solver/autoregressive_solver/__init__.py +++ b/pina/solver/autoregressive_solver/__init__.py @@ -2,3 +2,4 @@ from .autoregressive_solver import AutoregressiveSolver from .autoregressive_solver_interface import AutoregressiveSolverInterface +from .autoregressive_solver_interface import UnrollInstructions diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 0606a3fd6..0a31d1ae2 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -3,7 +3,8 @@ from pina.solver.solver import SingleSolverInterface from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface - +from .autoregressive_solver_interface import UnrollInstructions +from typing import List class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface @@ -16,10 +17,9 
@@ class AutoregressiveSolver( def __init__( self, - unroll_instructions, + unroll_instructions_list:List[UnrollInstructions], problem, model, - eps=None, loss=None, optimizer=None, scheduler=None, @@ -28,7 +28,7 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. - :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + :param list unroll_instructions_list: A list of UnrollInstructions, one for each condition. this is supposed to map condition names to dict objects with unroll instructions. :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The model to be trained. @@ -40,7 +40,7 @@ def __init__( """ super().__init__( - unroll_instructions=unroll_instructions, + unroll_instructions_list=unroll_instructions_list, problem=problem, model=model, loss=loss, @@ -50,25 +50,23 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, data, condition_unroll_instructions): + def loss_data(self, data, unroll_instructions:UnrollInstructions): """ Compute the data loss for the recursive autoregressive solver. This will be applied to each condition individually. :param torch.Tensor data: all training data. - :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: instructions on how to unroll the model for this condition. :return: Computed loss value. 
:rtype: torch.Tensor """ initial_data, unroll_data = self.create_unroll_windows( - data, condition_unroll_instructions + data, unroll_instructions ) - - unroll_length = condition_unroll_instructions["unroll_length"] current_state = initial_data # [num_unrolls, features] losses = [] - for step in range(unroll_length): + for step in range(unroll_instructions.unroll_length): predicted_state = self.forward(current_state) # [num_unrolls, features] target_state = unroll_data[:, step, :] # [num_unrolls, features] @@ -79,76 +77,60 @@ def loss_data(self, data, condition_unroll_instructions): step_losses = torch.stack(losses) # [unroll_length] with torch.no_grad(): - weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions) + eps = unroll_instructions.eps + if eps is None: + weights = torch.ones_like(step_losses) + else: + weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) + weights = weights / weights.sum() - weighted_loss = (step_losses * weights).sum() - return weighted_loss + return (step_losses * weights).sum() - def create_unroll_windows(self, data, condition_unroll_instructions): + def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): """ Create unroll windows for each condition from the data based on the provided instructions. :param torch.Tensor data: The full data tensor. - :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. :return: Tuple of initial data and unroll data tensors. 
:rtype: (torch.Tensor, torch.Tensor) """ - unroll_length = condition_unroll_instructions["unroll_length"] + unroll_length = unroll_instructions.unroll_length start_list = [] unroll_list = [] for starting_index in self.decide_starting_indices( - data, condition_unroll_instructions + data, unroll_instructions ): idx = starting_index.item() - start = data[idx] - target_start = idx + 1 - unroll = data[target_start : target_start + unroll_length, :] - start_list.append(start) - unroll_list.append(unroll) + start_list.append(data[idx]) + unroll_list.append(data[idx+1 : idx+1+unroll_length, :]) + initial_data = torch.stack(start_list) # [num_unrolls, features] unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] return initial_data, unroll_data - def decide_starting_indices(self, data, condition_unroll_instructions): + def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): """ Decide the starting indices for unrolling based on the provided instructions. :param torch.Tensor data: The full data tensor. - :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. :return: Tensor of starting indices. 
:rtype: torch.Tensor """ n_step, n_features = data.shape - num_unrolls = condition_unroll_instructions.get("num_unrolls", None) - unroll_length = condition_unroll_instructions["unroll_length"] - randomize = condition_unroll_instructions.get("randomize", True) + num_unrolls = unroll_instructions.num_unrolls - max_start = n_step - unroll_length + max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) if num_unrolls is not None and num_unrolls < len(indices): indices = indices[:num_unrolls] - if randomize: + if unroll_instructions.randomize: indices = indices[torch.randperm(len(indices), device=data.device)] return indices - - def compute_adaptive_weights(self, step_losses, condition_unroll_instructions): - """ - Compute adaptive weights for each time step based on cumulative losses. - :param torch.Tensor step_losses: Tensor of shape [unroll_length] containing losses at each time step. - :return: Tensor of shape [unroll_length] containing normalized weights. 
- :rtype: torch.Tensor - """ - num_steps = len(step_losses) - eps = condition_unroll_instructions.get("eps", None) - if eps is None: - weights = torch.ones_like(step_losses) - else: - weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) - - return weights / weights.sum() def predict(self, initial_state, num_steps): """ diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index d0a6f919a..bf6a67462 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -3,19 +3,29 @@ from abc import abstractmethod import torch from torch.nn.modules.loss import _Loss +from dataclasses import dataclass from ..solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...condition import DataCondition +from typing import Optional +@dataclass +class UnrollInstructions: + """Instructions for unrolling a single condition.""" + condition_name: str + unroll_length: int + num_unrolls: Optional[int] = None + randomize: bool = True + eps: Optional[float] = None class AutoregressiveSolverInterface(SolverInterface): - def __init__(self, unroll_instructions, loss=None, **kwargs): + def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + :param list unroll_instructions: A list of UnrollInstructions, one for each condition. this is supposed to map condition names to dict objects with unroll instructions. :param loss: The loss function to be minimized. If None, defaults to MSELoss. 
:type loss: torch.nn.Module or LossInterface, optional @@ -28,7 +38,7 @@ def __init__(self, unroll_instructions, loss=None, **kwargs): check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss - self._unroll_instructions = unroll_instructions + self._unroll_instructions_list = unroll_instructions_list def optimization_cycle(self, batch): """ @@ -41,22 +51,26 @@ def optimization_cycle(self, batch): condition_loss = {} for condition_name, points in batch: - condition_unroll_instructions = self._unroll_instructions[condition_name] + #find unroll instructions for this condition + unroll_instructions = next( + ui for ui in self._unroll_instructions_list + if ui.condition_name == condition_name + ) loss = self.loss_data( points["input"], - condition_unroll_instructions, + unroll_instructions, ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, condition_unroll_instructions): + def loss_data(self, input, unroll_instructions:UnrollInstructions): """ Computes the data loss for each condition. N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. :param torch.Tensor input: all training data. - :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :param UnrollInstruction unroll_instructions: instructions on how to unroll the model for this condition. :return: Computed loss value. 
:rtype: torch.Tensor """ From 1a3bfa8aea816de2f64e16b9fa0080124c9720a8 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 13 Jan 2026 14:52:06 +0100 Subject: [PATCH 4/5] add docstings and tests --- autoregressive_prova_generic_condition.py | 150 ------------ .../autoregressive_solver.py | 157 +++++++++---- .../autoregressive_solver_interface.py | 109 ++++++--- .../test_solver/test_autoregressive_solver.py | 213 ++++++++++++++++++ 4 files changed, 410 insertions(+), 219 deletions(-) delete mode 100644 autoregressive_prova_generic_condition.py create mode 100644 tests/test_solver/test_autoregressive_solver.py diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py deleted file mode 100644 index 4812048fb..000000000 --- a/autoregressive_prova_generic_condition.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import matplotlib.pyplot as plt - -from pina import Trainer -from pina.optim import TorchOptimizer -from pina.problem import AbstractProblem -from pina.condition.data_condition import DataCondition -from pina.solver import AutoregressiveSolver,UnrollInstructions - -NUM_TIMESTEPS = 100 -NUM_FEATURES = 15 -USE_TEST_MODEL = False - -# ============================================================================ -# DATA -# ============================================================================ - -torch.manual_seed(42) - -y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) -y[0] = torch.rand(NUM_FEATURES) # Random initial state - -for t in range(NUM_TIMESTEPS - 1): - y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) - -# ============================================================================ -# TRAINING -# ============================================================================ - -class SimpleModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.Sequential( - torch.nn.Linear(y.shape[1], 20), - torch.nn.ReLU(), - torch.nn.Dropout(0.2), - torch.nn.Linear(20, y.shape[1]), - 
) - - def forward(self, x): - return x + self.layers(x) - - -class TestModel(torch.nn.Module): - """ - Debug model that implements the EXACT transformation rule. - y[t+1] = 0.95 * y[t] - Expected loss is zero - """ - - def __init__(self, data_series=None): - super().__init__() - self.dummy_param = torch.nn.Parameter(torch.zeros(1)) - - def forward(self, x): - next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) - return next_state + 0.0 * self.dummy_param - - -class Problem(AbstractProblem): - output_variables = None - input_variables = None - conditions = { - "data_condition_0":DataCondition(input=y), - "data_condition_1":DataCondition(input=y), - } - -problem = Problem() - -#for each condition, define unroll instructions with these keys: -# - unroll_length: length of each unroll window -# - num_unrolls: number of unroll windows to create (if None, use all possible) -# - randomize: whether to randomize the starting indices of the unroll windows -unroll_instructions_list = [ - UnrollInstructions( - condition_name="data_condition_0", - unroll_length=10, - num_unrolls=89, - randomize=True, - eps=5.0 - ), - UnrollInstructions( - condition_name="data_condition_1", - unroll_length=20, - num_unrolls=79, - randomize=True, - eps=10.0 - ), -] - -solver = AutoregressiveSolver( - unroll_instructions_list=unroll_instructions_list, - problem=problem, - model=TestModel() if USE_TEST_MODEL else SimpleModel(), - optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), -) - -trainer = Trainer( - solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False -) -trainer.train() - -# ============================================================================ -# VISUALIZATION -# ============================================================================ - -test_start_idx = 50 -num_prediction_steps = 30 - -initial_state = y[test_start_idx] # Shape: [features] -predictions = solver.predict(initial_state, num_prediction_steps) -actual = 
y[test_start_idx : test_start_idx + num_prediction_steps + 1] - -total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) -print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") - -# viauzlize single dof -dof_to_plot = [0, 3, 6, 9, 12] -colors = [ - "r", - "g", - "b", - "c", - "m", - "y", - "k", -] -plt.figure(figsize=(10, 6)) -for dof, color in zip(dof_to_plot, colors): - plt.plot( - range(test_start_idx, test_start_idx + num_prediction_steps + 1), - actual[:, dof].numpy(), - label="Actual", - marker="o", - color=color, - markerfacecolor="none", - ) - plt.plot( - range(test_start_idx, test_start_idx + num_prediction_steps + 1), - predictions[:, dof].numpy(), - label="Predicted", - marker="x", - color=color, - ) - -plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") -plt.legend() -plt.xlabel("Timestep") -plt.savefig(f"autoregressive_predictions.png") -plt.close() diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 0a31d1ae2..c754aae4d 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -6,18 +6,39 @@ from .autoregressive_solver_interface import UnrollInstructions from typing import List + class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface ): - """ - Autoregressive Solver class. + r""" + Autoregressive Solver for learning dynamical systems. + + This solver learns a one-step transition function + :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps + a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`. + + During training, the model is unrolled over multiple time steps to + learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`, + the model generates predictions recursively: + + .. 
math:: + \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t), + \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0 + + The loss is computed over the entire unroll window: + + .. math:: + \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2 + + where :math:`w_t` are exponential weights (if ``eps`` is specified) + that down-weight later predictions to stabilize training. """ accepted_conditions_types = DataCondition def __init__( self, - unroll_instructions_list:List[UnrollInstructions], + unroll_instructions_list: List[UnrollInstructions], problem, model, loss=None, @@ -28,15 +49,27 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. - :param list unroll_instructions_list: A list of UnrollInstructions, one for each condition. - this is supposed to map condition names to dict objects with unroll instructions. - :param AbstractProblem problem: The problem to be solved. - :param torch.nn.Module model: The model to be trained. - :param torch.nn.Module or LossInterface or None loss: The loss function to be minimized. If None, defaults to MSELoss. - :param TorchOptimizer or None optimizer: The optimizer to be used. If None, no optimization is performed. - :param TorchScheduler or None scheduler: The learning rate scheduler to be used. If None, no scheduling is performed. - :param Weighting or None weighting: The weighting scheme for combining losses from different conditions. If None, equal weighting is applied. - :param bool use_lt: Whether to use learning rate tuning. + + :param list[UnrollInstructions] unroll_instructions_list: List of + :class:`UnrollInstructions` specifying how to create training + windows for each condition. + :param AbstractProblem problem: The problem instance containing + the time series data conditions. + :param torch.nn.Module model: Neural network that predicts the + next state given the current state. + :param torch.nn.Module loss: Loss function to minimize. 
+ If ``None``, :class:`torch.nn.MSELoss` is used. + Default is ``None``. + :param TorchOptimizer optimizer: Optimizer for training. + If ``None``, :class:`torch.optim.Adam` is used. + Default is ``None``. + :param TorchScheduler scheduler: Learning rate scheduler. + If ``None``, no scheduling is applied. Default is ``None``. + :param WeightingInterface weighting: Weighting scheme for + combining losses from multiple conditions. + If ``None``, uniform weighting is used. Default is ``None``. + :param bool use_lt: Whether to use LabelTensors. + Default is ``False``. """ super().__init__( @@ -50,30 +83,37 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, data, unroll_instructions:UnrollInstructions): + def loss_data(self, data, unroll_instructions: UnrollInstructions): """ Compute the data loss for the recursive autoregressive solver. - This will be applied to each condition individually. - :param torch.Tensor data: all training data. - :param UnrollInstructions unroll_instructions: instructions on how to unroll the model for this condition. - :return: Computed loss value. + + Creates unroll windows from the data, then iteratively predicts + each next state and computes the loss against the ground truth. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + for window creation and loss weighting. + :return: Weighted sum of step losses. 
:rtype: torch.Tensor """ initial_data, unroll_data = self.create_unroll_windows( data, unroll_instructions ) - current_state = initial_data # [num_unrolls, features] + current_state = initial_data # [num_unrolls, features] losses = [] for step in range(unroll_instructions.unroll_length): - predicted_state = self.forward(current_state) # [num_unrolls, features] + predicted_state = self.forward( + current_state + ) # [num_unrolls, features] target_state = unroll_data[:, step, :] # [num_unrolls, features] step_loss = self._loss_fn(predicted_state, target_state) losses.append(step_loss) current_state = predicted_state - + step_losses = torch.stack(losses) # [unroll_length] with torch.no_grad(): @@ -83,16 +123,28 @@ def loss_data(self, data, unroll_instructions:UnrollInstructions): else: weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) weights = weights / weights.sum() - + return (step_losses * weights).sum() - def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): + def create_unroll_windows( + self, data, unroll_instructions: UnrollInstructions + ): """ - Create unroll windows for each condition from the data based on the provided instructions. - :param torch.Tensor data: The full data tensor. - :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. - :return: Tuple of initial data and unroll data tensors. - :rtype: (torch.Tensor, torch.Tensor) + Create unroll windows from time series data. + + Slices the input time series into overlapping windows, each + consisting of an initial state and subsequent target states. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + specifying window length and count. 
+ :return: Tuple of ``(initial_data, unroll_data)`` where: + + - ``initial_data``: Shape ``[num_unrolls, n_features]`` + - ``unroll_data``: Shape ``[num_unrolls, unroll_length, n_features]`` + + :rtype: tuple[torch.Tensor, torch.Tensor] """ unroll_length = unroll_instructions.unroll_length @@ -104,22 +156,34 @@ def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): ): idx = starting_index.item() start_list.append(data[idx]) - unroll_list.append(data[idx+1 : idx+1+unroll_length, :]) + unroll_list.append(data[idx + 1 : idx + 1 + unroll_length, :]) - initial_data = torch.stack(start_list) # [num_unrolls, features] - unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] + initial_data = torch.stack(start_list) # [num_unrolls, features] + unroll_data = torch.stack( + unroll_list + ) # [num_unrolls, unroll_length, features] return initial_data, unroll_data - def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): + def decide_starting_indices( + self, data, unroll_instructions: UnrollInstructions + ): """ - Decide the starting indices for unrolling based on the provided instructions. - :param torch.Tensor data: The full data tensor. - :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. - :return: Tensor of starting indices. + Determine starting indices for unroll windows. + + Computes valid starting positions ensuring each window has + enough subsequent time steps for the specified unroll length. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + with ``unroll_length``, ``num_unrolls``, and ``randomize``. + :return: 1D tensor of starting indices. 
:rtype: torch.Tensor """ n_step, n_features = data.shape num_unrolls = unroll_instructions.num_unrolls + if unroll_instructions.unroll_length >= n_step: + return [] #TODO: decide if it is better to raise an error here max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) @@ -134,21 +198,28 @@ def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): def predict(self, initial_state, num_steps): """ - Make recursive predictions starting from an initial state. - :param torch.Tensor initial_state: Initial state tensor. - :param int num_steps: Number of steps to predict ahead. - :return: Tensor of predictions. + Generate predictions by recursively applying the model. + + Starting from the initial state, applies the model repeatedly + to generate a trajectory of predicted states. + + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]``. + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory with shape + ``[num_steps + 1, n_features]``, where the first row is + the initial state. 
:rtype: torch.Tensor """ self.eval() # Set model to evaluation mode - + current_state = initial_state predictions = [current_state] - + with torch.no_grad(): for step in range(num_steps): next_state = self.forward(current_state) predictions.append(next_state) current_state = next_state - - return torch.stack(predictions) \ No newline at end of file + + return torch.stack(predictions) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index bf6a67462..9d08b6b4e 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -13,22 +13,67 @@ @dataclass class UnrollInstructions: - """Instructions for unrolling a single condition.""" + """ + Configuration for creating unroll windows from a time series condition. + + This dataclass specifies how to slice a time series into overlapping + windows for autoregressive training. Each window consists of an initial + state and a sequence of subsequent states used as targets. + + :param str condition_name: Name of the condition in the problem's + conditions dictionary. Must match a key in + ``problem.conditions``. + :param int unroll_length: The length of each unroll window. + :param Optional[int] num_unrolls: The number of unroll windows to create. + If ``None``, all possible windows are used. Default is None. + :param bool randomize: Whether to randomize the starting indices of the unroll windows. + Default is True. + :param Optional[float] eps: Epsilon parameter for exponential loss weighting. + If ``None``, uniform weighting is applied. Default is ``None``. + + :Example: + >>> instructions = UnrollInstructions( + ... condition_name="trajectory", + ... unroll_length=10, + ... num_unrolls=100, + ... randomize=True, + ... eps=0.1 + ... 
) + """ + condition_name: str unroll_length: int num_unrolls: Optional[int] = None randomize: bool = True eps: Optional[float] = None + class AutoregressiveSolverInterface(SolverInterface): + """ + Base class for autoregressive solvers. + + This interface defines solvers that learn to predict the next state + of a dynamical system given the current state. The solver uses an + unrolling strategy where predictions are made recursively over + multiple time steps during training. + The ``AutoregressiveSolverInterface`` is compatible with problems + containing :class:`~pina.condition.data_condition.DataCondition` + conditions, where the input represents a time series trajectory. + """ def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param list unroll_instructions: A list of UnrollInstructions, one for each condition. - this is supposed to map condition names to dict objects with unroll instructions. - :param loss: The loss function to be minimized. If None, defaults to MSELoss. - :type loss: torch.nn.Module or LossInterface, optional + + :param list[UnrollInstructions] unroll_instructions_list: List of + :class:`UnrollInstructions` objects, one for each condition + in the problem. Each instruction specifies how to create + unroll windows for training. + :param torch.nn.Module loss: Loss function to minimize. + If ``None``, :class:`torch.nn.MSELoss` is used. + Default is ``None``. + :param kwargs: Additional keyword arguments passed to + :class:`~pina.solver.solver.SolverInterface`. """ super().__init__(**kwargs) @@ -43,35 +88,42 @@ def __init__(self, unroll_instructions_list, loss=None, **kwargs): def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. - Iterates over each conditions and each time applies the specialized loss_data function. - :param dict batch: A dictionary mapping condition names to data batches.
- :return: A dictionary mapping condition names to computed loss values. - :rtype: dict + Iterates over each condition and each time applies the specialized loss_data function. + + :param list[tuple[str, dict]] batch: List of tuples where each + tuple contains a condition name and a dictionary with the + ``"input"`` key mapping to the time series tensor. + :return: Dictionary mapping condition names to computed loss values. + :rtype: dict[str, torch.Tensor] """ condition_loss = {} for condition_name, points in batch: - #find unroll instructions for this condition + # find unroll instructions for this condition unroll_instructions = next( - ui for ui in self._unroll_instructions_list + ui + for ui in self._unroll_instructions_list if ui.condition_name == condition_name ) loss = self.loss_data( - points["input"], - unroll_instructions, - ) + points["input"], + unroll_instructions, + ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, unroll_instructions:UnrollInstructions): + def loss_data(self, input, unroll_instructions: UnrollInstructions): """ - Computes the data loss for each condition. - N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. - - :param torch.Tensor input: all training data. - :param UnrollInstruction unroll_instructions: instructions on how to unroll the model for this condition. - :return: Computed loss value. + Compute the data loss for each condition. + This method must be implemented by subclasses to define the + specific loss computation strategy. + + :param torch.Tensor input: Time series data with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + for creating unroll windows from the input data. + :return: Scalar loss value for this condition. 
:rtype: torch.Tensor """ pass @@ -79,11 +131,16 @@ def loss_data(self, input, unroll_instructions:UnrollInstructions): @abstractmethod def predict(self, initial_state, num_steps): """ - Make recursive predictions starting from an initial state. + Generate recursive predictions from an initial state. + + Starting from the initial state, repeatedly applies the model + to predict subsequent states. - :param torch.Tensor initial_state: Initial state tensor. - :param int num_steps: Number of steps to predict ahead. - :return: Tensor of predictions. + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]`` or ``[batch_size, n_features]``. + :param int num_steps: Number of future steps to predict. + :return: Tensor of predictions with shape + ``[num_steps + 1, n_features]``, including the initial state. :rtype: torch.Tensor """ pass @@ -96,4 +153,4 @@ def loss(self): :return: The loss function to be minimized. :rtype: torch.nn.Module """ - return self._loss_fn \ No newline at end of file + return self._loss_fn diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py new file mode 100644 index 000000000..467f90afe --- /dev/null +++ b/tests/test_solver/test_autoregressive_solver.py @@ -0,0 +1,213 @@ +import pytest +import torch + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver, UnrollInstructions + +NUM_TIMESTEPS = 10 +NUM_FEATURES = 3 + + +@pytest.fixture +def y_data(): + torch.manual_seed(42) + y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) + y[0] = torch.rand(NUM_FEATURES) + for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] + return y + + +# create a test model +class ExactModel(torch.nn.Module): + """ + This model implements the EXACT transformation rule.
+ y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x + return next_state + 0.0 * self.dummy_param + + +@pytest.fixture +def solver(y_data): + """Create a minimal solver for testing internal methods.""" + + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} + + return AutoregressiveSolver( + unroll_instructions_list=[ + UnrollInstructions(condition_name="data", unroll_length=3) + ], + problem=Problem(), + model=ExactModel(), + ) + + +# Tests start here ============================================== + + +def test_exact_model(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = { + "data_condition": DataCondition(input=y_data), + } + + unroll_instruction = UnrollInstructions( + condition_name="data_condition", + unroll_length=5, + ) + + solver = AutoregressiveSolver( + unroll_instructions_list=[unroll_instruction], + problem=Problem(), + model=ExactModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01), + ) + + loss = solver.loss_data(y_data, unroll_instruction) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + shuffle=False, + enable_model_summary=False, + ) + trainer.train() + + loss_after_training = solver.loss_data(y_data, unroll_instruction) + assert torch.isclose(loss_after_training, torch.tensor(0.0), atol=1e-6) + + predictions = solver.predict( + initial_state=y_data[0], num_steps=NUM_TIMESTEPS - 1 + ) + expected_predictions = y_data + assert torch.allclose(predictions, expected_predictions, atol=1e-6) + + +def test_indices_sequential_when_no_randomize(solver, y_data): + """Indices should be [0, 1, 2, ...] 
when randomize=False.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + # y_data has 10 timesteps, unroll_length=3 → max_start = 10 - 3 = 7 + expected = torch.arange(7) + assert torch.equal(indices, expected) + + +def test_indices_permuted_when_randomize(solver, y_data): + """Indices should contain same values but permuted when randomize=True.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + randomize=True, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + expected_values = set(range(7)) + actual_values = set(indices.tolist()) + assert actual_values == expected_values + + +def test_num_unrolls_parameter(solver, y_data): + """num_unrolls should limit the number of indices returned.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + num_unrolls=3, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + assert len(indices) == 3 + assert torch.equal(indices, torch.arange(3)) + + +def test_num_unrolls_greater_than_max_possible(solver, y_data): + """num_unrolls > max_possible should return all possible indices.""" + unroll_length = 3 + maximum_number_of_unrolls = ( + NUM_TIMESTEPS - unroll_length + ) # 10 - unroll_length(3) = 7 + instructions = UnrollInstructions( + condition_name="data", + unroll_length=unroll_length, + num_unrolls=100, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + assert len(indices) == maximum_number_of_unrolls + + +def test_no_valid_indices_when_unroll_too_long(solver, y_data): + """When unroll_length >= n_timesteps, no valid indices exist.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=NUM_TIMESTEPS + 1, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + print(indices) + assert 
len(indices) == 0 + + +def test_unroll_window_shape(solver, y_data): + """Unroll windows should have correct shapes.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=4, + num_unrolls=2, + randomize=False, + ) + initial_data, unroll_data = solver.create_unroll_windows( + y_data, instructions + ) + + assert initial_data.shape == (2, NUM_FEATURES) # [num_unrolls, features] + assert unroll_data.shape == ( + 2, + 4, + NUM_FEATURES, + ) # [num_unrolls, unroll_length, features] + + +def test_unroll_windows_content(solver, y_data): + """Verify actual content of unroll windows.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + num_unrolls=2, + randomize=False, + ) + initial_data, unroll_data = solver.create_unroll_windows( + y_data, instructions + ) + + # initial_data[i] should be y_data[i] + assert torch.equal(initial_data[0], y_data[0]) + assert torch.equal(initial_data[1], y_data[1]) + + # unroll_data[i] should be y_data[i+1 : i+1+unroll_length] + assert torch.equal(unroll_data[0], y_data[1:4]) + assert torch.equal(unroll_data[1], y_data[2:5]) From aa644c5b658782aacc9198c8b88baf4306b2cd3c Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 13 Jan 2026 15:10:45 +0100 Subject: [PATCH 5/5] fix formatting --- .../autoregressive_solver.py | 28 +++++++++---------- .../autoregressive_solver_interface.py | 7 +++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index c754aae4d..d91e1e254 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -16,7 +16,7 @@ class AutoregressiveSolver( This solver learns a one-step transition function :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`. 
- + During training, the model is unrolled over multiple time steps to learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`, the model generates predictions recursively: @@ -183,7 +183,7 @@ def decide_starting_indices( n_step, n_features = data.shape num_unrolls = unroll_instructions.num_unrolls if unroll_instructions.unroll_length >= n_step: - return [] #TODO: decide if it is better to raise an error here + return [] # TODO: decide if it is better to raise an error here max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) @@ -198,18 +198,18 @@ def decide_starting_indices( def predict(self, initial_state, num_steps): """ - Generate predictions by recursively applying the model. - - Starting from the initial state, applies the model repeatedly - to generate a trajectory of predicted states. - - :param torch.Tensor initial_state: Starting state with shape - ``[n_features]``. - :param int num_steps: Number of future time steps to predict. - :return: Predicted trajectory with shape - ``[num_steps + 1, n_features]``, where the first row is - the initial state. - :rtype: torch.Tensor + Generate predictions by recursively applying the model. + + Starting from the initial state, applies the model repeatedly + to generate a trajectory of predicted states. + + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]``. + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory with shape + ``[num_steps + 1, n_features]``, where the first row is + the initial state. 
+ :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index 9d08b6b4e..67f64b6cd 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -11,6 +11,7 @@ from ...condition import DataCondition from typing import Optional + @dataclass class UnrollInstructions: """ @@ -30,7 +31,7 @@ class UnrollInstructions: Default is True. :param Optional[float] eps: Epsilon parameter for exponential loss weighting. If ``None``, uniform weighting is applied. Default is ``None``. - + :Example: >>> instructions = UnrollInstructions( ... condition_name="trajectory", @@ -64,7 +65,7 @@ class AutoregressiveSolverInterface(SolverInterface): def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - + :param list[UnrollInstructions] unroll_instructions_list: List of :class:`UnrollInstructions` objects, one for each condition in the problem. Each instruction specifies how to create @@ -89,7 +90,7 @@ def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. Iterates over each condition and each time applies the specialized loss_data function. - + :param list[tuple[str, dict]] batch: List of tuples where each tuple contains a condition name and a dictionary with the ``"input"`` key mapping to the time series tensor.