Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion ultraplot/axes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3148,12 +3148,39 @@ def _update_share_labels(self, axes=None, target="x"):
target : {'x', 'y'}, optional
Which axis labels to share ('x' for x-axis, 'y' for y-axis)
"""
if not axes:
if axes is False:
self.figure._clear_share_label_groups([self], target=target)
return
if axes is None or not len(list(axes)):
return

# Convert indices to actual axes objects
if isinstance(axes[0], int):
axes = [self.figure.axes[i] for i in axes]
axes = [
ax._get_topmost_axes() if hasattr(ax, "_get_topmost_axes") else ax
for ax in axes
if ax is not None
]
if len(axes) < 2:
return
# Preserve order while de-duplicating
seen = set()
unique = []
for ax in axes:
ax_id = id(ax)
if ax_id in seen:
continue
seen.add(ax_id)
unique.append(ax)
axes = unique
if len(axes) < 2:
return

# Prefer figure-managed spanning labels when possible
if all(isinstance(ax, maxes.SubplotBase) for ax in axes):
self.figure._register_share_label_group(axes, target=target, source=self)
return

# Get the center position of the axes group
if box := self.get_center_of_axes(axes):
Expand Down
28 changes: 19 additions & 9 deletions ultraplot/axes/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,27 @@
import copy
import inspect

import matplotlib.axis as maxis
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import numpy as np

from packaging import version

from .. import constructor
from .. import scale as pscale
from .. import ticker as pticker
from ..config import rc
from ..internals import ic # noqa: F401
from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings
from . import plot, shared
import matplotlib.axis as maxis

from ..internals import (
_not_none,
_pop_rc,
_version_mpl,
docstring,
ic, # noqa: F401
labels,
warnings,
)
from ..utils import units
from . import plot, shared

__all__ = ["CartesianAxes"]

Expand Down Expand Up @@ -432,9 +437,14 @@ def _apply_axis_sharing_for_axis(

# Handle axis label sharing (level > 0)
if level > 0:
shared_axis_obj = getattr(shared_axis, f"{axis_name}axis")
labels._transfer_label(axis.label, shared_axis_obj.label)
axis.label.set_visible(False)
if self.figure._is_share_label_group_member(self, axis_name):
pass
elif self.figure._is_share_label_group_member(shared_axis, axis_name):
axis.label.set_visible(False)
else:
shared_axis_obj = getattr(shared_axis, f"{axis_name}axis")
labels._transfer_label(axis.label, shared_axis_obj.label)
axis.label.set_visible(False)

# Handle tick label sharing (level > 2)
if level > 2:
Expand Down
19 changes: 19 additions & 0 deletions ultraplot/axes/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_version_cartopy,
docstring,
ic, # noqa: F401
labels,
warnings,
)
from ..utils import units
Expand Down Expand Up @@ -661,6 +662,24 @@ def _apply_axis_sharing(self):
the leftmost and bottommost is the *figure* sharing level.
"""

# Share axis labels
if self._sharex and self.figure._sharex >= 1:
if self.figure._is_share_label_group_member(self, "x"):
pass
elif self.figure._is_share_label_group_member(self._sharex, "x"):
self.xaxis.label.set_visible(False)
else:
labels._transfer_label(self.xaxis.label, self._sharex.xaxis.label)
self.xaxis.label.set_visible(False)
if self._sharey and self.figure._sharey >= 1:
if self.figure._is_share_label_group_member(self, "y"):
pass
elif self.figure._is_share_label_group_member(self._sharey, "y"):
self.yaxis.label.set_visible(False)
else:
labels._transfer_label(self.yaxis.label, self._sharey.yaxis.label)
self.yaxis.label.set_visible(False)

# Share interval x
if self._sharex and self.figure._sharex >= 2:
self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval())
Expand Down
219 changes: 219 additions & 0 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ def __init__(
self._supxlabel_dict = {} # an axes: label mapping
self._supylabel_dict = {} # an axes: label mapping
self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}}
self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups
self._suptitle_pad = rc["suptitle.pad"]
d = self._suplabel_props = {} # store the super label props
d["left"] = {"va": "center", "ha": "right"}
Expand All @@ -840,6 +841,7 @@ def draw(self, renderer):
# we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars
self._share_ticklabels(axis="x")
self._share_ticklabels(axis="y")
self._apply_share_label_groups()
super().draw(renderer)

def _share_ticklabels(self, *, axis: str) -> None:
Expand Down Expand Up @@ -1889,6 +1891,223 @@ def _align_axis_label(self, x):
if span:
self._update_axis_label(pos, axs)

# Apply explicit label-sharing groups for this axis
self._apply_share_label_groups(axis=x)

def _register_share_label_group(self, axes, *, target, source=None):
"""
Register an explicit label-sharing group for a subset of axes.
"""
if not axes:
return
axes = list(axes)
axes = [ax for ax in axes if ax is not None and ax.figure is self]
if len(axes) < 2:
return

# Preserve order while de-duplicating
seen = set()
unique = []
for ax in axes:
ax_id = id(ax)
if ax_id in seen:
continue
seen.add(ax_id)
unique.append(ax)
axes = unique
if len(axes) < 2:
return

# Split by label side if mixed
axes_by_side = {}
if target == "x":
for ax in axes:
axes_by_side.setdefault(ax.xaxis.get_label_position(), []).append(ax)
else:
for ax in axes:
axes_by_side.setdefault(ax.yaxis.get_label_position(), []).append(ax)
if len(axes_by_side) > 1:
for side, side_axes in axes_by_side.items():
side_source = source if source in side_axes else None
self._register_share_label_group_for_side(
side_axes, target=target, side=side, source=side_source
)
return

side, side_axes = next(iter(axes_by_side.items()))
self._register_share_label_group_for_side(
side_axes, target=target, side=side, source=source
)

def _register_share_label_group_for_side(self, axes, *, target, side, source=None):
"""
Register a single label-sharing group for a given label side.
"""
if not axes:
return
axes = [ax for ax in axes if ax is not None and ax.figure is self]
if len(axes) < 2:
return

# Prefer label text from the source axes if available
label = None
if source in axes:
candidate = getattr(source, f"{target}axis").label
if candidate.get_text().strip():
label = candidate
if label is None:
for ax in axes:
candidate = getattr(ax, f"{target}axis").label
if candidate.get_text().strip():
label = candidate
break

text = label.get_text() if label else ""
props = None
if label is not None:
props = {
"color": label.get_color(),
"fontproperties": label.get_font_properties(),
"rotation": label.get_rotation(),
"rotation_mode": label.get_rotation_mode(),
"ha": label.get_ha(),
"va": label.get_va(),
}

group_key = tuple(sorted(id(ax) for ax in axes))
groups = self._share_label_groups[target]
group = groups.get(group_key)
if group is None:
groups[group_key] = {
"axes": axes,
"side": side,
"text": text if text.strip() else "",
"props": props,
}
else:
group["axes"] = axes
group["side"] = side
if text.strip():
group["text"] = text
group["props"] = props

def _is_share_label_group_member(self, ax, axis):
"""
Return True if the axes belongs to any explicit label-sharing group.
"""
groups = self._share_label_groups.get(axis, {})
return any(ax in group["axes"] for group in groups.values())

def _has_share_label_groups(self, axis):
"""
Return True if there are any explicit label-sharing groups for an axis.
"""
return bool(self._share_label_groups.get(axis, {}))

def _clear_share_label_groups(self, axes=None, *, target=None):
"""
Clear explicit label-sharing groups, optionally filtered by axes.
"""
targets = ("x", "y") if target is None else (target,)
for axis in targets:
groups = self._share_label_groups.get(axis, {})
if axes is None:
groups.clear()
continue
axes_set = {ax for ax in axes if ax is not None}
for key in list(groups):
if any(ax in axes_set for ax in groups[key]["axes"]):
del groups[key]
# Clear any existing spanning labels tied to these axes
if axis == "x":
for ax in axes_set:
if ax in self._supxlabel_dict:
self._supxlabel_dict[ax].set_text("")
else:
for ax in axes_set:
if ax in self._supylabel_dict:
self._supylabel_dict[ax].set_text("")

def _apply_share_label_groups(self, axis=None):
"""
Apply explicit label-sharing groups, overriding default label sharing.
"""

def _order_axes_for_side(axs, side):
if side in ("bottom", "top"):
key = (
(lambda ax: ax._range_subplotspec("y")[1])
if side == "bottom"
else (lambda ax: ax._range_subplotspec("y")[0])
)
reverse = side == "bottom"
else:
key = (
(lambda ax: ax._range_subplotspec("x")[1])
if side == "right"
else (lambda ax: ax._range_subplotspec("x")[0])
)
reverse = side == "right"
try:
return sorted(axs, key=key, reverse=reverse)
except Exception:
return list(axs)

axes = (axis,) if axis in ("x", "y") else ("x", "y")
for target in axes:
groups = self._share_label_groups.get(target, {})
for group in groups.values():
axs = [
ax for ax in group["axes"] if ax.figure is self and ax.get_visible()
]
if len(axs) < 2:
continue

side = group["side"]
ordered_axs = _order_axes_for_side(axs, side)

# Refresh label text from any axis with non-empty text
label = None
for ax in ordered_axs:
candidate = getattr(ax, f"{target}axis").label
if candidate.get_text().strip():
label = candidate
break
text = group["text"]
props = group["props"]
if label is not None:
text = label.get_text()
props = {
"color": label.get_color(),
"fontproperties": label.get_font_properties(),
"rotation": label.get_rotation(),
"rotation_mode": label.get_rotation_mode(),
"ha": label.get_ha(),
"va": label.get_va(),
}
group["text"] = text
group["props"] = props

if not text:
continue

try:
_, ax = self._get_align_coord(
side, ordered_axs, includepanels=self._includepanels
)
except Exception:
continue
axlab = getattr(ax, f"{target}axis").label
axlab.set_text(text)
if props is not None:
axlab.set_color(props["color"])
axlab.set_fontproperties(props["fontproperties"])
axlab.set_rotation(props["rotation"])
axlab.set_rotation_mode(props["rotation_mode"])
axlab.set_ha(props["ha"])
axlab.set_va(props["va"])
self._update_axis_label(side, ordered_axs)

def _align_super_labels(self, side, renderer):
"""
Adjust the position of super labels.
Expand Down
Loading