Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)


@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
"""Fuse nested BlockDiagonal ops into a single BlockDiagonal."""

if not isinstance(node.op, BlockDiagonal):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this check is useless, the tracks already enforces it

return None

new_inputs = []
changed = False

for inp in node.inputs:
if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
new_inputs.extend(inp.owner.inputs)
changed = True
else:
new_inputs.append(inp)

if changed:
fused_op = BlockDiagonal(len(new_inputs))
new_output = fused_op(*new_inputs)
return [new_output]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing a copy_stack_trace at the end


return None


def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
Expand Down
65 changes: 65 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,71 @@
from tests.test_rop import break_op


def test_nested_blockdiag_fusion():
x = pt.tensor("x", shape=(3, 3))
y = pt.tensor("y", shape=(3, 3))
z = pt.tensor("z", shape=(3, 3))

inner = BlockDiagonal(2)(x, y)
outer = BlockDiagonal(2)(inner, z)

nodes_before = ancestors([outer])
initial_count = sum(
1
for node in nodes_before
Copy link
Member

@ricardoV94 ricardoV94 Dec 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we tend to call these variables and their owners nodes. you can also use ancestor_nodes or ancestor_applys (don't remember the name right now) to iterate over the nodes directly

if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"

f = pytensor.function([x, y, z], outer)
Copy link
Member

@ricardoV94 ricardoV94 Dec 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you bother to create a function evaluate it. But don't bother to create one, just use rewrite_graph and then check what you checked here

fgraph = f.maker.fgraph

nodes_after = fgraph.apply_nodes
fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"

fused_op = fused_nodes[0].op

assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"

out_shape = fgraph.outputs[0].type.shape
assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"


def test_deeply_nested_blockdiag_fusion():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comments apply to this test

x = pt.tensor("x", shape=(3, 3))
y = pt.tensor("y", shape=(3, 3))
z = pt.tensor("z", shape=(3, 3))
w = pt.tensor("w", shape=(3, 3))

inner1 = BlockDiagonal(2)(x, y)
inner2 = BlockDiagonal(2)(inner1, z)
outer = BlockDiagonal(2)(inner2, w)

f = pytensor.function([x, y, z, w], outer)
fgraph = f.maker.fgraph

fused_nodes = [
node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
]

assert len(fused_nodes) == 1, (
f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
)

fused_op = fused_nodes[0].op

assert fused_op.n_inputs == 4, (
f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
)

out_shape = fgraph.outputs[0].type.shape
expected_shape = (12, 12) # 4 blocks of (3x3)
assert out_shape == expected_shape, (
f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
)


def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx")
Expand Down