diff --git a/pyproject.toml b/pyproject.toml index df1f469a..c0888f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dev = [ {include-group = "lint"}, {include-group = "test"}, "pre-commit>=4.3.0", + "viztracer>=1.0.4", ] lint = [ "flake8>=7.3.0", @@ -34,7 +35,8 @@ test = [ "pytest>=8.4.1", "pytest-asyncio>=1.1.0", "pytest-cov>=6.2.1", - "xarray>=2023.1.0;python_version<'3.13'" + "xarray>=2023.1.0;python_version<'3.13'", + "psutil>=7.1.0", ] docs = [ "ezmsg-sigproc>=2.2.0", @@ -44,6 +46,7 @@ docs = [ [project.scripts] ezmsg = "ezmsg.core.command:cmdline" +ezmsg-perf = "ezmsg.util.perf.command:command" [project.optional-dependencies] axisarray = [ diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 50e1aa0c..9a28d9ee 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -210,7 +210,7 @@ async def _handle_publisher( self._incoming.put_nowait((id, msg_id)) - except (ConnectionResetError, BrokenPipeError): + except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): logger.debug(f"connection fail: sub:{self.id} -> pub:{id}") finally: diff --git a/src/ezmsg/util/perf/__init__.py b/src/ezmsg/util/perf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py new file mode 100644 index 00000000..2df4bebb --- /dev/null +++ b/src/ezmsg/util/perf/analysis.py @@ -0,0 +1,454 @@ +import json +import typing +import dataclasses +import argparse +import html +import math +import webbrowser + +from pathlib import Path + +from ..messagecodec import MessageDecoder +from .envinfo import TestEnvironmentInfo, format_env_diff +from .run import get_datestamp +from .impl import ( + TestParameters, + Metrics, + TestLogEntry, +) + +import ezmsg.core as ez + +try: + import xarray as xr + import pandas as pd # xarray depends on pandas +except ImportError: + ez.logger.error('ezmsg perf analysis requires xarray') + raise + +try: + import numpy as np +except ImportError: + ez.logger.error('ezmsg perf analysis requires numpy') + raise + +TEST_DESCRIPTION = """ +Configurations (config): +- fanin: many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) +""" + + +def load_perf(perf: Path) -> xr.Dataset: + + all_results: typing.Dict[TestParameters, typing.Dict[int, typing.List[Metrics]]] = dict() + + run_idx = 0 + + with open(perf, 'r') as perf_f: + info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) + for line in perf_f: + obj = json.loads(line, cls = MessageDecoder) + if isinstance(obj, TestEnvironmentInfo): + run_idx += 1 + elif isinstance(obj, TestLogEntry): + runs = all_results.get(obj.params, dict()) + metrics = runs.get(run_idx, list()) + 
metrics.append(obj.results) + runs[run_idx] = metrics + all_results[obj.params] = runs + + n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()]))) + msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()]))) + comms_axis = list(sorted(set([p.comms for p in all_results.keys()]))) + config_axis = list(sorted(set([p.config for p in all_results.keys()]))) + + dims = ['n_clients', 'msg_size', 'comms', 'config'] + coords = { + 'n_clients': n_clients_axis, + 'msg_size': msg_size_axis, + 'comms': comms_axis, + 'config': config_axis + } + + data_vars = {} + for field in dataclasses.fields(Metrics): + m = np.zeros(( + len(n_clients_axis), + len(msg_size_axis), + len(comms_axis), + len(config_axis) + )) * np.nan + for p, a in all_results.items(): + # tests are run multiple times; get the median of means + m[ + n_clients_axis.index(p.n_clients), + msg_size_axis.index(p.msg_size), + comms_axis.index(p.comms), + config_axis.index(p.config) + ] = np.median([np.mean([getattr(v, field.name) for v in r]) for r in a.values()]) + data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) + + dataset = xr.Dataset(data_vars, attrs = dict(info = info)) + return dataset + +def _escape(s: str) -> str: + return html.escape(str(s), quote=True) + +def _env_block(title: str, body: str) -> str: + return f""" +
+    <section>
+      <h2>{_escape(title)}</h2>
+      <pre>{_escape(body).strip()}</pre>
+    </section>
+    """
+
+def _legend_block() -> str:
+    return """
+    <section>
+      <h2>Legend</h2>
+      <p>For sample_rate and data_rate, comparison cells above 100% of baseline are
+         shaded green and cells below are shaded red; for latency metrics the coloring
+         is reversed. Values inside the ±10% noise band around 100% are left unshaded.</p>
+    </section>
+ """ + +def _base_css() -> str: + # Minimal, print-friendly CSS + color scales for cells. + return """ + + """ + +def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10.0) -> str: + """ + Returns inline CSS background for a comparison % value. + value: e.g., 97.3, 104.8, etc. + For sample_rate/data_rate: improvement > 100 (good). + For latency_mean: improvement < 100 (good). + Noise band ±10% around 100 is neutral. + """ + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "" + + delta = value - 100.0 + # Determine direction: + is good for sample/data; - is good for latency + if 'rate' in metric: + # positive delta good, negative bad + magnitude = abs(delta) + sign_good = delta > 0 + elif 'latency' in metric: + # negative delta good (lower latency) + magnitude = abs(delta) + sign_good = delta < 0 + else: + return "" + + # Noise band: keep neutral + if magnitude <= noise_band_pct: + return "" + + # Scale 5%..50% across 0..1; clamp + scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) + + # Choose hue and lightness; use HSL with gentle saturation + hue = "var(--green)" if sign_good else "var(--red)" + # opacity via alpha blend on lightness via HSLa + # Use saturation ~70%, lightness around 40–50% blended with table bg + alpha = 0.15 + 0.35 * scale # 0.15..0.50 + return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + +def _format_number(x) -> str: + if isinstance(x, (int,)) and not isinstance(x, bool): + return f"{x:d}" + try: + xf = float(x) + except Exception: + return _escape(str(x)) + # Heuristic: for comparison percentages, 1 decimal is nice; for absolute, 3 decimals for latency. + return f"{xf:.3f}" + + +def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: + """ print perf test results and comparisons to the console """ + + output = '' + + perf = load_perf(perf_path) + info: TestEnvironmentInfo = perf.attrs['info'] + output += str(info) + '\n\n' + + relative = False + env_diff = None + if baseline_path is not None: + relative = True + output += "PERFORMANCE COMPARISON\n\n" + baseline = load_perf(baseline_path) + perf = (perf / baseline) * 100.0 + baseline_info: TestEnvironmentInfo = baseline.attrs['info'] + env_diff = format_env_diff(info.diff(baseline_info)) + output += env_diff + '\n\n' + + # These raw stats are still valuable to have, but are confusing + # when making relative comparisons + perf = perf.drop_vars(['latency_total', 'num_msgs']) + + perf = perf.stack(params = ['n_clients', 'msg_size']).dropna('params') + df = perf.squeeze().to_dataframe() + df = df.drop('n_clients', axis = 1) + df = df.drop('msg_size', axis = 1) + + for _, config_ds in perf.groupby('config'): + for _, comms_ds in config_ds.groupby('comms'): + output += str(comms_ds.squeeze().to_dataframe()) + '\n\n' + output += '\n' + + print(output) + + if html: + # Ensure expected columns exist + expected_cols = {"sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"} + missing = expected_cols - set(df.columns) + if missing: + raise ValueError(f"Missing expected columns in dataset: {missing}") + + # We'll render a table per (config, comms) group. + groups = df.reset_index().sort_values( + by=["config", "comms", "n_clients", "msg_size"] + ).groupby(["config", "comms"], sort=False) + + # Build HTML + parts: list[str] = [] + parts.append("") + parts.append("") + parts.append("ezmsg perf report") + parts.append(_base_css()) + parts.append("
") + + parts.append("
") + parts.append("

ezmsg Performance Report

") + sub = str(perf_path) + if baseline_path is not None: + sub += f' relative to {str(baseline_path)}' + parts.append(f"
{_escape(sub)}
") + parts.append("
") + + if info is not None: + parts.append(_env_block("Test Environment", str(info))) + + parts.append(_env_block("Test Details", TEST_DESCRIPTION)) + + if env_diff is not None: + # Show diffs using your helper + parts.append("
") + parts.append("

Environment Differences vs Baseline

") + parts.append(f"
{_escape(env_diff)}
") + parts.append("
") + parts.append(_legend_block()) + + # Render each group + for (config, comms), g in groups: + # Keep only expected columns in order + cols = ["n_clients", "msg_size", "sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"] + g = g[cols].copy() + + # String format some columns (msg_size with separators) + g["msg_size"] = g["msg_size"].map(lambda x: f"{int(x):,}" if pd.notna(x) else x) + + # Build table manually so we can inject inline cell styles easily + # (pandas Styler is great but produces bulky HTML; manual keeps it clean) + header = f""" + + + n_clients + msg_size {'' if relative else '(b)'} + sample_rate_mean {'' if relative else '(msgs/s)'} + sample_rate_median {'' if relative else '(msgs/s)'} + data_rate {'' if relative else '(MB/s)'} + latency_mean {'' if relative else '(us)'} + latency_median {'' if relative else '(us)'} + + + """ + body_rows: list[str] = [] + for _, row in g.iterrows(): + sr, srm, dr, lt, lm = row["sample_rate_mean"], row["sample_rate_median"], row["data_rate"], row["latency_mean"], row["latency_median"] + dr = dr if relative else dr / 2**20 + lt = lt if relative else lt * 1e6 + lm = lm if relative else lm * 1e6 + sr_style = _color_for_comparison(sr, "sample_rate_mean") if relative else "" + srm_style = _color_for_comparison(srm, "sample_rate_median") if relative else "" + dr_style = _color_for_comparison(dr, "data_rate") if relative else "" + lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + lm_style = _color_for_comparison(lm, "latency_median") if relative else "" + + body_rows.append( + "" + f"{_format_number(row['n_clients'])}" + f"{_escape(row['msg_size'])}" + f"{_format_number(sr)}" + f"{_format_number(srm)}" + f"{_format_number(dr)}" + f"{_format_number(lt)}" + f"{_format_number(lm)}" + "" + ) + table_html = f"{header}{''.join(body_rows)}
" + + parts.append( + f"

" + f"{_escape(config)}" + f"{_escape(comms)}" + f"

{table_html}
" + ) + + parts.append("
") + html_text = "".join(parts) + + out_path = Path(f'report_{get_datestamp()}.html') + out_path.write_text(html_text, encoding="utf-8") + webbrowser.open(out_path.resolve().as_uri()) + + +def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_summary = subparsers.add_parser("summary", help = "summarize performance results") + p_summary.add_argument( + "perf", + type=Path, + help="perf test", + ) + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", + ) + p_summary.add_argument( + "--html", + action = 'store_true', + help = "generate an html output file and render results in browser", + ) + + p_summary.set_defaults(_handler=lambda ns: summary( + perf_path = ns.perf, + baseline_path = ns.baseline, + html = ns.html + )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py new file mode 100644 index 00000000..ab919999 --- /dev/null +++ b/src/ezmsg/util/perf/command.py @@ -0,0 +1,17 @@ +import argparse + +from .analysis import setup_summary_cmdline +from .run import setup_run_cmdline + +def command() -> None: + parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') + subparsers = parser.add_subparsers(dest="command", required=True) + + setup_run_cmdline(subparsers) + setup_summary_cmdline(subparsers) + + ns = parser.parse_args() + ns._handler(ns) + +if __name__ == "__main__": + command() diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py new file mode 100644 index 00000000..66fff18b --- /dev/null +++ b/src/ezmsg/util/perf/envinfo.py @@ -0,0 +1,83 @@ +import dataclasses +import datetime +import platform +import typing +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def _git_commit() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +def _git_branch() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) + cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) + memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) + start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = 
dataclasses.field(default_factory=_git_branch) + + def __str__(self) -> str: + fields = dataclasses.asdict(self) + width = max(len(k) for k in fields) + lines = ["TestEnvironmentInfo:"] + for key, value in fields.items(): + lines.append(f" {key.ljust(width)} : {value}") + return "\n".join(lines) + + def diff(self, other: "TestEnvironmentInfo") -> typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]: + """Return a structured diff: {field: (self_value, other_value)} for changed fields.""" + a = dataclasses.asdict(self) + b = dataclasses.asdict(other) + keys = set(a) | set(b) + return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)} + + +def format_env_diff(diffs: typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]) -> str: + """Pretty-print the structured diff in the same aligned style.""" + if not diffs: + return "No differences." + width = max(len(k) for k in diffs) + lines = ["Differences in TestEnvironmentInfo:"] + for k in sorted(diffs): + left, right = diffs[k] + lines.append(f" {k.ljust(width)} : {left} != {right}") + return "\n".join(lines) + +def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str: + return format_env_diff(a.diff(b)) \ No newline at end of file diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py new file mode 100644 index 00000000..51ae600b --- /dev/null +++ b/src/ezmsg/util/perf/impl.py @@ -0,0 +1,362 @@ +import asyncio +import dataclasses +import os +import time +import typing +import enum + +import ezmsg.core as ez + +from ezmsg.util.messages.util import replace +from ezmsg.core.netprotocol import Address + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +from .util import stable_perf + + +def collect( + components: typing.Optional[typing.Mapping[str, ez.Component]] = None, + network: ez.NetworkDefinition = (), + process_components: typing.Collection[ez.Component] | None = None, + **components_kwargs: ez.Component, +) -> ez.Collection: + """ collect a grouping of pre-configured components into a new "Collection" """ + from ezmsg.core.util import either_dict_or_kwargs + + components = either_dict_or_kwargs(components, components_kwargs, "collect") + if components is None: + raise ValueError("Must supply at least one component to run") + + out = ez.Collection() + for name, comp in components.items(): + comp._set_name(name) + out._components = components # FIXME: Component._components should be typehinted as a Mapping + out.network = lambda: network + out.process_components = lambda: (out,) if process_components is None else process_components + return out + + +@dataclasses.dataclass +class Metrics: + num_msgs: int + sample_rate_mean: float + sample_rate_median: float + latency_mean: float + latency_median: float + latency_total: float + data_rate: float + + +class LoadTestSettings(ez.Settings): + max_duration: float + num_msgs: int + dynamic_size: int + buffers: int + force_tcp: bool + + +@dataclasses.dataclass +class LoadTestSample: + _timestamp: float + counter: int + dynamic_data: np.ndarray + key: str + +class LoadTestSourceState(ez.State): + counter: int = 0 + +class LoadTestSource(ez.Unit): + OUTPUT = ez.OutputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestSourceState + + async def initialize(self) -> None: + self.OUTPUT.num_buffers = self.SETTINGS.buffers + self.OUTPUT.force_tcp = self.SETTINGS.force_tcp + + @ez.publisher(OUTPUT) + async def publish(self) -> typing.AsyncGenerator: + ez.logger.info(f"Load test publisher started. 
(PID: {os.getpid()})") + start_time = time.time() + for _ in range(self.SETTINGS.num_msgs): + + current_time = time.time() + if current_time - start_time >= self.SETTINGS.max_duration: + break + + yield ( + self.OUTPUT, + LoadTestSample( + _timestamp=time.time(), + counter=self.STATE.counter, + dynamic_data=np.zeros( + int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 + ), + key = self.name, + ), + ) + self.STATE.counter += 1 + + ez.logger.info("Exiting publish") + raise ez.Complete + + async def shutdown(self) -> None: + ez.logger.info(f"Samples sent: {self.STATE.counter}") + + + +class LoadTestRelay(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + OUTPUT = ez.OutputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy = True) + @ez.publisher(OUTPUT) + async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: + yield self.OUTPUT, msg + + +class LoadTestReceiverState(ez.State): + # Tuples of sent timestamp, received timestamp, counter, dynamic size + received_data: typing.List[typing.Tuple[float, float, int]] = dataclasses.field( + default_factory=list + ) + counters: typing.Dict[str, int] = dataclasses.field(default_factory=dict) + + +class LoadTestReceiver(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestReceiverState + + async def initialize(self) -> None: + ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + counter = self.STATE.counters.get(sample.key, -1) + if sample.counter != counter + 1: + ez.logger.warning( + f"{sample.counter - counter - 1} samples skipped!" + ) + self.STATE.received_data.append( + (sample._timestamp, time.time(), sample.counter) + ) + self.STATE.counters[sample.key] = sample.counter + + +class LoadTestSink(LoadTestReceiver): + + INPUT = ez.InputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + await super().receive(sample) + if len(self.STATE.received_data) == self.SETTINGS.num_msgs: + raise ez.NormalTermination + + @ez.task + async def terminate(self) -> None: + # Wait for the max duration of the load test + await asyncio.sleep(self.SETTINGS.max_duration) + ez.logger.warning("TIMEOUT -- terminating test.") + raise ez.NormalTermination + + +### TEST CONFIGURATIONS + +@dataclasses.dataclass +class ConfigSettings: + n_clients: int + settings: LoadTestSettings + source: LoadTestSource + sink: LoadTestSink + +Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition] +Configurator = typing.Callable[[ConfigSettings], Configuration] + +def fanout(config: ConfigSettings) -> Configuration: + """ one pub to many subs """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)] + for sub in subs: + connections.append((config.source.OUTPUT, sub.INPUT)) + + return subs, connections + +def fanin(config: ConfigSettings) -> Configuration: + """ many pubs to one sub """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)] + expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs) + config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs = expected_num_msgs) # type: ignore + for pub in pubs: + connections.append((pub.OUTPUT, config.sink.INPUT)) + return pubs, connections + 
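# Illustrative sketch, not part of this patch: each entry in CONFIGS (defined below)
# is a Configurator, i.e. a callable ConfigSettings -> (extra clients, connections).
# For example, fanout() above with n_clients=2 returns two extra LoadTestReceiver
# units plus three (OutputStream, InputStream) edges, which perform_test() then hands
# to ez.run(); settings/source/sink here stand for the objects perform_test() builds:
#
#     cfg = ConfigSettings(n_clients=2, settings=settings, source=source, sink=sink)
#     clients, connections = fanout(cfg)
#     # clients     == [LoadTestReceiver, LoadTestReceiver]
#     # connections == [(source.OUTPUT, sink.INPUT),
#     #                 (source.OUTPUT, clients[0].INPUT),
#     #                 (source.OUTPUT, clients[1].INPUT)]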
+ +def relay(config: ConfigSettings) -> Configuration: + """ one pub to one sub through many relays """ + connections: ez.NetworkDefinition = [] + + relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] + if len(relays): + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + else: connections.append((config.source.OUTPUT, config.sink.INPUT)) + + return relays, connections + +CONFIGS: typing.Mapping[str, Configurator] = { + c.__name__: c for c in [ + fanin, + fanout, + relay + ] +} + +class Communication(enum.StrEnum): + LOCAL = "local" + SHM = "shm" + SHM_SPREAD = "shm_spread" + TCP = "tcp" + TCP_SPREAD = "tcp_spread" + +def perform_test( + n_clients: int, + max_duration: float, + num_msgs: int, + msg_size: int, + buffers: int, + comms: Communication, + config: Configurator, + graph_address: Address +) -> Metrics: + + settings = LoadTestSettings( + dynamic_size = int(msg_size), + num_msgs = num_msgs, + max_duration = max_duration, + buffers = buffers, + force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)), + ) + + source = LoadTestSource(settings) + sink = LoadTestSink(settings) + + components: typing.Mapping[str, ez.Component] = dict( + SINK = sink, + ) + + clients, connections = config(ConfigSettings(n_clients, settings, source, sink)) + + # The 'sink' MUST remain in this process for us to pull its state. + process_components: typing.Iterable[ez.Component] = [] + if comms == Communication.LOCAL: + # Every component in the same process (this one) + components["SOURCE"] = source + for i, client in enumerate(clients): + components[f"CLIENT_{i+1}"] = client + + else: + + if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD): + # Every component in its own process. + components["SOURCE"] = source + process_components.append(source) + for i, client in enumerate(clients): + components[f'CLIENT_{i+1}'] = client + process_components.append(client) + + else: + # All clients and the source in ONE other process. 
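#    (They are grouped by the collect() helper above into a single Collection, which
#     is handed to ez.run() via process_components so the whole group runs in one
#     separate process, while the sink stays in the main process so its STATE can be
#     read by calculate_metrics() afterwards.)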
+ collect_comps: typing.Mapping[str, ez.Component] = dict() + collect_comps["SOURCE"] = source + for i, client in enumerate(clients): + collect_comps[f"CLIENT_{i+1}"] = client + proc_collection = collect(components = collect_comps) + components["PROC"] = proc_collection + process_components = [proc_collection] + + with stable_perf(): + ez.run( + components = components, + connections = connections, + process_components = process_components, + graph_address = graph_address + ) + + return calculate_metrics(sink) + + +def calculate_metrics(sink: LoadTestSink) -> Metrics: + + # Log some useful summary statistics + min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) + max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) + latency = [ + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ] + total_latency = abs(sum(latency)) + + counters = list(sorted(t[2] for t in sink.STATE.received_data)) + dropped_samples = sum( + [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])] + ) + + rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data]) + runtime = max_timestamp - min_timestamp + num_samples = len(sink.STATE.received_data) + samplerate_mean = num_samples / runtime + samplerate_median = 1.0 / float(np.median(np.diff(rx_timestamps))) + latency_mean = total_latency / num_samples + latency_median = list(sorted(latency))[len(latency) // 2] + total_data = num_samples * sink.SETTINGS.dynamic_size + data_rate = total_data / runtime + + ez.logger.info(f"Samples received: {num_samples}") + ez.logger.info(f"Mean sample rate: {samplerate_mean} Hz") + ez.logger.info(f"Median sample rate: {samplerate_median} Hz") + ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Median latency: {latency_median} s") + ez.logger.info(f"Total latency: {total_latency} s") + ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s") + + if dropped_samples: + ez.logger.error( + f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", + ) + + return Metrics( + num_msgs = num_samples, + sample_rate_mean = samplerate_mean, + sample_rate_median = samplerate_median, + latency_mean = latency_mean, + latency_median = latency_median, + latency_total = total_latency, + data_rate = data_rate + ) + + +@dataclasses.dataclass(unsafe_hash=True) +class TestParameters: + msg_size: int + num_msgs: int + n_clients: int + config: str + comms: str + max_duration: float + num_buffers: int + + +@dataclasses.dataclass +class TestLogEntry: + params: TestParameters + results: Metrics \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py new file mode 100644 index 00000000..ce6c2071 --- /dev/null +++ b/src/ezmsg/util/perf/run.py @@ -0,0 +1,317 @@ +import os +import sys +import json +import datetime +import itertools +import argparse +import typing +import random +import time + +from datetime import datetime, timedelta +from contextlib import contextmanager, redirect_stdout, redirect_stderr + +import ezmsg.core as ez +from ezmsg.core.graphserver import GraphServer + +from ..messagecodec import MessageEncoder +from .envinfo import TestEnvironmentInfo +from .util import warmup +from .impl import ( + TestParameters, + TestLogEntry, + perform_test, + Communication, + CONFIGS, +) + +DEFAULT_MSG_SIZES = [2**4, 2**20] +DEFAULT_N_CLIENTS = [1, 16] +DEFAULT_COMMS = [c for c in Communication] + +# --- Output Suppression Context 
Manager --- +@contextmanager +def suppress_output(verbose: bool = False): + """Context manager to redirect stdout and stderr to os.devnull""" + if verbose: + yield + else: + # Open the null device for writing + with open(os.devnull, 'w') as fnull: + # Redirect both stdout and stderr to the null device + with redirect_stderr(fnull): + with redirect_stdout(fnull): + yield + +CHECK_FOR_QUIT = lambda: False + +if sys.platform.startswith('win'): + import msvcrt + def _check_for_quit_win() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Windows: Use msvcrt for non-blocking keyboard hit detection + if msvcrt.kbhit(): # type: ignore + # Read the key press (returns bytes) + key = msvcrt.getch() # type: ignore + try: + # Decode and check for 'q' + return key.decode().lower() == 'q' + except UnicodeDecodeError: + # Handle potential non-text key presses gracefully + return False + return False + + CHECK_FOR_QUIT = _check_for_quit_win + +else: + import select + def _check_for_quit() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Linux/macOS: Use select to check if stdin has data + # select.select(rlist, wlist, xlist, timeout) + # timeout=0 makes it non-blocking + if sys.stdin.isatty(): + i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore + if i: + # Read the buffered character + key = sys.stdin.read(1) + return key.lower() == 'q' + return False + + CHECK_FOR_QUIT = _check_for_quit + +def get_datestamp() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + +def perf_run( + max_duration: float, + num_msgs: int, + num_buffers: int, + iters: int, + repeats: int, + msg_sizes: typing.List[int] | None, + n_clients: typing.List[int] | None, + comms: typing.Iterable[str] | None, + configs: typing.Iterable[str] | None, + grid: bool, + warmup_dur: float, +) -> None: + + if n_clients is None: + n_clients = DEFAULT_N_CLIENTS + if any(c < 0 for c in n_clients): + ez.logger.error('All tests must have >=0 clients') + return + + if msg_sizes is None: + msg_sizes = DEFAULT_MSG_SIZES + if any(s < 0 for s in msg_sizes): + ez.logger.error('All msg_sizes must be >=0 bytes') + + if not grid and len(list(n_clients)) != len(list(msg_sizes)): + ez.logger.warning( + "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + \ + f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. " + ) + + try: + communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + except ValueError: + ez.logger.error(f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}") + return + + try: + configurators = list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + except ValueError: + ez.logger.error(f"Invalid test configuration requested. 
Valid configurations: {', '.join([c for c in CONFIGS])}") + return + + subitr = itertools.product if grid else zip + + test_list = [ + (msg_size, clients, conf, comm) + for msg_size, clients in subitr(msg_sizes, n_clients) + for conf, comm in itertools.product(configurators, communications) + ] * iters + + random.shuffle(test_list) + + server = GraphServer() + server.start() + + ez.logger.info(f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each.") + ez.logger.info(f"During each test, source will attempt to send {num_msgs} messages to the sink.") + ez.logger.info(f"Please try to avoid running other taxing software while this perf test runs.") + ez.logger.info(f"NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early.") + + quitting = False + + start_time = time.time() + + try: + ez.logger.info(f"Warming up for {warmup_dur} seconds...") + warmup(warmup_dur) + + with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: + for _ in range(repeats): + out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): + + if CHECK_FOR_QUIT(): + ez.logger.info("Stopping tests early...") + quitting = True + break + + ez.logger.info( + f"TEST {test_idx + 1}/{len(test_list)}: " \ + f"{clients=}, {msg_size=}, conf={conf.__name__}, " \ + f"comm={comm.value}" + ) + + output = TestLogEntry( + params = TestParameters( + msg_size = msg_size, + num_msgs = num_msgs, + n_clients = clients, + config = conf.__name__, + comms = comm.value, + max_duration = max_duration, + num_buffers = num_buffers + ), + results = perform_test( + n_clients = clients, + max_duration = max_duration, + num_msgs = num_msgs, + msg_size = msg_size, + buffers = num_buffers, + comms = comm, + config = conf, + graph_address = server.address + ) + ) + + out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + + if quitting: + break + + finally: + server.stop() + d = datetime(1,1,1) + timedelta(seconds = time.time() - start_time) + dur_str = ':'.join([str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0]) + ez.logger.info(f"Tests concluded. Wallclock Runtime: {dur_str}s") + + + + +def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: + + p_run = subparsers.add_parser("run", help="run performance test") + + p_run.add_argument( + "--max-duration", + type=float, + default=5.0, + help="maximum individual test duration in seconds (default = 5.0)", + ) + + p_run.add_argument( + "--num-msgs", + type=int, + default=1000, + help = "number of messages to send per-test (default = 1000)" + ) + + # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting + # messages as fast as possible through the system) results in one of two scenerios: + # 1. A (few) messages is/are enqueued and dequeued before another message is posted + # 2. The buffer fills up before being FULLY emptied resulting in longer latency. + # (once a channel enters this condition, it tends to stay in this condition) + # + # This _indeterminate_ behavior results in bimodal distributions of runtimes that make + # A/B performance comparisons difficult. The perf test is not representative of the vast + # majority of production ezmsg systems where publishing is generally rate-limited. 
+ # + # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is + # generally implemented by enforcing delays on the publish side which simply degrades + # performance in the vast majority of ezmsg systems. - Griff + p_run.add_argument( + "--num-buffers", + type=int, + default=1, + help="shared memory buffers (default = 1)", + ) + + p_run.add_argument( + "--iters", "-i", + type = int, + default = 5, + help = "number of times to run each test (default = 5)" + ) + + p_run.add_argument( + "--repeats", "-r", + type = int, + default = 10, + help = "number of times to repeat the perf (default = 10)" + ) + + p_run.add_argument( + "--msg-sizes", + type = int, + default = None, + nargs = "*", + help = f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})" + ) + + p_run.add_argument( + "--n-clients", + type = int, + default = None, + nargs = "*", + help = f"number of clients (default = {DEFAULT_N_CLIENTS})" + ) + + p_run.add_argument( + "--comms", + type = str, + default = None, + nargs = "*", + help = f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})" + ) + + p_run.add_argument( + "--configs", + type = str, + default = None, + nargs = "*", + help = f"configurations to test (default = {[c for c in CONFIGS]})" + ) + + p_run.add_argument( + "--warmup", + type = float, + default = 60.0, + help = "warmup CPU with busy task for some number of seconds (default = 60.0)" + ) + + p_run.set_defaults(_handler=lambda ns: perf_run( + max_duration = ns.max_duration, + num_msgs = ns.num_msgs, + num_buffers = ns.num_buffers, + iters = ns.iters, + repeats = ns.repeats, + msg_sizes = ns.msg_sizes, + n_clients = ns.n_clients, + comms = ns.comms, + configs = ns.configs, + grid = True, + warmup_dur = ns.warmup, + )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py new file mode 100644 index 00000000..d22d1496 --- /dev/null +++ b/src/ezmsg/util/perf/util.py @@ -0,0 +1,288 @@ +import os +import sys +import gc +import time +import statistics as stats +import contextlib +import subprocess +import platform +from dataclasses import dataclass +from typing import Iterable, List, Optional + +try: + import psutil # optional but helpful +except Exception: + psutil = None + +_IS_WIN = os.name == "nt" +_IS_MAC = sys.platform == "darwin" +_IS_LINUX = sys.platform.startswith("linux") + +# ---------- Utilities ---------- + +def _set_env_threads(single_thread: bool = True): + """ + Normalize math/threading libs so they don't spawn surprise worker threads. + """ + if single_thread: + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1") + os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") + os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") + # Keep PYTHONHASHSEED stable for deterministic dict/set iteration costs + os.environ.setdefault("PYTHONHASHSEED", "0") + +# ---------- Priority & Affinity ---------- + +@contextlib.contextmanager +def _process_priority(): + """ + Elevate process priority in a cross-platform best-effort way. 
+ """ + if psutil is None: + yield + return + + p = psutil.Process() + orig_nice = None + if _IS_WIN: + try: + import ctypes, ctypes.wintypes as wt + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + ABOVE_NORMAL_PRIORITY_CLASS = 0x00008000 + HIGH_PRIORITY_CLASS = 0x00000080 + # Try High, fall back to Above Normal + if not kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS): + kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS) + except Exception: + pass + else: + try: + orig_nice = p.nice() + # Negative nice may need privileges; try smaller magnitude first + for nice_val in (-10, -5, 0): + try: + p.nice(nice_val) + break + except Exception: + continue + except Exception: + pass + try: + yield + finally: + # restore nice if we changed it + if psutil is not None and not _IS_WIN and orig_nice is not None: + try: + p.nice(orig_nice) + except Exception: + pass + +@contextlib.contextmanager +def _cpu_affinity(prefer_isolation: bool = True): + """ + Set CPU affinity to a small, stable set of CPUs (where supported). + macOS does not support affinity via psutil; we no-op there. + """ + if psutil is None or _IS_MAC: + yield + return + + p = psutil.Process() + original = None + try: + if hasattr(p, "cpu_affinity"): + original = p.cpu_affinity() + cpus = original + if prefer_isolation and len(cpus) > 2: + # Pick two middle CPUs to avoid 0 which often handles interrupts + mid = len(cpus) // 2 + cpus = [cpus[mid-1], cpus[mid]] + p.cpu_affinity(cpus) + yield + finally: + try: + if original is not None and hasattr(p, "cpu_affinity"): + p.cpu_affinity(original) + except Exception: + pass + +# ---------- Platform-specific helpers ---------- + +@contextlib.contextmanager +def _mac_caffeinate(): + """ + Keep macOS awake during the run via a background caffeinate process. + """ + if not _IS_MAC: + yield + return + proc = None + try: + proc = subprocess.Popen(["caffeinate", "-dimsu"]) + except Exception: + proc = None + try: + yield + finally: + if proc is not None: + try: + proc.terminate() + except Exception: + pass + +@contextlib.contextmanager +def _win_timer_resolution(ms: int = 1): + """ + On Windows, request a finer system timer to stabilize sleeps and scheduling slices. + """ + if not _IS_WIN: + yield + return + import ctypes + winmm = ctypes.WinDLL("winmm") + timeBeginPeriod = winmm.timeBeginPeriod + timeEndPeriod = winmm.timeEndPeriod + try: + timeBeginPeriod(ms) + except Exception: + pass + try: + yield + finally: + try: + timeEndPeriod(ms) + except Exception: + pass + +# ---------- Warm-up & GC ---------- + +def warmup(seconds: float = 60.0, fn=None, *args, **kwargs): + """ + Optional warm-up to reach steady clocks/caches. + If fn is provided, call it in a loop for the given time. + """ + if seconds <= 0: + return + end = time.perf_counter() + target = end + seconds + if fn is None: + # Busy wait / sleep mix to heat up without heavy CPU + while time.perf_counter() < target: + x = 0 + for _ in range(10000): + x += 1 + time.sleep(0) + else: + while time.perf_counter() < target: + fn(*args, **kwargs) + +@contextlib.contextmanager +def gc_pause(): + """ + Disable GC inside timing windows; re-enable and collect after. 
+ """ + was_enabled = gc.isenabled() + try: + gc.disable() + yield + finally: + if was_enabled: + gc.enable() + gc.collect() + +# ---------- Robust statistics ---------- + +def median_of_means(samples: Iterable[float], k: int = 5) -> float: + """ + Robust estimate: split samples into k buckets (round-robin), average each, take median of bucket means. + """ + samples = list(samples) + if not samples: + return float("nan") + k = max(1, min(k, len(samples))) + buckets = [[] for _ in range(k)] + for i, v in enumerate(samples): + buckets[i % k].append(v) + means = [sum(b)/len(b) for b in buckets if b] + means.sort() + return means[len(means)//2] + +def coef_var(samples: Iterable[float]) -> float: + vals = list(samples) + if len(vals) < 2: + return 0.0 + m = sum(vals)/len(vals) + if m == 0: + return 0.0 + sd = stats.pstdev(vals) + return sd / m + +# ---------- Public context manager ---------- + +@dataclass +class PerfOptions: + single_thread_math: bool = True + prefer_isolated_cpus: bool = True + warmup_seconds: float = 0.0 + adjust_priority: bool = True + tweak_timer_windows: bool = True + keep_mac_awake: bool = True + +@contextlib.contextmanager +def stable_perf(opts: PerfOptions = PerfOptions()): + """ + Wrap your perf runs with this context manager for a stabler environment. + """ + _set_env_threads(opts.single_thread_math) + + cm_stack = contextlib.ExitStack() + try: + if opts.adjust_priority: + cm_stack.enter_context(_process_priority()) + if opts.tweak_timer_windows: + cm_stack.enter_context(_win_timer_resolution(1)) + if opts.prefer_isolated_cpus: + cm_stack.enter_context(_cpu_affinity(True)) + if opts.keep_mac_awake: + cm_stack.enter_context(_mac_caffeinate()) + + if opts.warmup_seconds > 0: + warmup(opts.warmup_seconds) + + with gc_pause(): + yield + finally: + cm_stack.close() + +# ---------- Example runners ---------- + +def run_interleaved(configs: List[dict], run_fn, trials: int = 5, trial_seconds: float = 30.0, seed: int = 42): + """ + Interleave scenarios to cancel slow drift. `run_fn(config, seconds, seed_offset)` should return a dict of metrics. + Returns a list of per-trial results per config. + """ + import random + random.seed(seed) + results = [ [] for _ in configs ] + order = list(range(len(configs))) + # fixed order per pass; you'll re-use the same order for stability + for t in range(trials): + for idx in order: + res = run_fn(configs[idx], trial_seconds, seed + t) + results[idx].append(res) + return results + +def summarize_metric(trials: List[dict], key: str, mom_buckets: int = 5): + """ + Extract a metric across trial dicts and summarize with median-of-means and CV. 
+ """ + vals = [float(tr[key]) for tr in trials if key in tr] + return { + "count": len(vals), + "mom": median_of_means(vals, mom_buckets), + "mean": sum(vals)/len(vals) if vals else float("nan"), + "p50": stats.median(vals) if vals else float("nan"), + "cv": coef_var(vals) if vals else float("nan"), + } \ No newline at end of file diff --git a/src/ezmsg/util/perf_test.py b/src/ezmsg/util/perf_test.py deleted file mode 100644 index b7536b5e..00000000 --- a/src/ezmsg/util/perf_test.py +++ /dev/null @@ -1,220 +0,0 @@ -import asyncio -import dataclasses -import datetime -import os -import platform -import time - -from typing import List, Tuple, AsyncGenerator - -import ezmsg.core as ez - -# We expect this test to generate LOTS of backpressure warnings -# PERF_LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "ERROR") -# ez.logger.setLevel(PERF_LOGLEVEL) - -PLATFORM = { - "Darwin": "mac", - "Linux": "linux", - "Windows": "win", -}[platform.system()] -SAMPLE_SUMMARY_DATASET_PREFIX = "sample_summary" -COUNT_DATASET_NAME = "count" - - -try: - import numpy as np -except ImportError: - ez.logger.error("This test requires Numpy to run.") - raise - - -def get_datestamp() -> str: - return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - -class LoadTestSettings(ez.Settings): - duration: float = 30.0 - dynamic_size: int = 8 - buffers: int = 32 - - -@dataclasses.dataclass -class LoadTestSample: - _timestamp: float - counter: int - dynamic_data: np.ndarray - - -class LoadTestPublisher(ez.Unit): - OUTPUT = ez.OutputStream(LoadTestSample) - SETTINGS = LoadTestSettings - - async def initialize(self) -> None: - self.running = True - self.counter = 0 - self.OUTPUT.num_buffers = self.SETTINGS.buffers - - @ez.publisher(OUTPUT) - async def publish(self) -> AsyncGenerator: - ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.time() - while self.running: - current_time = time.time() - if current_time - start_time >= self.SETTINGS.duration: - break - - yield ( - self.OUTPUT, - LoadTestSample( - _timestamp=time.time(), - counter=self.counter, - dynamic_data=np.zeros( - int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 - ), - ), - ) - self.counter += 1 - ez.logger.info("Exiting publish") - raise ez.Complete - - async def shutdown(self) -> None: - self.running = False - ez.logger.info(f"Samples sent: {self.counter}") - - -class LoadTestSubscriberState(ez.State): - # Tuples of sent timestamp, received timestamp, counter, dynamic size - received_data: List[Tuple[float, float, int]] = dataclasses.field( - default_factory=list - ) - counter: int = -1 - - -class LoadTestSubscriber(ez.Unit): - INPUT = ez.InputStream(LoadTestSample) - SETTINGS = LoadTestSettings - STATE = LoadTestSubscriberState - - @ez.subscriber(INPUT, zero_copy=True) - async def receive(self, sample: LoadTestSample) -> None: - if sample.counter != self.STATE.counter + 1: - ez.logger.warning( - f"{sample.counter - self.STATE.counter - 1} samples skipped!" - ) - self.STATE.received_data.append( - (sample._timestamp, time.time(), sample.counter) - ) - self.STATE.counter = sample.counter - - @ez.task - async def log_result(self) -> None: - ez.logger.info(f"Load test subscriber started. 
(PID: {os.getpid()})") - - # Wait for the duration of the load test - await asyncio.sleep(self.SETTINGS.duration) - # logger.info(f"STATE = {self.STATE.received_data}") - - # Log some useful summary statistics - min_timestamp = min(timestamp for timestamp, _, _ in self.STATE.received_data) - max_timestamp = max(timestamp for timestamp, _, _ in self.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in self.STATE.received_data - ) - ) - - counters = list(sorted(t[2] for t in self.STATE.received_data)) - dropped_samples = sum( - [(x1 - x0) - 1 for x1, x0 in zip(counters[1:], counters[:-1])] - ) - - num_samples = len(self.STATE.received_data) - ez.logger.info(f"Samples received: {num_samples}") - ez.logger.info( - f"Sample rate: {num_samples / (max_timestamp - min_timestamp)} Hz" - ) - ez.logger.info(f"Mean latency: {total_latency / num_samples} s") - ez.logger.info(f"Total latency: {total_latency} s") - - total_data = num_samples * self.SETTINGS.dynamic_size - ez.logger.info( - f"Data rate: {total_data / (max_timestamp - min_timestamp) * 1e-6} MB/s" - ) - ez.logger.info( - f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", - ) - - raise ez.NormalTermination - - -class LoadTest(ez.Collection): - SETTINGS = LoadTestSettings - - PUBLISHER = LoadTestPublisher() - SUBSCRIBER = LoadTestSubscriber() - - def configure(self) -> None: - self.PUBLISHER.apply_settings(self.SETTINGS) - self.SUBSCRIBER.apply_settings(self.SETTINGS) - - def network(self) -> ez.NetworkDefinition: - return ((self.PUBLISHER.OUTPUT, self.SUBSCRIBER.INPUT),) - - def process_components(self): - return ( - self.PUBLISHER, - self.SUBSCRIBER, - ) - - -def get_time() -> float: - # time.perf_counter() isn't system-wide on Windows Python 3.6: - # https://bugs.python.org/issue37205 - return time.time() if PLATFORM == "win" else time.perf_counter() - - -def test_performance(duration, size, buffers) -> None: - ez.logger.info(f"Running load test for dynamic size: {size} bytes") - system = LoadTest( - LoadTestSettings(dynamic_size=int(size), duration=duration, buffers=buffers) - ) - ez.run(SYSTEM=system) - - -def run_many_dynamic_sizes(duration, buffers) -> None: - for exp in range(5, 22, 4): - test_performance(duration, 2**exp, buffers) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--many-dynamic-sizes", - action="store_true", - help="Run load test for many dynamic sizes", - ) - parser.add_argument( - "--duration", - type=int, - default=2, - help="How long to run the load test (seconds)", - ) - parser.add_argument( - "--num-buffers", type=int, default=32, help="Shared memory buffers" - ) - - class Args: - many_dynamic_sizes: bool - duration: int - num_buffers: int - - args = parser.parse_args(namespace=Args) - - if args.many_dynamic_sizes: - run_many_dynamic_sizes(args.duration, args.num_buffers) - else: - test_performance(args.duration, 8, args.num_buffers)
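# Usage sketch (illustrative; the file names below are hypothetical): `ezmsg-perf run`
# writes perf_<datestamp>.txt, and `ezmsg-perf summary results.txt -b baseline.txt`
# prints/renders the comparison. The same data can also be explored programmatically:

from pathlib import Path
from ezmsg.util.perf.analysis import load_perf

baseline = load_perf(Path("perf_baseline.txt"))    # hypothetical baseline run
candidate = load_perf(Path("perf_candidate.txt"))  # hypothetical run under test
relative = (candidate / baseline) * 100.0          # percent of baseline, as in summary()

# latency_mean below 100% means the candidate had lower latency than the baseline
print(relative["latency_mean"].sel(config="relay", comms="local"))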