From a169969d652bc933892b74e5daaa83ac7093187c Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 26 Dec 2025 12:26:39 +0100 Subject: [PATCH 01/17] refact base and input-target condition --- pina/condition/condition_base.py | 210 +++++++++++++++++++++++ pina/condition/condition_interface.py | 94 +--------- pina/condition/input_target_condition.py | 186 +++++++++++++++++--- 3 files changed, 378 insertions(+), 112 deletions(-) create mode 100644 pina/condition/condition_base.py diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py new file mode 100644 index 000000000..361352d0f --- /dev/null +++ b/pina/condition/condition_base.py @@ -0,0 +1,210 @@ +import torch +from copy import deepcopy +from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor +from ..data.dummy_dataloader import DummyDataloader +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader +from functools import partial + + +class ConditionBase(ConditionInterface): + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + super().__init__() + self.data = self._store_data(**kwargs) + + @property + def problem(self): + return self._problem + + @problem.setter + def problem(self, value): + self._problem = value + + @staticmethod + def _check_graph_list_consistency(data_list): + """ + Check the consistency of the list of Data | Graph objects. + The following checks are performed: + + - All elements in the list must be of the same type (either + :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). + + - All elements in the list must have the same keys. + + - The data type of each tensor must be consistent across all elements. + + - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels + must also be consistent across all elements. + + :param data_list: The list of Data | Graph objects to check. + :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] + :raises ValueError: If the input types are invalid. + :raises ValueError: If all elements in the list do not have the same + keys. + :raises ValueError: If the type of each tensor is not consistent across + all elements in the list. + :raises ValueError: If the labels of the LabelTensors are not consistent + across all elements in the list. + """ + # If the data is a Graph or Data object, perform no checks + if isinstance(data_list, (Graph, Data)): + return + + # Check all elements in the list are of the same type + if not all(isinstance(i, (Graph, Data)) for i in data_list): + raise ValueError( + "Invalid input. Please, provide either Data or Graph objects." + ) + + # Store the keys, data types and labels of the first element + data = data_list[0] + keys = sorted(list(data.keys())) + data_types = {name: tensor.__class__ for name, tensor in data.items()} + labels = { + name: tensor.labels + for name, tensor in data.items() + if isinstance(tensor, LabelTensor) + } + + # Iterate over the list of Data | Graph objects + for data in data_list[1:]: + + # Check that all elements in the list have the same keys + if sorted(list(data.keys())) != keys: + raise ValueError( + "All elements in the list must have the same keys." + ) + + # Iterate over the tensors in the current element + for name, tensor in data.items(): + # Check that the type of each tensor is consistent + if tensor.__class__ is not data_types[name]: + raise ValueError( + f"Data {name} must be a {data_types[name]}, got " + f"{tensor.__class__}" + ) + + # Check that the labels of each LabelTensor are consistent + if isinstance(tensor, LabelTensor): + if tensor.labels != labels[name]: + raise ValueError( + "LabelTensor must have the same labels" + ) + + def _store_tensor_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + def _store_graph_data(self, graphs, tensors=None, key=None): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + tensor = tensors[i] + setattr(new_graph, key, tensor) + data.append(new_graph) + return {"data": data} + + def _store_data(self, **kwargs): + return self._store_tensor_data(**kwargs) + + def __len__(self): + return len(next(iter(self.data.values()))) + + def __getitem__(self, idx): + return {key: self.data[key][idx] for key in self.data} + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + + to_return = {key: [] for key in batch[0].keys()} + for item in batch: + for key, value in item.items(): + to_return[key].append(value) + for key, values in to_return.items(): + collate_function = cls.collate_fn_dict.get( + "label_tensor" + if isinstance(values[0], LabelTensor) + else ( + "label_tensor" + if isinstance(values[0], torch.Tensor) + else "graph" if isinstance(values[0], Graph) else "data" + ) + ) + to_return[key] = collate_function(values) + return to_return + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for automatic batching to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: list + """ + data = condition[batch] + return data + + def create_dataloader( + self, dataset, batch_size, shuffle, automatic_batching + ): + """ + Create a DataLoader for the condition. + + :param int batch_size: The batch size for the DataLoader. + :param bool shuffle: Whether to shuffle the data. Default is ``False``. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + if batch_size == len(dataset): + return DummyDataloader(dataset) + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + # collate_fn = self.automatic_batching_collate_fn + ) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index b0264517c..427b85502 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,6 +1,6 @@ """Module for the Condition interface.""" -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from torch_geometric.data import Data from ..label_tensor import LabelTensor from ..graph import Graph @@ -15,13 +15,14 @@ class ConditionInterface(metaclass=ABCMeta): description of all available conditions and how to instantiate them. """ - def __init__(self): + @abstractmethod + def __init__(self, **kwargs): """ Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property + @abstractmethod def problem(self): """ Return the problem associated with this condition. @@ -29,9 +30,9 @@ def problem(self): :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ - return self._problem @problem.setter + @abstractmethod def problem(self, value): """ Set the problem associated with this condition. @@ -39,88 +40,3 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ - self._problem = value - - @staticmethod - def _check_graph_list_consistency(data_list): - """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: - - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. - """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." - ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - - def __getattribute__(self, name): - """ - Get an attribute from the object. - - :param str name: The name of the attribute to get. - :return: The requested attribute. - :rtype: Any - """ - to_return = super().__getattribute__(name) - if isinstance(to_return, (Graph, Data)): - to_return = [to_return] - return to_return diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 07b07bb7b..965eeecfc 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,13 +3,15 @@ """ import torch +from copy import deepcopy from torch_geometric.data import Data from ..label_tensor import LabelTensor -from ..graph import Graph -from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from .condition_base import ConditionBase +from torch_geometric.data import Batch -class InputTargetCondition(ConditionInterface): +class InputTargetCondition(ConditionBase): """ The :class:`InputTargetCondition` class represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to @@ -55,7 +57,7 @@ class InputTargetCondition(ConditionInterface): """ # Available input and target data types - __slots__ = ["input", "target"] + __fields__ = ["input", "target"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) @@ -109,16 +111,6 @@ def __new__(cls, input, target): subclass = GraphInputTensorTargetCondition return subclass.__new__(subclass, input, target) - # Graph - Graph - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(input) - cls._check_graph_list_consistency(target) - subclass = GraphInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # If the input and/or target are not of the correct type raise an error raise ValueError( "Invalid input | target types." "Please provide either torch_geometric.data.Data, Graph, " @@ -143,10 +135,8 @@ def __init__(self, input, target): objects, all elements in the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() self._check_input_target_len(input, target) - self.input = input - self.target = target + super().__init__(input=input, target=target) @staticmethod def _check_input_target_len(input, target): @@ -181,6 +171,26 @@ class TensorInputTensorTargetCondition(InputTargetCondition): :class:`~pina.label_tensor.LabelTensor` objects. """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["input"] + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["target"] + class TensorInputGraphTargetCondition(InputTargetCondition): """ @@ -190,6 +200,65 @@ class TensorInputGraphTargetCondition(InputTargetCondition): :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ + def _store_data(self, **kwargs): + return self._store_graph_data( + kwargs["target"], kwargs["input"], key="x" + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].x, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.x) + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + data = self.batch_fn([self.data["data"][i] for i in indices]) + x = data.x + del data.x # Avoid duplication of y on GPU memory + return { + "input": x, + "target": data, + } + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch) + x = collated_graphs["data"].x + del collated_graphs["data"].x # Avoid duplication of y on GPU memory + to_return = {"input": x, "input": collated_graphs["data"]} + return to_return + class GraphInputTensorTargetCondition(InputTargetCondition): """ @@ -199,10 +268,81 @@ class GraphInputTensorTargetCondition(InputTargetCondition): :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`GraphInputTensorTargetCondition` class. -class GraphInputGraphTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` objects. - """ + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param target: The target data for the condition. + :type target: torch.Tensor | LabelTensor + """ + super().__init__(input=input, target=target) + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(input[0], Graph) + else Batch.from_data_list + ) + + def _store_data(self, **kwargs): + return self._store_graph_data( + kwargs["input"], kwargs["target"], key="y" + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].y, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.y) + + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + data = self.batch_fn([self.data["data"][i] for i in indices]) + y = data.y + del data.y # Avoid duplication of y on GPU memory + return { + "input": data, + "target": y, + } + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch) + y = collated_graphs["data"].y + del collated_graphs["data"].y # Avoid duplication of y on GPU memory + print("y shape:", y.shape) + print(y.labels) + to_return = {"target": y, "input": collated_graphs["data"]} + return to_return From 52ee3e78a687b688aa7944274a871e92f47a5122 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 08:58:01 +0100 Subject: [PATCH 02/17] refact condition factory --- pina/condition/condition.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index ad8764c9f..3c43f7176 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -86,12 +86,12 @@ class Condition: """ # Combine all possible keyword arguments from the different Condition types - __slots__ = list( + available_kwargs = list( set( - InputTargetCondition.__slots__ - + InputEquationCondition.__slots__ - + DomainEquationCondition.__slots__ - + DataCondition.__slots__ + InputTargetCondition.__fields__ + + InputEquationCondition.__fields__ + + DomainEquationCondition.__fields__ + + DataCondition.__fields__ ) ) @@ -112,28 +112,28 @@ def __new__(cls, *args, **kwargs): if len(args) != 0: raise ValueError( "Condition takes only the following keyword " - f"arguments: {Condition.__slots__}." + f"arguments: {Condition.available_kwargs}." ) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition - if sorted_keys == sorted(InputTargetCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__fields__): return InputTargetCondition(**kwargs) # Input - Equation Condition - if sorted_keys == sorted(InputEquationCondition.__slots__): + if sorted_keys == sorted(InputEquationCondition.__fields__): return InputEquationCondition(**kwargs) # Domain - Equation Condition - if sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(DomainEquationCondition.__fields__): return DomainEquationCondition(**kwargs) # Data Condition if ( - sorted_keys == sorted(DataCondition.__slots__) - or sorted_keys[0] == DataCondition.__slots__[0] + sorted_keys == sorted(DataCondition.__fields__) + or sorted_keys[0] == DataCondition.__fields__[0] ): return DataCondition(**kwargs) From 8ae31089022ffe2e3cadc6b07ad4c298d9202014 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 10:00:44 +0100 Subject: [PATCH 03/17] fix TensorInputGraphTargetCondition --- pina/condition/input_target_condition.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 965eeecfc..ece2f22e6 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -200,6 +200,23 @@ class TensorInputGraphTargetCondition(InputTargetCondition): :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`TensorInputGraphTargetCondition` class. + + :param input: The input data for the condition. + :type input: torch.Tensor | LabelTensor + :param target: The target data for the condition. + :type target: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + """ + super().__init__(input=input, target=target) + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(target[0], Graph) + else Batch.from_data_list + ) + def _store_data(self, **kwargs): return self._store_graph_data( kwargs["target"], kwargs["input"], key="x" From 5ed425dc5fc62ee7fcbea806f254540feb41a1b0 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 10:01:20 +0100 Subject: [PATCH 04/17] Implement test for InputTargetCondition --- .../test_input_target_condition.py | 294 ++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 tests/test_condition/test_input_target_condition.py diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py new file mode 100644 index 000000000..033f7094a --- /dev/null +++ b/tests/test_condition/test_input_target_condition.py @@ -0,0 +1,294 @@ +import torch +import pytest +from torch_geometric.data import Batch +from pina import LabelTensor, Condition +from pina.condition import ( + TensorInputGraphTargetCondition, + TensorInputTensorTargetCondition, + GraphInputTensorTargetCondition, +) +from pina.graph import RadiusGraph, LabelBatch + + +def _create_tensor_data(use_lt=False): + if use_lt: + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + return input_tensor, target_tensor + input_tensor = torch.rand((10, 3)) + target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor + + +def _create_graph_data(tensor_input=True, use_lt=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + graph = [ + RadiusGraph( + pos=pos[i], + radius=radius, + x=x[i] if not tensor_input else None, + y=x[i] if tensor_input else None, + ) + for i in range(len(x)) + ] + if use_lt: + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + tensor = torch.rand(10, 20, 1) + return graph, tensor + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + assert isinstance(condition, TensorInputTensorTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputTensorTargetCondition input failed" + assert torch.allclose( + condition.target, target_tensor + ), "TensorInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputTensorTargetCondition input type failed" + assert condition.input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition input labels failed" + assert isinstance( + condition.target, LabelTensor + ), "TensorInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + assert isinstance(condition, TensorInputGraphTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputGraphTargetCondition input failed" + for i in range(len(target_graph)): + assert torch.allclose( + condition.target[i].y, target_graph[i].y + ), "TensorInputGraphTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target[i].y, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.target[i].y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition target labels failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.input.labels == [ + "f" + ], "TensorInputGraphTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + assert isinstance(condition, GraphInputTensorTargetCondition) + for i in range(len(input_graph)): + assert torch.allclose( + condition.input[i].x, input_graph[i].x + ), "GraphInputTensorTargetCondition input failed" + if use_lt: + assert isinstance( + condition.input[i].x, LabelTensor + ), "GraphInputTensorTargetCondition input type failed" + assert ( + condition.input[i].x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition labels failed" + + assert torch.allclose( + condition.target[i], target_tensor[i] + ), "GraphInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target, LabelTensor + ), "GraphInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + for i in range(len(input_tensor)): + item = condition[i] + assert torch.allclose( + item["input"], input_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item["target"], target_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + for i in range(len(input_tensor)): + item = condition[i]["data"] + assert torch.allclose( + item.x, input_tensor[i] + ), "TensorInputGraphTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_graph[i].y + ), "TensorInputGraphTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitem__ target type failed" + assert item.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitem__ target labels failed" + + +def test_getitem_graph_input_tensor_target_condition(): + input_graph, target_tensor = _create_graph_data(False) + condition = Condition(input=input_graph, target=target_tensor) + for i in range(len(input_graph)): + item = condition[i]["data"] + print(item) + assert torch.allclose( + item.x, input_graph[i].x + ), "GraphInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_tensor[i] + ), "GraphInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelBatch.from_data_list([input_graph[i] for i in indices]) + target_ = LabelTensor.cat([target_tensor[i] for i in indices], dim=0) + else: + input_ = Batch.from_data_list([input_graph[i] for i in indices]) + target_ = torch.cat([target_tensor[i] for i in indices], dim=0) + assert torch.allclose( + candidate_input.x, input_.x + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + if use_lt: + assert isinstance( + candidate_target, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitems__ target labels failed" + + assert isinstance( + candidate_input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ input type failed" + assert ( + candidate_input.x.labels == input_graph[0].x.labels + ), "GraphInputTensorTargetCondition __getitems__ input labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_tensor_target_condition(use_lt): + + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + indices = [1, 3, 5, 7] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + target_ = LabelTensor.stack([target_tensor[i] for i in indices]) + else: + input_ = torch.stack([input_tensor[i] for i in indices]) + target_ = torch.stack([target_tensor[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputTensorTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "TensorInputTensorTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition __getitems__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(True, use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + if use_lt: + input_ = LabelTensor.cat([input_tensor[i] for i in indices], dim=0) + target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) + else: + input_ = torch.cat([input_tensor[i] for i in indices], dim=0) + target_ = Batch.from_data_list([target_graph[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputGraphTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target.y, target_.y + ), "TensorInputGraphTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "f" + ], "TensorInputGraphTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ target type failed" + assert candidate_target.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitems__ target labels failed" From ea26f3dae4f6cc3794d4edfc1b04686b352b9afe Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 14:50:46 +0100 Subject: [PATCH 05/17] fix codacy --- pina/condition/condition_base.py | 37 ++++++++++++++++++++---- pina/condition/condition_interface.py | 21 ++++++++++++-- pina/condition/input_target_condition.py | 29 +++++++++++++++---- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 361352d0f..0232375e3 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -1,15 +1,25 @@ -import torch +""" +Base class for conditions. +""" + from copy import deepcopy +from functools import partial +import torch +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader from .condition_interface import ConditionInterface from ..graph import Graph, LabelBatch from ..label_tensor import LabelTensor from ..data.dummy_dataloader import DummyDataloader -from torch_geometric.data import Data, Batch -from torch.utils.data import DataLoader -from functools import partial class ConditionBase(ConditionInterface): + """ + Base abstract class for all conditions in PINA. + This class provides common functionality for handling data storage, + batching, and interaction with the associated problem. + """ + collate_fn_dict = { "tensor": torch.stack, "label_tensor": LabelTensor.stack, @@ -18,15 +28,32 @@ class ConditionBase(ConditionInterface): } def __init__(self, **kwargs): + """ + Initialization of the :class:`ConditionBase` class. + + :param kwargs: Keyword arguments representing the data to be stored. + """ super().__init__() self.data = self._store_data(**kwargs) @property def problem(self): + """ + Return the problem associated with this condition. + + :return: Problem associated with this condition. + :rtype: ~pina.problem.abstract_problem.AbstractProblem + """ return self._problem @problem.setter def problem(self, value): + """ + Set the problem associated with this condition. + + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition + """ self._problem = value @staticmethod @@ -141,7 +168,7 @@ def __len__(self): return len(next(iter(self.data.values()))) def __getitem__(self, idx): - return {key: self.data[key][idx] for key in self.data} + return {name: data[idx] for name, data in self.data.items()} @classmethod def automatic_batching_collate_fn(cls, batch): diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 427b85502..229b9a025 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,9 +1,6 @@ """Module for the Condition interface.""" from abc import ABCMeta, abstractmethod -from torch_geometric.data import Data -from ..label_tensor import LabelTensor -from ..graph import Graph class ConditionInterface(metaclass=ABCMeta): @@ -40,3 +37,21 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ + + @abstractmethod + def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ + + @abstractmethod + def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param int idx: Index of the data point(s) to retrieve. + :return: Data point(s) at the specified index. + """ diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index ece2f22e6..c90fcc8e3 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,12 +3,10 @@ """ import torch -from copy import deepcopy -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch from ..label_tensor import LabelTensor from ..graph import Graph, LabelBatch from .condition_base import ConditionBase -from torch_geometric.data import Batch class InputTargetCondition(ConditionBase): @@ -218,6 +216,13 @@ def __init__(self, input, target): ) def _store_data(self, **kwargs): + """ + Store the input and target data for the condition. + + :param kwargs: Keyword arguments containing 'input' and 'target'. + :return: Stored data dictionary. + :rtype: dict + """ return self._store_graph_data( kwargs["target"], kwargs["input"], key="x" ) @@ -252,6 +257,13 @@ def __getitem__(self, idx): return {"data": self.data["data"][idx]} def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ data = self.batch_fn([self.data["data"][i] for i in indices]) x = data.x del data.x # Avoid duplication of y on GPU memory @@ -266,14 +278,14 @@ def automatic_batching_collate_fn(cls, batch): Collate function to be used in DataLoader. :param batch: A list of items from the dataset. - :type batch: list + :type batch: List[dict] :return: A collated batch. :rtype: dict """ collated_graphs = super().automatic_batching_collate_fn(batch) x = collated_graphs["data"].x del collated_graphs["data"].x # Avoid duplication of y on GPU memory - to_return = {"input": x, "input": collated_graphs["data"]} + to_return = {"input": x, "target": collated_graphs["data"]} return to_return @@ -338,6 +350,13 @@ def __getitem__(self, idx): return {"data": self.data["data"][idx]} def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ data = self.batch_fn([self.data["data"][i] for i in indices]) y = data.y del data.y # Avoid duplication of y on GPU memory From 74749fc40a0434a6e59b0bd747e81aea421658b8 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 15:13:38 +0100 Subject: [PATCH 06/17] refact InputEquationCondition --- pina/condition/input_equation_condition.py | 30 +++------ .../test_input_equation_condition.py | 65 +++++++++++++++++++ 2 files changed, 75 insertions(+), 20 deletions(-) create mode 100644 tests/test_condition/test_input_equation_condition.py diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index d32597894..fa41f79e2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,13 +1,12 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph -from ..utils import check_consistency from ..equation.equation_interface import EquationInterface -class InputEquationCondition(ConditionInterface): +class InputEquationCondition(ConditionBase): """ The class :class:`InputEquationCondition` defines a condition based on ``input`` data and an ``equation``. This condition is typically used in @@ -41,7 +40,7 @@ class InputEquationCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "equation"] + __fields__ = ["input", "equation"] _avail_input_cls = (LabelTensor, Graph, list, tuple) _avail_equation_cls = EquationInterface @@ -97,27 +96,18 @@ def __init__(self, input, equation): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input + super().__init__(input=input) self.equation = equation - def __setattr__(self, key, value): + @property + def input(self): """ - Set the attribute value with type checking. + Return the input data for the condition. - :param str key: The attribute name. - :param any value: The value to set for the attribute. + :return: The input data. + :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] """ - if key == "input": - check_consistency(value, self._avail_input_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, self._avail_equation_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key in ("_problem"): - super().__setattr__(key, value) + return self.data["input"] class InputTensorEquationCondition(InputEquationCondition): diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py new file mode 100644 index 000000000..b6a687e2a --- /dev/null +++ b/tests/test_condition/test_input_equation_condition.py @@ -0,0 +1,65 @@ +import torch +from pina import Condition +from pina.condition.input_equation_condition import ( + InputTensorEquationCondition, + InputGraphEquationCondition, +) +from pina.equation import Equation +from pina import LabelTensor + + +def _create_pts_and_equation(): + def dummy_equation(pts): + return pts["x"] ** 2 + pts["y"] ** 2 - 1 + + pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + equation = Equation(dummy_equation) + return pts, equation + + +def _create_graph_and_equation(): + from pina.graph import KNNGraph + + def dummy_equation(pts): + return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 + + x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) + pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) + equation = Equation(dummy_equation) + return graph, equation + + +def test_init_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + assert isinstance(condition, InputTensorEquationCondition) + assert condition.input.shape == (100, 2) + assert condition.equation is equation + + +def test_init_graph_equation_condition(): + graph, equation = _create_graph_and_equation() + condition = Condition(input=graph, equation=equation) + assert isinstance(condition, InputGraphEquationCondition) + assert condition.input is graph + assert condition.equation is equation + + +def test_getitem_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (2,) + + +def test_getitems_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + idxs = [0, 1, 3] + item = condition[idxs] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (3, 2) From 8fbfde94af38e39164de7ea157c34f15d3de251b Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 15:22:11 +0100 Subject: [PATCH 07/17] refact DomainEquationCondition --- pina/condition/domain_equation_condition.py | 45 ++++++++++--------- .../test_domain_equation_condition.py | 27 +++++++++++ 2 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 tests/test_condition/test_domain_equation_condition.py diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3565c0b41..3e4adbaee 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,12 +1,11 @@ """Module for the DomainEquationCondition class.""" -from .condition_interface import ConditionInterface -from ..utils import check_consistency +from .condition_base import ConditionBase from ..domain import DomainInterface from ..equation.equation_interface import EquationInterface -class DomainEquationCondition(ConditionInterface): +class DomainEquationCondition(ConditionBase): """ The class :class:`DomainEquationCondition` defines a condition based on a ``domain`` and an ``equation``. This condition is typically used in @@ -30,7 +29,7 @@ class DomainEquationCondition(ConditionInterface): """ # Available slots - __slots__ = ["domain", "equation"] + __fields__ = ["domain", "equation"] def __init__(self, domain, equation): """ @@ -41,24 +40,30 @@ def __init__(self, domain, equation): :param EquationInterface equation: The equation to be satisfied over the specified domain. """ + if not isinstance(domain, (DomainInterface, str)): + raise ValueError( + f"`domain` must be an instance of DomainInterface, " + f"got {type(domain)} instead." + ) + if not isinstance(equation, EquationInterface): + raise ValueError( + f"`equation` must be an instance of EquationInterface, " + f"got {type(equation)} instead." + ) super().__init__() self.domain = domain self.equation = equation - def __setattr__(self, key, value): - """ - Set the attribute value with type checking. - - :param str key: The attribute name. - :param any value: The value to set for the attribute. - """ - if key == "domain": - check_consistency(value, (DomainInterface, str)) - DomainEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, (EquationInterface)) - DomainEquationCondition.__dict__[key].__set__(self, value) + def __len__(self): + raise NotImplementedError( + "`__len__` method is not implemented for " + "`DomainEquationCondition` since the number of points is " + "determined by the domain sampling strategy." + ) - elif key in ("_problem"): - super().__setattr__(key, value) + def __getitem__(self, idx): + """ """ + raise NotImplementedError( + "`__getitem__` method is not implemented for " + "`DomainEquationCondition`" + ) diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py new file mode 100644 index 000000000..2b7c78b00 --- /dev/null +++ b/tests/test_condition/test_domain_equation_condition.py @@ -0,0 +1,27 @@ +import pytest +from pina import Condition +from pina.domain import CartesianDomain +from pina.equation.equation_factory import FixedValue +from pina.condition import DomainEquationCondition + +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_equation = FixedValue(0.0) + + +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=example_equation) + assert isinstance(cond, DomainEquationCondition) + assert cond.domain is example_domain + assert cond.equation is example_equation + + +def test_len_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + len(cond) + + +def test_getitem_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + cond[0] From c941444007947efa95f6a106a3d6d5cc7470ddaa Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:27:09 +0100 Subject: [PATCH 08/17] implement TensorCondition and GraphCondition --- pina/condition/condition_base.py | 112 ++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 38 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 0232375e3..b8b828767 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -13,6 +13,79 @@ from ..data.dummy_dataloader import DummyDataloader +class TensorCondition: + def store_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + +class GraphCondition: + def __init__(self, **kwargs): + super().__init__(**kwargs) + example = kwargs.get(self.graph_field)[0] + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(example, Graph) + else Batch.from_data_list + ) + + def store_data(self, **kwargs): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + graphs = kwargs.get(self.graph_field) + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + for key in self.tensor_fields: + tensor = kwargs[key][i] + mapping_key = self.keys_map.get(key) + setattr(new_graph, mapping_key, tensor) + data.append(new_graph) + return {"data": data} + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ + to_return_dict = {} + data = self.batch_fn([self.data["data"][i] for i in indices]) + to_return_dict[self.graph_field] = data + for key in self.tensor_fields: + mapping_key = self.keys_map.get(key) + y = getattr(data, mapping_key) + delattr(data, mapping_key) # Avoid duplication of y on GPU memory + to_return_dict[key] = y + return to_return_dict + + class ConditionBase(ConditionInterface): """ Base abstract class for all conditions in PINA. @@ -34,7 +107,7 @@ def __init__(self, **kwargs): :param kwargs: Keyword arguments representing the data to be stored. """ super().__init__() - self.data = self._store_data(**kwargs) + self.data = self.store_data(**kwargs) @property def problem(self): @@ -127,43 +200,6 @@ def _check_graph_list_consistency(data_list): "LabelTensor must have the same labels" ) - def _store_tensor_data(self, **kwargs): - """ - Store data for standard tensor condition - - :param kwargs: Keyword arguments representing the data to be stored. - :return: A dictionary containing the stored data. - :rtype: dict - """ - data = {} - for key, value in kwargs.items(): - data[key] = value - return data - - def _store_graph_data(self, graphs, tensors=None, key=None): - """ - Store data for graph condition - - :param graphs: List of graphs to store data in. - :type graphs: list[Graph] | list[Data] - :param tensors: List of tensors to store in the graphs. - :type tensors: list[torch.Tensor] | list[LabelTensor] - :param key: Key under which to store the tensors in the graphs. - :type key: str - :return: A dictionary containing the stored data. - :rtype: dict - """ - data = [] - for i, graph in enumerate(graphs): - new_graph = deepcopy(graph) - tensor = tensors[i] - setattr(new_graph, key, tensor) - data.append(new_graph) - return {"data": data} - - def _store_data(self, **kwargs): - return self._store_tensor_data(**kwargs) - def __len__(self): return len(next(iter(self.data.values()))) From 4fda9a7875ae69f3b11c9ad514bb6603dd4238fa Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:27:49 +0100 Subject: [PATCH 09/17] use GraphCondition and Tensor condition classes --- pina/condition/__init__.py | 4 +- pina/condition/data_condition.py | 113 +++++++++++++++++-- pina/condition/domain_equation_condition.py | 9 ++ pina/condition/input_equation_condition.py | 46 ++++++-- pina/condition/input_target_condition.py | 117 +++----------------- 5 files changed, 166 insertions(+), 123 deletions(-) diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..13429a829 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -8,7 +8,7 @@ "TensorInputTensorTargetCondition", "TensorInputGraphTargetCondition", "GraphInputTensorTargetCondition", - "GraphInputGraphTargetCondition", + # "GraphInputGraphTargetCondition", "InputEquationCondition", "InputTensorEquationCondition", "InputGraphEquationCondition", @@ -25,7 +25,7 @@ TensorInputTensorTargetCondition, TensorInputGraphTargetCondition, GraphInputTensorTargetCondition, - GraphInputGraphTargetCondition, + # GraphInputGraphTargetCondition, ) from .input_equation_condition import ( InputEquationCondition, diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 5f5e7d36b..b04166b51 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -1,13 +1,13 @@ """Module for the DataCondition class.""" import torch -from torch_geometric.data import Data -from .condition_interface import ConditionInterface +from torch_geometric.data import Data, Batch +from .condition_base import ConditionBase, GraphCondition, TensorCondition from ..label_tensor import LabelTensor -from ..graph import Graph +from ..graph import Graph, LabelBatch -class DataCondition(ConditionInterface): +class DataCondition(ConditionBase): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -38,7 +38,7 @@ class DataCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "conditional_variables"] + __fields__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) @@ -99,22 +99,115 @@ def __init__(self, input, conditional_variables=None): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input - self.conditional_variables = conditional_variables + if conditional_variables is None: + super().__init__(input=input) + else: + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def conditional_variables(self): + """ + Return the conditional variables for the condition. + + :return: The conditional variables. + :rtype: torch.Tensor | LabelTensor | None + """ + return self.data.get("conditional_variables", None) -class TensorDataCondition(DataCondition): +class TensorDataCondition(TensorCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a :class:`torch.Tensor` object. """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["input"] + -class GraphDataCondition(DataCondition): +class GraphDataCondition(GraphCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` object or a :class:`~torch_geometric.data.Data` object. """ + + def __init__(self, input, conditional_variables=None): + """ + Initialization of the :class:`GraphDataCondition` class. + + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param conditional_variables: The conditional variables for the + condition. Default is ``None``. + :type conditional_variables: torch.Tensor | LabelTensor + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data`, all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + if conditional_variables is not None: + self.tensor_fields.append("conditional_variables") + self.keys_map["conditional_variables"] = "cond_vars" + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | + tuple[Data] + """ + return self.data["data"] + + @property + def conditional_variables(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + + if not hasattr(self.data["data"][0], "cond_vars"): + return None + cond_vars = [] + is_lt = isinstance(self.data["data"][0].cond_vars, LabelTensor) + for graph in self.data["data"]: + cond_vars.append(graph.cond_vars) + return ( + torch.stack(cond_vars) + if not is_lt + else LabelTensor.stack(cond_vars) + ) + + def __getitem__(self, idx): + """ + Get item by index from the input data. + + :param int index: The index of the item to retrieve. + :return: The item at the specified index. + :rtype: Graph | Data + """ + input_ = self.batch_fn(self.data["input"][idx]) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3e4adbaee..0ce05eeab 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -67,3 +67,12 @@ def __getitem__(self, idx): "`__getitem__` method is not implemented for " "`DomainEquationCondition`" ) + + def store_data(self): + """ + Store the data for the condition by sampling points from the domain. + + :return: Sampled points from the domain. + :rtype: dict + """ + return {} diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index fa41f79e2..913cdc4d2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,6 +1,6 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_base import ConditionBase +from .condition_base import ConditionBase, TensorCondition, GraphCondition from ..label_tensor import LabelTensor from ..graph import Graph from ..equation.equation_interface import EquationInterface @@ -99,6 +99,13 @@ def __init__(self, input, equation): super().__init__(input=input) self.equation = equation + +class InputTensorEquationCondition(TensorCondition, InputEquationCondition): + """ + Specialization of the :class:`InputEquationCondition` class for the case + where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. + """ + @property def input(self): """ @@ -110,18 +117,31 @@ def input(self): return self.data["input"] -class InputTensorEquationCondition(InputEquationCondition): +class InputGraphEquationCondition(GraphCondition, InputEquationCondition): """ Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. + where ``input`` is a :class:`~pina.graph.Graph` object. """ + def __init__(self, input, equation): + """ + Initialization of the :class:`InputGraphEquationCondition` class. -class InputGraphEquationCondition(InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.graph.Graph` object. - """ + :param input: The input data for the condition. + :type input: Graph | list[Graph] | tuple[Graph] + :param EquationInterface equation: The equation to be satisfied over the + specified input points. + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + super().__init__(input=[input], equation=equation) @staticmethod def _check_label_tensor(input): @@ -145,3 +165,13 @@ def _check_label_tensor(input): return raise ValueError("The input must contain at least one LabelTensor.") + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index c90fcc8e3..3e041bf90 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,10 +3,10 @@ """ import torch -from torch_geometric.data import Data, Batch +from torch_geometric.data import Data from ..label_tensor import LabelTensor -from ..graph import Graph, LabelBatch -from .condition_base import ConditionBase +from ..graph import Graph +from .condition_base import ConditionBase, GraphCondition, TensorCondition class InputTargetCondition(ConditionBase): @@ -115,7 +115,7 @@ def __new__(cls, input, target): "LabelTensor or torch.Tensor objects." ) - def __init__(self, input, target): + def __init__(self, **kwargs): """ Initialization of the :class:`InputTargetCondition` class. @@ -133,36 +133,10 @@ def __init__(self, input, target): objects, all elements in the list must share the same structure, with matching keys and consistent data types. """ - self._check_input_target_len(input, target) - super().__init__(input=input, target=target) + super().__init__(**kwargs) - @staticmethod - def _check_input_target_len(input, target): - """ - Check that the length of the input and target lists are the same. - :param input: The input data. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :raises ValueError: If the lengths of the input and target lists do not - match. - """ - if isinstance(input, (Graph, Data)) or isinstance( - target, (Graph, Data) - ): - return - - # Raise an error if the lengths of the input and target do not match - if len(input) != len(target): - raise ValueError( - "The input and target lists must have the same length." - ) - - -class TensorInputTensorTargetCondition(InputTargetCondition): +class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where both ``input`` and ``target`` are :class:`torch.Tensor` or @@ -190,7 +164,7 @@ def target(self): return self.data["target"] -class TensorInputGraphTargetCondition(InputTargetCondition): +class TensorInputGraphTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`torch.Tensor` or a @@ -208,24 +182,10 @@ def __init__(self, input, target): :type target: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] """ + self.graph_field = "target" + self.tensor_fields = ["input"] + self.keys_map = {"input": "x"} super().__init__(input=input, target=target) - self.batch_fn = ( - LabelBatch.from_data_list - if isinstance(target[0], Graph) - else Batch.from_data_list - ) - - def _store_data(self, **kwargs): - """ - Store the input and target data for the condition. - - :param kwargs: Keyword arguments containing 'input' and 'target'. - :return: Stored data dictionary. - :rtype: dict - """ - return self._store_graph_data( - kwargs["target"], kwargs["input"], key="x" - ) @property def input(self): @@ -251,27 +211,6 @@ def target(self): """ return self.data["data"] - def __getitem__(self, idx): - if isinstance(idx, list): - return self.get_multiple_data(idx) - return {"data": self.data["data"][idx]} - - def get_multiple_data(self, indices): - """ - Get multiple data items based on the provided indices. - - :param List[int] indices: List of indices to retrieve. - :return: Dictionary containing 'input' and 'target' data. - :rtype: dict - """ - data = self.batch_fn([self.data["data"][i] for i in indices]) - x = data.x - del data.x # Avoid duplication of y on GPU memory - return { - "input": x, - "target": data, - } - @classmethod def automatic_batching_collate_fn(cls, batch): """ @@ -289,7 +228,7 @@ def automatic_batching_collate_fn(cls, batch): return to_return -class GraphInputTensorTargetCondition(InputTargetCondition): +class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` or @@ -307,17 +246,10 @@ def __init__(self, input, target): :param target: The target data for the condition. :type target: torch.Tensor | LabelTensor """ + self.graph_field = "input" + self.tensor_fields = ["target"] + self.keys_map = {"target": "y"} super().__init__(input=input, target=target) - self.batch_fn = ( - LabelBatch.from_data_list - if isinstance(input[0], Graph) - else Batch.from_data_list - ) - - def _store_data(self, **kwargs): - return self._store_graph_data( - kwargs["input"], kwargs["target"], key="y" - ) @property def input(self): @@ -344,27 +276,6 @@ def target(self): return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) - def __getitem__(self, idx): - if isinstance(idx, list): - return self.get_multiple_data(idx) - return {"data": self.data["data"][idx]} - - def get_multiple_data(self, indices): - """ - Get multiple data items based on the provided indices. - - :param List[int] indices: List of indices to retrieve. - :return: Dictionary containing 'input' and 'target' data. - :rtype: dict - """ - data = self.batch_fn([self.data["data"][i] for i in indices]) - y = data.y - del data.y # Avoid duplication of y on GPU memory - return { - "input": data, - "target": y, - } - @classmethod def automatic_batching_collate_fn(cls, batch): """ From 598bce42cf09b62366926e51dec59ac3a49b826e Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:28:04 +0100 Subject: [PATCH 10/17] fix tests --- tests/test_condition.py | 325 +++++++++--------- tests/test_condition/test_data_condition.py | 100 ++++++ .../test_input_equation_condition.py | 4 +- .../test_input_target_condition.py | 3 + .../test_ensemble_supervised_solver.py | 3 +- tests/test_solver/test_supervised_solver.py | 3 +- 6 files changed, 281 insertions(+), 157 deletions(-) create mode 100644 tests/test_condition/test_data_condition.py diff --git a/tests/test_condition.py b/tests/test_condition.py index 9199f2bd9..8a5480499 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -1,154 +1,171 @@ -import torch -import pytest - -from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - TensorInputTensorTargetCondition, - GraphInputGraphTargetCondition, - GraphInputTensorTargetCondition, -) -from pina.condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, - DomainEquationCondition, -) -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) -from pina.domain import CartesianDomain -from pina.equation.equation_factory import FixedValue -from pina.graph import RadiusGraph - -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -input_tensor = torch.rand((10, 3)) -target_tensor = torch.rand((10, 2)) -input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) - -x = torch.rand(10, 20, 2) -pos = torch.rand(10, 20, 2) -radius = 0.1 -input_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] -target_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] - -x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) -pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) -radius = 0.1 -input_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] -target_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] - -input_single_graph = input_graph[0] -target_single_graph = target_graph[0] - - -def test_init_input_target(): - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_graph, target=target_tensor) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - cond = Condition(input=input_lt, target=input_single_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_lt) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_single_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - with pytest.raises(ValueError): - Condition(input_tensor, input_tensor) - with pytest.raises(ValueError): - Condition(input=3.0, target="example") - with pytest.raises(ValueError): - Condition(input=example_domain, target=example_domain) - - # Test wrong graph condition initialisation - input = [input_graph[0], input_graph_lt[0]] - target = [target_graph[0], target_graph_lt[0]] - with pytest.raises(ValueError): - Condition(input=input, target=target) - - input_graph_lt[0].x.labels = ["a", "b"] - with pytest.raises(ValueError): - Condition(input=input_graph_lt, target=target_graph_lt) - input_graph_lt[0].x.labels = ["u", "v"] - - -def test_init_domain_equation(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - assert isinstance(cond, DomainEquationCondition) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(domain=3.0, equation="example") - with pytest.raises(ValueError): - Condition(domain=input_tensor, equation=input_graph) - - -def test_init_input_equation(): - cond = Condition(input=input_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputTensorEquationCondition) - cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputGraphEquationCondition) - with pytest.raises(ValueError): - cond = Condition(input=input_tensor, equation=FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(input=3.0, equation="example") - with pytest.raises(ValueError): - Condition(input=example_domain, equation=input_graph) - - -test_init_input_equation() - - -def test_init_data_condition(): - cond = Condition(input=input_lt) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_graph) - assert isinstance(cond, GraphDataCondition) - cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) - assert isinstance(cond, GraphDataCondition) +# import torch +# import pytest + +# from pina import LabelTensor, Condition +# from pina.condition import ( +# TensorInputGraphTargetCondition, +# TensorInputTensorTargetCondition, +# # GraphInputGraphTargetCondition, +# GraphInputTensorTargetCondition, +# ) +# from pina.condition import ( +# InputTensorEquationCondition, +# InputGraphEquationCondition, +# DomainEquationCondition, +# ) +# from pina.condition import ( +# TensorDataCondition, +# GraphDataCondition, +# ) +# from pina.domain import CartesianDomain +# from pina.equation.equation_factory import FixedValue +# from pina.graph import RadiusGraph + +# def _create_tensor_data(): +# input_tensor = torch.rand((10, 3)) +# target_tensor = torch.rand((10, 2)) +# return input_tensor, target_tensor + +# def _create_graph_data(): +# x = torch.rand(10, 20, 2) +# pos = torch.rand(10, 20, 2) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# target_graph = [ +# RadiusGraph( +# y=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# return input_graph, target_graph + +# def _create_lt_data(): +# input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) +# target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) +# return input_lt, target_lt + +# def _create_graph_lt_data(): +# x = LabelTensor(torch.rand((10, 20, 2)), ["u", "v"]) +# pos = LabelTensor(torch.rand((10, 20, 2)), ["x", "y"]) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# target_graph = [ +# RadiusGraph( +# y=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# return input_graph, target_graph + +# example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + +# def test_init_input_target(): +# input_tensor, target_tensor = _create_tensor_data() +# cond = Condition(input=input_tensor, target=target_tensor) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_lt, target_lt = _create_lt_data() +# cond = Condition(input=input_lt, target=target_lt) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_graph, target_graph = _create_graph_data() +# cond = Condition(input=input_tensor, target=target_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) + +# cond = Condition(input=input_graph, target=target_tensor) +# assert isinstance(cond, GraphInputTensorTargetCondition) + +# input_single_graph = input_graph[0] +# target_single_graph = target_graph[0] +# cond = Condition(input=input_lt, target=input_single_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) +# cond = Condition(input=input_single_graph, target=target_lt) +# assert isinstance(cond, GraphInputTensorTargetCondition) +# # cond = Condition(input=input_graph, target=target_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# # cond = Condition(input=input_single_graph, target=target_single_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# input_graph_lt, target_graph_lt = _create_graph_lt_data() + +# with pytest.raises(ValueError): +# Condition(input_tensor, input_tensor) +# with pytest.raises(ValueError): +# Condition(input=3.0, target="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, target=example_domain) + +# # Test wrong graph condition initialisation +# input = [input_graph[0], input_graph_lt[0]] +# target = [target_graph[0], target_graph_lt[0]] +# with pytest.raises(ValueError): +# Condition(input=input, target=target) + +# input_graph_lt[0].x.labels = ["a", "b"] +# with pytest.raises(ValueError): +# Condition(input=input_graph_lt, target=target_graph_lt) +# input_graph_lt[0].x.labels = ["u", "v"] + + +# def test_init_domain_equation(): +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(domain=example_domain, equation=FixedValue(0.0)) +# assert isinstance(cond, DomainEquationCondition) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(domain=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(domain=input_tensor, equation=input_graph) + + +# def test_init_input_equation(): +# input_lt, _ = _create_lt_data() +# input_graph_lt, _ = _create_graph_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputTensorEquationCondition) +# cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputGraphEquationCondition) +# with pytest.raises(ValueError): +# cond = Condition(input=input_tensor, equation=FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(input=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, equation=input_graph) + +# def test_init_data_condition(): +# input_lt, _ = _create_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_graph) +# assert isinstance(cond, GraphDataCondition) +# cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, GraphDataCondition) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py new file mode 100644 index 000000000..954e8f777 --- /dev/null +++ b/tests/test_condition/test_data_condition.py @@ -0,0 +1,100 @@ +import pytest +import torch +from pina import Condition, LabelTensor +from pina.condition import ( + TensorDataCondition, + GraphDataCondition, +) +from pina.graph import RadiusGraph +from torch_geometric.data import Data + + +def _create_tensor_data(use_lt=False, conditional_variables=False): + input_tensor = torch.rand((10, 3)) + if use_lt: + input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) + if conditional_variables: + cond_vars = torch.rand((10, 2)) + if use_lt: + cond_vars = LabelTensor(cond_vars, ["a", "b"]) + else: + cond_vars = None + return input_tensor, cond_vars + + +def _create_graph_data(use_lt=False, conditional_variables=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + input_graph = [ + RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) + ] + if conditional_variables: + if use_lt: + cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + cond_vars = torch.rand(10, 20, 1) + else: + cond_vars = None + return input_graph, cond_vars + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = TensorDataCondition( + input=input_tensor, conditional_variables=cond_vars + ) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["a", "b"] + else: + assert condition.conditional_variables is None + assert isinstance(condition.input, type_) + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + + +test_init_tensor_data_condition(False, False) + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = GraphDataCondition( + input=input_graph, conditional_variables=cond_vars + ) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["f"] + else: + assert condition.conditional_variables is None + # assert "conditional_variables" not in condition.data.keys() + assert isinstance(condition.input, list) + for graph in condition.input: + assert isinstance(graph, Data) + assert isinstance(graph.x, type_) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + if use_lt: + assert graph.pos.labels == ["x", "y"] + + +test_init_graph_data_condition(False, False) diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index b6a687e2a..af11d382e 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -42,7 +42,9 @@ def test_init_graph_equation_condition(): graph, equation = _create_graph_and_equation() condition = Condition(input=graph, equation=equation) assert isinstance(condition, InputGraphEquationCondition) - assert condition.input is graph + assert isinstance(condition.input, list) + assert len(condition.input) == 1 + assert condition.input[0].x.shape == (100, 2) assert condition.equation is equation diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 033f7094a..81c3a9b24 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -292,3 +292,6 @@ def test_getitems_tensor_input_graph_target_condition(use_lt): "u", "v", ], "TensorInputGraphTargetCondition __getitems__ target labels failed" + + +test_init_graph_input_tensor_target_condition(use_lt=True) diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py index c5f0b9e52..4be2897d9 100644 --- a/tests/test_solver/test_ensemble_supervised_solver.py +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_models = [Models() for i in range(10)] diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..461130a6b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_model = Model() From f6619e101d2e0769ee5a4149968513d05c7fe11a Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 31 Dec 2025 08:38:20 +0100 Subject: [PATCH 11/17] fix DataCondition and relative tests --- pina/condition/condition_base.py | 29 ++++- pina/condition/data_condition.py | 14 +-- pina/condition/input_target_condition.py | 54 --------- tests/test_condition/test_data_condition.py | 122 ++++++++++++++++++-- 4 files changed, 141 insertions(+), 78 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index b8b828767..8365c51d4 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -14,6 +14,10 @@ class TensorCondition: + """ + Base class for tensor conditions. + """ + def store_data(self, **kwargs): """ Store data for standard tensor condition @@ -29,6 +33,10 @@ def store_data(self, **kwargs): class GraphCondition: + """ + Base class for graph conditions. + """ + def __init__(self, **kwargs): super().__init__(**kwargs) example = kwargs.get(self.graph_field)[0] @@ -85,6 +93,26 @@ def get_multiple_data(self, indices): to_return_dict[key] = y return to_return_dict + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch)["data"] + to_return_dict = {} + for key in cls.tensor_fields: + mapping_key = cls.keys_map.get(key) + tensor = getattr(collated_graphs, mapping_key) + to_return_dict[key] = tensor + delattr(collated_graphs, mapping_key) + to_return_dict[cls.graph_field] = collated_graphs + return to_return_dict + class ConditionBase(ConditionInterface): """ @@ -269,5 +297,4 @@ def create_dataloader( if not automatic_batching else self.automatic_batching_collate_fn ), - # collate_fn = self.automatic_batching_collate_fn ) diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index b04166b51..e78dc800f 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -1,10 +1,10 @@ """Module for the DataCondition class.""" import torch -from torch_geometric.data import Data, Batch +from torch_geometric.data import Data from .condition_base import ConditionBase, GraphCondition, TensorCondition from ..label_tensor import LabelTensor -from ..graph import Graph, LabelBatch +from ..graph import Graph class DataCondition(ConditionBase): @@ -201,13 +201,3 @@ def conditional_variables(self): if not is_lt else LabelTensor.stack(cond_vars) ) - - def __getitem__(self, idx): - """ - Get item by index from the input data. - - :param int index: The index of the item to retrieve. - :return: The item at the specified index. - :rtype: Graph | Data - """ - input_ = self.batch_fn(self.data["input"][idx]) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 3e041bf90..064f4b5eb 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -115,26 +115,6 @@ def __new__(cls, input, target): "LabelTensor or torch.Tensor objects." ) - def __init__(self, **kwargs): - """ - Initialization of the :class:`InputTargetCondition` class. - - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - - .. note:: - - If either ``input`` or ``target`` is a list of - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` - objects, all elements in the list must share the same structure, - with matching keys and consistent data types. - """ - super().__init__(**kwargs) - class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): """ @@ -211,22 +191,6 @@ def target(self): """ return self.data["data"] - @classmethod - def automatic_batching_collate_fn(cls, batch): - """ - Collate function to be used in DataLoader. - - :param batch: A list of items from the dataset. - :type batch: List[dict] - :return: A collated batch. - :rtype: dict - """ - collated_graphs = super().automatic_batching_collate_fn(batch) - x = collated_graphs["data"].x - del collated_graphs["data"].x # Avoid duplication of y on GPU memory - to_return = {"input": x, "target": collated_graphs["data"]} - return to_return - class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): """ @@ -275,21 +239,3 @@ def target(self): targets.append(graph.y) return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) - - @classmethod - def automatic_batching_collate_fn(cls, batch): - """ - Collate function to be used in DataLoader. - - :param batch: A list of items from the dataset. - :type batch: list - :return: A collated batch. - :rtype: dict - """ - collated_graphs = super().automatic_batching_collate_fn(batch) - y = collated_graphs["data"].y - del collated_graphs["data"].y # Avoid duplication of y on GPU memory - print("y shape:", y.shape) - print(y.labels) - to_return = {"target": y, "input": collated_graphs["data"]} - return to_return diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 954e8f777..e922bdbcd 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -6,7 +6,8 @@ GraphDataCondition, ) from pina.graph import RadiusGraph -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch +from pina.graph import Graph, LabelBatch def _create_tensor_data(use_lt=False, conditional_variables=False): @@ -49,9 +50,8 @@ def test_init_tensor_data_condition(use_lt, conditional_variables): input_tensor, cond_vars = _create_tensor_data( use_lt=use_lt, conditional_variables=conditional_variables ) - condition = TensorDataCondition( - input=input_tensor, conditional_variables=cond_vars - ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + assert isinstance(condition, TensorDataCondition) type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -65,18 +65,14 @@ def test_init_tensor_data_condition(use_lt, conditional_variables): assert condition.input.labels == ["x", "y", "z"] -test_init_tensor_data_condition(False, False) - - @pytest.mark.parametrize("use_lt", [False, True]) @pytest.mark.parametrize("conditional_variables", [False, True]) def test_init_graph_data_condition(use_lt, conditional_variables): input_graph, cond_vars = _create_graph_data( use_lt=use_lt, conditional_variables=conditional_variables ) - condition = GraphDataCondition( - input=input_graph, conditional_variables=cond_vars - ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + assert isinstance(condition, GraphDataCondition) type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -97,4 +93,108 @@ def test_init_graph_data_condition(use_lt, conditional_variables): assert graph.pos.labels == ["x", "y"] -test_init_graph_data_condition(False, False) +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(item["input"], type_) + assert item["input"].shape == (3,) + if type_ is LabelTensor: + assert item["input"].labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in item + assert isinstance(item["conditional_variables"], type_) + assert item["conditional_variables"].shape == (2,) + if type_ is LabelTensor: + assert item["conditional_variables"].labels == ["a", "b"] + else: + assert "conditional_variables" not in item + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "data" in item + graph = item["data"] + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + assert hasattr(graph, "cond_vars") + cond_var = graph.cond_vars + assert isinstance(cond_var, type_) + assert cond_var.shape == (20, 1) + if use_lt: + assert cond_var.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + graphs = items["input"] + assert isinstance(graphs, LabelBatch) + assert graphs.num_graphs == 3 + if conditional_variables: + type_ = LabelTensor if use_lt else torch.Tensor + assert "conditional_variables" in items + cond_vars_batch = items["conditional_variables"] + assert isinstance(cond_vars_batch, type_) + assert cond_vars_batch.shape == (60, 1) + if use_lt: + assert cond_vars_batch.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + type_ = LabelTensor if use_lt else torch.Tensor + inputs = items["input"] + assert isinstance(inputs, type_) + assert inputs.shape == (3, 3) + if use_lt: + assert inputs.labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in items + cond_vars_items = items["conditional_variables"] + assert isinstance(cond_vars_items, type_) + assert cond_vars_items.shape == (3, 2) + if use_lt: + assert cond_vars_items.labels == ["a", "b"] + else: + assert "conditional_variables" not in items From 7b9096ef3431ceb759cc91c67ffddfbc5d184655 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 12 Jan 2026 15:37:52 +0100 Subject: [PATCH 12/17] fixes --- pina/condition/condition_base.py | 61 ++++++++++++++------- pina/condition/domain_equation_condition.py | 19 ++++++- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 8365c51d4..6283a78bb 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -10,7 +10,6 @@ from .condition_interface import ConditionInterface from ..graph import Graph, LabelBatch from ..label_tensor import LabelTensor -from ..data.dummy_dataloader import DummyDataloader class TensorCondition: @@ -229,48 +228,68 @@ def _check_graph_list_consistency(data_list): ) def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ return len(next(iter(self.data.values()))) def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param idx: Index(es) of the data point(s) to retrieve. + :type idx: int | list[int] + :return: Data point(s) at the specified index. + """ return {name: data[idx] for name, data in self.data.items()} @classmethod def automatic_batching_collate_fn(cls, batch): """ - Collate function to be used in DataLoader. - + Collate function for automatic batching to be used in DataLoader. :param batch: A list of items from the dataset. :type batch: list :return: A collated batch. :rtype: dict """ + if not batch: + return {} + keys = batch[0].keys() + columns = zip(*[item.values() for item in batch]) + + to_return = {} + + # 2. Process each column + for key, values in zip(keys, columns): + # Determine type based on the first sample only + first_val = values[0] + + if isinstance(first_val, (LabelTensor, torch.Tensor)): + lookup_key = "label_tensor" + elif isinstance(first_val, Graph): + lookup_key = "graph" + else: + lookup_key = "data" + + # Execute the specific collate function + to_return[key] = cls.collate_fn_dict[lookup_key](list(values)) - to_return = {key: [] for key in batch[0].keys()} - for item in batch: - for key, value in item.items(): - to_return[key].append(value) - for key, values in to_return.items(): - collate_function = cls.collate_fn_dict.get( - "label_tensor" - if isinstance(values[0], LabelTensor) - else ( - "label_tensor" - if isinstance(values[0], torch.Tensor) - else "graph" if isinstance(values[0], Graph) else "data" - ) - ) - to_return[key] = collate_function(values) return to_return @staticmethod def collate_fn(batch, condition): """ - Collate function for automatic batching to be used in DataLoader. + Collate function for custom batching to be used in DataLoader. :param batch: A list of items from the dataset. :type batch: list + :param condition: The condition instance. + :type condition: ConditionBase :return: A collated batch. - :rtype: list + :rtype: dict """ data = condition[batch] return data @@ -287,7 +306,7 @@ def create_dataloader( :rtype: torch.utils.data.DataLoader """ if batch_size == len(dataset): - return DummyDataloader(dataset) + pass # will be updated in the near future return DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 0ce05eeab..673bf1612 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -55,6 +55,13 @@ def __init__(self, domain, equation): self.equation = equation def __len__(self): + """ + Raise NotImplementedError since the number of points is determined by + the domain sampling strategy. + + :raises NotImplementedError: Always raised since the number of points is + determined by the domain sampling strategy. + """ raise NotImplementedError( "`__len__` method is not implemented for " "`DomainEquationCondition` since the number of points is " @@ -62,7 +69,13 @@ def __len__(self): ) def __getitem__(self, idx): - """ """ + """ + Raise NotImplementedError since data retrieval is not applicable. + + :param int idx: Index of the data point(s) to retrieve. + :raises NotImplementedError: Always raised since data retrieval is not + applicable for this condition. + """ raise NotImplementedError( "`__getitem__` method is not implemented for " "`DomainEquationCondition`" @@ -70,9 +83,9 @@ def __getitem__(self, idx): def store_data(self): """ - Store the data for the condition by sampling points from the domain. + Store data for the condition. No data is stored for this condition. - :return: Sampled points from the domain. + :return: An empty dictionary since no data is stored. :rtype: dict """ return {} From a6cbc1f7e0644156dea1752b3a669ab6571c7f25 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 19 Jan 2026 12:00:25 +0100 Subject: [PATCH 13/17] Add _DataManager and refact conditions --- pina/condition/__init__.py | 26 +- pina/condition/condition_base.py | 102 +------- pina/condition/data_condition.py | 178 ++++--------- pina/condition/data_manager.py | 222 ++++++++++++++++ pina/condition/domain_equation_condition.py | 37 +-- pina/condition/input_equation_condition.py | 148 +++-------- pina/condition/input_target_condition.py | 239 ++++-------------- tests/test_condition/test_data_condition.py | 191 ++++++++++---- .../test_domain_equation_condition.py | 11 + .../test_input_equation_condition.py | 39 +-- .../test_input_target_condition.py | 161 ++++++------ tests/test_data_manager.py | 156 ++++++++++++ 12 files changed, 810 insertions(+), 700 deletions(-) create mode 100644 pina/condition/data_manager.py create mode 100644 tests/test_data_manager.py diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 13429a829..c104e0ef9 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -5,13 +5,7 @@ "ConditionInterface", "DomainEquationCondition", "InputTargetCondition", - "TensorInputTensorTargetCondition", - "TensorInputGraphTargetCondition", - "GraphInputTensorTargetCondition", - # "GraphInputGraphTargetCondition", "InputEquationCondition", - "InputTensorEquationCondition", - "InputGraphEquationCondition", "DataCondition", "GraphDataCondition", "TensorDataCondition", @@ -20,20 +14,6 @@ from .condition_interface import ConditionInterface from .condition import Condition from .domain_equation_condition import DomainEquationCondition -from .input_target_condition import ( - InputTargetCondition, - TensorInputTensorTargetCondition, - TensorInputGraphTargetCondition, - GraphInputTensorTargetCondition, - # GraphInputGraphTargetCondition, -) -from .input_equation_condition import ( - InputEquationCondition, - InputTensorEquationCondition, - InputGraphEquationCondition, -) -from .data_condition import ( - DataCondition, - GraphDataCondition, - TensorDataCondition, -) +from .input_target_condition import InputTargetCondition +from .input_equation_condition import InputEquationCondition +from .data_condition import DataCondition diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 6283a78bb..388e7141d 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -156,77 +156,6 @@ def problem(self, value): """ self._problem = value - @staticmethod - def _check_graph_list_consistency(data_list): - """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: - - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. - """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." - ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - def __len__(self): """ Return the number of data points in the condition. @@ -234,7 +163,7 @@ def __len__(self): :return: Number of data points. :rtype: int """ - return len(next(iter(self.data.values()))) + return len(self.data) def __getitem__(self, idx): """ @@ -244,7 +173,7 @@ def __getitem__(self, idx): :type idx: int | list[int] :return: Data point(s) at the specified index. """ - return {name: data[idx] for name, data in self.data.items()} + return self.data[idx] @classmethod def automatic_batching_collate_fn(cls, batch): @@ -257,27 +186,8 @@ def automatic_batching_collate_fn(cls, batch): """ if not batch: return {} - keys = batch[0].keys() - columns = zip(*[item.values() for item in batch]) - - to_return = {} - - # 2. Process each column - for key, values in zip(keys, columns): - # Determine type based on the first sample only - first_val = values[0] - - if isinstance(first_val, (LabelTensor, torch.Tensor)): - lookup_key = "label_tensor" - elif isinstance(first_val, Graph): - lookup_key = "graph" - else: - lookup_key = "data" - - # Execute the specific collate function - to_return[key] = cls.collate_fn_dict[lookup_key](list(values)) - - return to_return + instance_class = batch[0].__class__ + return instance_class._create_batch(batch) @staticmethod def collate_fn(batch, condition): @@ -291,7 +201,9 @@ def collate_fn(batch, condition): :return: A collated batch. :rtype: dict """ - data = condition[batch] + print("Custom collate_fn called") + print("batch:", batch) + data = condition.data[batch] return data def create_dataloader( diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index e78dc800f..debb71bad 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -2,9 +2,10 @@ import torch from torch_geometric.data import Data -from .condition_base import ConditionBase, GraphCondition, TensorCondition +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph +from ..condition.data_manager import _DataManager class DataCondition(ConditionBase): @@ -16,17 +17,6 @@ class DataCondition(ConditionBase): the provided data during training. Optional ``conditional_variables`` can be specified when the model depends on additional parameters. - The class automatically selects the appropriate implementation based on the - type of the ``input`` data. Depending on whether the ``input`` is a tensor - or graph-based data, one of the following specialized subclasses is - instantiated: - - - :class:`TensorDataCondition`: For cases where the ``input`` is either a - :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. - - - :class:`GraphDataCondition`: For cases where the ``input`` is either a - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object. - :Example: >>> from pina import Condition, LabelTensor @@ -44,8 +34,8 @@ class DataCondition(ConditionBase): def __new__(cls, input, conditional_variables=None): """ - Instantiate the appropriate subclass of :class:`DataCondition` based on - the type of the ``input``. + Check the types of ``input`` and ``conditional_variables`` and + instantiate a class of :class:`DataCondition` accordingly. :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | @@ -63,48 +53,51 @@ def __new__(cls, input, conditional_variables=None): if cls != DataCondition: return super().__new__(cls) - # If the input is a tensor - if isinstance(input, (torch.Tensor, LabelTensor)): - subclass = TensorDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is a graph - if isinstance(input, (Graph, Data, list, tuple)): - cls._check_graph_list_consistency(input) - subclass = GraphDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is not of the correct type raise an error - raise ValueError( - "Invalid input type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "an iterable of the previous types." - ) - - def __init__(self, input, conditional_variables=None): + # Check input type + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "Invalid input type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "an iterable of the previous types." + ) + if isinstance(input, (list, tuple)): + for item in input: + if not isinstance(item, (Data, Graph)): + raise ValueError( + "if input is a list or tuple, all its elements must" + " be of type Graph or Data." + ) + + # Check conditional_variables type + if conditional_variables is not None: + if not isinstance( + conditional_variables, cls._avail_conditional_variables_cls + ): + raise ValueError( + "Invalid conditional_variables type. Expected one of the " + "following: torch.Tensor, LabelTensor." + ) + + return super().__new__(cls) + + def store_data(self, **kwargs): """ - Initialization of the :class:`DataCondition` class. + Store the input data and conditional variables in a dictionary. :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param conditional_variables: The conditional variables for the - condition. Default is ``None``. + condition. :type conditional_variables: torch.Tensor | LabelTensor - - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements in - the list must share the same structure, with matching keys and - consistent data types. + :return: A dictionary containing the stored data. + :rtype: dict """ - if conditional_variables is None: - super().__init__(input=input) - else: - super().__init__( - input=input, conditional_variables=conditional_variables - ) + data_dict = {"input": kwargs.get("input")} + cond_vars = kwargs.get("conditional_variables", None) + if cond_vars is not None: + data_dict["conditional_variables"] = cond_vars + return _DataManager(**data_dict) @property def conditional_variables(self): @@ -114,15 +107,9 @@ def conditional_variables(self): :return: The conditional variables. :rtype: torch.Tensor | LabelTensor | None """ - return self.data.get("conditional_variables", None) - - -class TensorDataCondition(TensorCondition, DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a - :class:`torch.Tensor` object. - """ + if hasattr(self.data, "conditional_variables"): + return self.data.conditional_variables + return None @property def input(self): @@ -130,74 +117,7 @@ def input(self): Return the input data for the condition. :return: The input data. - :rtype: torch.Tensor | LabelTensor - """ - return self.data["input"] - - -class GraphDataCondition(GraphCondition, DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.graph.Graph` object or a - :class:`~torch_geometric.data.Data` object. - """ - - def __init__(self, input, conditional_variables=None): + :rtype: torch.Tensor | LabelTensor | Graph | Data | + list[Graph] | list[Data] | tuple[Graph] | tuple[Data] """ - Initialization of the :class:`GraphDataCondition` class. - - :param input: The input data for the condition. - :type input: Graph | Data | list[Graph] | list[Data] | - tuple[Graph] | tuple[Data] - :param conditional_variables: The conditional variables for the - condition. Default is ``None``. - :type conditional_variables: torch.Tensor | LabelTensor - - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements in - the list must share the same structure, with matching keys and - consistent data types. - """ - self.graph_field = "input" - self.tensor_fields = [] - self.keys_map = {} - if conditional_variables is not None: - self.tensor_fields.append("conditional_variables") - self.keys_map["conditional_variables"] = "cond_vars" - super().__init__( - input=input, conditional_variables=conditional_variables - ) - - @property - def input(self): - """ - Return the input data for the condition. - - :return: The input data. - :rtype: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | - tuple[Data] - """ - return self.data["data"] - - @property - def conditional_variables(self): - """ - Return the target data for the condition. - - :return: The target data. - :rtype: list[torch.Tensor] | list[LabelTensor] - """ - - if not hasattr(self.data["data"][0], "cond_vars"): - return None - cond_vars = [] - is_lt = isinstance(self.data["data"][0].cond_vars, LabelTensor) - for graph in self.data["data"]: - cond_vars.append(graph.cond_vars) - return ( - torch.stack(cond_vars) - if not is_lt - else LabelTensor.stack(cond_vars) - ) + return self.data.input diff --git a/pina/condition/data_manager.py b/pina/condition/data_manager.py new file mode 100644 index 000000000..3a33cf99d --- /dev/null +++ b/pina/condition/data_manager.py @@ -0,0 +1,222 @@ +import torch +from pina import LabelTensor +from pina.graph import Graph +from torch_geometric.data import Data +from torch_geometric.data.batch import Batch +from pina.graph import LabelBatch +from pina.equation.equation_interface import EquationInterface +from abc import ABC, abstractmethod + + +class _BatchManager: + def __init__(self, **dict): + self.keys = list(dict.keys()) + for k, v in dict.items(): + setattr(self, k, v) + + def to(self, device): + for k in self.keys: + val = getattr(self, k) + setattr(self, k, val.to(device)) + return self + + +class _DataManager(ABC): + """Interfaccia base ottimizzata per la gestione dei dati.""" + + def __new__(cls, **kwargs): + # Dispatching Factory + if cls is not _DataManager: + return super().__new__(cls) + + # Determina se usare il gestore Tensori o Grafi + # (Controllo ottimizzato: evita cicli se possibile) + is_tensor_only = all( + isinstance(v, (torch.Tensor, LabelTensor, EquationInterface)) + for v in kwargs.values() + ) + + subclass = _TensorDataManager if is_tensor_only else _GraphDataManager + return super().__new__(subclass) + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __getitem__(self, idx): + pass + + def to_dict(self): + return {k: getattr(self, k) for k in self.keys} + + +# --- GESTORE TENSORI --- + + +class _TensorDataManager(_DataManager): + def __init__(self, **kwargs): + self.keys = list(kwargs.keys()) + self._data = kwargs # Memorizzazione in dizionario per accesso O(1) + + # # Identifica i tensori una sola volta + # self._tensor_keys = [ + # k for k, v in kwargs.items() + # if isinstance(v, (torch.Tensor, LabelTensor)) + # ] + + # Espone le chiavi come attributi (facoltativo, ma mantiene compatibilità) + for k, v in kwargs.items(): + setattr(self, k, v) + + def __len__(self) -> int: + # Prende la lunghezza dal primo tensore disponibile + return self._data[self.keys[0]].shape[0] + + def __getitem__(self, idx): + # Mapping efficiente degli elementi + new_data = { + k: (self._data[k][idx] if k in self.keys else self._data[k]) + for k in self.keys + } + return _TensorDataManager(**new_data) + + @staticmethod + def _create_batch(items): + if not items: + return None + first = items[0] + batch_data = {} + + for k in first.keys: + vals = [it._data[k] for it in items] + sample = vals[0] + + if isinstance(sample, (torch.Tensor, LabelTensor)): + batch_fn = ( + LabelTensor.stack + if isinstance(sample, LabelTensor) + else torch.stack + ) + batch_data[k] = batch_fn(vals, dim=0) + else: + batch_data[k] = sample + + return _BatchManager(**batch_data) + + +class _GraphDataManager(_DataManager): + def __init__(self, **kwargs): + self.keys = list(kwargs.keys()) + + self.graph_key = next( + k + for k, v in kwargs.items() + if isinstance(v, (Graph, Data, list, tuple)) + ) + + self.keys = [ + k + for k in self.keys + if k != self.graph_key + and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) + ] + + # Prepara la lista di grafi internamente + self.data = self._prepare_graphs(kwargs) + + def _prepare_graphs(self, kwargs): + graphs = kwargs[self.graph_key] + if not isinstance(graphs, (list, tuple)): + graphs = [graphs] + + # Iniezione attributi nei grafi + for k in self.keys: + val_source = kwargs[k] + # Ottimizzazione: se la lunghezza coincide, distribuiamo i tensori, + # altrimenti trattiamo il tensore come costante per tutti. + use_idx = ( + len(val_source) == len(graphs) + if hasattr(val_source, "__len__") + else False + ) + + for i, g in enumerate(graphs): + setattr(g, k, val_source[i] if use_idx else val_source) + return graphs + + def __len__(self) -> int: + return len(self.data) + + def __getattr__(self, name): + + # If the requested attribute is a tensor key, stack the tensors from + # all graphs + if name in self.keys: + tensors = [getattr(g, name) for g in self.data] + batch_fn = ( + LabelTensor.stack + if isinstance(tensors[0], LabelTensor) + else torch.stack + ) + return batch_fn(tensors) + + # If the requested attribute is the graph key, return the graphs + if name == self.graph_key: + return self.data if len(self.data) > 1 else self.data[0] + + super().__getattribute__(name) + + @classmethod + def _init_from_graphs_list(cls, graphs, graph_key, keys): + # Create a new instance without calling __init__ + obj = _GraphDataManager.__new__(_GraphDataManager) + obj.graph_key = graph_key + obj.keys = keys + # obj._tensor_keys = tensor_keys + obj.data = graphs + return obj + + def __getitem__(self, idx): + # Manage int and slice directly + if isinstance(idx, (int, slice)): + selected = self.data[idx] + # Manage list or tensor of indices + elif isinstance(idx, (list, torch.Tensor)): + selected = [self.data[i] for i in idx] + else: + raise TypeError(f"Invalid index type: {type(idx)}") + + # Ensure selected is a list + if not isinstance(selected, list): + selected = [selected] + + # Return a new _GraphDataManager instance with the selected graphs + return _GraphDataManager._init_from_graphs_list( + selected, + # tensor_keys=self._tensor_keys, + graph_key=self.graph_key, + keys=self.keys, + ) + + def _create_batch(items): + if not items: + return None + first = items[0] + batching_fn = ( + LabelBatch.from_data_list + if isinstance(first.data[0], Graph) + else Batch.from_data_list + ) + + graphs_to_batch = [item.data[0] for item in items] + batch_graph = batching_fn(graphs_to_batch) + + batch_data = {first.graph_key: batch_graph} + + for k in first.keys: + if k == first.graph_key: + continue + batch_data[k] = getattr(batch_graph, k) + delattr(batch_graph, k) + return _BatchManager(**batch_data) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 673bf1612..b8d465581 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -31,28 +31,30 @@ class DomainEquationCondition(ConditionBase): # Available slots __fields__ = ["domain", "equation"] - def __init__(self, domain, equation): + _avail_domain_cls = DomainInterface + _avail_equation_cls = EquationInterface + + def __new__(cls, domain, equation): """ - Initialization of the :class:`DomainEquationCondition` class. + Check the types of ``domain`` and ``equation`` and instantiate an + instance of :class:`DomainEquationCondition`. - :param DomainInterface domain: The domain over which the equation is - defined. - :param EquationInterface equation: The equation to be satisfied over the - specified domain. + :return: An instance of :class:`DomainEquationCondition`. + :rtype: pina.condition.domain_equation_condition.DomainEquationCondition + :raises ValueError: If ``domain`` is not of type :class:`DomainInterface` or + ``equation`` is not of type :class:` """ - if not isinstance(domain, (DomainInterface, str)): + if not isinstance(domain, cls._avail_domain_cls): raise ValueError( - f"`domain` must be an instance of DomainInterface, " - f"got {type(domain)} instead." + "The domain must be an instance of DomainInterface." ) - if not isinstance(equation, EquationInterface): + + if not isinstance(equation, cls._avail_equation_cls): raise ValueError( - f"`equation` must be an instance of EquationInterface, " - f"got {type(equation)} instead." + "The equation must be an instance of EquationInterface." ) - super().__init__() - self.domain = domain - self.equation = equation + + return super().__new__(cls) def __len__(self): """ @@ -81,11 +83,12 @@ def __getitem__(self, idx): "`DomainEquationCondition`" ) - def store_data(self): + def store_data(self, **kwargs): """ Store data for the condition. No data is stored for this condition. :return: An empty dictionary since no data is stored. :rtype: dict """ - return {} + self.domain = kwargs.get("domain") + self.equation = kwargs.get("equation") diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 913cdc4d2..b8ef0209b 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,9 +1,10 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_base import ConditionBase, TensorCondition, GraphCondition +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph from ..equation.equation_interface import EquationInterface +from ..condition.data_manager import _DataManager class InputEquationCondition(ConditionBase): @@ -14,17 +15,6 @@ class InputEquationCondition(ConditionBase): ``equation`` through the evaluation of the residual performed at the provided ``input``. - The class automatically selects the appropriate implementation based on - the type of the ``input`` data. Depending on whether the ``input`` is a - tensor or graph-based data, one of the following specialized subclasses is - instantiated: - - - :class:`InputTensorEquationCondition`: For cases where the ``input`` - data is a :class:`~pina.label_tensor.LabelTensor` object. - - - :class:`InputGraphEquationCondition`: For cases where the ``input`` data - is a :class:`~pina.graph.Graph` object. - :Example: >>> from pina import Condition, LabelTensor @@ -41,13 +31,13 @@ class InputEquationCondition(ConditionBase): # Available input data types __fields__ = ["input", "equation"] - _avail_input_cls = (LabelTensor, Graph, list, tuple) + _avail_input_cls = (LabelTensor, Graph) _avail_equation_cls = EquationInterface def __new__(cls, input, equation): """ - Instantiate the appropriate subclass of :class:`InputEquationCondition` - based on the type of ``input`` data. + Check the types of ``input`` and ``equation`` and instantiate a class + of :class:`InputEquationCondition` accordingly. :param input: The input data for the condition. :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] @@ -61,50 +51,28 @@ def __new__(cls, input, equation): :raises ValueError: If input is not of type :class:`~pina.graph.Graph` or :class:`~pina.label_tensor.LabelTensor`. """ - if cls != InputEquationCondition: - return super().__new__(cls) - - # If the input is a Graph object - if isinstance(input, (Graph, list, tuple)): - subclass = InputGraphEquationCondition - cls._check_graph_list_consistency(input) - subclass._check_label_tensor(input) - return subclass.__new__(subclass, input, equation) - - # If the input is a LabelTensor - if isinstance(input, LabelTensor): - subclass = InputTensorEquationCondition - return subclass.__new__(subclass, input, equation) - - # If the input is not a LabelTensor or a Graph object raise an error - raise ValueError( - "The input data object must be a LabelTensor or a Graph object." - ) - - def __init__(self, input, equation): - """ - Initialization of the :class:`InputEquationCondition` class. - :param input: The input data for the condition. - :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] - :param EquationInterface equation: The equation to be satisfied over the - specified input points. + # CHeck input type + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "The input data object must be a LabelTensor or a Graph object." + ) - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` all elements in - the list must share the same structure, with matching keys and - consistent data types. - """ - super().__init__(input=input) - self.equation = equation + # Check equation type + if not isinstance(equation, cls._avail_equation_cls): + raise ValueError( + "The equation must be an instance of EquationInterface." + ) + return super().__new__(cls) -class InputTensorEquationCondition(TensorCondition, InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. - """ + def store_data(self, **kwargs): + """ + Store the input data in a :class:`_DataManager` object. + :param dict kwargs: The keyword arguments containing the input data. + """ + self.__setattr__("_equation", kwargs.pop("equation")) + return _DataManager(**kwargs) @property def input(self): @@ -114,64 +82,28 @@ def input(self): :return: The input data. :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] """ - return self.data["input"] - - -class InputGraphEquationCondition(GraphCondition, InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.graph.Graph` object. - """ - - def __init__(self, input, equation): - """ - Initialization of the :class:`InputGraphEquationCondition` class. - - :param input: The input data for the condition. - :type input: Graph | list[Graph] | tuple[Graph] - :param EquationInterface equation: The equation to be satisfied over the - specified input points. - - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` all elements in - the list must share the same structure, with matching keys and - consistent data types. - """ - self.graph_field = "input" - self.tensor_fields = [] - self.keys_map = {} - super().__init__(input=[input], equation=equation) + return self.data.input - @staticmethod - def _check_label_tensor(input): + @property + def equation(self): """ - Check if at least one :class:`~pina.label_tensor.LabelTensor` is present - in the ``input`` object. + Return the equation associated with this condition. - :param input: The input data. - :type input: torch.Tensor | Graph | list[Graph] | tuple[Graph] - :raises ValueError: If the input data object does not contain at least - one LabelTensor. + :return: Equation associated with this condition. + :rtype: EquationInterface """ + return self._equation - # Store the first element: it is sufficient to check this since all - # elements must have the same type and structure (already checked). - data = input[0] if isinstance(input, (list, tuple)) else input - - # Check if the input data contains at least one LabelTensor - for v in data.values(): - if isinstance(v, LabelTensor): - return - - raise ValueError("The input must contain at least one LabelTensor.") - - @property - def input(self): + @equation.setter + def equation(self, value): """ - Return the input data for the condition. + Set the equation associated with this condition. - :return: The input data. - :rtype: list[Graph] | list[Data] + :param EquationInterface value: The equation to associate with this + condition """ - return self.data["data"] + if not isinstance(value, EquationInterface): + raise TypeError( + "The equation must be an instance of EquationInterface." + ) + self._equation = value diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 064f4b5eb..e5939c498 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -6,7 +6,8 @@ from torch_geometric.data import Data from ..label_tensor import LabelTensor from ..graph import Graph -from .condition_base import ConditionBase, GraphCondition, TensorCondition +from .condition_base import ConditionBase +from .data_manager import _DataManager class InputTargetCondition(ConditionBase): @@ -17,29 +18,6 @@ class InputTargetCondition(ConditionBase): include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - The class automatically selects the appropriate implementation based on - the types of ``input`` and ``target``. Depending on whether the ``input`` - and ``target`` are tensors or graph-based data, one of the following - specialized subclasses is instantiated: - - - :class:`TensorInputTensorTargetCondition`: For cases where both ``input`` - and ``target`` data are either :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor`. - - - :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is - either a :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor` - and ``target`` is either a :class:`~pina.graph.Graph` or a - :class:`torch_geometric.data.Data`. - - - :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is - either a :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` - and ``target`` is either a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - - - :class:`GraphInputGraphTargetCondition`: For cases where both ``input`` - and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data`. - :Example: >>> from pina import Condition, LabelTensor @@ -61,67 +39,57 @@ class InputTargetCondition(ConditionBase): def __new__(cls, input, target): """ - Instantiate the appropriate subclass of :class:`InputTargetCondition` - based on the types of both ``input`` and ``target`` data. + Check the types of ``input`` and ``target`` data and instantiate the + :class:`InputTargetCondition`. :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :return: The subclass of InputTargetCondition. - :rtype: pina.condition.input_target_condition. - TensorInputTensorTargetCondition | - pina.condition.input_target_condition. - TensorInputGraphTargetCondition | - pina.condition.input_target_condition. - GraphInputTensorTargetCondition | - pina.condition.input_target_condition.GraphInputGraphTargetCondition - - :raises ValueError: If ``input`` and/or ``target`` are not of type - :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, - :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - """ - if cls != InputTargetCondition: - return super().__new__(cls) - - # Tensor - Tensor - if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( - target, (torch.Tensor, LabelTensor) - ): - subclass = TensorInputTensorTargetCondition - return subclass.__new__(subclass, input, target) - - # Tensor - Graph - if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(target) - subclass = TensorInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # Graph - Tensor - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (torch.Tensor, LabelTensor) - ): - cls._check_graph_list_consistency(input) - subclass = GraphInputTensorTargetCondition - return subclass.__new__(subclass, input, target) - - raise ValueError( - "Invalid input | target types." - "Please provide either torch_geometric.data.Data, Graph, " - "LabelTensor or torch.Tensor objects." - ) - - -class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` objects. - """ + :type target: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] + :return: An instance of :class:`InputTargetCondition`. + :rtype: pina.condition.input_target_condition.InputTargetCondition + :raises ValueError: If ``input`` or ``target`` are not of supported types. + """ + + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "Invalid input type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "list/tuple of Graph/Data objects." + ) + elif isinstance(input, (list, tuple)): + for item in input: + if not isinstance(item, (Graph, Data)): + raise ValueError( + "If target is a list or tuple, all its elements " + "must be of type Graph or Data." + ) + + if not isinstance(target, cls._avail_output_cls): + raise ValueError( + "Invalid target type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "list/tuple of Graph/Data objects." + ) + elif isinstance(target, (list, tuple)): + for item in target: + if not isinstance(item, (Graph, Data)): + raise ValueError( + "If target is a list or tuple, all its elements " + "must be of type Graph or Data." + ) + + return super().__new__(cls) + + def store_data(self, **kwargs): + """ + Store the input and target data in a :class:`_DataManager` object. + :param dict kwargs: The keyword arguments containing the input and + target data. + """ + return _DataManager(**kwargs) @property def input(self): @@ -129,101 +97,10 @@ def input(self): Return the input data for the condition. :return: The input data. - :rtype: torch.Tensor | LabelTensor - """ - return self.data["input"] - - @property - def target(self): - """ - Return the target data for the condition. - - :return: The target data. - :rtype: torch.Tensor | LabelTensor - """ - return self.data["target"] - - -class TensorInputGraphTargetCondition(GraphCondition, InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - ``input`` is either a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a - :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. - """ - - def __init__(self, input, target): - """ - Initialization of the :class:`TensorInputGraphTargetCondition` class. - - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor - :param target: The target data for the condition. - :type target: Graph | Data | list[Graph] | list[Data] | - tuple[Graph] | tuple[Data] - """ - self.graph_field = "target" - self.tensor_fields = ["input"] - self.keys_map = {"input": "x"} - super().__init__(input=input, target=target) - - @property - def input(self): - """ - Return the input data for the condition. - - :return: The input data. - :rtype: list[torch.Tensor] | list[LabelTensor] - """ - targets = [] - is_lt = isinstance(self.data["data"][0].x, LabelTensor) - for graph in self.data["data"]: - targets.append(graph.x) - return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) - - @property - def target(self): - """ - Return the target data for the condition. - - :return: The target data. - :rtype: list[Graph] | list[Data] - """ - return self.data["data"] - - -class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - ``input`` is either a :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` object and ``target`` is either a - :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. - """ - - def __init__(self, input, target): - """ - Initialization of the :class:`GraphInputTensorTargetCondition` class. - - :param input: The input data for the condition. - :type input: Graph | Data | list[Graph] | list[Data] | - tuple[Graph] | tuple[Data] - :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor - """ - self.graph_field = "input" - self.tensor_fields = ["target"] - self.keys_map = {"target": "y"} - super().__init__(input=input, target=target) - - @property - def input(self): - """ - Return the input data for the condition. - - :return: The input data. - :rtype: list[Graph] | list[Data] + :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] """ - return self.data["data"] + return self.data.input @property def target(self): @@ -231,11 +108,7 @@ def target(self): Return the target data for the condition. :return: The target data. - :rtype: list[torch.Tensor] | list[LabelTensor] + :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] """ - targets = [] - is_lt = isinstance(self.data["data"][0].y, LabelTensor) - for graph in self.data["data"]: - targets.append(graph.y) - - return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + return self.data.target diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index e922bdbcd..289830abc 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -1,13 +1,11 @@ import pytest import torch from pina import Condition, LabelTensor -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) +from pina.condition import DataCondition from pina.graph import RadiusGraph from torch_geometric.data import Data, Batch from pina.graph import Graph, LabelBatch +from pina.condition.data_manager import _DataManager def _create_tensor_data(use_lt=False, conditional_variables=False): @@ -51,7 +49,9 @@ def test_init_tensor_data_condition(use_lt, conditional_variables): use_lt=use_lt, conditional_variables=conditional_variables ) condition = Condition(input=input_tensor, conditional_variables=cond_vars) - assert isinstance(condition, TensorDataCondition) + print(condition) + assert isinstance(condition, DataCondition) + type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -72,7 +72,7 @@ def test_init_graph_data_condition(use_lt, conditional_variables): use_lt=use_lt, conditional_variables=conditional_variables ) condition = Condition(input=input_graph, conditional_variables=cond_vars) - assert isinstance(condition, GraphDataCondition) + assert isinstance(condition, DataCondition) type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -101,21 +101,21 @@ def test_getitem_tensor_data_condition(use_lt, conditional_variables): ) condition = Condition(input=input_tensor, conditional_variables=cond_vars) item = condition[0] - assert isinstance(item, dict) - assert "input" in item + assert isinstance(item, _DataManager) + assert hasattr(item, "input") type_ = LabelTensor if use_lt else torch.Tensor - assert isinstance(item["input"], type_) - assert item["input"].shape == (3,) + assert isinstance(item.input, type_) + assert item.input.shape == (3,) if type_ is LabelTensor: - assert item["input"].labels == ["x", "y", "z"] + assert item.input.labels == ["x", "y", "z"] if conditional_variables: - assert "conditional_variables" in item - assert isinstance(item["conditional_variables"], type_) - assert item["conditional_variables"].shape == (2,) + assert hasattr(item, "conditional_variables") + assert isinstance(item.conditional_variables, type_) + assert item.conditional_variables.shape == (2,) if type_ is LabelTensor: - assert item["conditional_variables"].labels == ["a", "b"] + assert item.conditional_variables.labels == ["a", "b"] else: - assert "conditional_variables" not in item + assert not hasattr(item, "conditional_variables") @pytest.mark.parametrize("use_lt", [False, True]) @@ -126,9 +126,9 @@ def test_getitem_graph_data_condition(use_lt, conditional_variables): ) condition = Condition(input=input_graph, conditional_variables=cond_vars) item = condition[0] - assert isinstance(item, dict) - assert "data" in item - graph = item["data"] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + graph = item.input assert isinstance(graph, Data) type_ = LabelTensor if use_lt else torch.Tensor assert isinstance(graph.x, type_) @@ -140,38 +140,14 @@ def test_getitem_graph_data_condition(use_lt, conditional_variables): if use_lt: assert graph.pos.labels == ["x", "y"] if conditional_variables: - assert hasattr(graph, "cond_vars") - cond_var = graph.cond_vars + assert hasattr(item, "conditional_variables") + cond_var = item.conditional_variables assert isinstance(cond_var, type_) - assert cond_var.shape == (20, 1) + assert cond_var.shape == (1, 20, 1) if use_lt: assert cond_var.labels == ["f"] -@pytest.mark.parametrize("use_lt", [False, True]) -@pytest.mark.parametrize("conditional_variables", [False, True]) -def test_getitems_graph_data_condition(use_lt, conditional_variables): - input_graph, cond_vars = _create_graph_data( - use_lt=use_lt, conditional_variables=conditional_variables - ) - condition = Condition(input=input_graph, conditional_variables=cond_vars) - idxs = [0, 1, 3] - items = condition[idxs] - assert isinstance(items, dict) - assert "input" in items - graphs = items["input"] - assert isinstance(graphs, LabelBatch) - assert graphs.num_graphs == 3 - if conditional_variables: - type_ = LabelTensor if use_lt else torch.Tensor - assert "conditional_variables" in items - cond_vars_batch = items["conditional_variables"] - assert isinstance(cond_vars_batch, type_) - assert cond_vars_batch.shape == (60, 1) - if use_lt: - assert cond_vars_batch.labels == ["f"] - - @pytest.mark.parametrize("use_lt", [False, True]) @pytest.mark.parametrize("conditional_variables", [False, True]) def test_getitems_tensor_data_condition(use_lt, conditional_variables): @@ -181,20 +157,131 @@ def test_getitems_tensor_data_condition(use_lt, conditional_variables): condition = Condition(input=input_tensor, conditional_variables=cond_vars) idxs = [0, 1, 3] items = condition[idxs] - assert isinstance(items, dict) - assert "input" in items + assert isinstance(items, _DataManager) + assert hasattr(items, "input") type_ = LabelTensor if use_lt else torch.Tensor - inputs = items["input"] + inputs = items.input assert isinstance(inputs, type_) assert inputs.shape == (3, 3) if use_lt: assert inputs.labels == ["x", "y", "z"] if conditional_variables: - assert "conditional_variables" in items - cond_vars_items = items["conditional_variables"] + assert hasattr(items, "conditional_variables") + cond_vars_items = items.conditional_variables assert isinstance(cond_vars_items, type_) assert cond_vars_items.shape == (3, 2) if use_lt: assert cond_vars_items.labels == ["a", "b"] else: - assert "conditional_variables" not in items + assert not hasattr(items, "conditional_variables") + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, _DataManager) + assert hasattr(items, "input") + graphs = items.input + assert isinstance(graphs, list) + assert len(graphs) == 3 + for graph in graphs: + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + type_ = LabelTensor if use_lt else torch.Tensor + assert hasattr(items, "conditional_variables") + cond_vars_batch = items.conditional_variables + assert isinstance(cond_vars_batch, type_) + assert cond_vars_batch.shape == (3, 20, 1) + if use_lt: + assert cond_vars_batch.labels == ["f"] + + +if __name__ == "__main__": + test_init_tensor_data_condition(use_lt=False, conditional_variables=False) + print("Passed tensor data condition init test without LT and cond vars.") + test_init_tensor_data_condition(use_lt=True, conditional_variables=False) + print( + "Passed tensor data condition init test with LT and without cond vars." + ) + test_init_tensor_data_condition(use_lt=False, conditional_variables=True) + print( + "Passed tensor data condition init test without LT and with cond vars." + ) + test_init_tensor_data_condition(use_lt=True, conditional_variables=True) + print("Passed tensor data condition init test with LT and cond vars.") + test_init_graph_data_condition(use_lt=False, conditional_variables=False) + print("Passed graph data condition init test without LT and cond vars.") + test_init_graph_data_condition(use_lt=True, conditional_variables=False) + print( + "Passed graph data condition init test with LT and without cond vars." + ) + test_init_graph_data_condition(use_lt=False, conditional_variables=True) + print( + "Passed graph data condition init test without LT and with cond vars." + ) + test_init_graph_data_condition(use_lt=True, conditional_variables=True) + print("Passed graph data condition init test with LT and cond vars.") + + test_getitem_tensor_data_condition( + use_lt=False, conditional_variables=False + ) + print("Passed tensor data condition getitem test without LT and cond vars.") + test_getitem_tensor_data_condition(use_lt=True, conditional_variables=False) + print( + "Passed tensor data condition getitem test with LT and without cond vars." + ) + test_getitem_tensor_data_condition(use_lt=False, conditional_variables=True) + print( + "Passed tensor data condition getitem test without LT and with cond vars." + ) + test_getitem_tensor_data_condition(use_lt=True, conditional_variables=True) + print("Passed tensor data condition getitem test with LT and cond vars.") + + test_getitem_graph_data_condition(use_lt=False, conditional_variables=False) + print("Passed graph data condition getitem test without LT and cond vars.") + test_getitem_graph_data_condition(use_lt=True, conditional_variables=False) + print( + "Passed graph data condition getitem test with LT and without cond vars." + ) + test_getitem_graph_data_condition(use_lt=False, conditional_variables=True) + print( + "Passed graph data condition getitem test without LT and with cond vars." + ) + test_getitem_graph_data_condition(use_lt=True, conditional_variables=True) + print("Passed graph data condition getitem test with LT and cond vars.") + + test_getitems_tensor_data_condition( + use_lt=False, conditional_variables=False + ) + print( + "Passed tensor data condition getitems test without LT and cond vars." + ) + test_getitems_tensor_data_condition( + use_lt=True, conditional_variables=False + ) + print( + "Passed tensor data condition getitems test with LT and without cond vars." + ) + test_getitems_tensor_data_condition( + use_lt=False, conditional_variables=True + ) + print( + "Passed tensor data condition getitems test without LT and with cond vars." + ) + test_getitems_tensor_data_condition(use_lt=True, conditional_variables=True) + print("Passed tensor data condition getitems test with LT and cond vars.") diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py index 2b7c78b00..24817f9b4 100644 --- a/tests/test_condition/test_domain_equation_condition.py +++ b/tests/test_condition/test_domain_equation_condition.py @@ -13,6 +13,8 @@ def test_init_domain_equation(): assert isinstance(cond, DomainEquationCondition) assert cond.domain is example_domain assert cond.equation is example_equation + assert hasattr(cond, "data") + assert cond.data is None def test_len_not_implemented(): @@ -25,3 +27,12 @@ def test_getitem_not_implemented(): cond = Condition(domain=example_domain, equation=FixedValue(0.0)) with pytest.raises(NotImplementedError): cond[0] + + +if __name__ == "__main__": + test_init_domain_equation() + print("Passed domain equation condition init test.") + test_len_not_implemented() + print("Passed domain equation condition len test.") + test_getitem_not_implemented() + print("Passed domain equation condition getitem test.") diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index af11d382e..f02848a23 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -1,11 +1,10 @@ import torch from pina import Condition -from pina.condition.input_equation_condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, -) +from pina.condition.input_equation_condition import InputEquationCondition from pina.equation import Equation from pina import LabelTensor +from pina.graph import Graph +from pina.condition.data_manager import _DataManager def _create_pts_and_equation(): @@ -33,7 +32,7 @@ def dummy_equation(pts): def test_init_tensor_equation_condition(): pts, equation = _create_pts_and_equation() condition = Condition(input=pts, equation=equation) - assert isinstance(condition, InputTensorEquationCondition) + assert isinstance(condition, InputEquationCondition) assert condition.input.shape == (100, 2) assert condition.equation is equation @@ -41,10 +40,9 @@ def test_init_tensor_equation_condition(): def test_init_graph_equation_condition(): graph, equation = _create_graph_and_equation() condition = Condition(input=graph, equation=equation) - assert isinstance(condition, InputGraphEquationCondition) - assert isinstance(condition.input, list) - assert len(condition.input) == 1 - assert condition.input[0].x.shape == (100, 2) + assert isinstance(condition, InputEquationCondition) + assert isinstance(condition.input, Graph) + assert condition.input.x.shape == (100, 2) assert condition.equation is equation @@ -52,9 +50,9 @@ def test_getitem_tensor_equation_condition(): pts, equation = _create_pts_and_equation() condition = Condition(input=pts, equation=equation) item = condition[0] - assert isinstance(item, dict) - assert "input" in item - assert item["input"].shape == (2,) + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (2,) def test_getitems_tensor_equation_condition(): @@ -62,6 +60,17 @@ def test_getitems_tensor_equation_condition(): condition = Condition(input=pts, equation=equation) idxs = [0, 1, 3] item = condition[idxs] - assert isinstance(item, dict) - assert "input" in item - assert item["input"].shape == (3, 2) + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (3, 2) + + +if __name__ == "__main__": + test_init_tensor_equation_condition() + print("Passed tensor equation condition init test.") + test_init_graph_equation_condition() + print("Passed graph equation condition init test.") + test_getitem_tensor_equation_condition() + print("Passed tensor equation condition getitem test.") + test_getitems_tensor_equation_condition() + print("Passed tensor equation condition getitems test.") diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 81c3a9b24..95e3aac6a 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -2,11 +2,12 @@ import pytest from torch_geometric.data import Batch from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - TensorInputTensorTargetCondition, - GraphInputTensorTargetCondition, -) + +# from pina.condition import ( +# TensorInputGraphTargetCondition, +# TensorInputTensorTargetCondition, +# GraphInputTensorTargetCondition, +# ) from pina.graph import RadiusGraph, LabelBatch @@ -48,7 +49,7 @@ def _create_graph_data(tensor_input=True, use_lt=False): def test_init_tensor_input_tensor_target_condition(use_lt): input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) condition = Condition(input=input_tensor, target=target_tensor) - assert isinstance(condition, TensorInputTensorTargetCondition) + # assert isinstance(condition, TensorInputTensorTargetCondition) assert torch.allclose( condition.input, input_tensor ), "TensorInputTensorTargetCondition input failed" @@ -77,7 +78,7 @@ def test_init_tensor_input_tensor_target_condition(use_lt): def test_init_tensor_input_graph_target_condition(use_lt): target_graph, input_tensor = _create_graph_data(use_lt=use_lt) condition = Condition(input=input_tensor, target=target_graph) - assert isinstance(condition, TensorInputGraphTargetCondition) + # assert isinstance(condition, TensorInputGraphTargetCondition) assert torch.allclose( condition.input, input_tensor ), "TensorInputGraphTargetCondition input failed" @@ -106,7 +107,7 @@ def test_init_tensor_input_graph_target_condition(use_lt): def test_init_graph_input_tensor_target_condition(use_lt): input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) condition = Condition(input=input_graph, target=target_tensor) - assert isinstance(condition, GraphInputTensorTargetCondition) + # assert isinstance(condition, GraphInputTensorTargetCondition) for i in range(len(input_graph)): assert torch.allclose( condition.input[i].x, input_graph[i].x @@ -138,10 +139,10 @@ def test_getitem_tensor_input_tensor_target_condition(use_lt): for i in range(len(input_tensor)): item = condition[i] assert torch.allclose( - item["input"], input_tensor[i] + item.input, input_tensor[i] ), "TensorInputTensorTargetCondition __getitem__ input failed" assert torch.allclose( - item["target"], target_tensor[i] + item.target, target_tensor[i] ), "TensorInputTensorTargetCondition __getitem__ target failed" @@ -150,83 +151,59 @@ def test_getitem_tensor_input_graph_target_condition(use_lt): target_graph, input_tensor = _create_graph_data(use_lt=use_lt) condition = Condition(input=input_tensor, target=target_graph) for i in range(len(input_tensor)): - item = condition[i]["data"] + item = condition[i] assert torch.allclose( - item.x, input_tensor[i] + item.input, input_tensor[i] ), "TensorInputGraphTargetCondition __getitem__ input failed" assert torch.allclose( - item.y, target_graph[i].y + item.target.y, target_graph[i].y ), "TensorInputGraphTargetCondition __getitem__ target failed" if use_lt: assert isinstance( - item.y, LabelTensor + item.target.y, LabelTensor ), "TensorInputGraphTargetCondition __getitem__ target type failed" - assert item.y.labels == [ + assert item.target.y.labels == [ "u", "v", ], "TensorInputGraphTargetCondition __getitem__ target labels failed" -def test_getitem_graph_input_tensor_target_condition(): - input_graph, target_tensor = _create_graph_data(False) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) condition = Condition(input=input_graph, target=target_tensor) + assert len(condition) == len(input_graph) for i in range(len(input_graph)): - item = condition[i]["data"] - print(item) + item = condition[i] assert torch.allclose( - item.x, input_graph[i].x + item.input.x, input_graph[i].x ), "GraphInputTensorTargetCondition __getitem__ input failed" assert torch.allclose( - item.y, target_tensor[i] + item.target, target_tensor[i] ), "GraphInputTensorTargetCondition __getitem__ target failed" - - -@pytest.mark.parametrize("use_lt", [True, False]) -def test_getitems_graph_input_tensor_target_condition(use_lt): - input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) - condition = Condition(input=input_graph, target=target_tensor) - indices = [0, 2, 4] - items = condition[indices] - candidate_input = items["input"] - candidate_target = items["target"] - - if use_lt: - input_ = LabelBatch.from_data_list([input_graph[i] for i in indices]) - target_ = LabelTensor.cat([target_tensor[i] for i in indices], dim=0) - else: - input_ = Batch.from_data_list([input_graph[i] for i in indices]) - target_ = torch.cat([target_tensor[i] for i in indices], dim=0) - assert torch.allclose( - candidate_input.x, input_.x - ), "GraphInputTensorTargetCondition __geitemsem__ input failed" - assert torch.allclose( - candidate_target, target_ - ), "GraphInputTensorTargetCondition __geitemsem__ input failed" - if use_lt: - assert isinstance( - candidate_target, LabelTensor - ), "GraphInputTensorTargetCondition __getitems__ target type failed" - assert candidate_target.labels == [ - "f" - ], "GraphInputTensorTargetCondition __getitems__ target labels failed" - - assert isinstance( - candidate_input.x, LabelTensor - ), "GraphInputTensorTargetCondition __getitems__ input type failed" - assert ( - candidate_input.x.labels == input_graph[0].x.labels - ), "GraphInputTensorTargetCondition __getitems__ input labels failed" + if use_lt: + assert isinstance( + item.input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ input type failed" + assert ( + item.input.x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition __getitem__ input labels failed" + assert isinstance( + item.target, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ target type failed" + assert item.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitem__ target labels failed" @pytest.mark.parametrize("use_lt", [True, False]) def test_getitems_tensor_input_tensor_target_condition(use_lt): - input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) condition = Condition(input=input_tensor, target=target_tensor) indices = [1, 3, 5, 7] items = condition[indices] - candidate_input = items["input"] - candidate_target = items["target"] + candidate_input = items.input + candidate_target = items.target if use_lt: input_ = LabelTensor.stack([input_tensor[i] for i in indices]) @@ -264,20 +241,26 @@ def test_getitems_tensor_input_graph_target_condition(use_lt): condition = Condition(input=input_tensor, target=target_graph) indices = [0, 2, 4] items = condition[indices] - candidate_input = items["input"] - candidate_target = items["target"] + candidate_input = items.input + candidate_target = items.target if use_lt: - input_ = LabelTensor.cat([input_tensor[i] for i in indices], dim=0) - target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + # target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) else: - input_ = torch.cat([input_tensor[i] for i in indices], dim=0) - target_ = Batch.from_data_list([target_graph[i] for i in indices]) + input_ = torch.stack([input_tensor[i] for i in indices]) + # target_ = Batch.from_data_list([target_graph[i] for i in indices]) assert torch.allclose( candidate_input, input_ ), "TensorInputGraphTargetCondition __getitems__ input failed" - assert torch.allclose( - candidate_target.y, target_.y - ), "TensorInputGraphTargetCondition __getitems__ target failed" + + assert len(candidate_target) == len( + indices + ), "TensorInputGraphTargetCondition __getitems__ target length failed" + for idx, graph_idx in enumerate(indices): + assert torch.allclose( + candidate_target[idx].y, target_graph[graph_idx].y + ), "TensorInputGraphTargetCondition __getitems__ target failed" + if use_lt: assert isinstance( candidate_input, LabelTensor @@ -285,13 +268,35 @@ def test_getitems_tensor_input_graph_target_condition(use_lt): assert candidate_input.labels == [ "f" ], "TensorInputGraphTargetCondition __getitems__ input labels failed" - assert isinstance( - candidate_target.y, LabelTensor - ), "TensorInputGraphTargetCondition __getitems__ target type failed" - assert candidate_target.y.labels == [ - "u", - "v", - ], "TensorInputGraphTargetCondition __getitems__ target labels failed" + for g in candidate_target: + assert isinstance( + g.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ target type failed" + assert g.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitems__ target labels failed" + + +if __name__ == "__main__": + test_init_tensor_input_tensor_target_condition(use_lt=True) + test_init_tensor_input_tensor_target_condition(use_lt=False) + test_getitem_tensor_input_tensor_target_condition(use_lt=True) + test_getitem_tensor_input_tensor_target_condition(use_lt=False) + test_getitems_tensor_input_tensor_target_condition(use_lt=True) + test_getitems_tensor_input_tensor_target_condition(use_lt=False) + print("All tests for Tensor/Tensor conditions passed.") + test_init_tensor_input_graph_target_condition(use_lt=True) + test_init_tensor_input_graph_target_condition(use_lt=False) + test_init_graph_input_tensor_target_condition(use_lt=True) + test_init_graph_input_tensor_target_condition(use_lt=False) + print("All tests init for Tensor/Graph conditions passed.") -test_init_graph_input_tensor_target_condition(use_lt=True) + test_getitem_tensor_input_graph_target_condition(use_lt=True) + test_getitem_tensor_input_graph_target_condition(use_lt=False) + test_getitem_graph_input_tensor_target_condition(use_lt=True) + test_getitem_graph_input_tensor_target_condition(use_lt=False) + test_getitems_tensor_input_graph_target_condition(use_lt=True) + test_getitems_tensor_input_graph_target_condition(use_lt=False) + print("All tests getitem for Tensor/Graph conditions passed.") diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py new file mode 100644 index 000000000..335c8c567 --- /dev/null +++ b/tests/test_data_manager.py @@ -0,0 +1,156 @@ +import torch +from pina.condition.data_manager import ( + _DataManager, + _TensorDataManager, + _GraphDataManager, +) +from pina.graph import Graph +from pina.equation import Equation + + +def test_tensor_data_manager_init(): + pippo = torch.rand((10, 5)) + pluto = torch.rand((10, 7)) + paperino = torch.rand((10, 11)) + data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager, _TensorDataManager) + assert hasattr(data_manager, "pippo") + assert hasattr(data_manager, "pluto") + assert hasattr(data_manager, "paperino") + assert torch.equal(data_manager.pippo, pippo) + assert torch.equal(data_manager.pluto, pluto) + assert torch.equal(data_manager.paperino, paperino) + + paperino = Equation(lambda x: x**2) + data_manager3 = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager3, _TensorDataManager) + assert hasattr(data_manager3, "pippo") + assert hasattr(data_manager3, "pluto") + assert hasattr(data_manager3, "paperino") + assert torch.equal(data_manager3.pippo, pippo) + assert torch.equal(data_manager3.pluto, pluto) + assert isinstance(data_manager3.paperino, Equation) + + +def test_graph_data_manager_init(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + assert hasattr(data_manager, "graph_key") + assert data_manager.graph_key == "graph" + assert hasattr(data_manager, "graph") + assert len(data_manager.data) == 3 + for i in range(3): + g = data_manager.graph[i] + assert torch.equal(g.x, x[i]) + assert torch.equal(g.pos, pos[i]) + assert torch.equal(g.edge_index, edge_index[i]) + assert torch.equal(g.target, target[i]) + + +def test_graph_data_manager_getattribute(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + target_retrieved = data_manager.target + assert torch.equal(target_retrieved, target) + + +def test_graph_data_manager_getitem(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + item = data_manager[1] + assert isinstance(item, _DataManager) + assert hasattr(item, "graph_key") + assert item.graph_key == "graph" + assert hasattr(item, "graph") + assert torch.equal(item.graph.x, x[1]) + assert torch.equal(item.graph.pos, pos[1]) + assert torch.equal(item.graph.edge_index, edge_index[1]) + assert torch.equal(item.target, target[1].unsqueeze(0)) + + +def test_graph_data_create_batch(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + item1 = data_manager[0] + item2 = data_manager[1] + batch_data = _GraphDataManager._create_batch([item1, item2]) + assert hasattr(batch_data, "graph") + assert hasattr(batch_data, "target") + batched_graphs = batch_data.graph + batched_target = batch_data.target + assert batched_graphs.num_graphs == 2 + assert batched_target.shape == (20, 1) + assert torch.equal(batched_target, torch.cat([target[0], target[1]], dim=0)) + mps_data = batch_data.to("mps") + assert mps_data.graph.num_graphs == 2 + assert torch.equal(mps_data.target, batched_target.to("mps")) + assert torch.equal(mps_data.graph.x, batched_graphs.x.to("mps")) + + +def test_tensor_data_create_batch(): + pippo = torch.rand((10, 5)) + pluto = torch.rand((10, 7)) + paperino = torch.rand((10, 11)) + data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + item1 = data_manager[0] + item2 = data_manager[1] + batch_data = _TensorDataManager._create_batch([item1, item2]) + assert hasattr(batch_data, "pippo") + assert hasattr(batch_data, "pluto") + assert hasattr(batch_data, "paperino") + assert torch.equal( + batch_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0) + ) + assert torch.equal( + batch_data.pluto, torch.stack([pluto[0], pluto[1]], dim=0) + ) + assert torch.equal( + batch_data.paperino, torch.stack([paperino[0], paperino[1]], dim=0) + ) + mps_data = batch_data.to("mps") + assert torch.equal( + mps_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0).to("mps") + ) + + +if __name__ == "__main__": + test_tensor_data_manager_init() + print("test_tensor_data_manager_init passed.") + test_graph_data_manager_init() + print("test_graph_data_manager_init passed.") + test_graph_data_manager_getattribute() + print("test_graph_data_manager_getattribute passed.") + test_graph_data_manager_getitem() + print("test_graph_data_manager_getitem passed.") + test_graph_data_create_batch() + print("test_graph_data_create_batch passed.") + test_tensor_data_create_batch() + print("test_tensor_data_create_batch passed.") From 717499029ce54a9b821cf9af034e4123abb6a8e8 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 19 Jan 2026 12:07:28 +0100 Subject: [PATCH 14/17] Remove old tests for conditions --- tests/test_condition.py | 171 ---------------------------------------- 1 file changed, 171 deletions(-) delete mode 100644 tests/test_condition.py diff --git a/tests/test_condition.py b/tests/test_condition.py deleted file mode 100644 index 8a5480499..000000000 --- a/tests/test_condition.py +++ /dev/null @@ -1,171 +0,0 @@ -# import torch -# import pytest - -# from pina import LabelTensor, Condition -# from pina.condition import ( -# TensorInputGraphTargetCondition, -# TensorInputTensorTargetCondition, -# # GraphInputGraphTargetCondition, -# GraphInputTensorTargetCondition, -# ) -# from pina.condition import ( -# InputTensorEquationCondition, -# InputGraphEquationCondition, -# DomainEquationCondition, -# ) -# from pina.condition import ( -# TensorDataCondition, -# GraphDataCondition, -# ) -# from pina.domain import CartesianDomain -# from pina.equation.equation_factory import FixedValue -# from pina.graph import RadiusGraph - -# def _create_tensor_data(): -# input_tensor = torch.rand((10, 3)) -# target_tensor = torch.rand((10, 2)) -# return input_tensor, target_tensor - -# def _create_graph_data(): -# x = torch.rand(10, 20, 2) -# pos = torch.rand(10, 20, 2) -# radius = 0.1 -# input_graph = [ -# RadiusGraph( -# x=x_, -# pos=pos_, -# radius=radius, -# ) -# for x_, pos_ in zip(x, pos) -# ] -# target_graph = [ -# RadiusGraph( -# y=x_, -# pos=pos_, -# radius=radius, -# ) -# for x_, pos_ in zip(x, pos) -# ] -# return input_graph, target_graph - -# def _create_lt_data(): -# input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -# target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) -# return input_lt, target_lt - -# def _create_graph_lt_data(): -# x = LabelTensor(torch.rand((10, 20, 2)), ["u", "v"]) -# pos = LabelTensor(torch.rand((10, 20, 2)), ["x", "y"]) -# radius = 0.1 -# input_graph = [ -# RadiusGraph( -# x=x[i], -# pos=pos[i], -# radius=radius, -# ) -# for i in range(len(x)) -# ] -# target_graph = [ -# RadiusGraph( -# y=x[i], -# pos=pos[i], -# radius=radius, -# ) -# for i in range(len(x)) -# ] -# return input_graph, target_graph - -# example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -# def test_init_input_target(): -# input_tensor, target_tensor = _create_tensor_data() -# cond = Condition(input=input_tensor, target=target_tensor) -# assert isinstance(cond, TensorInputTensorTargetCondition) - -# input_lt, target_lt = _create_lt_data() -# cond = Condition(input=input_lt, target=target_lt) -# assert isinstance(cond, TensorInputTensorTargetCondition) - -# input_graph, target_graph = _create_graph_data() -# cond = Condition(input=input_tensor, target=target_graph) -# assert isinstance(cond, TensorInputGraphTargetCondition) - -# cond = Condition(input=input_graph, target=target_tensor) -# assert isinstance(cond, GraphInputTensorTargetCondition) - -# input_single_graph = input_graph[0] -# target_single_graph = target_graph[0] -# cond = Condition(input=input_lt, target=input_single_graph) -# assert isinstance(cond, TensorInputGraphTargetCondition) -# cond = Condition(input=input_single_graph, target=target_lt) -# assert isinstance(cond, GraphInputTensorTargetCondition) -# # cond = Condition(input=input_graph, target=target_graph) -# # assert isinstance(cond, GraphInputGraphTargetCondition) -# # cond = Condition(input=input_single_graph, target=target_single_graph) -# # assert isinstance(cond, GraphInputGraphTargetCondition) -# input_graph_lt, target_graph_lt = _create_graph_lt_data() - -# with pytest.raises(ValueError): -# Condition(input_tensor, input_tensor) -# with pytest.raises(ValueError): -# Condition(input=3.0, target="example") -# with pytest.raises(ValueError): -# Condition(input=example_domain, target=example_domain) - -# # Test wrong graph condition initialisation -# input = [input_graph[0], input_graph_lt[0]] -# target = [target_graph[0], target_graph_lt[0]] -# with pytest.raises(ValueError): -# Condition(input=input, target=target) - -# input_graph_lt[0].x.labels = ["a", "b"] -# with pytest.raises(ValueError): -# Condition(input=input_graph_lt, target=target_graph_lt) -# input_graph_lt[0].x.labels = ["u", "v"] - - -# def test_init_domain_equation(): -# input_tensor, _ = _create_tensor_data() -# input_graph, _ = _create_graph_data() -# cond = Condition(domain=example_domain, equation=FixedValue(0.0)) -# assert isinstance(cond, DomainEquationCondition) -# with pytest.raises(ValueError): -# Condition(example_domain, FixedValue(0.0)) -# with pytest.raises(ValueError): -# Condition(domain=3.0, equation="example") -# with pytest.raises(ValueError): -# Condition(domain=input_tensor, equation=input_graph) - - -# def test_init_input_equation(): -# input_lt, _ = _create_lt_data() -# input_graph_lt, _ = _create_graph_lt_data() -# input_tensor, _ = _create_tensor_data() -# input_graph, _ = _create_graph_data() -# cond = Condition(input=input_lt, equation=FixedValue(0.0)) -# assert isinstance(cond, InputTensorEquationCondition) -# cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) -# assert isinstance(cond, InputGraphEquationCondition) -# with pytest.raises(ValueError): -# cond = Condition(input=input_tensor, equation=FixedValue(0.0)) -# with pytest.raises(ValueError): -# Condition(example_domain, FixedValue(0.0)) -# with pytest.raises(ValueError): -# Condition(input=3.0, equation="example") -# with pytest.raises(ValueError): -# Condition(input=example_domain, equation=input_graph) - -# def test_init_data_condition(): -# input_lt, _ = _create_lt_data() -# input_tensor, _ = _create_tensor_data() -# input_graph, _ = _create_graph_data() -# cond = Condition(input=input_lt) -# assert isinstance(cond, TensorDataCondition) -# cond = Condition(input=input_tensor) -# assert isinstance(cond, TensorDataCondition) -# cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) -# assert isinstance(cond, TensorDataCondition) -# cond = Condition(input=input_graph) -# assert isinstance(cond, GraphDataCondition) -# cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) -# assert isinstance(cond, GraphDataCondition) From 1e47c108b986235bd2146175dcce2cd882d760d1 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 19 Jan 2026 12:09:10 +0100 Subject: [PATCH 15/17] Minor fixes in tests --- tests/test_condition/test_data_condition.py | 76 ------------------- .../test_domain_equation_condition.py | 9 --- .../test_input_equation_condition.py | 11 --- .../test_input_target_condition.py | 24 ------ tests/test_data_manager.py | 19 ----- 5 files changed, 139 deletions(-) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 289830abc..63446a0bf 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -209,79 +209,3 @@ def test_getitems_graph_data_condition(use_lt, conditional_variables): assert cond_vars_batch.shape == (3, 20, 1) if use_lt: assert cond_vars_batch.labels == ["f"] - - -if __name__ == "__main__": - test_init_tensor_data_condition(use_lt=False, conditional_variables=False) - print("Passed tensor data condition init test without LT and cond vars.") - test_init_tensor_data_condition(use_lt=True, conditional_variables=False) - print( - "Passed tensor data condition init test with LT and without cond vars." - ) - test_init_tensor_data_condition(use_lt=False, conditional_variables=True) - print( - "Passed tensor data condition init test without LT and with cond vars." - ) - test_init_tensor_data_condition(use_lt=True, conditional_variables=True) - print("Passed tensor data condition init test with LT and cond vars.") - test_init_graph_data_condition(use_lt=False, conditional_variables=False) - print("Passed graph data condition init test without LT and cond vars.") - test_init_graph_data_condition(use_lt=True, conditional_variables=False) - print( - "Passed graph data condition init test with LT and without cond vars." - ) - test_init_graph_data_condition(use_lt=False, conditional_variables=True) - print( - "Passed graph data condition init test without LT and with cond vars." - ) - test_init_graph_data_condition(use_lt=True, conditional_variables=True) - print("Passed graph data condition init test with LT and cond vars.") - - test_getitem_tensor_data_condition( - use_lt=False, conditional_variables=False - ) - print("Passed tensor data condition getitem test without LT and cond vars.") - test_getitem_tensor_data_condition(use_lt=True, conditional_variables=False) - print( - "Passed tensor data condition getitem test with LT and without cond vars." - ) - test_getitem_tensor_data_condition(use_lt=False, conditional_variables=True) - print( - "Passed tensor data condition getitem test without LT and with cond vars." - ) - test_getitem_tensor_data_condition(use_lt=True, conditional_variables=True) - print("Passed tensor data condition getitem test with LT and cond vars.") - - test_getitem_graph_data_condition(use_lt=False, conditional_variables=False) - print("Passed graph data condition getitem test without LT and cond vars.") - test_getitem_graph_data_condition(use_lt=True, conditional_variables=False) - print( - "Passed graph data condition getitem test with LT and without cond vars." - ) - test_getitem_graph_data_condition(use_lt=False, conditional_variables=True) - print( - "Passed graph data condition getitem test without LT and with cond vars." - ) - test_getitem_graph_data_condition(use_lt=True, conditional_variables=True) - print("Passed graph data condition getitem test with LT and cond vars.") - - test_getitems_tensor_data_condition( - use_lt=False, conditional_variables=False - ) - print( - "Passed tensor data condition getitems test without LT and cond vars." - ) - test_getitems_tensor_data_condition( - use_lt=True, conditional_variables=False - ) - print( - "Passed tensor data condition getitems test with LT and without cond vars." - ) - test_getitems_tensor_data_condition( - use_lt=False, conditional_variables=True - ) - print( - "Passed tensor data condition getitems test without LT and with cond vars." - ) - test_getitems_tensor_data_condition(use_lt=True, conditional_variables=True) - print("Passed tensor data condition getitems test with LT and cond vars.") diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py index 24817f9b4..f871a2a99 100644 --- a/tests/test_condition/test_domain_equation_condition.py +++ b/tests/test_condition/test_domain_equation_condition.py @@ -27,12 +27,3 @@ def test_getitem_not_implemented(): cond = Condition(domain=example_domain, equation=FixedValue(0.0)) with pytest.raises(NotImplementedError): cond[0] - - -if __name__ == "__main__": - test_init_domain_equation() - print("Passed domain equation condition init test.") - test_len_not_implemented() - print("Passed domain equation condition len test.") - test_getitem_not_implemented() - print("Passed domain equation condition getitem test.") diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index f02848a23..bd77f96e9 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -63,14 +63,3 @@ def test_getitems_tensor_equation_condition(): assert isinstance(item, _DataManager) assert hasattr(item, "input") assert item.input.shape == (3, 2) - - -if __name__ == "__main__": - test_init_tensor_equation_condition() - print("Passed tensor equation condition init test.") - test_init_graph_equation_condition() - print("Passed graph equation condition init test.") - test_getitem_tensor_equation_condition() - print("Passed tensor equation condition getitem test.") - test_getitems_tensor_equation_condition() - print("Passed tensor equation condition getitems test.") diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 95e3aac6a..b1f661184 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -276,27 +276,3 @@ def test_getitems_tensor_input_graph_target_condition(use_lt): "u", "v", ], "TensorInputGraphTargetCondition __getitems__ target labels failed" - - -if __name__ == "__main__": - test_init_tensor_input_tensor_target_condition(use_lt=True) - test_init_tensor_input_tensor_target_condition(use_lt=False) - test_getitem_tensor_input_tensor_target_condition(use_lt=True) - test_getitem_tensor_input_tensor_target_condition(use_lt=False) - test_getitems_tensor_input_tensor_target_condition(use_lt=True) - test_getitems_tensor_input_tensor_target_condition(use_lt=False) - print("All tests for Tensor/Tensor conditions passed.") - - test_init_tensor_input_graph_target_condition(use_lt=True) - test_init_tensor_input_graph_target_condition(use_lt=False) - test_init_graph_input_tensor_target_condition(use_lt=True) - test_init_graph_input_tensor_target_condition(use_lt=False) - print("All tests init for Tensor/Graph conditions passed.") - - test_getitem_tensor_input_graph_target_condition(use_lt=True) - test_getitem_tensor_input_graph_target_condition(use_lt=False) - test_getitem_graph_input_tensor_target_condition(use_lt=True) - test_getitem_graph_input_tensor_target_condition(use_lt=False) - test_getitems_tensor_input_graph_target_condition(use_lt=True) - test_getitems_tensor_input_graph_target_condition(use_lt=False) - print("All tests getitem for Tensor/Graph conditions passed.") diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py index 335c8c567..ade91413c 100644 --- a/tests/test_data_manager.py +++ b/tests/test_data_manager.py @@ -135,22 +135,3 @@ def test_tensor_data_create_batch(): assert torch.equal( batch_data.paperino, torch.stack([paperino[0], paperino[1]], dim=0) ) - mps_data = batch_data.to("mps") - assert torch.equal( - mps_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0).to("mps") - ) - - -if __name__ == "__main__": - test_tensor_data_manager_init() - print("test_tensor_data_manager_init passed.") - test_graph_data_manager_init() - print("test_graph_data_manager_init passed.") - test_graph_data_manager_getattribute() - print("test_graph_data_manager_getattribute passed.") - test_graph_data_manager_getitem() - print("test_graph_data_manager_getitem passed.") - test_graph_data_create_batch() - print("test_graph_data_create_batch passed.") - test_tensor_data_create_batch() - print("test_tensor_data_create_batch passed.") From 1350be767cc054292f5c59e3810738e0a4a96fa8 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 19 Jan 2026 14:07:09 +0100 Subject: [PATCH 16/17] add wrong tests --- tests/test_condition/test_data_condition.py | 16 ++++++++++++++++ .../test_input_equation_condition.py | 14 ++++++++++++++ .../test_input_target_condition.py | 12 ++++++++++++ 3 files changed, 42 insertions(+) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 63446a0bf..08762a9f8 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -93,6 +93,22 @@ def test_init_graph_data_condition(use_lt, conditional_variables): assert graph.pos.labels == ["x", "y"] +def test_wrong_init_data_condition(): + input_tensor, cond_vars = _create_tensor_data() + # Wrong input type + with pytest.raises(ValueError): + Condition(input="invalid_input", conditional_variables=cond_vars) + # Wrong conditional_variables type + with pytest.raises(ValueError): + Condition(input=input_tensor, conditional_variables="invalid_cond_vars") + # Wrong input type (list with wrong elements) + with pytest.raises(ValueError): + Condition(input=[input_tensor], conditional_variables=cond_vars) + # Wrong conditional_variables type (list) + with pytest.raises(ValueError): + Condition(input=input_tensor, conditional_variables=[cond_vars]) + + @pytest.mark.parametrize("use_lt", [False, True]) @pytest.mark.parametrize("conditional_variables", [False, True]) def test_getitem_tensor_data_condition(use_lt, conditional_variables): diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index bd77f96e9..b6bf62296 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -1,4 +1,5 @@ import torch +import pytest from pina import Condition from pina.condition.input_equation_condition import InputEquationCondition from pina.equation import Equation @@ -46,6 +47,19 @@ def test_init_graph_equation_condition(): assert condition.equation is equation +def test_wrong_init_equation_condition(): + pts, equation = _create_pts_and_equation() + # Wrong input type + with pytest.raises(ValueError): + Condition(input=torch.randn(10, 2), equation=equation) + # Wrong equation type + with pytest.raises(ValueError): + Condition(input=pts, equation="not_an_equation") + # Wrong input type (list with wrong elements) + with pytest.raises(ValueError): + Condition(input=[torch.randn(10, 2)], equation=equation) + + def test_getitem_tensor_equation_condition(): pts, equation = _create_pts_and_equation() condition = Condition(input=pts, equation=equation) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index b1f661184..f0dd2cfcf 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -132,6 +132,18 @@ def test_init_graph_input_tensor_target_condition(use_lt): ], "GraphInputTensorTargetCondition target labels failed" +def test_wrong_init(): + input_tensor, target_tensor = _create_tensor_data() + with pytest.raises(ValueError): + Condition(input="invalid_input", target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target="invalid_target") + with pytest.raises(ValueError): + Condition(input=[input_tensor], target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target=[target_tensor]) + + @pytest.mark.parametrize("use_lt", [True, False]) def test_getitem_tensor_input_tensor_target_condition(use_lt): input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) From 95f4773686a5713f904ea095b4fcc651e889fb4a Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 19 Jan 2026 16:05:45 +0100 Subject: [PATCH 17/17] revise docstring] --- pina/condition/data_manager.py | 226 ++++++++++++++++++++++++++------- 1 file changed, 181 insertions(+), 45 deletions(-) diff --git a/pina/condition/data_manager.py b/pina/condition/data_manager.py index 3a33cf99d..d26cd1941 100644 --- a/pina/condition/data_manager.py +++ b/pina/condition/data_manager.py @@ -1,20 +1,40 @@ +""" +Module for managing data in conditions. +""" + +from abc import ABC, abstractmethod import torch -from pina import LabelTensor -from pina.graph import Graph from torch_geometric.data import Data from torch_geometric.data.batch import Batch +from pina import LabelTensor +from pina.graph import Graph from pina.graph import LabelBatch from pina.equation.equation_interface import EquationInterface -from abc import ABC, abstractmethod class _BatchManager: + """ + Class for managing batches of data. + """ + def __init__(self, **dict): + """ + Store the batch data from the provided dictionary. + + :param dict dict: The dictionary containing the batch data. + """ self.keys = list(dict.keys()) for k, v in dict.items(): setattr(self, k, v) def to(self, device): + """ + Move all data in the batch to the specified device. + :param device: The device to move the data to. + :type device: torch.device | str + :return: The batch manager with data moved to the specified device. + :rtype: _BatchManager + """ for k in self.keys: val = getattr(self, k) setattr(self, k, val.to(device)) @@ -22,74 +42,130 @@ def to(self, device): class _DataManager(ABC): - """Interfaccia base ottimizzata per la gestione dei dati.""" + """ + Abstract base class for data managers. + + This class dynamically selects between :class:`_TensorDataManager` and + :class:`_GraphDataManager` based on the types of the input data. + """ def __new__(cls, **kwargs): - # Dispatching Factory + """ + Dynamically instantiate the appropriate subclass based on the types + of the input data. + - If all values in ``kwargs`` are instances of + :class:`torch.Tensor`, :class:`LabelTensor` then + :class:`_TensorDataManager` is instantiated. + - Otherwise, :class:`_GraphDataManager` is instantiated. + + :param dict kwargs: The keyword arguments containing the data. + :return: An instance of :class:`_TensorDataManager` or + :class:`_GraphDataManager`. + :rtype: _TensorDataManager | _GraphDataManager + """ + # If not called directly, proceed with normal instantiation if cls is not _DataManager: return super().__new__(cls) - # Determina se usare il gestore Tensori o Grafi - # (Controllo ottimizzato: evita cicli se possibile) + # Does the data contain only tensors/LabelTensors/Equations? is_tensor_only = all( isinstance(v, (torch.Tensor, LabelTensor, EquationInterface)) for v in kwargs.values() ) - + # Choose the appropriate subclass, GraphDataManager or TensorDataManager subclass = _TensorDataManager if is_tensor_only else _GraphDataManager return super().__new__(subclass) + def __init__(self, **kwargs): + """ + Initialize the data manager with the provided keyword arguments. + + :param dict kwargs: The keyword arguments containing the data. + """ + self.keys = list(kwargs.keys()) + @abstractmethod - def __len__(self) -> int: - pass + def __len__(self): + """ + Return the number of samples in the data manager. + """ @abstractmethod def __getitem__(self, idx): - pass + """ + Retrieve a data item or a subset of data items by index. + """ def to_dict(self): + """ + Convert the data manager to a dictionary. + """ return {k: getattr(self, k) for k in self.keys} - -# --- GESTORE TENSORI --- + @staticmethod + @abstractmethod + def create_batch(items): + """ + Create a batch from a list of data manager items. + """ class _TensorDataManager(_DataManager): - def __init__(self, **kwargs): - self.keys = list(kwargs.keys()) - self._data = kwargs # Memorizzazione in dizionario per accesso O(1) + """ + Data manager for tensor data. Handles data stored as `torch.Tensor` or + `LabelTensor`. + """ - # # Identifica i tensori una sola volta - # self._tensor_keys = [ - # k for k, v in kwargs.items() - # if isinstance(v, (torch.Tensor, LabelTensor)) - # ] + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data = kwargs - # Espone le chiavi come attributi (facoltativo, ma mantiene compatibilità) for k, v in kwargs.items(): setattr(self, k, v) - def __len__(self) -> int: - # Prende la lunghezza dal primo tensore disponibile - return self._data[self.keys[0]].shape[0] + def __len__(self): + """ + Return the number of samples in the tensor data manager. + + :return: Number of samples. + :rtype: int + """ + return self.data[self.keys[0]].shape[0] def __getitem__(self, idx): + """ + Return a data item or a subset of data items by index. + + :param idx: Index or indices of the data items to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_TensorDataManager` instance containing the + selected data items. + :rtype: _TensorDataManager + """ # Mapping efficiente degli elementi new_data = { - k: (self._data[k][idx] if k in self.keys else self._data[k]) + k: (self.data[k][idx] if k in self.keys else self.data[k]) for k in self.keys } return _TensorDataManager(**new_data) @staticmethod def _create_batch(items): + """ + Create a batch from a list of :class:`_TensorDataManager` items. + + :param list items: List of :class:`_TensorDataManager` items to batch. + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ if not items: return None first = items[0] batch_data = {} for k in first.keys: - vals = [it._data[k] for it in items] + vals = [it.data[k] for it in items] sample = vals[0] if isinstance(sample, (torch.Tensor, LabelTensor)): @@ -106,9 +182,19 @@ def _create_batch(items): class _GraphDataManager(_DataManager): + """ + Data manager for graph data. Handles data stored as :class:`Graph`, + :class:`Data`, or lists/tuples of these types. Moreover , it can also manage + associated tensors stored as :class:`torch.Tensor` or :class:`LabelTensor`. + """ + def __init__(self, **kwargs): - self.keys = list(kwargs.keys()) + """ + Initialize the graph data manager with the provided keyword arguments. + :param dict kwargs: The keyword arguments containing the data. + """ + super().__init__(**kwargs) self.graph_key = next( k for k, v in kwargs.items() @@ -122,34 +208,56 @@ def __init__(self, **kwargs): and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) ] - # Prepara la lista di grafi internamente + # Prepare graphs and assign tensors self.data = self._prepare_graphs(kwargs) def _prepare_graphs(self, kwargs): - graphs = kwargs[self.graph_key] + """ + Store tensors in the corresponding graphs. + + :param dict kwargs: The keyword arguments containing the graphs and + associated tensors. + :return: A list of graphs with tensors assigned. + :rtype: list[Graph] | list[Data] + """ + graphs = kwargs.pop(self.graph_key) if not isinstance(graphs, (list, tuple)): graphs = [graphs] - # Iniezione attributi nei grafi - for k in self.keys: - val_source = kwargs[k] - # Ottimizzazione: se la lunghezza coincide, distribuiamo i tensori, - # altrimenti trattiamo il tensore come costante per tutti. - use_idx = ( - len(val_source) == len(graphs) - if hasattr(val_source, "__len__") - else False - ) - + n_graphs = len(graphs) + for name, tensor in kwargs.items(): + # Verify consistency between number of graphs and tensor samples + if n_graphs != tensor.shape[0]: + raise ValueError( + f"Number of graphs ({n_graphs}) does not match " + f"number of samples for key '{name}' " + f"({kwargs[name].shape[0]})." + ) + # Assign tensors to graphs for i, g in enumerate(graphs): - setattr(g, k, val_source[i] if use_idx else val_source) + setattr(g, name, tensor[i]) + return graphs - def __len__(self) -> int: + def __len__(self): + """ + Return the number of graphs in the graph data manager. + + :return: Number of graphs. + :rtype: int + """ return len(self.data) def __getattr__(self, name): - + """ + Override attribute access to retrieve tensors or graphs. If the graph + key is requested, return the list of graphs. If a tensor key is + requested, stack the tensors from all graphs and return the result. + + :param str name: The name of the attribute to retrieve. + :return: The requested tensor or graph. + :rtype: torch.Tensor | LabelTensor | Graph | list[Graph] | Data | + """ # If the requested attribute is a tensor key, stack the tensors from # all graphs if name in self.keys: @@ -169,15 +277,34 @@ def __getattr__(self, name): @classmethod def _init_from_graphs_list(cls, graphs, graph_key, keys): + """ + Initialize a :class:`_GraphDataManager` instance from a list of graphs. + This is used internally to create subsets of the data manager, without + going through the full initialization process. + + :param list graphs: List of graphs to initialize the data manager with. + :param str graph_key: Key under which the graphs are stored. + :param list keys: List of tensor keys associated with the graphs. + :return: A new :class:`_GraphDataManager` instance. + :rtype: _GraphDataManager + """ # Create a new instance without calling __init__ obj = _GraphDataManager.__new__(_GraphDataManager) obj.graph_key = graph_key obj.keys = keys - # obj._tensor_keys = tensor_keys obj.data = graphs return obj def __getitem__(self, idx): + """ + Retrieve a graph or a subset of graphs by index. + + :param idx: Index or indices of the graphs to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_GraphDataManager` instance containing the + selected graphs. + :rtype: _GraphDataManager + """ # Manage int and slice directly if isinstance(idx, (int, slice)): selected = self.data[idx] @@ -199,7 +326,16 @@ def __getitem__(self, idx): keys=self.keys, ) + @staticmethod def _create_batch(items): + """ + Create a batch from a list of :class:`_GraphDataManager` items. + + :param list items: List of :class:`_GraphDataManager` items to batch. + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ if not items: return None first = items[0]