From cf8e2f6df981fbe341f7d18c6a866e922bc68248 Mon Sep 17 00:00:00 2001 From: Bastien Orivel Date: Wed, 24 Dec 2025 03:09:23 +0100 Subject: [PATCH] Register the optimization schema with `register_strategy` Instead of hardcoding the schema, make it part of the registration of the strategy, this allows for people to add their own strategy without having to work around `OptimizationSchema` which is painful and brittle. The slight drawback from this is that we lose the documentation in the task schema about what is possible since it's now deferred to the strategies themselves, and that we now validate optimizations twice, first to check that it's a dict then to check that it matches the strategy. Fixes #368 --- src/taskgraph/optimize/base.py | 11 ++++++++++- src/taskgraph/optimize/strategies.py | 5 +++-- src/taskgraph/transforms/task.py | 5 ++--- src/taskgraph/util/schema.py | 10 ---------- test/test_optimize.py | 13 +++++++++++++ 5 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/taskgraph/optimize/base.py b/src/taskgraph/optimize/base.py index ba07ca276..0d333c624 100644 --- a/src/taskgraph/optimize/base.py +++ b/src/taskgraph/optimize/base.py @@ -22,13 +22,14 @@ from taskgraph.taskgraph import TaskGraph from taskgraph.util.parameterization import resolve_task_references, resolve_timestamps from taskgraph.util.python_path import import_sibling_modules +from taskgraph.util.schema import validate_schema from taskgraph.util.taskcluster import find_task_id_batched, status_task_batched logger = logging.getLogger("optimization") registry = {} -def register_strategy(name, args=(), kwargs=None): +def register_strategy(name, args=(), kwargs=None, schema=None): kwargs = kwargs or {} def wrap(cls): @@ -36,6 +37,7 @@ def wrap(cls): registry[name] = cls(*args, **kwargs) if not hasattr(registry[name], "description"): registry[name].description = name + registry[name].schema = schema return cls return wrap @@ -123,6 +125,13 @@ def optimizations(label): if task.optimization: opt_by, arg = list(task.optimization.items())[0] strategy = strategies[opt_by] + schema = getattr(strategy, "schema", None) + if schema: + validate_schema( + schema, + arg, + f"In task `{label}` optimization `{opt_by}`:", + ) if hasattr(strategy, "description"): opt_by += f" ({strategy.description})" return (opt_by, strategy, arg) diff --git a/src/taskgraph/optimize/strategies.py b/src/taskgraph/optimize/strategies.py index 8fed9e54a..7b461fe2a 100644 --- a/src/taskgraph/optimize/strategies.py +++ b/src/taskgraph/optimize/strategies.py @@ -5,12 +5,13 @@ from taskgraph.optimize.base import OptimizationStrategy, register_strategy from taskgraph.util.path import match as match_path +from taskgraph.util.schema import Schema from taskgraph.util.taskcluster import find_task_id, status_task logger = logging.getLogger("optimization") -@register_strategy("index-search") +@register_strategy("index-search", schema=Schema([str])) class IndexSearch(OptimizationStrategy): # A task with no dependencies remaining after optimization will be replaced # if artifacts exist for the corresponding index_paths. @@ -73,7 +74,7 @@ def should_replace_task(self, task, params, deadline, arg): return False -@register_strategy("skip-unless-changed") +@register_strategy("skip-unless-changed", schema=Schema([str])) class SkipUnlessChanged(OptimizationStrategy): def check(self, files_changed, patterns): for pattern in patterns: diff --git a/src/taskgraph/transforms/task.py b/src/taskgraph/transforms/task.py index 7c834dcc7..e6dc127cd 100644 --- a/src/taskgraph/transforms/task.py +++ b/src/taskgraph/transforms/task.py @@ -23,7 +23,6 @@ from taskgraph.util.hash import hash_path from taskgraph.util.keyed_by import evaluate_keyed_by from taskgraph.util.schema import ( - OptimizationSchema, Schema, optionally_keyed_by, resolve_keyed_by, @@ -340,10 +339,10 @@ def run_task_suffix(): description=dedent( """ Optimization to perform on this task during the optimization - phase. Defined in taskcluster/taskgraph/optimize.py. + phase. The schema for this value is specific to the given optimization. """.lstrip() ), - ): OptimizationSchema, + ): Any(None, dict), Required( "worker-type", description=dedent( diff --git a/src/taskgraph/util/schema.py b/src/taskgraph/util/schema.py index 3c5f4c955..f6d2ee9d3 100644 --- a/src/taskgraph/util/schema.py +++ b/src/taskgraph/util/schema.py @@ -230,16 +230,6 @@ def __getitem__(self, item): return self.schema[item] # type: ignore -OptimizationSchema = voluptuous.Any( - # always run this task (default) - None, - # search the index for the given index namespaces, and replace this task if found - # the search occurs in order, with the first match winning - {"index-search": [str]}, - # skip this task if none of the given file patterns match - {"skip-unless-changed": [str]}, -) - # shortcut for a string where task references are allowed taskref_or_string = voluptuous.Any( str, diff --git a/test/test_optimize.py b/test/test_optimize.py index bfc2e9709..07c2a6a9b 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -7,6 +7,7 @@ import pytest from pytest_taskgraph import make_graph, make_task +from voluptuous import Schema from taskgraph.graph import Graph from taskgraph.optimize import base as optimize_mod @@ -487,3 +488,15 @@ def test_register_strategy(mocker): func = register_strategy("foo", args=("one", "two"), kwargs={"n": 1}) func(m) m.assert_called_with("one", "two", n=1) + + +def test_register_strategy_with_schema(mocker, monkeypatch): + monkeypatch.setattr(optimize_mod, "registry", {}) + schema = Schema([str]) + + @register_strategy("bar", schema=schema) + class TestStrategy(OptimizationStrategy): + pass + + assert "bar" in optimize_mod.registry + assert optimize_mod.registry["bar"].schema is schema