diff --git a/CHANGELOG.md b/CHANGELOG.md index 2049a577..f5e43a54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,8 @@ Types of changes: ### Fixed - Fixed Complex value initialization error. ([#253](https://github.com/qBraid/pyqasm/pull/253)) - Fixed duplicate qubit argument check in function calls and Error in function call with aliased qubit. ([#260](https://github.com/qBraid/pyqasm/pull/260)) +- Fixed Gate ordering across registers in `pyqasm.draw()` function. ([#268](https://github.com/qBraid/pyqasm/pull/268)) +======= ### Dependencies diff --git a/src/pyqasm/printer.py b/src/pyqasm/printer.py index 4d7ea4f0..7a7c52f3 100644 --- a/src/pyqasm/printer.py +++ b/src/pyqasm/printer.py @@ -185,7 +185,21 @@ def _compute_line_nums( return line_nums, sizes -# pylint: disable-next=too-many-branches +def _get_line_keys( + all_keys: list[tuple[str, int]], line_nums: dict[tuple[str, int], int] +) -> list[tuple[str, int]]: + if not all_keys: + return [] + + # Get line numbers for the given keys and find min/max + key_line_nums = [line_nums[key] for key in all_keys] + start_key, end_key = min(key_line_nums), max(key_line_nums) + + # Return all keys within the range + return [key for key, line_num in line_nums.items() if start_key <= line_num <= end_key] + + +# pylint: disable-next=too-many-branches,too-many-locals def _compute_moments( statements: list[ast.Statement | ast.Pragma], line_nums: dict[tuple[str, int], int] ) -> tuple[list[list[QuantumStatement]], dict[tuple[str, int], int]]: @@ -194,6 +208,21 @@ def _compute_moments( depths[k] = -1 moments: list[list[QuantumStatement]] = [] + # Find and remove final measurements (measurements at the end of the statements list) + final_measurements: list[ast.QuantumMeasurementStatement] = [] + seen_keys = set[tuple[str, int]]() + + for stmt in reversed(statements): + if not isinstance(stmt, ast.QuantumMeasurementStatement): + break + control_key = _identifier_to_key(stmt.measure.qubit) + if control_key not in seen_keys: + seen_keys.add(control_key) + final_measurements.append(stmt) + + # Remove the final measurements from the end of statements list + statements = statements[: -len(final_measurements)] if final_measurements else statements + for statement in statements: if isinstance(statement, Declaration): continue @@ -201,17 +230,22 @@ def _compute_moments( raise ValueError(f"Unsupported statement: {statement}") if isinstance(statement, ast.QuantumGate): qubits = [_identifier_to_key(q) for q in statement.qubits] - depth = 1 + max(depths[q] for q in qubits) - for q in qubits: - depths[q] = depth + # Get line keys for multi-qubit gates, otherwise use qubits directly + target_keys = _get_line_keys(qubits, line_nums) if len(qubits) > 1 else qubits + # Calculate new depth and update all affected keys + depth = 1 + max(depths[key] for key in target_keys) + for key in target_keys: + depths[key] = depth elif isinstance(statement, ast.QuantumMeasurementStatement): keys = [_identifier_to_key(statement.measure.qubit)] if statement.target: target_key = _identifier_to_key(statement.target)[0], -1 keys.append(target_key) - depth = 1 + max(depths[k] for k in keys) - for k in keys: - depths[k] = depth + line_keys = _get_line_keys(keys, line_nums) + # Calculate new depth and update all affected keys + depth = 1 + max(depths[key] for key in line_keys) + for key in line_keys: + depths[key] = depth elif isinstance(statement, ast.QuantumBarrier): qubits = [] for expr in statement.qubits: @@ -236,6 +270,13 @@ def _compute_moments( moments[depth].append(statement) + depth = max(depths.values()) + for measurement in final_measurements: + depth += 1 + if depth >= len(moments): + moments.append([]) + moments[depth].append(measurement) + return moments, depths diff --git a/tests/visualization/images/misc2.png b/tests/visualization/images/misc2.png new file mode 100644 index 00000000..51c6f0f9 Binary files /dev/null and b/tests/visualization/images/misc2.png differ diff --git a/tests/visualization/test_mpl_draw.py b/tests/visualization/test_mpl_draw.py index f80b78df..a76e6d57 100644 --- a/tests/visualization/test_mpl_draw.py +++ b/tests/visualization/test_mpl_draw.py @@ -156,6 +156,34 @@ def test_draw_misc_ops(): return fig +@pytest.mark.mpl_image_compare(baseline_dir="images", filename="misc2.png") +def test_draw_misc_ops_2(): + """Test drawing a circuit with random operations.""" + qasm3 = """ + OPENQASM 2.0; + include "qelib1.inc"; + gate iswap q0,q1 { s q0; s q1; h q0; cx q0,q1; cx q1,q0; h q1; } + qreg q[4]; + creg c[4]; + h q[0]; + h q[1]; + h q[2]; + h q[3]; + measure q[3] -> c[3]; + iswap q[2],q[3]; + swap q[0],q[2]; + measure q[2] -> c[2]; + swap q[1],q[3]; + swap q[0],q[2]; + cx q[0],q[1]; + measure q[3] -> c[1]; + cp(pi/4) q[2],q[3]; + measure q -> c; + """ + fig = mpl_draw(qasm3) + return fig + + def test_draw_raises_unsupported_format_error(): """Test that an error is raised for unsupported formats.""" qasm = """