diff --git a/pyproject.toml b/pyproject.toml
index df1f469a..c0888f86 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dev = [
{include-group = "lint"},
{include-group = "test"},
"pre-commit>=4.3.0",
+ "viztracer>=1.0.4",
]
lint = [
"flake8>=7.3.0",
@@ -34,7 +35,8 @@ test = [
"pytest>=8.4.1",
"pytest-asyncio>=1.1.0",
"pytest-cov>=6.2.1",
- "xarray>=2023.1.0;python_version<'3.13'"
+ "xarray>=2023.1.0;python_version<'3.13'",
+ "psutil>=7.1.0",
]
docs = [
"ezmsg-sigproc>=2.2.0",
@@ -44,6 +46,7 @@ docs = [
[project.scripts]
ezmsg = "ezmsg.core.command:cmdline"
+ezmsg-perf = "ezmsg.util.perf.command:command"
[project.optional-dependencies]
axisarray = [
diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py
index 50e1aa0c..9a28d9ee 100644
--- a/src/ezmsg/core/subclient.py
+++ b/src/ezmsg/core/subclient.py
@@ -210,7 +210,7 @@ async def _handle_publisher(
self._incoming.put_nowait((id, msg_id))
- except (ConnectionResetError, BrokenPipeError):
+ except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
logger.debug(f"connection fail: sub:{self.id} -> pub:{id}")
finally:
diff --git a/src/ezmsg/util/perf/__init__.py b/src/ezmsg/util/perf/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py
new file mode 100644
index 00000000..2df4bebb
--- /dev/null
+++ b/src/ezmsg/util/perf/analysis.py
@@ -0,0 +1,454 @@
+import json
+import typing
+import dataclasses
+import argparse
+import html
+import math
+import webbrowser
+
+from pathlib import Path
+
+from ..messagecodec import MessageDecoder
+from .envinfo import TestEnvironmentInfo, format_env_diff
+from .run import get_datestamp
+from .impl import (
+ TestParameters,
+ Metrics,
+ TestLogEntry,
+)
+
+import ezmsg.core as ez
+
+try:
+ import xarray as xr
+ import pandas as pd # xarray depends on pandas
+except ImportError:
+ ez.logger.error('ezmsg perf analysis requires xarray')
+ raise
+
+try:
+ import numpy as np
+except ImportError:
+ ez.logger.error('ezmsg perf analysis requires numpy')
+ raise
+
+TEST_DESCRIPTION = """
+Configurations (config):
+- fanin: many publishers to one subscriber
+- fanout: one publisher to many subscribers
+- relay: one publisher to one subscriber through many relays
+
+Communication strategies (comms):
+- local: all subs, relays, and pubs are in the SAME process
+- shm / tcp: some clients move to a second process; comms via shared memory / TCP
+ * fanin: all publishers moved
+ * fanout: all subscribers moved
+ * relay: the publisher and all relay nodes moved
+- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively
+
+Variables:
+- n_clients: pubs (fanin), subs (fanout), or relays (relay)
+- msg_size: nominal message size (bytes)
+
+Metrics:
+- sample_rate: messages/sec at the sink (higher = better)
+- data_rate: bytes/sec at the sink (higher = better)
+- latency_mean: average send -> receive latency in seconds (lower = better)
+"""
+
+
+def load_perf(perf: Path) -> xr.Dataset:
+
+ all_results: typing.Dict[TestParameters, typing.Dict[int, typing.List[Metrics]]] = dict()
+
+ run_idx = 0
+
+ with open(perf, 'r') as perf_f:
+ info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder)
+ for line in perf_f:
+ obj = json.loads(line, cls = MessageDecoder)
+ if isinstance(obj, TestEnvironmentInfo):
+ run_idx += 1
+ elif isinstance(obj, TestLogEntry):
+ runs = all_results.get(obj.params, dict())
+ metrics = runs.get(run_idx, list())
+ metrics.append(obj.results)
+ runs[run_idx] = metrics
+ all_results[obj.params] = runs
+
+ n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()])))
+ msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()])))
+ comms_axis = list(sorted(set([p.comms for p in all_results.keys()])))
+ config_axis = list(sorted(set([p.config for p in all_results.keys()])))
+
+ dims = ['n_clients', 'msg_size', 'comms', 'config']
+ coords = {
+ 'n_clients': n_clients_axis,
+ 'msg_size': msg_size_axis,
+ 'comms': comms_axis,
+ 'config': config_axis
+ }
+
+ data_vars = {}
+ for field in dataclasses.fields(Metrics):
+ m = np.zeros((
+ len(n_clients_axis),
+ len(msg_size_axis),
+ len(comms_axis),
+ len(config_axis)
+ )) * np.nan
+ for p, a in all_results.items():
+ # tests are run multiple times; get the median of means
+ m[
+ n_clients_axis.index(p.n_clients),
+ msg_size_axis.index(p.msg_size),
+ comms_axis.index(p.comms),
+ config_axis.index(p.config)
+ ] = np.median([np.mean([getattr(v, field.name) for v in r]) for r in a.values()])
+ data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords)
+
+ dataset = xr.Dataset(data_vars, attrs = dict(info = info))
+ return dataset
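+
+# Note on shape: with e.g. two msg_sizes and two n_clients values, the returned Dataset holds one
+# DataArray per Metrics field, dimensioned (n_clients: 2, msg_size: 2, comms, config). Cells for
+# parameter combinations that were never run stay NaN and are dropped when summary() stacks the
+# parameter dimensions.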
+
+def _escape(s: str) -> str:
+ return html.escape(str(s), quote=True)
+
+def _env_block(title: str, body: str) -> str:
+ return f"""
+ <section>
+ <h2>{_escape(title)}</h2>
+ <pre>{_escape(body).strip()}</pre>
+ </section>
+ """
+
+def _legend_block() -> str:
+ return """
+ <section>
+ <h2>Legend</h2>
+ <ul>
+ <li>Comparison mode: values are percentages (100 = no change).</li>
+ <li>Green: improvement (↑ sample/data rate, ↓ latency).</li>
+ <li>Red: regression (↓ sample/data rate, ↑ latency).</li>
+ </ul>
+ </section>
+ """
+
+def _base_css() -> str:
+ # Minimal, print-friendly CSS + color scales for cells.
+ return """
+ <style>
+ :root { --green: 140; --red: 0; } /* hue values consumed by _color_for_comparison() */
+ body { font-family: sans-serif; margin: 1.5em; }
+ h1, h2, h3 { margin: 0.5em 0; }
+ table { border-collapse: collapse; margin: 0 0 1.5em 0; }
+ th, td { border: 1px solid #ccc; padding: 0.25em 0.75em; text-align: right; }
+ pre { background: #f6f6f6; padding: 0.5em; overflow-x: auto; }
+ </style>
+ """
+
+def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10.0) -> str:
+ """
+ Returns inline CSS background for a comparison % value.
+ value: e.g., 97.3, 104.8, etc.
+ For sample_rate/data_rate: improvement > 100 (good).
+ For latency_mean: improvement < 100 (good).
+ Noise band ±10% around 100 is neutral.
+ """
+ if not (isinstance(value, (int, float)) and math.isfinite(value)):
+ return ""
+
+ delta = value - 100.0
+ # Determine direction: + is good for sample/data; - is good for latency
+ if 'rate' in metric:
+ # positive delta good, negative bad
+ magnitude = abs(delta)
+ sign_good = delta > 0
+ elif 'latency' in metric:
+ # negative delta good (lower latency)
+ magnitude = abs(delta)
+ sign_good = delta < 0
+ else:
+ return ""
+
+ # Noise band: keep neutral
+ if magnitude <= noise_band_pct:
+ return ""
+
+ # Map the overshoot beyond the +/-10% noise band onto 0..1 (saturating at 55%); clamp
+ scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0))
+
+ # Choose hue and lightness; use HSL with gentle saturation
+ hue = "var(--green)" if sign_good else "var(--red)"
+ # Encode intensity in the alpha channel: 0.15 at the edge of the noise band up to 0.50 at full
+ # scale, over a 70%-saturation, 45%-lightness HSL base so the tint blends with the table background
+ alpha = 0.15 + 0.35 * scale # 0.15..0.50
+ return f"background-color: hsla({hue}, 70%, 45%, {alpha});"
+
+def _format_number(x) -> str:
+ if isinstance(x, (int,)) and not isinstance(x, bool):
+ return f"{x:d}"
+ try:
+ xf = float(x)
+ except Exception:
+ return _escape(str(x))
+ # Floats are rendered with 3 decimal places; enough resolution for both comparison percentages
+ # and absolute values (latencies are shown in microseconds in the HTML report).
+ return f"{xf:.3f}"
+
+
+def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None:
+ """ print perf test results and comparisons to the console """
+
+ output = ''
+
+ perf = load_perf(perf_path)
+ info: TestEnvironmentInfo = perf.attrs['info']
+ output += str(info) + '\n\n'
+
+ relative = False
+ env_diff = None
+ if baseline_path is not None:
+ relative = True
+ output += "PERFORMANCE COMPARISON\n\n"
+ baseline = load_perf(baseline_path)
+ perf = (perf / baseline) * 100.0
+ baseline_info: TestEnvironmentInfo = baseline.attrs['info']
+ env_diff = format_env_diff(info.diff(baseline_info))
+ output += env_diff + '\n\n'
+
+ # These raw stats are still valuable to have, but are confusing
+ # when making relative comparisons
+ perf = perf.drop_vars(['latency_total', 'num_msgs'])
+
+ perf = perf.stack(params = ['n_clients', 'msg_size']).dropna('params')
+ df = perf.squeeze().to_dataframe()
+ df = df.drop('n_clients', axis = 1)
+ df = df.drop('msg_size', axis = 1)
+
+ for _, config_ds in perf.groupby('config'):
+ for _, comms_ds in config_ds.groupby('comms'):
+ output += str(comms_ds.squeeze().to_dataframe()) + '\n\n'
+ output += '\n'
+
+ print(output)
+
+ if html:
+ # Ensure expected columns exist
+ expected_cols = {"sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"}
+ missing = expected_cols - set(df.columns)
+ if missing:
+ raise ValueError(f"Missing expected columns in dataset: {missing}")
+
+ # We'll render a table per (config, comms) group.
+ groups = df.reset_index().sort_values(
+ by=["config", "comms", "n_clients", "msg_size"]
+ ).groupby(["config", "comms"], sort=False)
+
+ # Build HTML
+ parts: list[str] = []
+ parts.append("
")
+ parts.append("")
+ parts.append("ezmsg perf report")
+ parts.append(_base_css())
+ parts.append("")
+
+ parts.append("
")
+ parts.append("ezmsg Performance Report
")
+ sub = str(perf_path)
+ if baseline_path is not None:
+ sub += f' relative to {str(baseline_path)}'
+ parts.append(f"{_escape(sub)}
")
+ parts.append("")
+
+ if info is not None:
+ parts.append(_env_block("Test Environment", str(info)))
+
+ parts.append(_env_block("Test Details", TEST_DESCRIPTION))
+
+ if env_diff is not None:
+ # Show environment differences relative to the baseline run
+ parts.append("<section>")
+ parts.append("<h2>Environment Differences vs Baseline</h2>")
+ parts.append(f"<pre>{_escape(env_diff)}</pre>")
+ parts.append("</section>")
+ parts.append(_legend_block())
+
+ # Render each group
+ for (config, comms), g in groups:
+ # Keep only expected columns in order
+ cols = ["n_clients", "msg_size", "sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"]
+ g = g[cols].copy()
+
+ # String format some columns (msg_size with separators)
+ g["msg_size"] = g["msg_size"].map(lambda x: f"{int(x):,}" if pd.notna(x) else x)
+
+ # Build table manually so we can inject inline cell styles easily
+ # (pandas Styler is great but produces bulky HTML; manual keeps it clean)
+ header = f"""
+ <thead>
+ <tr>
+ <th>n_clients</th>
+ <th>msg_size {'' if relative else '(b)'}</th>
+ <th>sample_rate_mean {'' if relative else '(msgs/s)'}</th>
+ <th>sample_rate_median {'' if relative else '(msgs/s)'}</th>
+ <th>data_rate {'' if relative else '(MB/s)'}</th>
+ <th>latency_mean {'' if relative else '(us)'}</th>
+ <th>latency_median {'' if relative else '(us)'}</th>
+ </tr>
+ </thead>
+ """
+ body_rows: list[str] = []
+ for _, row in g.iterrows():
+ sr, srm, dr, lt, lm = row["sample_rate_mean"], row["sample_rate_median"], row["data_rate"], row["latency_mean"], row["latency_median"]
+ dr = dr if relative else dr / 2**20
+ lt = lt if relative else lt * 1e6
+ lm = lm if relative else lm * 1e6
+ sr_style = _color_for_comparison(sr, "sample_rate_mean") if relative else ""
+ srm_style = _color_for_comparison(srm, "sample_rate_median") if relative else ""
+ dr_style = _color_for_comparison(dr, "data_rate") if relative else ""
+ lt_style = _color_for_comparison(lt, "latency_mean") if relative else ""
+ lm_style = _color_for_comparison(lm, "latency_median") if relative else ""
+
+ body_rows.append(
+ "<tr>"
+ f"<td>{_format_number(row['n_clients'])}</td>"
+ f"<td>{_escape(row['msg_size'])}</td>"
+ f"<td style='{sr_style}'>{_format_number(sr)}</td>"
+ f"<td style='{srm_style}'>{_format_number(srm)}</td>"
+ f"<td style='{dr_style}'>{_format_number(dr)}</td>"
+ f"<td style='{lt_style}'>{_format_number(lt)}</td>"
+ f"<td style='{lm_style}'>{_format_number(lm)}</td>"
+ "</tr>"
+ )
+ table_html = f"<table>{header}<tbody>{''.join(body_rows)}</tbody></table>"
+
+ parts.append(
+ "<section>"
+ f"<h2>{_escape(config)}</h2>"
+ f"<h3>{_escape(comms)}</h3>"
+ f"{table_html}"
+ "</section>"
+ )
+
+ parts.append("</body></html>")
+ html_text = "".join(parts)
+
+ out_path = Path(f'report_{get_datestamp()}.html')
+ out_path.write_text(html_text, encoding="utf-8")
+ webbrowser.open(out_path.resolve().as_uri())
+
+
+def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None:
+ p_summary = subparsers.add_parser("summary", help = "summarize performance results")
+ p_summary.add_argument(
+ "perf",
+ type=Path,
+ help="perf test",
+ )
+ p_summary.add_argument(
+ "--baseline",
+ "-b",
+ type=Path,
+ default=None,
+ help="baseline perf test for comparison",
+ )
+ p_summary.add_argument(
+ "--html",
+ action = 'store_true',
+ help = "generate an html output file and render results in browser",
+ )
+
+ p_summary.set_defaults(_handler=lambda ns: summary(
+ perf_path = ns.perf,
+ baseline_path = ns.baseline,
+ html = ns.html
+ ))
\ No newline at end of file
diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py
new file mode 100644
index 00000000..ab919999
--- /dev/null
+++ b/src/ezmsg/util/perf/command.py
@@ -0,0 +1,17 @@
+import argparse
+
+from .analysis import setup_summary_cmdline
+from .run import setup_run_cmdline
+
+def command() -> None:
+ parser = argparse.ArgumentParser(description = 'ezmsg perf test utility')
+ subparsers = parser.add_subparsers(dest="command", required=True)
+
+ setup_run_cmdline(subparsers)
+ setup_summary_cmdline(subparsers)
+
+ ns = parser.parse_args()
+ ns._handler(ns)
+
+if __name__ == "__main__":
+ command()
diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py
new file mode 100644
index 00000000..66fff18b
--- /dev/null
+++ b/src/ezmsg/util/perf/envinfo.py
@@ -0,0 +1,83 @@
+import dataclasses
+import datetime
+import platform
+import typing
+import sys
+import subprocess
+
+import ezmsg.core as ez
+
+try:
+ import numpy as np
+except ImportError:
+ ez.logger.error("ezmsg perf requires numpy")
+ raise
+
+try:
+ import psutil
+except ImportError:
+ ez.logger.error("ezmsg perf requires psutil")
+ raise
+
+
+def _git_commit() -> str:
+ try:
+ return subprocess.check_output(
+ ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
+ ).decode().strip()
+ except (subprocess.CalledProcessError, OSError):
+ return "unknown"
+
+def _git_branch() -> str:
+ try:
+ return subprocess.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL
+ ).decode().strip()
+ except (subprocess.CalledProcessError, OSError):
+ return "unknown"
+
+@dataclasses.dataclass
+class TestEnvironmentInfo:
+ ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__)
+ numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__)
+ python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " "))
+ os: str = dataclasses.field(default_factory=lambda: platform.system())
+ os_version: str = dataclasses.field(default_factory=lambda: platform.version())
+ machine: str = dataclasses.field(default_factory=lambda: platform.machine())
+ processor: str = dataclasses.field(default_factory=lambda: platform.processor())
+ cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True))
+ cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False))
+ memory_gb: float = dataclasses.field(default_factory=lambda: round(psutil.virtual_memory().total / (1024**3), 2))
+ start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds"))
+ git_commit: str = dataclasses.field(default_factory=_git_commit)
+ git_branch: str = dataclasses.field(default_factory=_git_branch)
+
+ def __str__(self) -> str:
+ fields = dataclasses.asdict(self)
+ width = max(len(k) for k in fields)
+ lines = ["TestEnvironmentInfo:"]
+ for key, value in fields.items():
+ lines.append(f" {key.ljust(width)} : {value}")
+ return "\n".join(lines)
+
+ def diff(self, other: "TestEnvironmentInfo") -> typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]:
+ """Return a structured diff: {field: (self_value, other_value)} for changed fields."""
+ a = dataclasses.asdict(self)
+ b = dataclasses.asdict(other)
+ keys = set(a) | set(b)
+ return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}
+
+
+def format_env_diff(diffs: typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]) -> str:
+ """Pretty-print the structured diff in the same aligned style."""
+ if not diffs:
+ return "No differences."
+ width = max(len(k) for k in diffs)
+ lines = ["Differences in TestEnvironmentInfo:"]
+ for k in sorted(diffs):
+ left, right = diffs[k]
+ lines.append(f" {k.ljust(width)} : {left} != {right}")
+ return "\n".join(lines)
+
+def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str:
+ return format_env_diff(a.diff(b))
\ No newline at end of file
diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py
new file mode 100644
index 00000000..51ae600b
--- /dev/null
+++ b/src/ezmsg/util/perf/impl.py
@@ -0,0 +1,362 @@
+import asyncio
+import dataclasses
+import os
+import time
+import typing
+import enum
+
+import ezmsg.core as ez
+
+from ezmsg.util.messages.util import replace
+from ezmsg.core.netprotocol import Address
+
+try:
+ import numpy as np
+except ImportError:
+ ez.logger.error("ezmsg perf requires numpy")
+ raise
+
+from .util import stable_perf
+
+
+def collect(
+ components: typing.Optional[typing.Mapping[str, ez.Component]] = None,
+ network: ez.NetworkDefinition = (),
+ process_components: typing.Collection[ez.Component] | None = None,
+ **components_kwargs: ez.Component,
+) -> ez.Collection:
+ """ collect a grouping of pre-configured components into a new "Collection" """
+ from ezmsg.core.util import either_dict_or_kwargs
+
+ components = either_dict_or_kwargs(components, components_kwargs, "collect")
+ if components is None:
+ raise ValueError("Must supply at least one component to run")
+
+ out = ez.Collection()
+ for name, comp in components.items():
+ comp._set_name(name)
+ out._components = components # FIXME: Component._components should be typehinted as a Mapping
+ out.network = lambda: network
+ out.process_components = lambda: (out,) if process_components is None else process_components
+ return out
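+
+# Usage sketch (mirrors how perform_test below groups the source and clients into one extra process):
+#   proc = collect(SOURCE=source, CLIENT_1=clients[0])
+#   components["PROC"] = proc
+#   process_components = [proc]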
+
+
+@dataclasses.dataclass
+class Metrics:
+ num_msgs: int
+ sample_rate_mean: float
+ sample_rate_median: float
+ latency_mean: float
+ latency_median: float
+ latency_total: float
+ data_rate: float
+
+
+class LoadTestSettings(ez.Settings):
+ max_duration: float
+ num_msgs: int
+ dynamic_size: int
+ buffers: int
+ force_tcp: bool
+
+
+@dataclasses.dataclass
+class LoadTestSample:
+ _timestamp: float
+ counter: int
+ dynamic_data: np.ndarray
+ key: str
+
+class LoadTestSourceState(ez.State):
+ counter: int = 0
+
+class LoadTestSource(ez.Unit):
+ OUTPUT = ez.OutputStream(LoadTestSample)
+ SETTINGS = LoadTestSettings
+ STATE = LoadTestSourceState
+
+ async def initialize(self) -> None:
+ self.OUTPUT.num_buffers = self.SETTINGS.buffers
+ self.OUTPUT.force_tcp = self.SETTINGS.force_tcp
+
+ @ez.publisher(OUTPUT)
+ async def publish(self) -> typing.AsyncGenerator:
+ ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})")
+ start_time = time.time()
+ for _ in range(self.SETTINGS.num_msgs):
+
+ current_time = time.time()
+ if current_time - start_time >= self.SETTINGS.max_duration:
+ break
+
+ yield (
+ self.OUTPUT,
+ LoadTestSample(
+ _timestamp=time.time(),
+ counter=self.STATE.counter,
+ dynamic_data=np.zeros(
+ int(self.SETTINGS.dynamic_size // 4), dtype=np.float32
+ ),
+ key = self.name,
+ ),
+ )
+ self.STATE.counter += 1
+
+ ez.logger.info("Exiting publish")
+ raise ez.Complete
+
+ async def shutdown(self) -> None:
+ ez.logger.info(f"Samples sent: {self.STATE.counter}")
+
+
+
+class LoadTestRelay(ez.Unit):
+ INPUT = ez.InputStream(LoadTestSample)
+ OUTPUT = ez.OutputStream(LoadTestSample)
+
+ @ez.subscriber(INPUT, zero_copy = True)
+ @ez.publisher(OUTPUT)
+ async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator:
+ yield self.OUTPUT, msg
+
+
+class LoadTestReceiverState(ez.State):
+ # Tuples of (sent timestamp, received timestamp, counter)
+ received_data: typing.List[typing.Tuple[float, float, int]] = dataclasses.field(
+ default_factory=list
+ )
+ counters: typing.Dict[str, int] = dataclasses.field(default_factory=dict)
+
+
+class LoadTestReceiver(ez.Unit):
+ INPUT = ez.InputStream(LoadTestSample)
+ SETTINGS = LoadTestSettings
+ STATE = LoadTestReceiverState
+
+ async def initialize(self) -> None:
+ ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})")
+
+ @ez.subscriber(INPUT, zero_copy=True)
+ async def receive(self, sample: LoadTestSample) -> None:
+ counter = self.STATE.counters.get(sample.key, -1)
+ if sample.counter != counter + 1:
+ ez.logger.warning(
+ f"{sample.counter - counter - 1} samples skipped!"
+ )
+ self.STATE.received_data.append(
+ (sample._timestamp, time.time(), sample.counter)
+ )
+ self.STATE.counters[sample.key] = sample.counter
+
+
+class LoadTestSink(LoadTestReceiver):
+
+ INPUT = ez.InputStream(LoadTestSample)
+
+ @ez.subscriber(INPUT, zero_copy=True)
+ async def receive(self, sample: LoadTestSample) -> None:
+ await super().receive(sample)
+ if len(self.STATE.received_data) == self.SETTINGS.num_msgs:
+ raise ez.NormalTermination
+
+ @ez.task
+ async def terminate(self) -> None:
+ # Wait for the max duration of the load test
+ await asyncio.sleep(self.SETTINGS.max_duration)
+ ez.logger.warning("TIMEOUT -- terminating test.")
+ raise ez.NormalTermination
+
+
+### TEST CONFIGURATIONS
+
+@dataclasses.dataclass
+class ConfigSettings:
+ n_clients: int
+ settings: LoadTestSettings
+ source: LoadTestSource
+ sink: LoadTestSink
+
+Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition]
+Configurator = typing.Callable[[ConfigSettings], Configuration]
+
+def fanout(config: ConfigSettings) -> Configuration:
+ """ one pub to many subs """
+ connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)]
+ subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)]
+ for sub in subs:
+ connections.append((config.source.OUTPUT, sub.INPUT))
+
+ return subs, connections
+
+def fanin(config: ConfigSettings) -> Configuration:
+ """ many pubs to one sub """
+ connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)]
+ pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)]
+ expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs)
+ config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs = expected_num_msgs) # type: ignore
+ for pub in pubs:
+ connections.append((pub.OUTPUT, config.sink.INPUT))
+ return pubs, connections
+
+
+def relay(config: ConfigSettings) -> Configuration:
+ """ one pub to one sub through many relays """
+ connections: ez.NetworkDefinition = []
+
+ relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)]
+ if len(relays):
+ connections.append((config.source.OUTPUT, relays[0].INPUT))
+ for from_relay, to_relay in zip(relays[:-1], relays[1:]):
+ connections.append((from_relay.OUTPUT, to_relay.INPUT))
+ connections.append((relays[-1].OUTPUT, config.sink.INPUT))
+ else:
+ connections.append((config.source.OUTPUT, config.sink.INPUT))
+
+ return relays, connections
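+
+# e.g. relay() with n_clients=2 produces the chain:
+#   source.OUTPUT -> relays[0].INPUT, relays[0].OUTPUT -> relays[1].INPUT, relays[1].OUTPUT -> sink.INPUT
+# and with n_clients=0 it degenerates to a direct source -> sink connection.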
+
+CONFIGS: typing.Mapping[str, Configurator] = {
+ c.__name__: c for c in [
+ fanin,
+ fanout,
+ relay
+ ]
+}
+
+class Communication(enum.StrEnum):
+ LOCAL = "local"
+ SHM = "shm"
+ SHM_SPREAD = "shm_spread"
+ TCP = "tcp"
+ TCP_SPREAD = "tcp_spread"
+
+def perform_test(
+ n_clients: int,
+ max_duration: float,
+ num_msgs: int,
+ msg_size: int,
+ buffers: int,
+ comms: Communication,
+ config: Configurator,
+ graph_address: Address
+) -> Metrics:
+
+ settings = LoadTestSettings(
+ dynamic_size = int(msg_size),
+ num_msgs = num_msgs,
+ max_duration = max_duration,
+ buffers = buffers,
+ force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)),
+ )
+
+ source = LoadTestSource(settings)
+ sink = LoadTestSink(settings)
+
+ components: typing.Dict[str, ez.Component] = dict(
+ SINK = sink,
+ )
+
+ clients, connections = config(ConfigSettings(n_clients, settings, source, sink))
+
+ # The 'sink' MUST remain in this process for us to pull its state.
+ process_components: typing.List[ez.Component] = []
+ if comms == Communication.LOCAL:
+ # Every component in the same process (this one)
+ components["SOURCE"] = source
+ for i, client in enumerate(clients):
+ components[f"CLIENT_{i+1}"] = client
+
+ else:
+
+ if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD):
+ # Every component in its own process.
+ components["SOURCE"] = source
+ process_components.append(source)
+ for i, client in enumerate(clients):
+ components[f'CLIENT_{i+1}'] = client
+ process_components.append(client)
+
+ else:
+ # All clients and the source in ONE other process.
+ collect_comps: typing.Dict[str, ez.Component] = dict()
+ collect_comps["SOURCE"] = source
+ for i, client in enumerate(clients):
+ collect_comps[f"CLIENT_{i+1}"] = client
+ proc_collection = collect(components = collect_comps)
+ components["PROC"] = proc_collection
+ process_components = [proc_collection]
+
+ with stable_perf():
+ ez.run(
+ components = components,
+ connections = connections,
+ process_components = process_components,
+ graph_address = graph_address
+ )
+
+ return calculate_metrics(sink)
+
+
+def calculate_metrics(sink: LoadTestSink) -> Metrics:
+
+ # Log some useful summary statistics
+ min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data)
+ max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data)
+ latency = [
+ receive_timestamp - send_timestamp
+ for send_timestamp, receive_timestamp, _ in sink.STATE.received_data
+ ]
+ total_latency = abs(sum(latency))
+
+ counters = list(sorted(t[2] for t in sink.STATE.received_data))
+ dropped_samples = sum(
+ [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])]
+ )
+
+ rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data])
+ runtime = max_timestamp - min_timestamp
+ num_samples = len(sink.STATE.received_data)
+ samplerate_mean = num_samples / runtime
+ samplerate_median = 1.0 / float(np.median(np.diff(rx_timestamps)))
+ latency_mean = total_latency / num_samples
+ latency_median = list(sorted(latency))[len(latency) // 2]
+ total_data = num_samples * sink.SETTINGS.dynamic_size
+ data_rate = total_data / runtime
+
+ ez.logger.info(f"Samples received: {num_samples}")
+ ez.logger.info(f"Mean sample rate: {samplerate_mean} Hz")
+ ez.logger.info(f"Median sample rate: {samplerate_median} Hz")
+ ez.logger.info(f"Mean latency: {latency_mean} s")
+ ez.logger.info(f"Median latency: {latency_median} s")
+ ez.logger.info(f"Total latency: {total_latency} s")
+ ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s")
+
+ if dropped_samples:
+ ez.logger.error(
+ f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)",
+ )
+
+ return Metrics(
+ num_msgs = num_samples,
+ sample_rate_mean = samplerate_mean,
+ sample_rate_median = samplerate_median,
+ latency_mean = latency_mean,
+ latency_median = latency_median,
+ latency_total = total_latency,
+ data_rate = data_rate
+ )
+
+
+@dataclasses.dataclass(unsafe_hash=True)
+class TestParameters:
+ msg_size: int
+ num_msgs: int
+ n_clients: int
+ config: str
+ comms: str
+ max_duration: float
+ num_buffers: int
+
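+# TestParameters is declared with unsafe_hash=True so that analysis.load_perf can use instances
+# as dictionary keys when grouping repeated runs of the same parameter combination.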
+
+@dataclasses.dataclass
+class TestLogEntry:
+ params: TestParameters
+ results: Metrics
\ No newline at end of file
diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py
new file mode 100644
index 00000000..ce6c2071
--- /dev/null
+++ b/src/ezmsg/util/perf/run.py
@@ -0,0 +1,317 @@
+import os
+import sys
+import json
+import itertools
+import argparse
+import typing
+import random
+import time
+
+from datetime import datetime, timedelta
+from contextlib import contextmanager, redirect_stdout, redirect_stderr
+
+import ezmsg.core as ez
+from ezmsg.core.graphserver import GraphServer
+
+from ..messagecodec import MessageEncoder
+from .envinfo import TestEnvironmentInfo
+from .util import warmup
+from .impl import (
+ TestParameters,
+ TestLogEntry,
+ perform_test,
+ Communication,
+ CONFIGS,
+)
+
+DEFAULT_MSG_SIZES = [2**4, 2**20]
+DEFAULT_N_CLIENTS = [1, 16]
+DEFAULT_COMMS = [c for c in Communication]
+
+# --- Output Suppression Context Manager ---
+@contextmanager
+def suppress_output(verbose: bool = False):
+ """Context manager to redirect stdout and stderr to os.devnull"""
+ if verbose:
+ yield
+ else:
+ # Open the null device for writing
+ with open(os.devnull, 'w') as fnull:
+ # Redirect both stdout and stderr to the null device
+ with redirect_stderr(fnull):
+ with redirect_stdout(fnull):
+ yield
+
+CHECK_FOR_QUIT = lambda: False
+
+if sys.platform.startswith('win'):
+ import msvcrt
+ def _check_for_quit_win() -> bool:
+ """
+ Checks for the 'q' key press in a non-blocking way.
+ Returns True if 'q' is pressed (case-insensitive), False otherwise.
+ """
+ # Windows: Use msvcrt for non-blocking keyboard hit detection
+ if msvcrt.kbhit(): # type: ignore
+ # Read the key press (returns bytes)
+ key = msvcrt.getch() # type: ignore
+ try:
+ # Decode and check for 'q'
+ return key.decode().lower() == 'q'
+ except UnicodeDecodeError:
+ # Handle potential non-text key presses gracefully
+ return False
+ return False
+
+ CHECK_FOR_QUIT = _check_for_quit_win
+
+else:
+ import select
+ def _check_for_quit() -> bool:
+ """
+ Checks for the 'q' key press in a non-blocking way.
+ Returns True if 'q' is pressed (case-insensitive), False otherwise.
+ """
+ # Linux/macOS: Use select to check if stdin has data
+ # select.select(rlist, wlist, xlist, timeout)
+ # timeout=0 makes it non-blocking
+ if sys.stdin.isatty():
+ i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore
+ if i:
+ # Read the buffered character
+ key = sys.stdin.read(1)
+ return key.lower() == 'q'
+ return False
+
+ CHECK_FOR_QUIT = _check_for_quit
+
+def get_datestamp() -> str:
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
+
+def perf_run(
+ max_duration: float,
+ num_msgs: int,
+ num_buffers: int,
+ iters: int,
+ repeats: int,
+ msg_sizes: typing.List[int] | None,
+ n_clients: typing.List[int] | None,
+ comms: typing.Iterable[str] | None,
+ configs: typing.Iterable[str] | None,
+ grid: bool,
+ warmup_dur: float,
+) -> None:
+
+ if n_clients is None:
+ n_clients = DEFAULT_N_CLIENTS
+ if any(c < 0 for c in n_clients):
+ ez.logger.error('All tests must have >=0 clients')
+ return
+
+ if msg_sizes is None:
+ msg_sizes = DEFAULT_MSG_SIZES
+ if any(s < 0 for s in msg_sizes):
+ ez.logger.error('All msg_sizes must be >=0 bytes')
+ return
+
+ if not grid and len(list(n_clients)) != len(list(msg_sizes)):
+ ez.logger.warning(
+ "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + \
+ f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. "
+ )
+
+ try:
+ communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms]
+ except ValueError:
+ ez.logger.error(f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}")
+ return
+
+ try:
+ configurators = list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs]
+ except KeyError:
+ ez.logger.error(f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}")
+ return
+
+ subitr = itertools.product if grid else zip
+
+ test_list = [
+ (msg_size, clients, conf, comm)
+ for msg_size, clients in subitr(msg_sizes, n_clients)
+ for conf, comm in itertools.product(configurators, communications)
+ ] * iters
+
+ random.shuffle(test_list)
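+ # With the defaults (grid search over 2 msg_sizes x 2 n_clients x 3 configs x 5 comms, iters=5)
+ # this yields 300 individual tests per repeat; shuffling spreads each parameter combination
+ # across the run so that slow environmental drift does not bias any one configuration.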
+
+ server = GraphServer()
+ server.start()
+
+ ez.logger.info(f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each.")
+ ez.logger.info(f"During each test, source will attempt to send {num_msgs} messages to the sink.")
+ ez.logger.info(f"Please try to avoid running other taxing software while this perf test runs.")
+ ez.logger.info(f"NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early.")
+
+ quitting = False
+
+ start_time = time.time()
+
+ try:
+ ez.logger.info(f"Warming up for {warmup_dur} seconds...")
+ warmup(warmup_dur)
+
+ with open(f'perf_{get_datestamp()}.txt', 'w') as out_f:
+ for _ in range(repeats):
+ out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n")
+
+ for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list):
+
+ if CHECK_FOR_QUIT():
+ ez.logger.info("Stopping tests early...")
+ quitting = True
+ break
+
+ ez.logger.info(
+ f"TEST {test_idx + 1}/{len(test_list)}: " \
+ f"{clients=}, {msg_size=}, conf={conf.__name__}, " \
+ f"comm={comm.value}"
+ )
+
+ output = TestLogEntry(
+ params = TestParameters(
+ msg_size = msg_size,
+ num_msgs = num_msgs,
+ n_clients = clients,
+ config = conf.__name__,
+ comms = comm.value,
+ max_duration = max_duration,
+ num_buffers = num_buffers
+ ),
+ results = perform_test(
+ n_clients = clients,
+ max_duration = max_duration,
+ num_msgs = num_msgs,
+ msg_size = msg_size,
+ buffers = num_buffers,
+ comms = comm,
+ config = conf,
+ graph_address = server.address
+ )
+ )
+
+ out_f.write(json.dumps(output, cls = MessageEncoder) + "\n")
+
+ if quitting:
+ break
+
+ finally:
+ server.stop()
+ d = datetime(1,1,1) + timedelta(seconds = time.time() - start_time)
+ dur_str = ':'.join(f'{n:02d}' for n in (d.hour, d.minute, d.second))
+ if d.day > 1:
+ dur_str = f'{d.day - 1}d {dur_str}'
+ ez.logger.info(f"Tests concluded. Wallclock runtime: {dur_str}")
+
+
+
+
+def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None:
+
+ p_run = subparsers.add_parser("run", help="run performance test")
+
+ p_run.add_argument(
+ "--max-duration",
+ type=float,
+ default=5.0,
+ help="maximum individual test duration in seconds (default = 5.0)",
+ )
+
+ p_run.add_argument(
+ "--num-msgs",
+ type=int,
+ default=1000,
+ help = "number of messages to send per-test (default = 1000)"
+ )
+
+ # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting
+ # messages as fast as possible through the system) results in one of two scenarios:
+ # 1. One (or a few) messages are enqueued and dequeued before another message is posted.
+ # 2. The buffer fills up before being FULLY emptied resulting in longer latency.
+ # (once a channel enters this condition, it tends to stay in this condition)
+ #
+ # This _indeterminate_ behavior results in bimodal distributions of runtimes that make
+ # A/B performance comparisons difficult. The perf test is not representative of the vast
+ # majority of production ezmsg systems where publishing is generally rate-limited.
+ #
+ # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is
+ # generally implemented by enforcing delays on the publish side which simply degrades
+ # performance in the vast majority of ezmsg systems. - Griff
+ p_run.add_argument(
+ "--num-buffers",
+ type=int,
+ default=1,
+ help="shared memory buffers (default = 1)",
+ )
+
+ p_run.add_argument(
+ "--iters", "-i",
+ type = int,
+ default = 5,
+ help = "number of times to run each test (default = 5)"
+ )
+
+ p_run.add_argument(
+ "--repeats", "-r",
+ type = int,
+ default = 10,
+ help = "number of times to repeat the perf (default = 10)"
+ )
+
+ p_run.add_argument(
+ "--msg-sizes",
+ type = int,
+ default = None,
+ nargs = "*",
+ help = f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})"
+ )
+
+ p_run.add_argument(
+ "--n-clients",
+ type = int,
+ default = None,
+ nargs = "*",
+ help = f"number of clients (default = {DEFAULT_N_CLIENTS})"
+ )
+
+ p_run.add_argument(
+ "--comms",
+ type = str,
+ default = None,
+ nargs = "*",
+ help = f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})"
+ )
+
+ p_run.add_argument(
+ "--configs",
+ type = str,
+ default = None,
+ nargs = "*",
+ help = f"configurations to test (default = {[c for c in CONFIGS]})"
+ )
+
+ p_run.add_argument(
+ "--warmup",
+ type = float,
+ default = 60.0,
+ help = "warmup CPU with busy task for some number of seconds (default = 60.0)"
+ )
+
+ p_run.set_defaults(_handler=lambda ns: perf_run(
+ max_duration = ns.max_duration,
+ num_msgs = ns.num_msgs,
+ num_buffers = ns.num_buffers,
+ iters = ns.iters,
+ repeats = ns.repeats,
+ msg_sizes = ns.msg_sizes,
+ n_clients = ns.n_clients,
+ comms = ns.comms,
+ configs = ns.configs,
+ grid = True,
+ warmup_dur = ns.warmup,
+ ))
\ No newline at end of file
diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py
new file mode 100644
index 00000000..d22d1496
--- /dev/null
+++ b/src/ezmsg/util/perf/util.py
@@ -0,0 +1,288 @@
+import os
+import sys
+import gc
+import time
+import statistics as stats
+import contextlib
+import subprocess
+from dataclasses import dataclass
+from typing import Iterable, List
+
+try:
+ import psutil # optional but helpful
+except Exception:
+ psutil = None
+
+_IS_WIN = os.name == "nt"
+_IS_MAC = sys.platform == "darwin"
+_IS_LINUX = sys.platform.startswith("linux")
+
+# ---------- Utilities ----------
+
+def _set_env_threads(single_thread: bool = True):
+ """
+ Normalize math/threading libs so they don't spawn surprise worker threads.
+ """
+ if single_thread:
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
+ os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
+ os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
+ # Keep PYTHONHASHSEED stable for deterministic dict/set iteration costs
+ os.environ.setdefault("PYTHONHASHSEED", "0")
+
+# ---------- Priority & Affinity ----------
+
+@contextlib.contextmanager
+def _process_priority():
+ """
+ Elevate process priority in a cross-platform best-effort way.
+ """
+ if psutil is None:
+ yield
+ return
+
+ p = psutil.Process()
+ orig_nice = None
+ if _IS_WIN:
+ try:
+ import ctypes
+ kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
+ ABOVE_NORMAL_PRIORITY_CLASS = 0x00008000
+ HIGH_PRIORITY_CLASS = 0x00000080
+ # Try High, fall back to Above Normal
+ if not kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS):
+ kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS)
+ except Exception:
+ pass
+ else:
+ try:
+ orig_nice = p.nice()
+ # Negative nice may need privileges; try smaller magnitude first
+ for nice_val in (-10, -5, 0):
+ try:
+ p.nice(nice_val)
+ break
+ except Exception:
+ continue
+ except Exception:
+ pass
+ try:
+ yield
+ finally:
+ # restore nice if we changed it
+ if psutil is not None and not _IS_WIN and orig_nice is not None:
+ try:
+ p.nice(orig_nice)
+ except Exception:
+ pass
+
+@contextlib.contextmanager
+def _cpu_affinity(prefer_isolation: bool = True):
+ """
+ Set CPU affinity to a small, stable set of CPUs (where supported).
+ macOS does not support affinity via psutil; we no-op there.
+ """
+ if psutil is None or _IS_MAC:
+ yield
+ return
+
+ p = psutil.Process()
+ original = None
+ try:
+ if hasattr(p, "cpu_affinity"):
+ original = p.cpu_affinity()
+ cpus = original
+ if prefer_isolation and len(cpus) > 2:
+ # Pick two middle CPUs to avoid 0 which often handles interrupts
+ mid = len(cpus) // 2
+ cpus = [cpus[mid-1], cpus[mid]]
+ p.cpu_affinity(cpus)
+ yield
+ finally:
+ try:
+ if original is not None and hasattr(p, "cpu_affinity"):
+ p.cpu_affinity(original)
+ except Exception:
+ pass
+
+# ---------- Platform-specific helpers ----------
+
+@contextlib.contextmanager
+def _mac_caffeinate():
+ """
+ Keep macOS awake during the run via a background caffeinate process.
+ """
+ if not _IS_MAC:
+ yield
+ return
+ proc = None
+ try:
+ proc = subprocess.Popen(["caffeinate", "-dimsu"])
+ except Exception:
+ proc = None
+ try:
+ yield
+ finally:
+ if proc is not None:
+ try:
+ proc.terminate()
+ except Exception:
+ pass
+
+@contextlib.contextmanager
+def _win_timer_resolution(ms: int = 1):
+ """
+ On Windows, request a finer system timer to stabilize sleeps and scheduling slices.
+ """
+ if not _IS_WIN:
+ yield
+ return
+ import ctypes
+ winmm = ctypes.WinDLL("winmm")
+ timeBeginPeriod = winmm.timeBeginPeriod
+ timeEndPeriod = winmm.timeEndPeriod
+ try:
+ timeBeginPeriod(ms)
+ except Exception:
+ pass
+ try:
+ yield
+ finally:
+ try:
+ timeEndPeriod(ms)
+ except Exception:
+ pass
+
+# ---------- Warm-up & GC ----------
+
+def warmup(seconds: float = 60.0, fn=None, *args, **kwargs):
+ """
+ Optional warm-up to reach steady clocks/caches.
+ If fn is provided, call it in a loop for the given time.
+ """
+ if seconds <= 0:
+ return
+ target = time.perf_counter() + seconds
+ if fn is None:
+ # Busy wait / sleep mix to heat up without heavy CPU
+ while time.perf_counter() < target:
+ x = 0
+ for _ in range(10000):
+ x += 1
+ time.sleep(0)
+ else:
+ while time.perf_counter() < target:
+ fn(*args, **kwargs)
+
+@contextlib.contextmanager
+def gc_pause():
+ """
+ Disable GC inside timing windows; re-enable and collect after.
+ """
+ was_enabled = gc.isenabled()
+ try:
+ gc.disable()
+ yield
+ finally:
+ if was_enabled:
+ gc.enable()
+ gc.collect()
+
+# ---------- Robust statistics ----------
+
+def median_of_means(samples: Iterable[float], k: int = 5) -> float:
+ """
+ Robust estimate: split samples into k buckets (round-robin), average each, take median of bucket means.
+ """
+ samples = list(samples)
+ if not samples:
+ return float("nan")
+ k = max(1, min(k, len(samples)))
+ buckets = [[] for _ in range(k)]
+ for i, v in enumerate(samples):
+ buckets[i % k].append(v)
+ means = [sum(b)/len(b) for b in buckets if b]
+ means.sort()
+ return means[len(means)//2]
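+
+# e.g. median_of_means([10.0, 12.0, 11.0, 50.0, 10.5, 11.5], k=3) forms round-robin buckets
+# [10.0, 50.0], [12.0, 10.5], [11.0, 11.5]; the bucket means are 30.0, 11.25, 11.25, and the
+# returned median of means is 11.25 -- the single 50.0 outlier is largely suppressed.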
+
+def coef_var(samples: Iterable[float]) -> float:
+ vals = list(samples)
+ if len(vals) < 2:
+ return 0.0
+ m = sum(vals)/len(vals)
+ if m == 0:
+ return 0.0
+ sd = stats.pstdev(vals)
+ return sd / m
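+
+# e.g. coef_var([10.0, 12.0]) = 1.0 / 11.0 ≈ 0.091 (population std-dev over the mean).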
+
+# ---------- Public context manager ----------
+
+@dataclass
+class PerfOptions:
+ single_thread_math: bool = True
+ prefer_isolated_cpus: bool = True
+ warmup_seconds: float = 0.0
+ adjust_priority: bool = True
+ tweak_timer_windows: bool = True
+ keep_mac_awake: bool = True
+
+@contextlib.contextmanager
+def stable_perf(opts: PerfOptions = PerfOptions()):
+ """
+ Wrap your perf runs with this context manager for a stabler environment.
+ """
+ _set_env_threads(opts.single_thread_math)
+
+ cm_stack = contextlib.ExitStack()
+ try:
+ if opts.adjust_priority:
+ cm_stack.enter_context(_process_priority())
+ if opts.tweak_timer_windows:
+ cm_stack.enter_context(_win_timer_resolution(1))
+ if opts.prefer_isolated_cpus:
+ cm_stack.enter_context(_cpu_affinity(True))
+ if opts.keep_mac_awake:
+ cm_stack.enter_context(_mac_caffeinate())
+
+ if opts.warmup_seconds > 0:
+ warmup(opts.warmup_seconds)
+
+ with gc_pause():
+ yield
+ finally:
+ cm_stack.close()
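+
+# Usage sketch -- perform_test() in impl.py wraps ez.run() with the default options:
+#   with stable_perf():
+#       ez.run(...)
+# Options can be tuned per run, e.g. (run_benchmark being a hypothetical callable):
+#   with stable_perf(PerfOptions(warmup_seconds=10.0, prefer_isolated_cpus=False)):
+#       run_benchmark()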
+
+# ---------- Example runners ----------
+
+def run_interleaved(configs: List[dict], run_fn, trials: int = 5, trial_seconds: float = 30.0, seed: int = 42):
+ """
+ Interleave scenarios to cancel slow drift. `run_fn(config, seconds, seed_offset)` should return a dict of metrics.
+ Returns a list of per-trial results per config.
+ """
+ import random
+ random.seed(seed)
+ results = [ [] for _ in configs ]
+ order = list(range(len(configs)))
+ # fixed order per pass; the same order is re-used on every trial for stability
+ for t in range(trials):
+ for idx in order:
+ res = run_fn(configs[idx], trial_seconds, seed + t)
+ results[idx].append(res)
+ return results
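+
+# Sketch of how these example runners compose (run_trial is a hypothetical callable that runs one
+# scenario for `seconds` and returns a metrics dict, e.g. {"sample_rate": ...}; summarize_metric
+# is defined below):
+#   results = run_interleaved([{"comms": "shm"}, {"comms": "tcp"}], run_trial, trials=5)
+#   print(summarize_metric(results[0], "sample_rate"))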
+
+def summarize_metric(trials: List[dict], key: str, mom_buckets: int = 5):
+ """
+ Extract a metric across trial dicts and summarize with median-of-means and CV.
+ """
+ vals = [float(tr[key]) for tr in trials if key in tr]
+ return {
+ "count": len(vals),
+ "mom": median_of_means(vals, mom_buckets),
+ "mean": sum(vals)/len(vals) if vals else float("nan"),
+ "p50": stats.median(vals) if vals else float("nan"),
+ "cv": coef_var(vals) if vals else float("nan"),
+ }
\ No newline at end of file
diff --git a/src/ezmsg/util/perf_test.py b/src/ezmsg/util/perf_test.py
deleted file mode 100644
index b7536b5e..00000000
--- a/src/ezmsg/util/perf_test.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import asyncio
-import dataclasses
-import datetime
-import os
-import platform
-import time
-
-from typing import List, Tuple, AsyncGenerator
-
-import ezmsg.core as ez
-
-# We expect this test to generate LOTS of backpressure warnings
-# PERF_LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "ERROR")
-# ez.logger.setLevel(PERF_LOGLEVEL)
-
-PLATFORM = {
- "Darwin": "mac",
- "Linux": "linux",
- "Windows": "win",
-}[platform.system()]
-SAMPLE_SUMMARY_DATASET_PREFIX = "sample_summary"
-COUNT_DATASET_NAME = "count"
-
-
-try:
- import numpy as np
-except ImportError:
- ez.logger.error("This test requires Numpy to run.")
- raise
-
-
-def get_datestamp() -> str:
- return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
-
-class LoadTestSettings(ez.Settings):
- duration: float = 30.0
- dynamic_size: int = 8
- buffers: int = 32
-
-
-@dataclasses.dataclass
-class LoadTestSample:
- _timestamp: float
- counter: int
- dynamic_data: np.ndarray
-
-
-class LoadTestPublisher(ez.Unit):
- OUTPUT = ez.OutputStream(LoadTestSample)
- SETTINGS = LoadTestSettings
-
- async def initialize(self) -> None:
- self.running = True
- self.counter = 0
- self.OUTPUT.num_buffers = self.SETTINGS.buffers
-
- @ez.publisher(OUTPUT)
- async def publish(self) -> AsyncGenerator:
- ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})")
- start_time = time.time()
- while self.running:
- current_time = time.time()
- if current_time - start_time >= self.SETTINGS.duration:
- break
-
- yield (
- self.OUTPUT,
- LoadTestSample(
- _timestamp=time.time(),
- counter=self.counter,
- dynamic_data=np.zeros(
- int(self.SETTINGS.dynamic_size // 8), dtype=np.float32
- ),
- ),
- )
- self.counter += 1
- ez.logger.info("Exiting publish")
- raise ez.Complete
-
- async def shutdown(self) -> None:
- self.running = False
- ez.logger.info(f"Samples sent: {self.counter}")
-
-
-class LoadTestSubscriberState(ez.State):
- # Tuples of sent timestamp, received timestamp, counter, dynamic size
- received_data: List[Tuple[float, float, int]] = dataclasses.field(
- default_factory=list
- )
- counter: int = -1
-
-
-class LoadTestSubscriber(ez.Unit):
- INPUT = ez.InputStream(LoadTestSample)
- SETTINGS = LoadTestSettings
- STATE = LoadTestSubscriberState
-
- @ez.subscriber(INPUT, zero_copy=True)
- async def receive(self, sample: LoadTestSample) -> None:
- if sample.counter != self.STATE.counter + 1:
- ez.logger.warning(
- f"{sample.counter - self.STATE.counter - 1} samples skipped!"
- )
- self.STATE.received_data.append(
- (sample._timestamp, time.time(), sample.counter)
- )
- self.STATE.counter = sample.counter
-
- @ez.task
- async def log_result(self) -> None:
- ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})")
-
- # Wait for the duration of the load test
- await asyncio.sleep(self.SETTINGS.duration)
- # logger.info(f"STATE = {self.STATE.received_data}")
-
- # Log some useful summary statistics
- min_timestamp = min(timestamp for timestamp, _, _ in self.STATE.received_data)
- max_timestamp = max(timestamp for timestamp, _, _ in self.STATE.received_data)
- total_latency = abs(
- sum(
- receive_timestamp - send_timestamp
- for send_timestamp, receive_timestamp, _ in self.STATE.received_data
- )
- )
-
- counters = list(sorted(t[2] for t in self.STATE.received_data))
- dropped_samples = sum(
- [(x1 - x0) - 1 for x1, x0 in zip(counters[1:], counters[:-1])]
- )
-
- num_samples = len(self.STATE.received_data)
- ez.logger.info(f"Samples received: {num_samples}")
- ez.logger.info(
- f"Sample rate: {num_samples / (max_timestamp - min_timestamp)} Hz"
- )
- ez.logger.info(f"Mean latency: {total_latency / num_samples} s")
- ez.logger.info(f"Total latency: {total_latency} s")
-
- total_data = num_samples * self.SETTINGS.dynamic_size
- ez.logger.info(
- f"Data rate: {total_data / (max_timestamp - min_timestamp) * 1e-6} MB/s"
- )
- ez.logger.info(
- f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)",
- )
-
- raise ez.NormalTermination
-
-
-class LoadTest(ez.Collection):
- SETTINGS = LoadTestSettings
-
- PUBLISHER = LoadTestPublisher()
- SUBSCRIBER = LoadTestSubscriber()
-
- def configure(self) -> None:
- self.PUBLISHER.apply_settings(self.SETTINGS)
- self.SUBSCRIBER.apply_settings(self.SETTINGS)
-
- def network(self) -> ez.NetworkDefinition:
- return ((self.PUBLISHER.OUTPUT, self.SUBSCRIBER.INPUT),)
-
- def process_components(self):
- return (
- self.PUBLISHER,
- self.SUBSCRIBER,
- )
-
-
-def get_time() -> float:
- # time.perf_counter() isn't system-wide on Windows Python 3.6:
- # https://bugs.python.org/issue37205
- return time.time() if PLATFORM == "win" else time.perf_counter()
-
-
-def test_performance(duration, size, buffers) -> None:
- ez.logger.info(f"Running load test for dynamic size: {size} bytes")
- system = LoadTest(
- LoadTestSettings(dynamic_size=int(size), duration=duration, buffers=buffers)
- )
- ez.run(SYSTEM=system)
-
-
-def run_many_dynamic_sizes(duration, buffers) -> None:
- for exp in range(5, 22, 4):
- test_performance(duration, 2**exp, buffers)
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--many-dynamic-sizes",
- action="store_true",
- help="Run load test for many dynamic sizes",
- )
- parser.add_argument(
- "--duration",
- type=int,
- default=2,
- help="How long to run the load test (seconds)",
- )
- parser.add_argument(
- "--num-buffers", type=int, default=32, help="Shared memory buffers"
- )
-
- class Args:
- many_dynamic_sizes: bool
- duration: int
- num_buffers: int
-
- args = parser.parse_args(namespace=Args)
-
- if args.many_dynamic_sizes:
- run_many_dynamic_sizes(args.duration, args.num_buffers)
- else:
- test_performance(args.duration, 8, args.num_buffers)