diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..c104e0ef9 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -5,13 +5,7 @@ "ConditionInterface", "DomainEquationCondition", "InputTargetCondition", - "TensorInputTensorTargetCondition", - "TensorInputGraphTargetCondition", - "GraphInputTensorTargetCondition", - "GraphInputGraphTargetCondition", "InputEquationCondition", - "InputTensorEquationCondition", - "InputGraphEquationCondition", "DataCondition", "GraphDataCondition", "TensorDataCondition", @@ -20,20 +14,6 @@ from .condition_interface import ConditionInterface from .condition import Condition from .domain_equation_condition import DomainEquationCondition -from .input_target_condition import ( - InputTargetCondition, - TensorInputTensorTargetCondition, - TensorInputGraphTargetCondition, - GraphInputTensorTargetCondition, - GraphInputGraphTargetCondition, -) -from .input_equation_condition import ( - InputEquationCondition, - InputTensorEquationCondition, - InputGraphEquationCondition, -) -from .data_condition import ( - DataCondition, - GraphDataCondition, - TensorDataCondition, -) +from .input_target_condition import InputTargetCondition +from .input_equation_condition import InputEquationCondition +from .data_condition import DataCondition diff --git a/pina/condition/condition.py b/pina/condition/condition.py index ad8764c9f..3c43f7176 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -86,12 +86,12 @@ class Condition: """ # Combine all possible keyword arguments from the different Condition types - __slots__ = list( + available_kwargs = list( set( - InputTargetCondition.__slots__ - + InputEquationCondition.__slots__ - + DomainEquationCondition.__slots__ - + DataCondition.__slots__ + InputTargetCondition.__fields__ + + InputEquationCondition.__fields__ + + DomainEquationCondition.__fields__ + + DataCondition.__fields__ ) ) @@ -112,28 +112,28 @@ def __new__(cls, *args, **kwargs): if len(args) != 0: raise ValueError( "Condition takes only the following keyword " - f"arguments: {Condition.__slots__}." + f"arguments: {Condition.available_kwargs}." ) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition - if sorted_keys == sorted(InputTargetCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__fields__): return InputTargetCondition(**kwargs) # Input - Equation Condition - if sorted_keys == sorted(InputEquationCondition.__slots__): + if sorted_keys == sorted(InputEquationCondition.__fields__): return InputEquationCondition(**kwargs) # Domain - Equation Condition - if sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(DomainEquationCondition.__fields__): return DomainEquationCondition(**kwargs) # Data Condition if ( - sorted_keys == sorted(DataCondition.__slots__) - or sorted_keys[0] == DataCondition.__slots__[0] + sorted_keys == sorted(DataCondition.__fields__) + or sorted_keys[0] == DataCondition.__fields__[0] ): return DataCondition(**kwargs) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py new file mode 100644 index 000000000..388e7141d --- /dev/null +++ b/pina/condition/condition_base.py @@ -0,0 +1,231 @@ +""" +Base class for conditions. +""" + +from copy import deepcopy +from functools import partial +import torch +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader +from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor + + +class TensorCondition: + """ + Base class for tensor conditions. + """ + + def store_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + +class GraphCondition: + """ + Base class for graph conditions. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + example = kwargs.get(self.graph_field)[0] + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(example, Graph) + else Batch.from_data_list + ) + + def store_data(self, **kwargs): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + graphs = kwargs.get(self.graph_field) + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + for key in self.tensor_fields: + tensor = kwargs[key][i] + mapping_key = self.keys_map.get(key) + setattr(new_graph, mapping_key, tensor) + data.append(new_graph) + return {"data": data} + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ + to_return_dict = {} + data = self.batch_fn([self.data["data"][i] for i in indices]) + to_return_dict[self.graph_field] = data + for key in self.tensor_fields: + mapping_key = self.keys_map.get(key) + y = getattr(data, mapping_key) + delattr(data, mapping_key) # Avoid duplication of y on GPU memory + to_return_dict[key] = y + return to_return_dict + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch)["data"] + to_return_dict = {} + for key in cls.tensor_fields: + mapping_key = cls.keys_map.get(key) + tensor = getattr(collated_graphs, mapping_key) + to_return_dict[key] = tensor + delattr(collated_graphs, mapping_key) + to_return_dict[cls.graph_field] = collated_graphs + return to_return_dict + + +class ConditionBase(ConditionInterface): + """ + Base abstract class for all conditions in PINA. + This class provides common functionality for handling data storage, + batching, and interaction with the associated problem. + """ + + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + """ + Initialization of the :class:`ConditionBase` class. + + :param kwargs: Keyword arguments representing the data to be stored. + """ + super().__init__() + self.data = self.store_data(**kwargs) + + @property + def problem(self): + """ + Return the problem associated with this condition. + + :return: Problem associated with this condition. + :rtype: ~pina.problem.abstract_problem.AbstractProblem + """ + return self._problem + + @problem.setter + def problem(self, value): + """ + Set the problem associated with this condition. + + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition + """ + self._problem = value + + def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ + return len(self.data) + + def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param idx: Index(es) of the data point(s) to retrieve. + :type idx: int | list[int] + :return: Data point(s) at the specified index. + """ + return self.data[idx] + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function for automatic batching to be used in DataLoader. + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + if not batch: + return {} + instance_class = batch[0].__class__ + return instance_class._create_batch(batch) + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for custom batching to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :param condition: The condition instance. + :type condition: ConditionBase + :return: A collated batch. + :rtype: dict + """ + print("Custom collate_fn called") + print("batch:", batch) + data = condition.data[batch] + return data + + def create_dataloader( + self, dataset, batch_size, shuffle, automatic_batching + ): + """ + Create a DataLoader for the condition. + + :param int batch_size: The batch size for the DataLoader. + :param bool shuffle: Whether to shuffle the data. Default is ``False``. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + if batch_size == len(dataset): + pass # will be updated in the near future + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + ) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index b0264517c..229b9a025 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,9 +1,6 @@ """Module for the Condition interface.""" -from abc import ABCMeta -from torch_geometric.data import Data -from ..label_tensor import LabelTensor -from ..graph import Graph +from abc import ABCMeta, abstractmethod class ConditionInterface(metaclass=ABCMeta): @@ -15,13 +12,14 @@ class ConditionInterface(metaclass=ABCMeta): description of all available conditions and how to instantiate them. """ - def __init__(self): + @abstractmethod + def __init__(self, **kwargs): """ Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property + @abstractmethod def problem(self): """ Return the problem associated with this condition. @@ -29,9 +27,9 @@ def problem(self): :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ - return self._problem @problem.setter + @abstractmethod def problem(self, value): """ Set the problem associated with this condition. @@ -39,88 +37,21 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ - self._problem = value - @staticmethod - def _check_graph_list_consistency(data_list): + @abstractmethod + def __len__(self): """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: + Return the number of data points in the condition. - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. + :return: Number of data points. + :rtype: int """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." - ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - def __getattribute__(self, name): + @abstractmethod + def __getitem__(self, idx): """ - Get an attribute from the object. + Return the data point(s) at the specified index. - :param str name: The name of the attribute to get. - :return: The requested attribute. - :rtype: Any + :param int idx: Index of the data point(s) to retrieve. + :return: Data point(s) at the specified index. """ - to_return = super().__getattribute__(name) - if isinstance(to_return, (Graph, Data)): - to_return = [to_return] - return to_return diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 5f5e7d36b..debb71bad 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -2,12 +2,13 @@ import torch from torch_geometric.data import Data -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph +from ..condition.data_manager import _DataManager -class DataCondition(ConditionInterface): +class DataCondition(ConditionBase): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -16,17 +17,6 @@ class DataCondition(ConditionInterface): the provided data during training. Optional ``conditional_variables`` can be specified when the model depends on additional parameters. - The class automatically selects the appropriate implementation based on the - type of the ``input`` data. Depending on whether the ``input`` is a tensor - or graph-based data, one of the following specialized subclasses is - instantiated: - - - :class:`TensorDataCondition`: For cases where the ``input`` is either a - :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. - - - :class:`GraphDataCondition`: For cases where the ``input`` is either a - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` object. - :Example: >>> from pina import Condition, LabelTensor @@ -38,14 +28,14 @@ class DataCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "conditional_variables"] + __fields__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) def __new__(cls, input, conditional_variables=None): """ - Instantiate the appropriate subclass of :class:`DataCondition` based on - the type of the ``input``. + Check the types of ``input`` and ``conditional_variables`` and + instantiate a class of :class:`DataCondition` accordingly. :param input: The input data for the condition. :type input: torch.Tensor | LabelTensor | Graph | @@ -63,58 +53,71 @@ def __new__(cls, input, conditional_variables=None): if cls != DataCondition: return super().__new__(cls) - # If the input is a tensor - if isinstance(input, (torch.Tensor, LabelTensor)): - subclass = TensorDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is a graph - if isinstance(input, (Graph, Data, list, tuple)): - cls._check_graph_list_consistency(input) - subclass = GraphDataCondition - return subclass.__new__(subclass, input, conditional_variables) - - # If the input is not of the correct type raise an error - raise ValueError( - "Invalid input type. Expected one of the following: " - "torch.Tensor, LabelTensor, Graph, Data or " - "an iterable of the previous types." - ) - - def __init__(self, input, conditional_variables=None): + # Check input type + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "Invalid input type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "an iterable of the previous types." + ) + if isinstance(input, (list, tuple)): + for item in input: + if not isinstance(item, (Data, Graph)): + raise ValueError( + "if input is a list or tuple, all its elements must" + " be of type Graph or Data." + ) + + # Check conditional_variables type + if conditional_variables is not None: + if not isinstance( + conditional_variables, cls._avail_conditional_variables_cls + ): + raise ValueError( + "Invalid conditional_variables type. Expected one of the " + "following: torch.Tensor, LabelTensor." + ) + + return super().__new__(cls) + + def store_data(self, **kwargs): """ - Initialization of the :class:`DataCondition` class. + Store the input data and conditional variables in a dictionary. :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param conditional_variables: The conditional variables for the - condition. Default is ``None``. + condition. :type conditional_variables: torch.Tensor | LabelTensor - - .. note:: - - If ``input`` is a list of :class:`~pina.graph.Graph` or - :class:`~torch_geometric.data.Data`, all elements in - the list must share the same structure, with matching keys and - consistent data types. + :return: A dictionary containing the stored data. + :rtype: dict """ - super().__init__() - self.input = input - self.conditional_variables = conditional_variables - + data_dict = {"input": kwargs.get("input")} + cond_vars = kwargs.get("conditional_variables", None) + if cond_vars is not None: + data_dict["conditional_variables"] = cond_vars + return _DataManager(**data_dict) + + @property + def conditional_variables(self): + """ + Return the conditional variables for the condition. -class TensorDataCondition(DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a - :class:`torch.Tensor` object. - """ + :return: The conditional variables. + :rtype: torch.Tensor | LabelTensor | None + """ + if hasattr(self.data, "conditional_variables"): + return self.data.conditional_variables + return None + @property + def input(self): + """ + Return the input data for the condition. -class GraphDataCondition(DataCondition): - """ - Specialization of the :class:`DataCondition` class for the case where - ``input`` is either a :class:`~pina.graph.Graph` object or a - :class:`~torch_geometric.data.Data` object. - """ + :return: The input data. + :rtype: torch.Tensor | LabelTensor | Graph | Data | + list[Graph] | list[Data] | tuple[Graph] | tuple[Data] + """ + return self.data.input diff --git a/pina/condition/data_manager.py b/pina/condition/data_manager.py new file mode 100644 index 000000000..d26cd1941 --- /dev/null +++ b/pina/condition/data_manager.py @@ -0,0 +1,358 @@ +""" +Module for managing data in conditions. +""" + +from abc import ABC, abstractmethod +import torch +from torch_geometric.data import Data +from torch_geometric.data.batch import Batch +from pina import LabelTensor +from pina.graph import Graph +from pina.graph import LabelBatch +from pina.equation.equation_interface import EquationInterface + + +class _BatchManager: + """ + Class for managing batches of data. + """ + + def __init__(self, **dict): + """ + Store the batch data from the provided dictionary. + + :param dict dict: The dictionary containing the batch data. + """ + self.keys = list(dict.keys()) + for k, v in dict.items(): + setattr(self, k, v) + + def to(self, device): + """ + Move all data in the batch to the specified device. + :param device: The device to move the data to. + :type device: torch.device | str + :return: The batch manager with data moved to the specified device. + :rtype: _BatchManager + """ + for k in self.keys: + val = getattr(self, k) + setattr(self, k, val.to(device)) + return self + + +class _DataManager(ABC): + """ + Abstract base class for data managers. + + This class dynamically selects between :class:`_TensorDataManager` and + :class:`_GraphDataManager` based on the types of the input data. + """ + + def __new__(cls, **kwargs): + """ + Dynamically instantiate the appropriate subclass based on the types + of the input data. + - If all values in ``kwargs`` are instances of + :class:`torch.Tensor`, :class:`LabelTensor` then + :class:`_TensorDataManager` is instantiated. + - Otherwise, :class:`_GraphDataManager` is instantiated. + + :param dict kwargs: The keyword arguments containing the data. + :return: An instance of :class:`_TensorDataManager` or + :class:`_GraphDataManager`. + :rtype: _TensorDataManager | _GraphDataManager + """ + # If not called directly, proceed with normal instantiation + if cls is not _DataManager: + return super().__new__(cls) + + # Does the data contain only tensors/LabelTensors/Equations? + is_tensor_only = all( + isinstance(v, (torch.Tensor, LabelTensor, EquationInterface)) + for v in kwargs.values() + ) + # Choose the appropriate subclass, GraphDataManager or TensorDataManager + subclass = _TensorDataManager if is_tensor_only else _GraphDataManager + return super().__new__(subclass) + + def __init__(self, **kwargs): + """ + Initialize the data manager with the provided keyword arguments. + + :param dict kwargs: The keyword arguments containing the data. + """ + self.keys = list(kwargs.keys()) + + @abstractmethod + def __len__(self): + """ + Return the number of samples in the data manager. + """ + + @abstractmethod + def __getitem__(self, idx): + """ + Retrieve a data item or a subset of data items by index. + """ + + def to_dict(self): + """ + Convert the data manager to a dictionary. + """ + return {k: getattr(self, k) for k in self.keys} + + @staticmethod + @abstractmethod + def create_batch(items): + """ + Create a batch from a list of data manager items. + """ + + +class _TensorDataManager(_DataManager): + """ + Data manager for tensor data. Handles data stored as `torch.Tensor` or + `LabelTensor`. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data = kwargs + + for k, v in kwargs.items(): + setattr(self, k, v) + + def __len__(self): + """ + Return the number of samples in the tensor data manager. + + :return: Number of samples. + :rtype: int + """ + return self.data[self.keys[0]].shape[0] + + def __getitem__(self, idx): + """ + Return a data item or a subset of data items by index. + + :param idx: Index or indices of the data items to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_TensorDataManager` instance containing the + selected data items. + :rtype: _TensorDataManager + """ + # Mapping efficiente degli elementi + new_data = { + k: (self.data[k][idx] if k in self.keys else self.data[k]) + for k in self.keys + } + return _TensorDataManager(**new_data) + + @staticmethod + def _create_batch(items): + """ + Create a batch from a list of :class:`_TensorDataManager` items. + + :param list items: List of :class:`_TensorDataManager` items to batch. + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ + if not items: + return None + first = items[0] + batch_data = {} + + for k in first.keys: + vals = [it.data[k] for it in items] + sample = vals[0] + + if isinstance(sample, (torch.Tensor, LabelTensor)): + batch_fn = ( + LabelTensor.stack + if isinstance(sample, LabelTensor) + else torch.stack + ) + batch_data[k] = batch_fn(vals, dim=0) + else: + batch_data[k] = sample + + return _BatchManager(**batch_data) + + +class _GraphDataManager(_DataManager): + """ + Data manager for graph data. Handles data stored as :class:`Graph`, + :class:`Data`, or lists/tuples of these types. Moreover , it can also manage + associated tensors stored as :class:`torch.Tensor` or :class:`LabelTensor`. + """ + + def __init__(self, **kwargs): + """ + Initialize the graph data manager with the provided keyword arguments. + + :param dict kwargs: The keyword arguments containing the data. + """ + super().__init__(**kwargs) + self.graph_key = next( + k + for k, v in kwargs.items() + if isinstance(v, (Graph, Data, list, tuple)) + ) + + self.keys = [ + k + for k in self.keys + if k != self.graph_key + and isinstance(kwargs[k], (torch.Tensor, LabelTensor)) + ] + + # Prepare graphs and assign tensors + self.data = self._prepare_graphs(kwargs) + + def _prepare_graphs(self, kwargs): + """ + Store tensors in the corresponding graphs. + + :param dict kwargs: The keyword arguments containing the graphs and + associated tensors. + :return: A list of graphs with tensors assigned. + :rtype: list[Graph] | list[Data] + """ + graphs = kwargs.pop(self.graph_key) + if not isinstance(graphs, (list, tuple)): + graphs = [graphs] + + n_graphs = len(graphs) + for name, tensor in kwargs.items(): + # Verify consistency between number of graphs and tensor samples + if n_graphs != tensor.shape[0]: + raise ValueError( + f"Number of graphs ({n_graphs}) does not match " + f"number of samples for key '{name}' " + f"({kwargs[name].shape[0]})." + ) + # Assign tensors to graphs + for i, g in enumerate(graphs): + setattr(g, name, tensor[i]) + + return graphs + + def __len__(self): + """ + Return the number of graphs in the graph data manager. + + :return: Number of graphs. + :rtype: int + """ + return len(self.data) + + def __getattr__(self, name): + """ + Override attribute access to retrieve tensors or graphs. If the graph + key is requested, return the list of graphs. If a tensor key is + requested, stack the tensors from all graphs and return the result. + + :param str name: The name of the attribute to retrieve. + :return: The requested tensor or graph. + :rtype: torch.Tensor | LabelTensor | Graph | list[Graph] | Data | + """ + # If the requested attribute is a tensor key, stack the tensors from + # all graphs + if name in self.keys: + tensors = [getattr(g, name) for g in self.data] + batch_fn = ( + LabelTensor.stack + if isinstance(tensors[0], LabelTensor) + else torch.stack + ) + return batch_fn(tensors) + + # If the requested attribute is the graph key, return the graphs + if name == self.graph_key: + return self.data if len(self.data) > 1 else self.data[0] + + super().__getattribute__(name) + + @classmethod + def _init_from_graphs_list(cls, graphs, graph_key, keys): + """ + Initialize a :class:`_GraphDataManager` instance from a list of graphs. + This is used internally to create subsets of the data manager, without + going through the full initialization process. + + :param list graphs: List of graphs to initialize the data manager with. + :param str graph_key: Key under which the graphs are stored. + :param list keys: List of tensor keys associated with the graphs. + :return: A new :class:`_GraphDataManager` instance. + :rtype: _GraphDataManager + """ + # Create a new instance without calling __init__ + obj = _GraphDataManager.__new__(_GraphDataManager) + obj.graph_key = graph_key + obj.keys = keys + obj.data = graphs + return obj + + def __getitem__(self, idx): + """ + Retrieve a graph or a subset of graphs by index. + + :param idx: Index or indices of the graphs to retrieve. + :type idx: int | slice | list[int] | torch.Tensor + :return: A new :class:`_GraphDataManager` instance containing the + selected graphs. + :rtype: _GraphDataManager + """ + # Manage int and slice directly + if isinstance(idx, (int, slice)): + selected = self.data[idx] + # Manage list or tensor of indices + elif isinstance(idx, (list, torch.Tensor)): + selected = [self.data[i] for i in idx] + else: + raise TypeError(f"Invalid index type: {type(idx)}") + + # Ensure selected is a list + if not isinstance(selected, list): + selected = [selected] + + # Return a new _GraphDataManager instance with the selected graphs + return _GraphDataManager._init_from_graphs_list( + selected, + # tensor_keys=self._tensor_keys, + graph_key=self.graph_key, + keys=self.keys, + ) + + @staticmethod + def _create_batch(items): + """ + Create a batch from a list of :class:`_GraphDataManager` items. + + :param list items: List of :class:`_GraphDataManager` items to batch. + :return: A new :class:`_BatchManager` instance containing the batched + data. + :rtype: _BatchManager + """ + if not items: + return None + first = items[0] + batching_fn = ( + LabelBatch.from_data_list + if isinstance(first.data[0], Graph) + else Batch.from_data_list + ) + + graphs_to_batch = [item.data[0] for item in items] + batch_graph = batching_fn(graphs_to_batch) + + batch_data = {first.graph_key: batch_graph} + + for k in first.keys: + if k == first.graph_key: + continue + batch_data[k] = getattr(batch_graph, k) + delattr(batch_graph, k) + return _BatchManager(**batch_data) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3565c0b41..b8d465581 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,12 +1,11 @@ """Module for the DomainEquationCondition class.""" -from .condition_interface import ConditionInterface -from ..utils import check_consistency +from .condition_base import ConditionBase from ..domain import DomainInterface from ..equation.equation_interface import EquationInterface -class DomainEquationCondition(ConditionInterface): +class DomainEquationCondition(ConditionBase): """ The class :class:`DomainEquationCondition` defines a condition based on a ``domain`` and an ``equation``. This condition is typically used in @@ -30,35 +29,66 @@ class DomainEquationCondition(ConditionInterface): """ # Available slots - __slots__ = ["domain", "equation"] + __fields__ = ["domain", "equation"] - def __init__(self, domain, equation): + _avail_domain_cls = DomainInterface + _avail_equation_cls = EquationInterface + + def __new__(cls, domain, equation): """ - Initialization of the :class:`DomainEquationCondition` class. + Check the types of ``domain`` and ``equation`` and instantiate an + instance of :class:`DomainEquationCondition`. - :param DomainInterface domain: The domain over which the equation is - defined. - :param EquationInterface equation: The equation to be satisfied over the - specified domain. + :return: An instance of :class:`DomainEquationCondition`. + :rtype: pina.condition.domain_equation_condition.DomainEquationCondition + :raises ValueError: If ``domain`` is not of type :class:`DomainInterface` or + ``equation`` is not of type :class:` """ - super().__init__() - self.domain = domain - self.equation = equation + if not isinstance(domain, cls._avail_domain_cls): + raise ValueError( + "The domain must be an instance of DomainInterface." + ) + + if not isinstance(equation, cls._avail_equation_cls): + raise ValueError( + "The equation must be an instance of EquationInterface." + ) + + return super().__new__(cls) - def __setattr__(self, key, value): + def __len__(self): """ - Set the attribute value with type checking. + Raise NotImplementedError since the number of points is determined by + the domain sampling strategy. - :param str key: The attribute name. - :param any value: The value to set for the attribute. + :raises NotImplementedError: Always raised since the number of points is + determined by the domain sampling strategy. """ - if key == "domain": - check_consistency(value, (DomainInterface, str)) - DomainEquationCondition.__dict__[key].__set__(self, value) + raise NotImplementedError( + "`__len__` method is not implemented for " + "`DomainEquationCondition` since the number of points is " + "determined by the domain sampling strategy." + ) - elif key == "equation": - check_consistency(value, (EquationInterface)) - DomainEquationCondition.__dict__[key].__set__(self, value) + def __getitem__(self, idx): + """ + Raise NotImplementedError since data retrieval is not applicable. - elif key in ("_problem"): - super().__setattr__(key, value) + :param int idx: Index of the data point(s) to retrieve. + :raises NotImplementedError: Always raised since data retrieval is not + applicable for this condition. + """ + raise NotImplementedError( + "`__getitem__` method is not implemented for " + "`DomainEquationCondition`" + ) + + def store_data(self, **kwargs): + """ + Store data for the condition. No data is stored for this condition. + + :return: An empty dictionary since no data is stored. + :rtype: dict + """ + self.domain = kwargs.get("domain") + self.equation = kwargs.get("equation") diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index d32597894..b8ef0209b 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,13 +1,13 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph -from ..utils import check_consistency from ..equation.equation_interface import EquationInterface +from ..condition.data_manager import _DataManager -class InputEquationCondition(ConditionInterface): +class InputEquationCondition(ConditionBase): """ The class :class:`InputEquationCondition` defines a condition based on ``input`` data and an ``equation``. This condition is typically used in @@ -15,17 +15,6 @@ class InputEquationCondition(ConditionInterface): ``equation`` through the evaluation of the residual performed at the provided ``input``. - The class automatically selects the appropriate implementation based on - the type of the ``input`` data. Depending on whether the ``input`` is a - tensor or graph-based data, one of the following specialized subclasses is - instantiated: - - - :class:`InputTensorEquationCondition`: For cases where the ``input`` - data is a :class:`~pina.label_tensor.LabelTensor` object. - - - :class:`InputGraphEquationCondition`: For cases where the ``input`` data - is a :class:`~pina.graph.Graph` object. - :Example: >>> from pina import Condition, LabelTensor @@ -41,14 +30,14 @@ class InputEquationCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "equation"] - _avail_input_cls = (LabelTensor, Graph, list, tuple) + __fields__ = ["input", "equation"] + _avail_input_cls = (LabelTensor, Graph) _avail_equation_cls = EquationInterface def __new__(cls, input, equation): """ - Instantiate the appropriate subclass of :class:`InputEquationCondition` - based on the type of ``input`` data. + Check the types of ``input`` and ``equation`` and instantiate a class + of :class:`InputEquationCondition` accordingly. :param input: The input data for the condition. :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] @@ -62,96 +51,59 @@ def __new__(cls, input, equation): :raises ValueError: If input is not of type :class:`~pina.graph.Graph` or :class:`~pina.label_tensor.LabelTensor`. """ - if cls != InputEquationCondition: - return super().__new__(cls) - - # If the input is a Graph object - if isinstance(input, (Graph, list, tuple)): - subclass = InputGraphEquationCondition - cls._check_graph_list_consistency(input) - subclass._check_label_tensor(input) - return subclass.__new__(subclass, input, equation) - - # If the input is a LabelTensor - if isinstance(input, LabelTensor): - subclass = InputTensorEquationCondition - return subclass.__new__(subclass, input, equation) - - # If the input is not a LabelTensor or a Graph object raise an error - raise ValueError( - "The input data object must be a LabelTensor or a Graph object." - ) - - def __init__(self, input, equation): - """ - Initialization of the :class:`InputEquationCondition` class. - :param input: The input data for the condition. - :type input: LabelTensor | Graph | list[Graph] | tuple[Graph] - :param EquationInterface equation: The equation to be satisfied over the - specified input points. + # CHeck input type + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "The input data object must be a LabelTensor or a Graph object." + ) - .. note:: + # Check equation type + if not isinstance(equation, cls._avail_equation_cls): + raise ValueError( + "The equation must be an instance of EquationInterface." + ) - If ``input`` is a list of :class:`~pina.graph.Graph` all elements in - the list must share the same structure, with matching keys and - consistent data types. - """ - super().__init__() - self.input = input - self.equation = equation + return super().__new__(cls) - def __setattr__(self, key, value): + def store_data(self, **kwargs): """ - Set the attribute value with type checking. - - :param str key: The attribute name. - :param any value: The value to set for the attribute. + Store the input data in a :class:`_DataManager` object. + :param dict kwargs: The keyword arguments containing the input data. """ - if key == "input": - check_consistency(value, self._avail_input_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, self._avail_equation_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key in ("_problem"): - super().__setattr__(key, value) - + self.__setattr__("_equation", kwargs.pop("equation")) + return _DataManager(**kwargs) -class InputTensorEquationCondition(InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. - """ - - -class InputGraphEquationCondition(InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.graph.Graph` object. - """ + @property + def input(self): + """ + Return the input data for the condition. - @staticmethod - def _check_label_tensor(input): + :return: The input data. + :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] """ - Check if at least one :class:`~pina.label_tensor.LabelTensor` is present - in the ``input`` object. + return self.data.input - :param input: The input data. - :type input: torch.Tensor | Graph | list[Graph] | tuple[Graph] - :raises ValueError: If the input data object does not contain at least - one LabelTensor. + @property + def equation(self): """ + Return the equation associated with this condition. - # Store the first element: it is sufficient to check this since all - # elements must have the same type and structure (already checked). - data = input[0] if isinstance(input, (list, tuple)) else input + :return: Equation associated with this condition. + :rtype: EquationInterface + """ + return self._equation - # Check if the input data contains at least one LabelTensor - for v in data.values(): - if isinstance(v, LabelTensor): - return + @equation.setter + def equation(self, value): + """ + Set the equation associated with this condition. - raise ValueError("The input must contain at least one LabelTensor.") + :param EquationInterface value: The equation to associate with this + condition + """ + if not isinstance(value, EquationInterface): + raise TypeError( + "The equation must be an instance of EquationInterface." + ) + self._equation = value diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 07b07bb7b..e5939c498 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -6,10 +6,11 @@ from torch_geometric.data import Data from ..label_tensor import LabelTensor from ..graph import Graph -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase +from .data_manager import _DataManager -class InputTargetCondition(ConditionInterface): +class InputTargetCondition(ConditionBase): """ The :class:`InputTargetCondition` class represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to @@ -17,29 +18,6 @@ class InputTargetCondition(ConditionInterface): include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - The class automatically selects the appropriate implementation based on - the types of ``input`` and ``target``. Depending on whether the ``input`` - and ``target`` are tensors or graph-based data, one of the following - specialized subclasses is instantiated: - - - :class:`TensorInputTensorTargetCondition`: For cases where both ``input`` - and ``target`` data are either :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor`. - - - :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is - either a :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor` - and ``target`` is either a :class:`~pina.graph.Graph` or a - :class:`torch_geometric.data.Data`. - - - :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is - either a :class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` - and ``target`` is either a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - - - :class:`GraphInputGraphTargetCondition`: For cases where both ``input`` - and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data`. - :Example: >>> from pina import Condition, LabelTensor @@ -55,154 +33,82 @@ class InputTargetCondition(ConditionInterface): """ # Available input and target data types - __slots__ = ["input", "target"] + __fields__ = ["input", "target"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) def __new__(cls, input, target): """ - Instantiate the appropriate subclass of :class:`InputTargetCondition` - based on the types of both ``input`` and ``target`` data. + Check the types of ``input`` and ``target`` data and instantiate the + :class:`InputTargetCondition`. :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :return: The subclass of InputTargetCondition. - :rtype: pina.condition.input_target_condition. - TensorInputTensorTargetCondition | - pina.condition.input_target_condition. - TensorInputGraphTargetCondition | - pina.condition.input_target_condition. - GraphInputTensorTargetCondition | - pina.condition.input_target_condition.GraphInputGraphTargetCondition - - :raises ValueError: If ``input`` and/or ``target`` are not of type - :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, - :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. - """ - if cls != InputTargetCondition: - return super().__new__(cls) - - # Tensor - Tensor - if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( - target, (torch.Tensor, LabelTensor) - ): - subclass = TensorInputTensorTargetCondition - return subclass.__new__(subclass, input, target) - - # Tensor - Graph - if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(target) - subclass = TensorInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # Graph - Tensor - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (torch.Tensor, LabelTensor) - ): - cls._check_graph_list_consistency(input) - subclass = GraphInputTensorTargetCondition - return subclass.__new__(subclass, input, target) - - # Graph - Graph - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(input) - cls._check_graph_list_consistency(target) - subclass = GraphInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # If the input and/or target are not of the correct type raise an error - raise ValueError( - "Invalid input | target types." - "Please provide either torch_geometric.data.Data, Graph, " - "LabelTensor or torch.Tensor objects." - ) - - def __init__(self, input, target): + :type target: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] + :return: An instance of :class:`InputTargetCondition`. + :rtype: pina.condition.input_target_condition.InputTargetCondition + :raises ValueError: If ``input`` or ``target`` are not of supported types. """ - Initialization of the :class:`InputTargetCondition` class. - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] + if not isinstance(input, cls._avail_input_cls): + raise ValueError( + "Invalid input type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "list/tuple of Graph/Data objects." + ) + elif isinstance(input, (list, tuple)): + for item in input: + if not isinstance(item, (Graph, Data)): + raise ValueError( + "If target is a list or tuple, all its elements " + "must be of type Graph or Data." + ) + + if not isinstance(target, cls._avail_output_cls): + raise ValueError( + "Invalid target type. Expected one of the following: " + "torch.Tensor, LabelTensor, Graph, Data or " + "list/tuple of Graph/Data objects." + ) + elif isinstance(target, (list, tuple)): + for item in target: + if not isinstance(item, (Graph, Data)): + raise ValueError( + "If target is a list or tuple, all its elements " + "must be of type Graph or Data." + ) - .. note:: + return super().__new__(cls) - If either ``input`` or ``target`` is a list of - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` - objects, all elements in the list must share the same structure, - with matching keys and consistent data types. + def store_data(self, **kwargs): + """ + Store the input and target data in a :class:`_DataManager` object. + :param dict kwargs: The keyword arguments containing the input and + target data. """ - super().__init__() - self._check_input_target_len(input, target) - self.input = input - self.target = target + return _DataManager(**kwargs) - @staticmethod - def _check_input_target_len(input, target): + @property + def input(self): """ - Check that the length of the input and target lists are the same. + Return the input data for the condition. - :param input: The input data. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + :return: The input data. + :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :raises ValueError: If the lengths of the input and target lists do not - match. """ - if isinstance(input, (Graph, Data)) or isinstance( - target, (Graph, Data) - ): - return - - # Raise an error if the lengths of the input and target do not match - if len(input) != len(target): - raise ValueError( - "The input and target lists must have the same length." - ) - - -class TensorInputTensorTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` objects. - """ - - -class TensorInputGraphTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - ``input`` is either a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a - :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. - """ - - -class GraphInputTensorTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - ``input`` is either a :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` object and ``target`` is either a - :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. - """ + return self.data.input + @property + def target(self): + """ + Return the target data for the condition. -class GraphInputGraphTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` objects. - """ + :return: The target data. + :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] + """ + return self.data.target diff --git a/tests/test_condition.py b/tests/test_condition.py deleted file mode 100644 index 9199f2bd9..000000000 --- a/tests/test_condition.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import pytest - -from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - TensorInputTensorTargetCondition, - GraphInputGraphTargetCondition, - GraphInputTensorTargetCondition, -) -from pina.condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, - DomainEquationCondition, -) -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) -from pina.domain import CartesianDomain -from pina.equation.equation_factory import FixedValue -from pina.graph import RadiusGraph - -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -input_tensor = torch.rand((10, 3)) -target_tensor = torch.rand((10, 2)) -input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) - -x = torch.rand(10, 20, 2) -pos = torch.rand(10, 20, 2) -radius = 0.1 -input_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] -target_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] - -x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) -pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) -radius = 0.1 -input_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] -target_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] - -input_single_graph = input_graph[0] -target_single_graph = target_graph[0] - - -def test_init_input_target(): - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_graph, target=target_tensor) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - cond = Condition(input=input_lt, target=input_single_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_lt) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_single_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - with pytest.raises(ValueError): - Condition(input_tensor, input_tensor) - with pytest.raises(ValueError): - Condition(input=3.0, target="example") - with pytest.raises(ValueError): - Condition(input=example_domain, target=example_domain) - - # Test wrong graph condition initialisation - input = [input_graph[0], input_graph_lt[0]] - target = [target_graph[0], target_graph_lt[0]] - with pytest.raises(ValueError): - Condition(input=input, target=target) - - input_graph_lt[0].x.labels = ["a", "b"] - with pytest.raises(ValueError): - Condition(input=input_graph_lt, target=target_graph_lt) - input_graph_lt[0].x.labels = ["u", "v"] - - -def test_init_domain_equation(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - assert isinstance(cond, DomainEquationCondition) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(domain=3.0, equation="example") - with pytest.raises(ValueError): - Condition(domain=input_tensor, equation=input_graph) - - -def test_init_input_equation(): - cond = Condition(input=input_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputTensorEquationCondition) - cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputGraphEquationCondition) - with pytest.raises(ValueError): - cond = Condition(input=input_tensor, equation=FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(input=3.0, equation="example") - with pytest.raises(ValueError): - Condition(input=example_domain, equation=input_graph) - - -test_init_input_equation() - - -def test_init_data_condition(): - cond = Condition(input=input_lt) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_graph) - assert isinstance(cond, GraphDataCondition) - cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) - assert isinstance(cond, GraphDataCondition) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py new file mode 100644 index 000000000..08762a9f8 --- /dev/null +++ b/tests/test_condition/test_data_condition.py @@ -0,0 +1,227 @@ +import pytest +import torch +from pina import Condition, LabelTensor +from pina.condition import DataCondition +from pina.graph import RadiusGraph +from torch_geometric.data import Data, Batch +from pina.graph import Graph, LabelBatch +from pina.condition.data_manager import _DataManager + + +def _create_tensor_data(use_lt=False, conditional_variables=False): + input_tensor = torch.rand((10, 3)) + if use_lt: + input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) + if conditional_variables: + cond_vars = torch.rand((10, 2)) + if use_lt: + cond_vars = LabelTensor(cond_vars, ["a", "b"]) + else: + cond_vars = None + return input_tensor, cond_vars + + +def _create_graph_data(use_lt=False, conditional_variables=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + input_graph = [ + RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) + ] + if conditional_variables: + if use_lt: + cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + cond_vars = torch.rand(10, 20, 1) + else: + cond_vars = None + return input_graph, cond_vars + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + print(condition) + assert isinstance(condition, DataCondition) + + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["a", "b"] + else: + assert condition.conditional_variables is None + assert isinstance(condition.input, type_) + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + assert isinstance(condition, DataCondition) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["f"] + else: + assert condition.conditional_variables is None + # assert "conditional_variables" not in condition.data.keys() + assert isinstance(condition.input, list) + for graph in condition.input: + assert isinstance(graph, Data) + assert isinstance(graph.x, type_) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + if use_lt: + assert graph.pos.labels == ["x", "y"] + + +def test_wrong_init_data_condition(): + input_tensor, cond_vars = _create_tensor_data() + # Wrong input type + with pytest.raises(ValueError): + Condition(input="invalid_input", conditional_variables=cond_vars) + # Wrong conditional_variables type + with pytest.raises(ValueError): + Condition(input=input_tensor, conditional_variables="invalid_cond_vars") + # Wrong input type (list with wrong elements) + with pytest.raises(ValueError): + Condition(input=[input_tensor], conditional_variables=cond_vars) + # Wrong conditional_variables type (list) + with pytest.raises(ValueError): + Condition(input=input_tensor, conditional_variables=[cond_vars]) + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(item.input, type_) + assert item.input.shape == (3,) + if type_ is LabelTensor: + assert item.input.labels == ["x", "y", "z"] + if conditional_variables: + assert hasattr(item, "conditional_variables") + assert isinstance(item.conditional_variables, type_) + assert item.conditional_variables.shape == (2,) + if type_ is LabelTensor: + assert item.conditional_variables.labels == ["a", "b"] + else: + assert not hasattr(item, "conditional_variables") + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + graph = item.input + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + assert hasattr(item, "conditional_variables") + cond_var = item.conditional_variables + assert isinstance(cond_var, type_) + assert cond_var.shape == (1, 20, 1) + if use_lt: + assert cond_var.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, _DataManager) + assert hasattr(items, "input") + type_ = LabelTensor if use_lt else torch.Tensor + inputs = items.input + assert isinstance(inputs, type_) + assert inputs.shape == (3, 3) + if use_lt: + assert inputs.labels == ["x", "y", "z"] + if conditional_variables: + assert hasattr(items, "conditional_variables") + cond_vars_items = items.conditional_variables + assert isinstance(cond_vars_items, type_) + assert cond_vars_items.shape == (3, 2) + if use_lt: + assert cond_vars_items.labels == ["a", "b"] + else: + assert not hasattr(items, "conditional_variables") + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, _DataManager) + assert hasattr(items, "input") + graphs = items.input + assert isinstance(graphs, list) + assert len(graphs) == 3 + for graph in graphs: + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + type_ = LabelTensor if use_lt else torch.Tensor + assert hasattr(items, "conditional_variables") + cond_vars_batch = items.conditional_variables + assert isinstance(cond_vars_batch, type_) + assert cond_vars_batch.shape == (3, 20, 1) + if use_lt: + assert cond_vars_batch.labels == ["f"] diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py new file mode 100644 index 000000000..f871a2a99 --- /dev/null +++ b/tests/test_condition/test_domain_equation_condition.py @@ -0,0 +1,29 @@ +import pytest +from pina import Condition +from pina.domain import CartesianDomain +from pina.equation.equation_factory import FixedValue +from pina.condition import DomainEquationCondition + +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_equation = FixedValue(0.0) + + +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=example_equation) + assert isinstance(cond, DomainEquationCondition) + assert cond.domain is example_domain + assert cond.equation is example_equation + assert hasattr(cond, "data") + assert cond.data is None + + +def test_len_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + len(cond) + + +def test_getitem_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + cond[0] diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py new file mode 100644 index 000000000..b6bf62296 --- /dev/null +++ b/tests/test_condition/test_input_equation_condition.py @@ -0,0 +1,79 @@ +import torch +import pytest +from pina import Condition +from pina.condition.input_equation_condition import InputEquationCondition +from pina.equation import Equation +from pina import LabelTensor +from pina.graph import Graph +from pina.condition.data_manager import _DataManager + + +def _create_pts_and_equation(): + def dummy_equation(pts): + return pts["x"] ** 2 + pts["y"] ** 2 - 1 + + pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + equation = Equation(dummy_equation) + return pts, equation + + +def _create_graph_and_equation(): + from pina.graph import KNNGraph + + def dummy_equation(pts): + return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 + + x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) + pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) + equation = Equation(dummy_equation) + return graph, equation + + +def test_init_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + assert isinstance(condition, InputEquationCondition) + assert condition.input.shape == (100, 2) + assert condition.equation is equation + + +def test_init_graph_equation_condition(): + graph, equation = _create_graph_and_equation() + condition = Condition(input=graph, equation=equation) + assert isinstance(condition, InputEquationCondition) + assert isinstance(condition.input, Graph) + assert condition.input.x.shape == (100, 2) + assert condition.equation is equation + + +def test_wrong_init_equation_condition(): + pts, equation = _create_pts_and_equation() + # Wrong input type + with pytest.raises(ValueError): + Condition(input=torch.randn(10, 2), equation=equation) + # Wrong equation type + with pytest.raises(ValueError): + Condition(input=pts, equation="not_an_equation") + # Wrong input type (list with wrong elements) + with pytest.raises(ValueError): + Condition(input=[torch.randn(10, 2)], equation=equation) + + +def test_getitem_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + item = condition[0] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (2,) + + +def test_getitems_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + idxs = [0, 1, 3] + item = condition[idxs] + assert isinstance(item, _DataManager) + assert hasattr(item, "input") + assert item.input.shape == (3, 2) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py new file mode 100644 index 000000000..f0dd2cfcf --- /dev/null +++ b/tests/test_condition/test_input_target_condition.py @@ -0,0 +1,290 @@ +import torch +import pytest +from torch_geometric.data import Batch +from pina import LabelTensor, Condition + +# from pina.condition import ( +# TensorInputGraphTargetCondition, +# TensorInputTensorTargetCondition, +# GraphInputTensorTargetCondition, +# ) +from pina.graph import RadiusGraph, LabelBatch + + +def _create_tensor_data(use_lt=False): + if use_lt: + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + return input_tensor, target_tensor + input_tensor = torch.rand((10, 3)) + target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor + + +def _create_graph_data(tensor_input=True, use_lt=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + graph = [ + RadiusGraph( + pos=pos[i], + radius=radius, + x=x[i] if not tensor_input else None, + y=x[i] if tensor_input else None, + ) + for i in range(len(x)) + ] + if use_lt: + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + tensor = torch.rand(10, 20, 1) + return graph, tensor + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + # assert isinstance(condition, TensorInputTensorTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputTensorTargetCondition input failed" + assert torch.allclose( + condition.target, target_tensor + ), "TensorInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputTensorTargetCondition input type failed" + assert condition.input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition input labels failed" + assert isinstance( + condition.target, LabelTensor + ), "TensorInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + # assert isinstance(condition, TensorInputGraphTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputGraphTargetCondition input failed" + for i in range(len(target_graph)): + assert torch.allclose( + condition.target[i].y, target_graph[i].y + ), "TensorInputGraphTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target[i].y, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.target[i].y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition target labels failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.input.labels == [ + "f" + ], "TensorInputGraphTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + # assert isinstance(condition, GraphInputTensorTargetCondition) + for i in range(len(input_graph)): + assert torch.allclose( + condition.input[i].x, input_graph[i].x + ), "GraphInputTensorTargetCondition input failed" + if use_lt: + assert isinstance( + condition.input[i].x, LabelTensor + ), "GraphInputTensorTargetCondition input type failed" + assert ( + condition.input[i].x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition labels failed" + + assert torch.allclose( + condition.target[i], target_tensor[i] + ), "GraphInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target, LabelTensor + ), "GraphInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition target labels failed" + + +def test_wrong_init(): + input_tensor, target_tensor = _create_tensor_data() + with pytest.raises(ValueError): + Condition(input="invalid_input", target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target="invalid_target") + with pytest.raises(ValueError): + Condition(input=[input_tensor], target=target_tensor) + with pytest.raises(ValueError): + Condition(input=input_tensor, target=[target_tensor]) + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + for i in range(len(input_tensor)): + item = condition[i] + assert torch.allclose( + item.input, input_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.target, target_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + for i in range(len(input_tensor)): + item = condition[i] + assert torch.allclose( + item.input, input_tensor[i] + ), "TensorInputGraphTargetCondition __getitem__ input failed" + assert torch.allclose( + item.target.y, target_graph[i].y + ), "TensorInputGraphTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.target.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitem__ target type failed" + assert item.target.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitem__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + assert len(condition) == len(input_graph) + for i in range(len(input_graph)): + item = condition[i] + assert torch.allclose( + item.input.x, input_graph[i].x + ), "GraphInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.target, target_tensor[i] + ), "GraphInputTensorTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ input type failed" + assert ( + item.input.x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition __getitem__ input labels failed" + assert isinstance( + item.target, LabelTensor + ), "GraphInputTensorTargetCondition __getitem__ target type failed" + assert item.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitem__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + indices = [1, 3, 5, 7] + items = condition[indices] + candidate_input = items.input + candidate_target = items.target + + if use_lt: + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + target_ = LabelTensor.stack([target_tensor[i] for i in indices]) + else: + input_ = torch.stack([input_tensor[i] for i in indices]) + target_ = torch.stack([target_tensor[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputTensorTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "TensorInputTensorTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition __getitems__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(True, use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items.input + candidate_target = items.target + if use_lt: + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + # target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) + else: + input_ = torch.stack([input_tensor[i] for i in indices]) + # target_ = Batch.from_data_list([target_graph[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputGraphTargetCondition __getitems__ input failed" + + assert len(candidate_target) == len( + indices + ), "TensorInputGraphTargetCondition __getitems__ target length failed" + for idx, graph_idx in enumerate(indices): + assert torch.allclose( + candidate_target[idx].y, target_graph[graph_idx].y + ), "TensorInputGraphTargetCondition __getitems__ target failed" + + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "f" + ], "TensorInputGraphTargetCondition __getitems__ input labels failed" + for g in candidate_target: + assert isinstance( + g.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ target type failed" + assert g.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitems__ target labels failed" diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py new file mode 100644 index 000000000..ade91413c --- /dev/null +++ b/tests/test_data_manager.py @@ -0,0 +1,137 @@ +import torch +from pina.condition.data_manager import ( + _DataManager, + _TensorDataManager, + _GraphDataManager, +) +from pina.graph import Graph +from pina.equation import Equation + + +def test_tensor_data_manager_init(): + pippo = torch.rand((10, 5)) + pluto = torch.rand((10, 7)) + paperino = torch.rand((10, 11)) + data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager, _TensorDataManager) + assert hasattr(data_manager, "pippo") + assert hasattr(data_manager, "pluto") + assert hasattr(data_manager, "paperino") + assert torch.equal(data_manager.pippo, pippo) + assert torch.equal(data_manager.pluto, pluto) + assert torch.equal(data_manager.paperino, paperino) + + paperino = Equation(lambda x: x**2) + data_manager3 = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + assert isinstance(data_manager3, _TensorDataManager) + assert hasattr(data_manager3, "pippo") + assert hasattr(data_manager3, "pluto") + assert hasattr(data_manager3, "paperino") + assert torch.equal(data_manager3.pippo, pippo) + assert torch.equal(data_manager3.pluto, pluto) + assert isinstance(data_manager3.paperino, Equation) + + +def test_graph_data_manager_init(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + assert hasattr(data_manager, "graph_key") + assert data_manager.graph_key == "graph" + assert hasattr(data_manager, "graph") + assert len(data_manager.data) == 3 + for i in range(3): + g = data_manager.graph[i] + assert torch.equal(g.x, x[i]) + assert torch.equal(g.pos, pos[i]) + assert torch.equal(g.edge_index, edge_index[i]) + assert torch.equal(g.target, target[i]) + + +def test_graph_data_manager_getattribute(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + target_retrieved = data_manager.target + assert torch.equal(target_retrieved, target) + + +def test_graph_data_manager_getitem(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + item = data_manager[1] + assert isinstance(item, _DataManager) + assert hasattr(item, "graph_key") + assert item.graph_key == "graph" + assert hasattr(item, "graph") + assert torch.equal(item.graph.x, x[1]) + assert torch.equal(item.graph.pos, pos[1]) + assert torch.equal(item.graph.edge_index, edge_index[1]) + assert torch.equal(item.target, target[1].unsqueeze(0)) + + +def test_graph_data_create_batch(): + x = [torch.rand((10, 5)) for _ in range(3)] + pos = [torch.rand((10, 3)) for _ in range(3)] + edge_index = [torch.randint(0, 10, (2, 20)) for _ in range(3)] + graph = [ + Graph(x=x_, pos=pos_, edge_index=edge_index_) + for x_, pos_, edge_index_ in zip(x, pos, edge_index) + ] + target = torch.rand((3, 10, 1)) + data_manager = _DataManager(graph=graph, target=target) + item1 = data_manager[0] + item2 = data_manager[1] + batch_data = _GraphDataManager._create_batch([item1, item2]) + assert hasattr(batch_data, "graph") + assert hasattr(batch_data, "target") + batched_graphs = batch_data.graph + batched_target = batch_data.target + assert batched_graphs.num_graphs == 2 + assert batched_target.shape == (20, 1) + assert torch.equal(batched_target, torch.cat([target[0], target[1]], dim=0)) + mps_data = batch_data.to("mps") + assert mps_data.graph.num_graphs == 2 + assert torch.equal(mps_data.target, batched_target.to("mps")) + assert torch.equal(mps_data.graph.x, batched_graphs.x.to("mps")) + + +def test_tensor_data_create_batch(): + pippo = torch.rand((10, 5)) + pluto = torch.rand((10, 7)) + paperino = torch.rand((10, 11)) + data_manager = _DataManager(pippo=pippo, pluto=pluto, paperino=paperino) + item1 = data_manager[0] + item2 = data_manager[1] + batch_data = _TensorDataManager._create_batch([item1, item2]) + assert hasattr(batch_data, "pippo") + assert hasattr(batch_data, "pluto") + assert hasattr(batch_data, "paperino") + assert torch.equal( + batch_data.pippo, torch.stack([pippo[0], pippo[1]], dim=0) + ) + assert torch.equal( + batch_data.pluto, torch.stack([pluto[0], pluto[1]], dim=0) + ) + assert torch.equal( + batch_data.paperino, torch.stack([paperino[0], paperino[1]], dim=0) + ) diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py index c5f0b9e52..4be2897d9 100644 --- a/tests/test_solver/test_ensemble_supervised_solver.py +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_models = [Models() for i in range(10)] diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..461130a6b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_model = Model()