Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Types of changes:
### Fixed
- Fixed Complex value initialization error. ([#253](https://github.com/qBraid/pyqasm/pull/253))
- Fixed duplicate qubit argument check in function calls and Error in function call with aliased qubit. ([#260](https://github.com/qBraid/pyqasm/pull/260))
- Fixed Gate ordering across registers in `pyqasm.draw()` function. ([#268](https://github.com/qBraid/pyqasm/pull/268))
=======


### Dependencies
Expand Down
55 changes: 48 additions & 7 deletions src/pyqasm/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,21 @@ def _compute_line_nums(
return line_nums, sizes


# pylint: disable-next=too-many-branches
def _get_line_keys(
all_keys: list[tuple[str, int]], line_nums: dict[tuple[str, int], int]
) -> list[tuple[str, int]]:
if not all_keys:
return []

# Get line numbers for the given keys and find min/max
key_line_nums = [line_nums[key] for key in all_keys]
start_key, end_key = min(key_line_nums), max(key_line_nums)

# Return all keys within the range
return [key for key, line_num in line_nums.items() if start_key <= line_num <= end_key]


# pylint: disable-next=too-many-branches,too-many-locals
def _compute_moments(
statements: list[ast.Statement | ast.Pragma], line_nums: dict[tuple[str, int], int]
) -> tuple[list[list[QuantumStatement]], dict[tuple[str, int], int]]:
Expand All @@ -194,24 +208,44 @@ def _compute_moments(
depths[k] = -1

moments: list[list[QuantumStatement]] = []
# Find and remove final measurements (measurements at the end of the statements list)
final_measurements: list[ast.QuantumMeasurementStatement] = []
seen_keys = set[tuple[str, int]]()

for stmt in reversed(statements):
if not isinstance(stmt, ast.QuantumMeasurementStatement):
break
control_key = _identifier_to_key(stmt.measure.qubit)
if control_key not in seen_keys:
seen_keys.add(control_key)
final_measurements.append(stmt)

# Remove the final measurements from the end of statements list
statements = statements[: -len(final_measurements)] if final_measurements else statements

for statement in statements:
if isinstance(statement, Declaration):
continue
if not isinstance(statement, QuantumStatement):
raise ValueError(f"Unsupported statement: {statement}")
if isinstance(statement, ast.QuantumGate):
qubits = [_identifier_to_key(q) for q in statement.qubits]
depth = 1 + max(depths[q] for q in qubits)
for q in qubits:
depths[q] = depth
# Get line keys for multi-qubit gates, otherwise use qubits directly
target_keys = _get_line_keys(qubits, line_nums) if len(qubits) > 1 else qubits
# Calculate new depth and update all affected keys
depth = 1 + max(depths[key] for key in target_keys)
for key in target_keys:
depths[key] = depth
elif isinstance(statement, ast.QuantumMeasurementStatement):
keys = [_identifier_to_key(statement.measure.qubit)]
if statement.target:
target_key = _identifier_to_key(statement.target)[0], -1
keys.append(target_key)
depth = 1 + max(depths[k] for k in keys)
for k in keys:
depths[k] = depth
line_keys = _get_line_keys(keys, line_nums)
# Calculate new depth and update all affected keys
depth = 1 + max(depths[key] for key in line_keys)
for key in line_keys:
depths[key] = depth
elif isinstance(statement, ast.QuantumBarrier):
qubits = []
for expr in statement.qubits:
Expand All @@ -236,6 +270,13 @@ def _compute_moments(

moments[depth].append(statement)

depth = max(depths.values())
for measurement in final_measurements:
depth += 1
if depth >= len(moments):
moments.append([])
moments[depth].append(measurement)

return moments, depths


Expand Down
Binary file added tests/visualization/images/misc2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions tests/visualization/test_mpl_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,34 @@ def test_draw_misc_ops():
return fig


@pytest.mark.mpl_image_compare(baseline_dir="images", filename="misc2.png")
def test_draw_misc_ops_2():
"""Test drawing a circuit with random operations."""
qasm3 = """
OPENQASM 2.0;
include "qelib1.inc";
gate iswap q0,q1 { s q0; s q1; h q0; cx q0,q1; cx q1,q0; h q1; }
qreg q[4];
creg c[4];
h q[0];
h q[1];
h q[2];
h q[3];
measure q[3] -> c[3];
iswap q[2],q[3];
swap q[0],q[2];
measure q[2] -> c[2];
swap q[1],q[3];
swap q[0],q[2];
cx q[0],q[1];
measure q[3] -> c[1];
cp(pi/4) q[2],q[3];
measure q -> c;
"""
fig = mpl_draw(qasm3)
return fig


def test_draw_raises_unsupported_format_error():
"""Test that an error is raised for unsupported formats."""
qasm = """
Expand Down
Loading