Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ jobs:
| {
name: .key,
label: (if (.key | contains("pre")) then .key + " (PRE-RELEASE DEPENDENCIES)" else .key end),
python: .value.python
python: .value.python,
allow_failure: (.key | contains("pre"))
}
)')
echo "envs=${ENVS_JSON}" | tee $GITHUB_OUTPUT

# Run tests through hatch. Spawns a separate runner for each environment defined in the hatch matrix obtained above.
test:
needs: get-environments

strategy:
fail-fast: false
matrix:
Expand All @@ -61,6 +61,8 @@ jobs:
name: ${{ matrix.env.label }}
runs-on: ${{ matrix.os }}

continue-on-error: ${{ matrix.env.allow_failure }}

steps:
- uses: actions/checkout@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ and this project adheres to [Semantic Versioning][].
### Added

### Changed
- `pycea.tl.clades` now resets `tdata.uns["clade_colors"]` when number of clades differs from number of colors (#45)

### Fixed
- Legend placement now works with tight and constrained layouts (#45)

## [0.2.0] - 2025-11-14

Expand Down
57 changes: 20 additions & 37 deletions src/pycea/pl/_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def _place_legend(
shared_kwargs: dict[str, Any],
at_x: float,
at_y: float,
box_width: float | None = None,
expand: bool = False,
) -> mlegend.Legend:
"""Place a legend on the axes at the specified position.

Expand All @@ -131,30 +129,23 @@ def _place_legend(
shared_kwargs
A dictionary of shared keyword arguments for all legends.
at_x
The x-coordinate (in axes fraction) to place the legend.
The x offset in pixels from the top-right corner of the axes.
at_y
The y-coordinate (in axes fraction) to place the legend.
box_width
The width of the legend box (in axes fraction).
expand
Whether to expand the legend to the box_width.
The y offset in pixels from the top-right corner of the axes.
"""
handlelength = legend_kwargs.get("handlelength", 2.0)
fontsize = shared_kwargs.get("fontsize", mpl.rcParams["legend.fontsize"])
if isinstance(fontsize, str):
fontsize = FontProperties(size=fontsize).get_size_in_points()
if handlelength == "dynamic":
handlelength = 100 / fontsize
if box_width is not None:
handlelength = (box_width * 325) / fontsize
offset_trans = mtransforms.ScaledTranslation(at_x / ax.figure.dpi, at_y / ax.figure.dpi, ax.figure.dpi_scale_trans)
opts: dict[str, Any] = {
"handlelength": handlelength,
"loc": legend_kwargs.get("loc", "upper left"),
"bbox_to_anchor": (at_x, at_y),
"bbox_to_anchor": (1, 1),
"bbox_transform": ax.transAxes + offset_trans,
}
if expand and box_width is not None:
opts["bbox_to_anchor"] = (at_x, at_y, box_width + 0.03, 0)
opts["mode"] = "expand"
opts.update({k: v for k, v in legend_kwargs.items() if k not in ("loc", "handlelength")})
opts.update(shared_kwargs)
leg: mlegend.Legend = ax.legend(**opts)
Expand Down Expand Up @@ -188,11 +179,14 @@ def _render_legends(
shared_kwargs = {}
fig = ax.figure
fig.canvas.draw() # make sure transforms are current
ax_height = ax.bbox.height
ax_width = ax.bbox.width
spacing *= ax_height # convert to pixels

if not hasattr(ax, "_attrs"):
ax._attrs = {} # type: ignore
x_offset = ax._attrs.get("x_offset", anchor_x) # type: ignore
y_offset = ax._attrs.get("y_offset", 1.0) # type: ignore
x_offset = ax._attrs.get("x_offset", (anchor_x - 1) * ax_width) # type: ignore
y_offset = ax._attrs.get("y_offset", 0.0) # type: ignore
column_max_width = ax._attrs.get("column_max_width", 0.0) # type: ignore

for legend_kwargs in legends:
Expand All @@ -201,39 +195,28 @@ def _render_legends(
ax.add_artist(ax.get_legend())
# 2) place normally to measure its natural size
legend = _place_legend(ax, legend_kwargs, shared_kwargs, x_offset, y_offset)
# 3) measure in axes fraction
# 3) measure in pixels
renderer = fig.canvas.get_renderer() # type: ignore
bbox_disp = legend.get_window_extent(renderer=renderer)
bbox_axes = mtransforms.Bbox(ax.transAxes.inverted().transform(bbox_disp))
height = bbox_axes.height
width = bbox_axes.width
width = bbox_disp.width
height = bbox_disp.height
# 4) if first in column, initialize max width
if column_max_width == 0.0:
column_max_width = width
# 5) if it overflows vertically, start new column
if (height > y_offset) & (y_offset != 1.0):
if (height - y_offset > ax_height) & (y_offset != 0.0):
legend.remove()
x_offset += column_max_width + spacing
y_offset = 1.0
column_max_width = 0.0
y_offset = 0.0
# place again and re-measure
legend = _place_legend(ax, legend_kwargs, shared_kwargs, x_offset, y_offset)
bbox_disp = legend.get_window_extent(renderer=renderer)
bbox_axes = mtransforms.Bbox(ax.transAxes.inverted().transform(bbox_disp))
height = bbox_axes.height
width = bbox_axes.width
column_max_width = width
# 6) if this legend is narrower than the column max, re-place with expand
elif width < column_max_width:
legend.remove()
legend = _place_legend(
ax, legend_kwargs, shared_kwargs, x_offset, y_offset, box_width=column_max_width, expand=True
)
# 7) otherwise, update column max if this one is wider
else:
column_max_width = width
# 8) finalize: update y_offset and save to _attrs
column_max_width = bbox_disp.width
height = bbox_disp.height
# 6) update offsets for next legend
y_offset -= height + spacing
if width > column_max_width:
column_max_width = width
ax._attrs.update({"y_offset": y_offset}) # type: ignore
ax._attrs.update({"x_offset": x_offset}) # type: ignore
ax._attrs.update({"column_max_width": column_max_width}) # type: ignore
9 changes: 9 additions & 0 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,12 @@ def _remove_attribute(tree, key, nodes=True, edges=True):
for u, v in tree.edges:
if key in tree.edges[u, v]:
del tree.edges[u, v][key]


def _check_colors_length(tdata, key: str):
"""Remove colors from uns if they do not match the number of unique entries in obs."""
if f"{key}_colors" not in tdata.uns.keys():
return
if tdata.obs[key].nunique() != len(tdata.uns[f"{key}_colors"]):
del tdata.uns[f"{key}_colors"]
return
8 changes: 7 additions & 1 deletion src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_trees,
)

from ._utils import _remove_attribute
from ._utils import _check_colors_length, _remove_attribute


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
Expand Down Expand Up @@ -149,6 +149,11 @@ def clades(
* `tdata.obst[tree].nodes[key_added]` : `Object`
- Clade assignment for each node.

Modifies the following fields:

* `tdata.uns[f"{key_added}_colors"]` : `List`
- Removed if its length does not match the number of unique clades.

Examples
--------
Mark clades at specified depth
Expand Down Expand Up @@ -183,5 +188,6 @@ def clades(
node_to_clade = get_keyed_node_data(tdata, key_added, tree_keys, slot="obst")
node_to_clade.index = node_to_clade.index.droplevel(0)
tdata.obs[key_added] = tdata.obs.index.map(node_to_clade[key_added])
_check_colors_length(tdata, key_added)
if copy:
return pd.concat(lcas)
10 changes: 1 addition & 9 deletions tests/test_plot_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,13 @@ def test_size_legend():
assert len(legend2["labels"]) == 6


def test_place_legend_default_and_expand():
def test_place_legend_default():
fig, ax = plt.subplots()
l1 = mlines.Line2D([], [], color="red", label="a")
legend_kwargs = {"title": "t1", "handles": [l1], "labels": ["a"]}
shared_kwargs = {"fontsize": 10}
leg = _place_legend(ax, legend_kwargs, shared_kwargs, at_x=0.5, at_y=0.5)
assert isinstance(leg, mlegend.Legend)
# Expand case
fig2, ax2 = plt.subplots()
p1 = mpatches.Patch(color="blue", label="b")
legend_kwargs2 = {"title": "t2", "handles": [p1], "labels": ["b"], "handlelength": 2}
shared_kwargs2 = {"fontsize": 12}
leg2 = _place_legend(ax2, legend_kwargs2, shared_kwargs2, at_x=0.2, at_y=0.8, box_width=0.3, expand=True)
assert isinstance(leg2, mlegend.Legend)
assert hasattr(leg2, "_bbox_to_anchor")


def test_render_legends():
Expand Down
Loading