Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
"ConditionInterface",
"DomainEquationCondition",
"InputTargetCondition",
"TensorInputTensorTargetCondition",
"TensorInputGraphTargetCondition",
"GraphInputTensorTargetCondition",
"GraphInputGraphTargetCondition",
"InputEquationCondition",
"InputTensorEquationCondition",
"InputGraphEquationCondition",
"DataCondition",
"GraphDataCondition",
"TensorDataCondition",
Expand All @@ -20,20 +14,6 @@
from .condition_interface import ConditionInterface
from .condition import Condition
from .domain_equation_condition import DomainEquationCondition
from .input_target_condition import (
InputTargetCondition,
TensorInputTensorTargetCondition,
TensorInputGraphTargetCondition,
GraphInputTensorTargetCondition,
GraphInputGraphTargetCondition,
)
from .input_equation_condition import (
InputEquationCondition,
InputTensorEquationCondition,
InputGraphEquationCondition,
)
from .data_condition import (
DataCondition,
GraphDataCondition,
TensorDataCondition,
)
from .input_target_condition import InputTargetCondition
from .input_equation_condition import InputEquationCondition
from .data_condition import DataCondition
22 changes: 11 additions & 11 deletions pina/condition/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ class Condition:
"""

# Combine all possible keyword arguments from the different Condition types
__slots__ = list(
available_kwargs = list(
set(
InputTargetCondition.__slots__
+ InputEquationCondition.__slots__
+ DomainEquationCondition.__slots__
+ DataCondition.__slots__
InputTargetCondition.__fields__
+ InputEquationCondition.__fields__
+ DomainEquationCondition.__fields__
+ DataCondition.__fields__
)
)

Expand All @@ -112,28 +112,28 @@ def __new__(cls, *args, **kwargs):
if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
f"arguments: {Condition.available_kwargs}."
)

# Class specialization based on keyword arguments
sorted_keys = sorted(kwargs.keys())

# Input - Target Condition
if sorted_keys == sorted(InputTargetCondition.__slots__):
if sorted_keys == sorted(InputTargetCondition.__fields__):
return InputTargetCondition(**kwargs)

# Input - Equation Condition
if sorted_keys == sorted(InputEquationCondition.__slots__):
if sorted_keys == sorted(InputEquationCondition.__fields__):
return InputEquationCondition(**kwargs)

# Domain - Equation Condition
if sorted_keys == sorted(DomainEquationCondition.__slots__):
if sorted_keys == sorted(DomainEquationCondition.__fields__):
return DomainEquationCondition(**kwargs)

# Data Condition
if (
sorted_keys == sorted(DataCondition.__slots__)
or sorted_keys[0] == DataCondition.__slots__[0]
sorted_keys == sorted(DataCondition.__fields__)
or sorted_keys[0] == DataCondition.__fields__[0]
):
return DataCondition(**kwargs)

Expand Down
231 changes: 231 additions & 0 deletions pina/condition/condition_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
Base class for conditions.
"""

from copy import deepcopy
from functools import partial
import torch
from torch_geometric.data import Data, Batch
from torch.utils.data import DataLoader
from .condition_interface import ConditionInterface
from ..graph import Graph, LabelBatch
from ..label_tensor import LabelTensor


class TensorCondition:
    """
    Base class for tensor conditions.
    """

    def store_data(self, **kwargs):
        """
        Store data for standard tensor condition.

        :param kwargs: Keyword arguments representing the data to be stored.
        :return: A dictionary containing the stored data, one entry per
            keyword argument.
        :rtype: dict
        """
        # A shallow copy of the keyword arguments is exactly the mapping
        # the original element-wise loop produced.
        return dict(kwargs)


class GraphCondition:
    """
    Base class for graph conditions.

    Mixin that stores per-sample tensors as attributes on (deep copies of)
    the corresponding graphs, and rebuilds batched tensors on retrieval.
    It relies on the concrete condition class to define ``graph_field``
    (the kwarg name holding the list of graphs), ``tensor_fields`` (the
    kwarg names holding per-graph tensors) and ``keys_map`` (mapping each
    tensor field to the attribute name used on the graph objects).
    """

    def __init__(self, **kwargs):
        """
        Initialization of the :class:`GraphCondition` class.

        :param kwargs: Keyword arguments representing the data to be stored.
            Must contain, under the key named by ``self.graph_field``, a
            non-empty sequence of graphs.
        """
        super().__init__(**kwargs)
        # Inspect the first graph to pick the batching function: PINA
        # Graph objects are batched with LabelBatch, plain PyG Data with
        # Batch.
        example = kwargs.get(self.graph_field)[0]
        self.batch_fn = (
            LabelBatch.from_data_list
            if isinstance(example, Graph)
            else Batch.from_data_list
        )

    def store_data(self, **kwargs):
        """
        Store data for graph condition.

        Each tensor listed in ``self.tensor_fields`` is attached, under
        the attribute name given by ``self.keys_map``, to a deep copy of
        the graph with the same position in the graph list.

        :param kwargs: Keyword arguments containing the list of graphs
            (under the key named by ``self.graph_field``) and, for each
            name in ``self.tensor_fields``, a sequence of per-graph
            tensors of the same length as the graph list.
        :return: A dictionary with the enriched graph copies stored under
            the ``"data"`` key.
        :rtype: dict
        """
        data = []
        graphs = kwargs.get(self.graph_field)
        for i, graph in enumerate(graphs):
            # Deep copy so the caller's graphs are never mutated.
            new_graph = deepcopy(graph)
            for key in self.tensor_fields:
                tensor = kwargs[key][i]
                mapping_key = self.keys_map.get(key)
                setattr(new_graph, mapping_key, tensor)
            data.append(new_graph)
        return {"data": data}

    def __getitem__(self, idx):
        """
        Return the data point(s) at the specified index.

        :param idx: Index (or list of indices) of the data point(s).
        :type idx: int | list[int]
        :return: Dictionary with the selected data; for a single index the
            graph is returned under the ``"data"`` key.
        :rtype: dict
        """
        if isinstance(idx, list):
            return self.get_multiple_data(idx)
        return {"data": self.data["data"][idx]}

    def get_multiple_data(self, indices):
        """
        Get multiple data items based on the provided indices.

        :param List[int] indices: List of indices to retrieve.
        :return: Dictionary with the batched graphs under
            ``self.graph_field`` and the batched tensors under each name
            in ``self.tensor_fields``.
        :rtype: dict
        """
        to_return_dict = {}
        data = self.batch_fn([self.data["data"][i] for i in indices])
        to_return_dict[self.graph_field] = data
        for key in self.tensor_fields:
            mapping_key = self.keys_map.get(key)
            y = getattr(data, mapping_key)
            delattr(data, mapping_key)  # Avoid duplication of y on GPU memory
            to_return_dict[key] = y
        return to_return_dict

    @classmethod
    def automatic_batching_collate_fn(cls, batch):
        """
        Collate function to be used in DataLoader.

        :param batch: A list of items from the dataset.
        :type batch: list
        :return: A collated batch, with the batched graphs under
            ``cls.graph_field`` and the batched tensors under each name in
            ``cls.tensor_fields``.
        :rtype: dict
        """
        # Delegate to the next class in the MRO (presumably ConditionBase —
        # depends on the concrete subclass's base order) to build one
        # batched graph object from the list of items.
        collated_graphs = super().automatic_batching_collate_fn(batch)["data"]
        to_return_dict = {}
        for key in cls.tensor_fields:
            mapping_key = cls.keys_map.get(key)
            tensor = getattr(collated_graphs, mapping_key)
            to_return_dict[key] = tensor
            # Detach the tensor from the batch object so it is not held twice.
            delattr(collated_graphs, mapping_key)
        to_return_dict[cls.graph_field] = collated_graphs
        return to_return_dict


class ConditionBase(ConditionInterface):
    """
    Base abstract class for all conditions in PINA.
    This class provides common functionality for handling data storage,
    batching, and interaction with the associated problem.
    """

    # Map from data category to the function used to collate a list of
    # items of that category into a single batched object.
    collate_fn_dict = {
        "tensor": torch.stack,
        "label_tensor": LabelTensor.stack,
        "graph": LabelBatch.from_data_list,
        "data": Batch.from_data_list,
    }

    def __init__(self, **kwargs):
        """
        Initialization of the :class:`ConditionBase` class.

        :param kwargs: Keyword arguments representing the data to be stored.
        """
        super().__init__()
        # ``store_data`` is supplied by the concrete condition class
        # (e.g. a TensorCondition or GraphCondition mixin).
        self.data = self.store_data(**kwargs)

    @property
    def problem(self):
        """
        Return the problem associated with this condition.

        :return: Problem associated with this condition.
        :rtype: ~pina.problem.abstract_problem.AbstractProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem associated with this condition.

        :param pina.problem.abstract_problem.AbstractProblem value: The problem
            to associate with this condition
        """
        self._problem = value

    def __len__(self):
        """
        Return the number of data points in the condition.

        :return: Number of data points.
        :rtype: int
        """
        # NOTE(review): ``self.data`` is the dict built by ``store_data``,
        # so this counts its keys, not samples — confirm subclasses that
        # need a sample count override this.
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the data point(s) at the specified index.

        :param idx: Index(es) of the data point(s) to retrieve.
        :type idx: int | list[int]
        :return: Data point(s) at the specified index.
        """
        return self.data[idx]

    @classmethod
    def automatic_batching_collate_fn(cls, batch):
        """
        Collate function for automatic batching to be used in DataLoader.

        :param batch: A list of items from the dataset.
        :type batch: list
        :return: A collated batch.
        :rtype: dict
        """
        if not batch:
            return {}
        # All items in a batch come from the same dataset class; let that
        # class decide how to assemble them.
        instance_class = batch[0].__class__
        return instance_class._create_batch(batch)

    @staticmethod
    def collate_fn(batch, condition):
        """
        Collate function for custom batching to be used in DataLoader.

        :param batch: A list of indices sampled by the DataLoader.
        :type batch: list
        :param condition: The condition instance holding the data.
        :type condition: ConditionBase
        :return: A collated batch.
        :rtype: dict
        """
        # Indexing with the whole list delegates batching to the
        # condition's own __getitem__.
        return condition.data[batch]

    def create_dataloader(
        self, dataset, batch_size, shuffle, automatic_batching
    ):
        """
        Create a DataLoader for the condition.

        :param dataset: The dataset to iterate over.
        :type dataset: torch.utils.data.Dataset
        :param int batch_size: The batch size for the DataLoader.
        :param bool shuffle: Whether to shuffle the data.
        :param bool automatic_batching: If ``True``, use the class-level
            automatic collate function; otherwise use the custom
            index-based one.
        :return: The DataLoader for the condition.
        :rtype: torch.utils.data.DataLoader
        """
        if batch_size == len(dataset):
            pass  # will be updated in the near future
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=(
                partial(self.collate_fn, condition=self)
                if not automatic_batching
                else self.automatic_batching_collate_fn
            ),
        )
Loading