6 changes: 6 additions & 0 deletions pina/solver/__init__.py
@@ -18,6 +18,7 @@
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
"GAROM",
"AutoregressiveSolver",
]

from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +42,8 @@
DeepEnsemblePINN,
)
from .garom import GAROM
from .autoregressive_solver import (
AutoregressiveSolver,
AutoregressiveSolverInterface,
UnrollInstructions,
)
5 changes: 5 additions & 0 deletions pina/solver/autoregressive_solver/__init__.py
@@ -0,0 +1,5 @@
__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]

from .autoregressive_solver import AutoregressiveSolver
from .autoregressive_solver_interface import AutoregressiveSolverInterface
from .autoregressive_solver_interface import UnrollInstructions
225 changes: 225 additions & 0 deletions pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,225 @@
import torch
from pina.utils import check_consistency
from pina.solver.solver import SingleSolverInterface
from pina.condition import DataCondition
from .autoregressive_solver_interface import AutoregressiveSolverInterface
from .autoregressive_solver_interface import UnrollInstructions
from typing import List


class AutoregressiveSolver(
AutoregressiveSolverInterface, SingleSolverInterface
):
r"""
Autoregressive Solver for learning dynamical systems.

This solver learns a one-step transition function
:math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.

During training, the model is unrolled over multiple time steps to
learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
the model generates predictions recursively:

.. math::
\hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
\quad \hat{\mathbf{y}}_0 = \mathbf{y}_0

The loss is computed over the entire unroll window:

.. math::
\mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2

where :math:`w_t \propto \exp(-\epsilon \sum_{s=1}^{t} \|\hat{\mathbf{y}}_s - \mathbf{y}_s\|^2)`
are normalized exponential weights computed from the cumulative step
losses (if ``eps`` is specified); they down-weight later predictions to
stabilize training. If ``eps`` is ``None``, all steps are weighted equally.
"""

accepted_conditions_types = DataCondition

def __init__(
self,
unroll_instructions_list: List[UnrollInstructions],
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=False,
):
"""
Initialization of the :class:`AutoregressiveSolver` class.

:param list[UnrollInstructions] unroll_instructions_list: List of
:class:`UnrollInstructions` specifying how to create training
windows for each condition.
:param AbstractProblem problem: The problem instance containing
the time series data conditions.
:param torch.nn.Module model: Neural network that predicts the
next state given the current state.
:param torch.nn.Module loss: Loss function to minimize.
If ``None``, :class:`torch.nn.MSELoss` is used.
Default is ``None``.
:param TorchOptimizer optimizer: Optimizer for training.
If ``None``, :class:`torch.optim.Adam` is used.
Default is ``None``.
:param TorchScheduler scheduler: Learning rate scheduler.
If ``None``, no scheduling is applied. Default is ``None``.
:param WeightingInterface weighting: Weighting scheme for
combining losses from multiple conditions.
If ``None``, uniform weighting is used. Default is ``None``.
:param bool use_lt: Whether to use LabelTensors.
Default is ``False``.
"""

super().__init__(
unroll_instructions_list=unroll_instructions_list,
problem=problem,
model=model,
loss=loss,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)

def loss_data(self, data, unroll_instructions: UnrollInstructions):
"""
Compute the data loss for the recursive autoregressive solver.

Creates unroll windows from the data, then iteratively predicts
each next state and computes the loss against the ground truth.

:param torch.Tensor data: Time series with shape
``[n_timesteps, n_features]``.
:param UnrollInstructions unroll_instructions: Configuration
for window creation and loss weighting.
:return: Weighted sum of step losses.
:rtype: torch.Tensor
"""

initial_data, unroll_data = self.create_unroll_windows(
data, unroll_instructions
)
current_state = initial_data # [num_unrolls, features]

losses = []
for step in range(unroll_instructions.unroll_length):

predicted_state = self.forward(
current_state
) # [num_unrolls, features]
target_state = unroll_data[:, step, :] # [num_unrolls, features]
step_loss = self._loss_fn(predicted_state, target_state)
losses.append(step_loss)
current_state = predicted_state

step_losses = torch.stack(losses) # [unroll_length]

with torch.no_grad():
eps = unroll_instructions.eps
if eps is None:
weights = torch.ones_like(step_losses)
else:
weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))
weights = weights / weights.sum()

return (step_losses * weights).sum()

def create_unroll_windows(
self, data, unroll_instructions: UnrollInstructions
):
"""
Create unroll windows from time series data.

Slices the input time series into overlapping windows, each
consisting of an initial state and subsequent target states.

:param torch.Tensor data: Time series with shape
``[n_timesteps, n_features]``.
:param UnrollInstructions unroll_instructions: Configuration
specifying window length and count.
:return: Tuple of ``(initial_data, unroll_data)`` where:

- ``initial_data``: Shape ``[num_unrolls, n_features]``
- ``unroll_data``: Shape ``[num_unrolls, unroll_length, n_features]``

:rtype: tuple[torch.Tensor, torch.Tensor]
"""

unroll_length = unroll_instructions.unroll_length

start_list = []
unroll_list = []
for starting_index in self.decide_starting_indices(
data, unroll_instructions
):
idx = starting_index.item()
start_list.append(data[idx])
unroll_list.append(data[idx + 1 : idx + 1 + unroll_length, :])

initial_data = torch.stack(start_list) # [num_unrolls, features]
unroll_data = torch.stack(
unroll_list
) # [num_unrolls, unroll_length, features]
return initial_data, unroll_data

def decide_starting_indices(
self, data, unroll_instructions: UnrollInstructions
):
"""
Determine starting indices for unroll windows.

Computes valid starting positions ensuring each window has
enough subsequent time steps for the specified unroll length.

:param torch.Tensor data: Time series with shape
``[n_timesteps, n_features]``.
:param UnrollInstructions unroll_instructions: Configuration
with ``unroll_length``, ``num_unrolls``, and ``randomize``.
:return: 1D tensor of starting indices.
:rtype: torch.Tensor
"""
n_step, n_features = data.shape
num_unrolls = unroll_instructions.num_unrolls
if unroll_instructions.unroll_length >= n_step:
# Not enough time steps for even one window; return an empty index
# tensor to match the documented return type.
# TODO: decide if it is better to raise an error here
return torch.empty(0, dtype=torch.long, device=data.device)

max_start = n_step - unroll_instructions.unroll_length
indices = torch.arange(max_start, device=data.device)

# Shuffle before truncating so that, when randomize is set, the retained
# windows are a random subset rather than always the earliest ones.
if unroll_instructions.randomize:
indices = indices[torch.randperm(len(indices), device=data.device)]

if num_unrolls is not None and num_unrolls < len(indices):
indices = indices[:num_unrolls]

return indices

def predict(self, initial_state, num_steps):
"""
Generate predictions by recursively applying the model.

Starting from the initial state, applies the model repeatedly
to generate a trajectory of predicted states.

:param torch.Tensor initial_state: Starting state with shape
``[n_features]``.
:param int num_steps: Number of future time steps to predict.
:return: Predicted trajectory with shape
``[num_steps + 1, n_features]``, where the first row is
the initial state.
:rtype: torch.Tensor
"""
self.eval() # Set model to evaluation mode

current_state = initial_state
predictions = [current_state]

with torch.no_grad():
for step in range(num_steps):
next_state = self.forward(current_state)
predictions.append(next_state)
current_state = next_state

return torch.stack(predictions)
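
The window construction performed by decide_starting_indices and create_unroll_windows can be checked on a toy tensor. The following is a minimal standalone sketch of that slicing logic in plain torch, outside the PINA API; series, unroll_length and num_unrolls are illustrative values standing in for the data and the corresponding UnrollInstructions fields.

import torch

# Toy time series: 10 steps, 2 features.
series = torch.arange(20.0).reshape(10, 2)
unroll_length, num_unrolls = 3, 4

# Valid starts must leave room for unroll_length target steps after each start.
starts = torch.arange(series.shape[0] - unroll_length)[:num_unrolls]

initial_data = series[starts]  # [num_unrolls, n_features]
unroll_data = torch.stack(
    [series[i + 1 : i + 1 + unroll_length] for i in starts.tolist()]
)  # [num_unrolls, unroll_length, n_features]
print(initial_data.shape, unroll_data.shape)
# torch.Size([4, 2]) torch.Size([4, 3, 2])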
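
The exponential weighting inside loss_data is easiest to see on synthetic per-step losses. A small sketch of just that weighting step follows; the loss values and eps are made up for illustration.

import torch

step_losses = torch.tensor([0.1, 0.4, 0.9])  # per-step unroll losses, steps 1..T
eps = 0.5

with torch.no_grad():
    # Later steps carry a larger cumulative loss, so exp(-eps * cumsum)
    # assigns them smaller normalized weights.
    weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))
    weights = weights / weights.sum()

total_loss = (step_losses * weights).sum()
print(weights)      # monotonically decreasing weights summing to one
print(total_loss)   # weighted unroll loss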
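
Finally, predict is a plain autoregressive rollout. The sketch below mirrors that loop with a stand-in torch.nn.Linear one-step model; the model, state size and number of steps are assumptions made only for this example.

import torch

model = torch.nn.Linear(2, 2)   # stand-in one-step map y_t -> y_{t+1}
initial_state = torch.zeros(2)  # [n_features]
num_steps = 5

model.eval()
predictions = [initial_state]
with torch.no_grad():
    current_state = initial_state
    for _ in range(num_steps):
        current_state = model(current_state)
        predictions.append(current_state)

trajectory = torch.stack(predictions)  # [num_steps + 1, n_features]
print(trajectory.shape)                # torch.Size([6, 2])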