Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,11 @@ def add_node(
if validate_keys:
self._validate_attributes(attrs, self.node_attr_keys(), "node")

if "t" not in attrs:
raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}")
if DEFAULT_ATTR_KEYS.T not in attrs:
raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}")

node_id = self.rx_graph.add_node(attrs)
self._time_to_nodes.setdefault(attrs["t"], []).append(node_id)
self._time_to_nodes.setdefault(attrs[DEFAULT_ATTR_KEYS.T], []).append(node_id)
self.node_added.emit_fast(node_id)
return node_id

Expand Down Expand Up @@ -449,7 +449,7 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None

node_indices = list(self.rx_graph.add_nodes_from(nodes))
for node, index in zip(nodes, node_indices, strict=True):
self._time_to_nodes.setdefault(node["t"], []).append(index)
self._time_to_nodes.setdefault(node[DEFAULT_ATTR_KEYS.T], []).append(index)

# checking if it has connections to reduce overhead
if is_signal_on(self.node_added):
Expand Down Expand Up @@ -481,7 +481,7 @@ def remove_node(self, node_id: int) -> None:
self.node_removed.emit_fast(node_id)

# Get the time value before removing the node
t = self.rx_graph[node_id]["t"]
t = self.rx_graph[node_id][DEFAULT_ATTR_KEYS.T]

# Remove the node from the graph (this also removes all connected edges)
self.rx_graph.remove_node(node_id)
Expand Down
64 changes: 41 additions & 23 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,29 +507,46 @@ class Base(DeclarativeBase):
self.Base = Base
return

class Node(Base):
__tablename__ = "Node"

# Use node_id as sole primary key for simpler updates
node_id = sa.Column(sa.BigInteger, primary_key=True, unique=True)

# Add t as a regular column
# NOTE might want to use as index for fast time-based queries
t = sa.Column(sa.Integer, nullable=False)
Node = type(
"Node",
(Base,),
{
"__tablename__": "Node",
# Use node_id as sole primary key for simpler updates
DEFAULT_ATTR_KEYS.NODE_ID: sa.Column(sa.BigInteger, primary_key=True, unique=True),
# Add t as a regular column
# NOTE might want to use as index for fast time-based queries
DEFAULT_ATTR_KEYS.T: sa.Column(sa.Integer, nullable=False),
},
)

node_tb_name = Node.__tablename__

class Edge(Base):
__tablename__ = "Edge"
edge_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True)
source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
Edge = type(
"Edge",
(Base,),
{
"__tablename__": "Edge",
DEFAULT_ATTR_KEYS.EDGE_ID: sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True),
DEFAULT_ATTR_KEYS.EDGE_SOURCE: sa.Column(
sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
),
DEFAULT_ATTR_KEYS.EDGE_TARGET: sa.Column(
sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
),
},
)

class Overlap(Base):
__tablename__ = "Overlap"

overlap_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True)
source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
source_id = sa.Column(
sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
)
target_id = sa.Column(
sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
)

class Metadata(Base):
__tablename__ = "Metadata"
Expand Down Expand Up @@ -570,8 +587,9 @@ def _update_max_id_per_time(self) -> None:
point and updates the internal cache to ensure newly created nodes
have unique IDs.
"""
t_column = getattr(self.Node, DEFAULT_ATTR_KEYS.T)
with Session(self._engine) as session:
stmt = sa.select(self.Node.t, sa.func.max(self.Node.node_id)).group_by(self.Node.t)
stmt = sa.select(t_column, sa.func.max(getattr(self.Node, DEFAULT_ATTR_KEYS.NODE_ID))).group_by(t_column)
self._max_id_per_time = {int(time): int(max_id) for time, max_id in session.execute(stmt).all()}

def filter(
Expand Down Expand Up @@ -637,10 +655,10 @@ def add_node(
if validate_keys:
self._validate_attributes(attrs, self.node_attr_keys(), "node")

if "t" not in attrs:
raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}")
if DEFAULT_ATTR_KEYS.T not in attrs:
raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}")

time = attrs["t"]
time = attrs[DEFAULT_ATTR_KEYS.T]

if index is None:
default_node_id = (time * self.node_id_time_multiplier) - 1
Expand Down Expand Up @@ -712,7 +730,7 @@ def bulk_add_nodes(

node_ids = []
for i, node in enumerate(nodes):
time = node["t"]
time = node[DEFAULT_ATTR_KEYS.T]

if indices is None:
default_node_id = (time * self.node_id_time_multiplier) - 1
Expand Down Expand Up @@ -1500,8 +1518,8 @@ def update_node_attrs(
attrs: dict[str, Any],
node_ids: Sequence[int] | None = None,
) -> None:
if "t" in attrs:
raise ValueError("Node attribute 't' cannot be updated.")
if DEFAULT_ATTR_KEYS.T in attrs:
raise ValueError(f"Node attribute '{DEFAULT_ATTR_KEYS.T}' cannot be updated.")

self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)

Expand Down
18 changes: 18 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,24 @@ def test_summary(graph_backend: BaseGraph) -> None:
assert "Number of edges" in summary


def test_changing_default_attr_keys(graph_backend: BaseGraph) -> None:
    """Verify graph backends honor a customized ``DEFAULT_ATTR_KEYS.T`` key.

    Temporarily renames the global time-attribute key from ``"t"`` to
    ``"frame"``, builds a fresh graph of the same backend type (the fixture
    graph was already initialized with ``"t"``), and checks that a node can
    be added and read back under the new key name.

    Parameters
    ----------
    graph_backend : BaseGraph
        Parametrized graph backend fixture; only its type is used to
        construct the new graph under test.
    """
    DEFAULT_ATTR_KEYS.T = "frame"
    # try/finally guarantees the global key is restored even when an
    # assertion below fails — otherwise every later test in the session
    # would run with T == "frame" and break.
    try:
        if isinstance(graph_backend, SQLGraph):
            kwargs = {"drivername": "sqlite", "database": ":memory:"}
        else:
            kwargs = {}

        # must be a new graph because `graph_backend` already has a `t` attribute initialized
        new_graph = type(graph_backend)(**kwargs)
        new_graph.add_node({"frame": 0})
        node_attrs = new_graph.node_attrs()
        assert "frame" in node_attrs.columns
    finally:
        # restore the default so other tests are unaffected
        DEFAULT_ATTR_KEYS.T = "t"
    assert DEFAULT_ATTR_KEYS.T == "t"


def test_spatial_filter_basic(graph_backend: BaseGraph) -> None:
graph_backend.add_node_attr_key("x", 0.0)
graph_backend.add_node_attr_key("y", 0.0)
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/graph/filters/_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
attr_keys: list[str] | None = None,
) -> None:
if attr_keys is None:
attr_keys = ["t", "z", "y", "x"]
attr_keys = [DEFAULT_ATTR_KEYS.T, "z", "y", "x"]
valid_keys = set(graph.node_attr_keys())
attr_keys = list(filter(lambda x: x in valid_keys, attr_keys))

Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/metrics/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def visualize_matches(
viewer = napari.Viewer()

if "z" in input_graph.node_attr_keys():
pos = ["t", "z", "y", "x"]
pos = [DEFAULT_ATTR_KEYS.T, "z", "y", "x"]
else:
pos = ["t", "y", "x"]
pos = [DEFAULT_ATTR_KEYS.T, "y", "x"]

node_attrs = input_graph.node_attrs()
ref_node_attrs = ref_graph.node_attrs()
Expand Down
3 changes: 2 additions & 1 deletion src/tracksdata/nodes/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.nodes._base_nodes import BaseNodesOperator
from tracksdata.utils._multiprocessing import multiprocessing_apply
Expand Down Expand Up @@ -169,4 +170,4 @@ def _nodes_per_time(
size=(n_nodes_at_t, len(self.spatial_cols)),
).tolist()

return [{"t": t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords]
return [{DEFAULT_ATTR_KEYS.T: t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords]
Loading