diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py
index c379ab1f..2e6d7b81 100644
--- a/src/tracksdata/graph/_rustworkx_graph.py
+++ b/src/tracksdata/graph/_rustworkx_graph.py
@@ -417,11 +417,11 @@ def add_node(
         if validate_keys:
             self._validate_attributes(attrs, self.node_attr_keys(), "node")
 
-        if "t" not in attrs:
-            raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}")
+        if DEFAULT_ATTR_KEYS.T not in attrs:
+            raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}")
 
         node_id = self.rx_graph.add_node(attrs)
-        self._time_to_nodes.setdefault(attrs["t"], []).append(node_id)
+        self._time_to_nodes.setdefault(attrs[DEFAULT_ATTR_KEYS.T], []).append(node_id)
 
         self.node_added.emit_fast(node_id)
         return node_id
@@ -449,7 +449,7 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
 
         node_indices = list(self.rx_graph.add_nodes_from(nodes))
         for node, index in zip(nodes, node_indices, strict=True):
-            self._time_to_nodes.setdefault(node["t"], []).append(index)
+            self._time_to_nodes.setdefault(node[DEFAULT_ATTR_KEYS.T], []).append(index)
 
         # checking if it has connections to reduce overhead
         if is_signal_on(self.node_added):
@@ -481,7 +481,7 @@ def remove_node(self, node_id: int) -> None:
         self.node_removed.emit_fast(node_id)
 
         # Get the time value before removing the node
-        t = self.rx_graph[node_id]["t"]
+        t = self.rx_graph[node_id][DEFAULT_ATTR_KEYS.T]
 
         # Remove the node from the graph (this also removes all connected edges)
         self.rx_graph.remove_node(node_id)
diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py
index 079b1f33..d975b91d 100644
--- a/src/tracksdata/graph/_sql_graph.py
+++ b/src/tracksdata/graph/_sql_graph.py
@@ -507,29 +507,46 @@ class Base(DeclarativeBase):
             self.Base = Base
             return
 
-        class Node(Base):
-            __tablename__ = "Node"
-
-            # Use node_id as sole primary key for simpler updates
-            node_id = sa.Column(sa.BigInteger, primary_key=True, unique=True)
-
-            # Add t as a regular column
-            # NOTE might want to use as index for fast time-based queries
-            t = sa.Column(sa.Integer, nullable=False)
+        Node = type(
+            "Node",
+            (Base,),
+            {
+                "__tablename__": "Node",
+                # Use node_id as sole primary key for simpler updates
+                DEFAULT_ATTR_KEYS.NODE_ID: sa.Column(sa.BigInteger, primary_key=True, unique=True),
+                # Add t as a regular column
+                # NOTE might want to use as index for fast time-based queries
+                DEFAULT_ATTR_KEYS.T: sa.Column(sa.Integer, nullable=False),
+            },
+        )
 
         node_tb_name = Node.__tablename__
 
-        class Edge(Base):
-            __tablename__ = "Edge"
-            edge_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True)
-            source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
-            target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
+        Edge = type(
+            "Edge",
+            (Base,),
+            {
+                "__tablename__": "Edge",
+                DEFAULT_ATTR_KEYS.EDGE_ID: sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True),
+                DEFAULT_ATTR_KEYS.EDGE_SOURCE: sa.Column(
+                    sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
+                ),
+                DEFAULT_ATTR_KEYS.EDGE_TARGET: sa.Column(
+                    sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
+                ),
+            },
+        )
 
         class Overlap(Base):
             __tablename__ = "Overlap"
+
             overlap_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True)
-            source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
-            target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True)
+            source_id = sa.Column(
+                sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
+            )
+            target_id = sa.Column(
+                sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True
+            )
 
         class Metadata(Base):
             __tablename__ = "Metadata"
@@ -570,8 +587,9 @@ def _update_max_id_per_time(self) -> None:
         point and updates the internal cache to ensure newly created nodes
        have unique IDs.
        """
+        t_column = getattr(self.Node, DEFAULT_ATTR_KEYS.T)
         with Session(self._engine) as session:
-            stmt = sa.select(self.Node.t, sa.func.max(self.Node.node_id)).group_by(self.Node.t)
+            stmt = sa.select(t_column, sa.func.max(getattr(self.Node, DEFAULT_ATTR_KEYS.NODE_ID))).group_by(t_column)
             self._max_id_per_time = {int(time): int(max_id) for time, max_id in session.execute(stmt).all()}
 
     def filter(
@@ -637,10 +655,10 @@ def add_node(
         if validate_keys:
             self._validate_attributes(attrs, self.node_attr_keys(), "node")
 
-        if "t" not in attrs:
-            raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}")
+        if DEFAULT_ATTR_KEYS.T not in attrs:
+            raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}")
 
-        time = attrs["t"]
+        time = attrs[DEFAULT_ATTR_KEYS.T]
 
         if index is None:
             default_node_id = (time * self.node_id_time_multiplier) - 1
@@ -712,7 +730,7 @@ def bulk_add_nodes(
         node_ids = []
 
         for i, node in enumerate(nodes):
-            time = node["t"]
+            time = node[DEFAULT_ATTR_KEYS.T]
 
             if indices is None:
                 default_node_id = (time * self.node_id_time_multiplier) - 1
@@ -1500,8 +1518,8 @@ def update_node_attrs(
         attrs: dict[str, Any],
         node_ids: Sequence[int] | None = None,
     ) -> None:
-        if "t" in attrs:
-            raise ValueError("Node attribute 't' cannot be updated.")
+        if DEFAULT_ATTR_KEYS.T in attrs:
+            raise ValueError(f"Node attribute '{DEFAULT_ATTR_KEYS.T}' cannot be updated.")
 
         self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)
 
diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py
index 778ef5ca..5ffc55d8 100644
--- a/src/tracksdata/graph/_test/test_graph_backends.py
+++ b/src/tracksdata/graph/_test/test_graph_backends.py
@@ -1641,6 +1641,24 @@ def test_summary(graph_backend: BaseGraph) -> None:
     assert "Number of edges" in summary
 
 
+def test_changing_default_attr_keys(graph_backend: BaseGraph) -> None:
+    DEFAULT_ATTR_KEYS.T = "frame"
+    if isinstance(graph_backend, SQLGraph):
+        kwargs = {"drivername": "sqlite", "database": ":memory:"}
+    else:
+        kwargs = {}
+
+    # must be a new graph because `graph_backend` already has a `t` attribute initialized
+    new_graph = type(graph_backend)(**kwargs)
+    new_graph.add_node({"frame": 0})
+    node_attrs = new_graph.node_attrs()
+    assert "frame" in node_attrs.columns
+
+    # this must be undone otherwise other tests will fail
+    DEFAULT_ATTR_KEYS.T = "t"
+    assert DEFAULT_ATTR_KEYS.T == "t"
+
+
 def test_spatial_filter_basic(graph_backend: BaseGraph) -> None:
     graph_backend.add_node_attr_key("x", 0.0)
     graph_backend.add_node_attr_key("y", 0.0)
diff --git a/src/tracksdata/nodes/_random.py b/src/tracksdata/nodes/_random.py
index 48abd995..d4a43b14 100644
--- a/src/tracksdata/nodes/_random.py
+++ b/src/tracksdata/nodes/_random.py
@@ -170,4 +170,4 @@ def _nodes_per_time(
                 size=(n_nodes_at_t, len(self.spatial_cols)),
             ).tolist()
 
-        return [{"t": t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords]
+        return [{DEFAULT_ATTR_KEYS.T: t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords]