WIP: Add rewrite to fuse nested BlockDiag Ops #1671
base: main
```diff
@@ -65,6 +65,32 @@
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
+
+
+@register_canonicalize
+@node_rewriter([BlockDiagonal])
+def fuse_blockdiagonal(fgraph, node):
+    """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
+    if not isinstance(node.op, BlockDiagonal):
+        return None
+
+    new_inputs = []
+    changed = False
+
+    for inp in node.inputs:
+        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
+            new_inputs.extend(inp.owner.inputs)
+            changed = True
+        else:
+            new_inputs.append(inp)
+
+    if changed:
+        fused_op = BlockDiagonal(len(new_inputs))
+        new_output = fused_op(*new_inputs)
+        return [new_output]
```
Member (on `return [new_output]`): missing a `copy_stack_trace` at the end
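A sketch of what that could look like, assuming `copy_stack_trace` from `pytensor.graph.rewriting.basic`; the helper name `_return_fused` is hypothetical and mirrors the tail of `fuse_blockdiagonal`:

```python
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.slinalg import BlockDiagonal


def _return_fused(node, new_inputs):
    # Hypothetical helper mirroring the end of fuse_blockdiagonal:
    # build the fused op, then copy debugging info from the variable
    # being replaced onto its replacement so stack traces survive.
    new_output = BlockDiagonal(len(new_inputs))(*new_inputs)
    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
```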
```diff
+
+    return None
+
+
 def is_matrix_transpose(x: TensorVariable) -> bool:
     """Check if a variable corresponds to a transpose of the last two axes"""
     node = x.owner
```
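For context, this is the pattern the rewrite targets. A minimal sketch of exercising the canonicalization directly, assuming `rewrite_graph` from `pytensor.graph.rewriting.utils` and `block_diag` from `pytensor.tensor.slinalg`:

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.slinalg import BlockDiagonal, block_diag

x, y, z = pt.matrix("x"), pt.matrix("y"), pt.matrix("z")

# Two nested BlockDiagonal nodes: block_diag(block_diag(x, y), z)
nested = block_diag(block_diag(x, y), z)

# After the canonicalize pass the graph should hold a single
# BlockDiagonal node taking all three inputs.
fused = rewrite_graph(nested, include=("canonicalize",))
assert isinstance(fused.owner.op, BlockDiagonal)
assert len(fused.owner.inputs) == 3
```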
```diff
@@ -44,6 +44,71 @@
 from tests.test_rop import break_op
+
+
+def test_nested_blockdiag_fusion():
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+
+    inner = BlockDiagonal(2)(x, y)
+    outer = BlockDiagonal(2)(inner, z)
+
+    nodes_before = ancestors([outer])
+    initial_count = sum(
+        1
+        for node in nodes_before
+        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
+    )
+    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
```

Member (on `for node in nodes_before`): we tend to call these "variables" and their owners "nodes". You can also use `ancestor_nodes` or `ancestor_applys` (don't remember the name right now) to iterate over the nodes directly.
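The helper the reviewer is reaching for may be `applys_between` from `pytensor.graph.basic`, which yields apply nodes directly; a sketch of the counting step under that assumption:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import applys_between
from pytensor.tensor.slinalg import BlockDiagonal

x, y, z = (pt.tensor(name, shape=(3, 3)) for name in "xyz")
outer = BlockDiagonal(2)(BlockDiagonal(2)(x, y), z)

# Iterate over apply nodes directly instead of over variables and
# then checking their owners.
initial_count = sum(
    isinstance(apply.op, BlockDiagonal)
    for apply in applys_between([x, y, z], [outer])
)
assert initial_count == 2
```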
```diff
+
+    f = pytensor.function([x, y, z], outer)
```

Member (on this line): if you bother to create a function, evaluate it. But don't bother to create one, just use …
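Acting on the first half of that comment, the compiled function could be evaluated against `scipy.linalg.block_diag` as a numerical reference (a sketch; `f`, `x`, `y`, `z` are the test's objects):

```python
import numpy as np
import pytensor
from scipy.linalg import block_diag as scipy_block_diag

rng = np.random.default_rng(0)
x_val, y_val, z_val = (
    rng.normal(size=(3, 3)).astype(pytensor.config.floatX) for _ in range(3)
)

# The fused graph must still compute an ordinary block-diagonal matrix.
np.testing.assert_allclose(
    f(x_val, y_val, z_val),
    scipy_block_diag(x_val, y_val, z_val),
    rtol=1e-5,
)
```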
```diff
+    fgraph = f.maker.fgraph
+
+    nodes_after = fgraph.apply_nodes
+    fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
+    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
+
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
+
+    out_shape = fgraph.outputs[0].type.shape
+    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
+
+
```
```diff
+def test_deeply_nested_blockdiag_fusion():
```

Member (on the `def` line): same comments apply to this test.
```diff
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+    w = pt.tensor("w", shape=(3, 3))
+
+    inner1 = BlockDiagonal(2)(x, y)
+    inner2 = BlockDiagonal(2)(inner1, z)
+    outer = BlockDiagonal(2)(inner2, w)
+
+    f = pytensor.function([x, y, z, w], outer)
+    fgraph = f.maker.fgraph
+
+    fused_nodes = [
+        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
+    ]
+
+    assert len(fused_nodes) == 1, (
+        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
+    )
+
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 4, (
+        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
+    )
+
+    out_shape = fgraph.outputs[0].type.shape
+    expected_shape = (12, 12)  # 4 blocks of (3x3)
+    assert out_shape == expected_shape, (
+        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
+    )
+
+
```
```diff
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
```
Member (on the `isinstance(node.op, BlockDiagonal)` guard in `fuse_blockdiagonal`): this check is useless, the tracks already enforce it.
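Folding the two review comments on the rewrite together (drop the redundant guard, copy the stack trace), the function might end up like the sketch below; this is not the merged code, and the import locations are assumed to follow pytensor's usual layout:

```python
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.slinalg import BlockDiagonal


@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
    """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
    # The tracks list passed to node_rewriter guarantees node.op is a
    # BlockDiagonal, so no isinstance guard is needed here.
    new_inputs = []
    changed = False

    for inp in node.inputs:
        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
            # Splice the inner block-diagonal's inputs into this node's.
            new_inputs.extend(inp.owner.inputs)
            changed = True
        else:
            new_inputs.append(inp)

    if not changed:
        return None

    new_output = BlockDiagonal(len(new_inputs))(*new_inputs)
    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
```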