From baa4b535e9fc392d6e456cf548a2135a55b21218 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 17:55:40 +0100 Subject: [PATCH 01/17] Make mypy happy with pytensor/printing.py --- pytensor/printing.py | 64 ++++++++++++++++++---------------------- scripts/mypy-failing.txt | 1 - scripts/run_mypy.py | 4 +-- 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/pytensor/printing.py b/pytensor/printing.py index 2600d15459..eae2392609 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -1360,6 +1360,7 @@ def pydotprint( cond_highlight = None if cond_highlight is not None: + assert cond is not None def recursive_pass(x, ls): if x.owner is None: @@ -1370,20 +1371,22 @@ def recursive_pass(x, ls): ls += recursive_pass(inp, ls) return ls - left = set(recursive_pass(cond.inputs[1], [])) - right = set(recursive_pass(cond.inputs[2], [])) - middle = left.intersection(right) - left = left.difference(middle) - right = right.difference(middle) - middle = list(middle) - left = list(left) - right = list(right) + set_left = set(recursive_pass(cond.inputs[1], [])) + set_right = set(recursive_pass(cond.inputs[2], [])) - var_str = {} - var_id = {} - all_strings = set() + set_middle = set_left.intersection(set_right) + set_left = set_left.difference(set_middle) + set_right = set_right.difference(set_middle) - def var_name(var): + middle = list(set_middle) + left = list(set_left) + right = list(set_right) + + var_str: dict[Any, str] = {} + var_id: dict[Any, str] = {} + all_strings: set[str] = set() + + def var_name(var) -> tuple[str, str]: if var in var_str: return var_str[var], var_id[var] @@ -1391,9 +1394,9 @@ def var_name(var): if var_with_name_simple: varstr = var.name else: - varstr = "name=" + var.name + " " + str(var.type) + varstr = f"name={var.name} {var.type}" elif isinstance(var, Constant): - dstr = "val=" + str(np.asarray(var.data)) + dstr = f"val={np.asarray(var.data)}" if "\n" in dstr: dstr = dstr[: dstr.index("\n")] varstr = f"{dstr} {var.type}" @@ -1414,8 +1417,8 @@ def var_name(var): return varstr, var_id[var] - apply_name_cache = {} - apply_name_id = {} + apply_name_cache: dict[Any, str] = {} + apply_name_id: dict[Any, str] = {} def apply_name(node): if node in apply_name_cache: @@ -1432,7 +1435,7 @@ def apply_name(node): applystr = str(node.op).replace(":", "_") applystr += prof_str if (applystr in all_strings) or with_ids: - idx = " id=" + str(topo.index(node)) + idx = f" id={topo.index(node)}" if len(applystr) + len(idx) > max_label_size: applystr = applystr[: max_label_size - 3 - len(idx)] + idx + "..." else: @@ -1442,7 +1445,7 @@ def apply_name(node): idx = 1 while applystr in all_strings: idx += 1 - suffix = " id=" + str(idx) + suffix = f" id={idx}" applystr = applystr[: max_label_size - 3 - len(suffix)] + "..." + suffix all_strings.add(applystr) @@ -1626,10 +1629,10 @@ def apply_name(node): for idx, scan_op in scan_ops: # is there a chance that name is not defined? 
if hasattr(scan_op.op, "name"): - new_name = outfile.stem + "_" + scan_op.op.name + "_" + str(idx) + new_stem = f"{outfile.stem}_{scan_op.op.name}_{idx}" else: - new_name = outfile.stem + "_" + str(idx) - new_name = outfile.with_stem(new_name) + new_stem = f"{outfile.stem}_{idx}" + new_name = outfile.with_stem(new_stem) if hasattr(scan_op.op, "_fn"): to_print = scan_op.op.fn else: @@ -1752,7 +1755,7 @@ def min_informative_str( if id(obj) in _prev_obs: tag = _prev_obs[id(obj)] - return indent + "<" + tag + ">" + return f"{indent}<{tag}>" if _tag_generator is None: _tag_generator = _TagGenerator() @@ -1778,11 +1781,7 @@ def min_informative_str( else: name = str(obj) - prefix = cur_tag + ". " - - rval = indent + prefix + name - - return rval + return f"{indent}{cur_tag}. {name}" def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str: @@ -1797,7 +1796,7 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s if id(obj) in _prev_obs: tag = _prev_obs[id(obj)] - return "<" + tag + ">" + return f"<{tag}>" if _tag_generator is None: _tag_generator = _TagGenerator() @@ -1833,17 +1832,12 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s if " at 0x" in name: raise AssertionError(name) - prefix = cur_tag + "=" - - rval = prefix + name - - return rval + return f"{cur_tag}={name}" def position_independent_str(obj) -> str: if isinstance(obj, Variable): - rval = "pytensor_var" - rval += "{type=" + str(obj.type) + "}" + rval = f"pytensor_var{{type={obj.type}}}" else: raise NotImplementedError() diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index ff73de2605..336845b5fd 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -7,7 +7,6 @@ pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/scan.py -pytensor/printing.py pytensor/raise_op.py pytensor/tensor/basic.py pytensor/tensor/blas_c.py diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 0b9529cc18..54179853a6 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -101,9 +101,7 @@ def check_no_unexpected_results(mypy_df: pd.DataFrame, show_expected: bool): print("!!!!!!!!!") print(f"{len(unexpected_passing)} files unexpectedly passed the type checks:") print("\n".join(sorted(map(str, unexpected_passing)))) - print( - "This is good news! Go to scripts/run_mypy.py and remove them from the `FAILING` list." - ) + print("This is good news! Remove them from scripts/mypy-failing.txt.") if all_files.issubset(passing): print("WOW! 
All files are passing the mypy type checks!") print("scripts\\run_mypy.py may no longer be needed.") From 6c1f7b0327db6aca2cedde52b83776068780a1c4 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 18:03:21 +0100 Subject: [PATCH 02/17] Make mypy happy with pytensor/tensor/blas_c.py --- pytensor/tensor/blas_c.py | 2 +- scripts/mypy-failing.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index 83cd87796a..f0c8f4995a 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -670,4 +670,4 @@ def must_initialize_y_gemv(): return must_initialize_y_gemv._force_init_beta -must_initialize_y_gemv._force_init_beta = None +must_initialize_y_gemv._force_init_beta = None # type: ignore[attr-defined] diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 336845b5fd..447385be54 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -9,7 +9,6 @@ pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/scan.py pytensor/raise_op.py pytensor/tensor/basic.py -pytensor/tensor/blas_c.py pytensor/tensor/blas_headers.py pytensor/tensor/elemwise.py pytensor/tensor/extra_ops.py From 5d29fe3569045e297b9c8e192ced8749904ecf94 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 18:23:16 +0100 Subject: [PATCH 03/17] Make mypy happy with pytensor/compile/debugmode.py --- pytensor/compile/debugmode.py | 8 ++++---- scripts/mypy-failing.txt | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 3b35cc2cd5..84caf8c9f9 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -29,7 +29,8 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Variable from pytensor.graph.destroyhandler import DestroyHandler -from pytensor.graph.features import AlreadyThere, BadOptimization +from pytensor.graph.features import AlreadyThere +from pytensor.graph.features import BadOptimization as _BadOptimization from pytensor.graph.fg import Output from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.traversal import io_toposort @@ -144,7 +145,7 @@ def str_diagnostic(self): return ret -class BadOptimization(DebugModeError, BadOptimization): +class BadOptimization(DebugModeError, _BadOptimization): pass @@ -2244,8 +2245,7 @@ class DebugMode(Mode): """ - check_preallocated_output = config.DebugMode__check_preallocated_output - check_preallocated_output = check_preallocated_output.split(":") + check_preallocated_output = config.DebugMode__check_preallocated_output.split(":") """ List of strings representing ways to pre-allocate output memory in tests. 
Valid values are: "previous" (previously-returned memory), diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 447385be54..045d2faaff 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -1,5 +1,4 @@ pytensor/compile/builders.py -pytensor/compile/debugmode.py pytensor/compile/function/pfunc.py pytensor/compile/function/types.py pytensor/compile/mode.py From 07d5d9aaffb3b53edb14897328589ff1e30ea88d Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 18:41:10 +0100 Subject: [PATCH 04/17] Make mypy happy with pytensor/compile/function/pfunc.py --- pytensor/compile/function/pfunc.py | 18 +++++++++--------- scripts/mypy-failing.txt | 1 - 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index b04c41fe6d..48bc699675 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -163,7 +163,7 @@ def rebuild_collect_shared( # This function implements similar functionality as graph.clone # and it should be merged with that - clone_d = {} + clone_d: dict = {} update_d = {} update_expr = [] # list of shared inputs that are used as inputs of the graph @@ -300,32 +300,32 @@ def clone_inputs(i): update_expr.append((store_into, update_val)) # Elements of "outputs" are here cloned to "cloned_outputs" + cloned_outputs: list[Variable] | Variable | Out | list[Out] if isinstance(outputs, list): - cloned_outputs = [] + cloned_outputs_list = [] for v in outputs: if isinstance(v, Variable): cloned_v = clone_v_get_shared_updates(v, copy_inputs_over) - cloned_outputs.append(cloned_v) + cloned_outputs_list.append(cloned_v) elif isinstance(v, Out): - cloned_v = clone_v_get_shared_updates(v.variable, copy_inputs_over) - cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) + cloned_o = clone_v_get_shared_updates(v.variable, copy_inputs_over) + cloned_outputs_list.append(Out(cloned_o, borrow=v.borrow)) else: raise TypeError( "Outputs must be pytensor Variable or " "Out instances. 
Received " + str(v) + " of type " + str(type(v)) ) # computed_list.append(cloned_v) + cloned_outputs = cloned_outputs_list else: if isinstance(outputs, Variable): cloned_v = clone_v_get_shared_updates(outputs, copy_inputs_over) cloned_outputs = cloned_v # computed_list.append(cloned_v) elif isinstance(outputs, Out): - cloned_v = clone_v_get_shared_updates(outputs.variable, copy_inputs_over) - cloned_outputs = Out(cloned_v, borrow=outputs.borrow) + cloned_o = clone_v_get_shared_updates(outputs.variable, copy_inputs_over) + cloned_outputs = Out(cloned_o, borrow=outputs.borrow) # computed_list.append(cloned_v) - elif outputs is None: - cloned_outputs = [] # TODO: get Function.__call__ to return None else: raise TypeError( "output must be an PyTensor Variable or Out instance (or list of them)", diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 045d2faaff..aee737bd83 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -1,5 +1,4 @@ pytensor/compile/builders.py -pytensor/compile/function/pfunc.py pytensor/compile/function/types.py pytensor/compile/mode.py pytensor/graph/rewriting/basic.py From 13bc90257ba2b81ba94b8c80a8f39e2d80876aa1 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 18:50:32 +0100 Subject: [PATCH 05/17] Make mypy happy with pytensor/tensor/blas_headers.py --- pytensor/compile/mode.py | 3 ++- pytensor/tensor/blas_headers.py | 6 +++--- scripts/mypy-failing.txt | 2 -- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 5a5e0c9cdc..ee02518bf1 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -306,7 +306,7 @@ def __init__( self, linker: str | Linker | None = None, optimizer: str | RewriteDatabaseQuery = "default", - db: RewriteDatabase = None, + db: RewriteDatabase | None = None, ): if linker is None: linker = config.linker @@ -317,6 +317,7 @@ def __init__( self.__setstate__((linker, optimizer)) + self.optdb: RewriteDatabase if db is None: global optdb self.optdb = optdb diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index 5d49b70ec4..43409a29d3 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -150,9 +150,9 @@ def detect_macos_sdot_bug(): return detect_macos_sdot_bug.present -detect_macos_sdot_bug.tested = False -detect_macos_sdot_bug.present = False -detect_macos_sdot_bug.fix_works = False +detect_macos_sdot_bug.tested = False # type: ignore[attr-defined] +detect_macos_sdot_bug.present = False # type: ignore[attr-defined] +detect_macos_sdot_bug.fix_works = False # type: ignore[attr-defined] def cblas_header_text(): diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index aee737bd83..94588833f2 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -1,13 +1,11 @@ pytensor/compile/builders.py pytensor/compile/function/types.py -pytensor/compile/mode.py pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/scan.py pytensor/raise_op.py pytensor/tensor/basic.py -pytensor/tensor/blas_headers.py pytensor/tensor/elemwise.py pytensor/tensor/extra_ops.py pytensor/tensor/math.py From 5209ce3ba11eb5861bf0e3df93ac725e8850627f Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 19:04:15 +0100 Subject: [PATCH 06/17] Improve typing in pytensor/tensor/random/op.py --- pytensor/tensor/random/op.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) 
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 02f1840521..b7bd39125c 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -62,11 +62,11 @@ class RandomVariable(RNGConsumerOp): def __init__( self, - name=None, + name: str | None = None, ndim_supp=None, ndims_params=None, - dtype: str | None = None, - inplace=None, + dtype: str | np.dtype | None = None, + inplace: bool | None = None, signature: str | None = None, ): """Create a random variable `Op`. @@ -115,7 +115,7 @@ def __init__( if self.signature is not None: # Assume a single output. Several methods need to be updated to handle multiple outputs. self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature) - self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig] + self.ndims_params = tuple([len(input_sig) for input_sig in self.inputs_sig]) self.ndim_supp = len(self.output_sig) else: if ( @@ -238,7 +238,7 @@ def _infer_shape( from pytensor.tensor.extra_ops import broadcast_shape_iter - supp_shape: tuple[Any] + supp_shape: tuple[Any, ...] if self.ndim_supp == 0: supp_shape = () else: @@ -406,19 +406,19 @@ def make_node(self, rng, size, *dist_params): def batch_ndim(self, node: Apply) -> int: return cast(int, node.default_output().type.ndim - self.ndim_supp) - def rng_param(self, node) -> Variable: + def rng_param(self, node: Apply) -> Variable: """Return the node input corresponding to the rng""" return node.inputs[0] - def size_param(self, node) -> Variable: + def size_param(self, node: Apply) -> Variable: """Return the node input corresponding to the size""" return node.inputs[1] - def dist_params(self, node) -> Sequence[Variable]: + def dist_params(self, node: Apply) -> Sequence[Variable]: """Return the node inpust corresponding to dist params""" return node.inputs[2:] - def perform(self, node, inputs, outputs): + def perform(self, node: Apply, inputs, outputs): rng, size, *args = inputs # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. 
From 3e764901ce279cdec84e836dd1b12d0066771252 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 19:04:45 +0100 Subject: [PATCH 07/17] Improve typing in pytensor/compile/builders.py --- pytensor/compile/builders.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 055497a8ff..7f78abba52 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -1,11 +1,11 @@ """Define new Ops from existing Ops""" import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable from copy import copy from functools import partial from itertools import chain -from typing import Union, cast +from typing import Union from pytensor.compile.function import function from pytensor.compile.function.pfunc import rebuild_collect_shared @@ -88,12 +88,12 @@ def local_traverse(out): def construct_nominal_fgraph( - inputs: Sequence[Variable], outputs: Sequence[Variable] + inputs: list[Variable], outputs: list[Variable] ) -> tuple[ FunctionGraph, - Sequence[Variable], - dict[Variable, Variable], - dict[Variable, Variable], + list[SharedVariable], + dict[SharedVariable, Variable], + list[Variable], ]: """Construct an inner-`FunctionGraph` with ordered nominal inputs.""" implicit_shared_inputs = [] @@ -119,7 +119,7 @@ def construct_nominal_fgraph( ) new = rebuild_collect_shared( - cast(Sequence[Variable], outputs), + outputs, inputs=inputs + implicit_shared_inputs, replace=replacements, copy_inputs_over=False, @@ -401,7 +401,7 @@ def __init__( self.output_types = [out.type for out in outputs] for override in (lop_overrides, grad_overrides, rop_overrides): - if override == "default": + if override == "default": # type: ignore[comparison-overlap] raise ValueError( "'default' is no longer a valid value for overrides. Use None instead." ) @@ -702,7 +702,7 @@ def _build_and_cache_rop_op(self): # Return a wrapper that combines connected and disconnected output gradients def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]: connected_output_grads = iter(rop_op(*inputs, **kwargs)) - all_output_grads = [] + all_output_grads: list[Variable | None] = [] for out_grad in output_grads: if isinstance(out_grad.type, DisconnectedType): # R_Op does not have DisconnectedType yet, None should be used instead From be0eea7aacec2e83e40e499a6607fc120fb782e4 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 19:05:53 +0100 Subject: [PATCH 08/17] Improve typing in pytensor/utils.py --- pytensor/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytensor/utils.py b/pytensor/utils.py index a86b293973..ee2197c929 100644 --- a/pytensor/utils.py +++ b/pytensor/utils.py @@ -9,6 +9,7 @@ from collections.abc import Iterable, Sequence from functools import partial from pathlib import Path +from typing import TypeVar import numpy as np @@ -57,6 +58,9 @@ NDARRAY_C_VERSION = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] +T = TypeVar("T") + + def __call_excepthooks(type, value, trace): """ This function is meant to replace excepthook and do some @@ -205,7 +209,7 @@ def hash_from_code(msg: str | bytes) -> str: return f"m{hashlib.sha256(msg).hexdigest()}" -def uniq(seq: Sequence) -> list: +def uniq(seq: Sequence[T]) -> list[T]: """ Do not use set, this must always return the same value at the same index. 
If we just exchange other values, but keep the same pattern of duplication, @@ -217,7 +221,7 @@ def uniq(seq: Sequence) -> list: return [x for i, x in enumerate(seq) if seq.index(x) == i] -def difference(seq1: Iterable, seq2: Iterable): +def difference(seq1: Iterable[T], seq2: Iterable[T]) -> list[T]: r""" Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``. @@ -236,7 +240,7 @@ def difference(seq1: Iterable, seq2: Iterable): return [x for x in seq1 if x not in seq2] -def to_return_values(values): +def to_return_values(values: Sequence[T]) -> T | Sequence[T]: if len(values) == 1: return values[0] else: From 26f4157b2254f4ca24d801953f3ceda33ce5eb4c Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 19:09:59 +0100 Subject: [PATCH 09/17] Make the names of arguments of overriding methods consistent --- doc/extending/ctype.rst | 6 +++--- doc/extending/type.rst | 2 +- pytensor/link/c/type.py | 2 +- pytensor/sparse/type.py | 21 +++++++++------------ pytensor/tensor/type.py | 6 +++--- pytensor/tensor/type_other.py | 12 ++++++------ pytensor/tensor/variable.py | 10 +++++----- pytensor/typed_list/type.py | 12 ++++++------ pytensor/xtensor/type.py | 7 +++++-- tests/graph/test_destroyhandler.py | 2 +- tests/graph/test_op.py | 8 ++++---- tests/graph/utils.py | 4 ++-- tests/link/numba/test_basic.py | 2 +- tests/link/test_link.py | 2 +- tests/tensor/test_merge.py | 2 +- tests/tensor/test_optimize.py | 6 +++--- 16 files changed, 52 insertions(+), 52 deletions(-) diff --git a/doc/extending/ctype.rst b/doc/extending/ctype.rst index 7f0f28003b..92d1e4e9e1 100644 --- a/doc/extending/ctype.rst +++ b/doc/extending/ctype.rst @@ -463,10 +463,10 @@ Final version class Double(Type): - def filter(self, x, strict=False, allow_downcast=None): - if strict and not isinstance(x, float): + def filter(self, data, strict=False, allow_downcast=None): + if strict and not isinstance(data, float): raise TypeError('Expected a float!') - return float(x) + return float(data) def values_eq_approx(self, x, y, tolerance=1e-4): return abs(x - y) / (x + y) < tolerance diff --git a/doc/extending/type.rst b/doc/extending/type.rst index 5f0c723c3f..c326e8df8f 100644 --- a/doc/extending/type.rst +++ b/doc/extending/type.rst @@ -417,7 +417,7 @@ required methods of the interface, except ``filter``. class DoubleType(Type): - def filter(self, x, strict=False, allow_downcast=None): + def filter(self, data, strict=False, allow_downcast=None): # See code above. ... 
diff --git a/pytensor/link/c/type.py b/pytensor/link/c/type.py index 84715eebb6..39f3cf5dde 100644 --- a/pytensor/link/c/type.py +++ b/pytensor/link/c/type.py @@ -73,7 +73,7 @@ class Generic(CType, Singleton): def filter(self, data, strict=False, allow_downcast=None): return data - def is_valid_value(self, a): + def is_valid_value(self, data, strict: bool = True) -> bool: return True def c_declare(self, name, sub, check_input=True): diff --git a/pytensor/sparse/type.py b/pytensor/sparse/type.py index bbc8a9fda1..a3fb703192 100644 --- a/pytensor/sparse/type.py +++ b/pytensor/sparse/type.py @@ -98,31 +98,28 @@ def clone( shape = self.shape return type(self)(format, dtype, shape=shape, **kwargs) - def filter(self, value, strict=False, allow_downcast=None): - if isinstance(value, Variable): + def filter(self, data, strict: bool = False, allow_downcast=None): + if isinstance(data, Variable): raise TypeError( "Expected an array-like object, but found a Variable: " "maybe you are trying to call a function on a (possibly " "shared) variable instead of a numeric array?" ) - if ( - isinstance(value, self.format_cls[self.format]) - and value.dtype == self.dtype - ): - return value + if isinstance(data, self.format_cls[self.format]) and data.dtype == self.dtype: + return data if strict: raise TypeError( - f"{value} is not sparse, or not the right dtype (is {value.dtype}, " + f"{data} is not sparse, or not the right dtype (is {data.dtype}, " f"expected {self.dtype})" ) # The input format could be converted here if allow_downcast: - sp = self.format_cls[self.format](value, dtype=self.dtype) + sp = self.format_cls[self.format](data, dtype=self.dtype) else: - data = self.format_cls[self.format](value) + data = self.format_cls[self.format](data) up_dtype = ps.upcast(self.dtype, data.dtype) if up_dtype != self.dtype: raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}") @@ -209,8 +206,8 @@ def values_eq(self, a, b): and abs(a - b).sum() == 0.0 ) - def is_valid_value(self, a): - return scipy.sparse.issparse(a) and (a.format == self.format) + def is_valid_value(self, data, strict: bool = True): + return scipy.sparse.issparse(data) and (data.format == self.format) def get_shape_info(self, obj): obj = self.filter(obj) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 5ae92006e2..e096806f7e 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -1,7 +1,7 @@ import logging import warnings from collections.abc import Iterable -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import numpy as np import numpy.typing as npt @@ -630,8 +630,8 @@ def c_code_cache_version(self): class DenseTypeMeta(MetaType): - def __instancecheck__(self, o): - if type(o) is TensorType or isinstance(o, DenseTypeMeta): + def __instancecheck__(self, instance: Any) -> bool: + if type(instance) is TensorType or isinstance(instance, DenseTypeMeta): return True return False diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 61545df370..16c4b0fb41 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -54,9 +54,9 @@ class SliceType(Type[slice]): def clone(self, **kwargs): return type(self)() - def filter(self, x, strict=False, allow_downcast=None): - if isinstance(x, slice): - return x + def filter(self, data, strict=False, allow_downcast=None): + if isinstance(data, slice): + return data else: raise TypeError("Expected a slice!") @@ -123,9 +123,9 @@ class NoneTypeT(Generic): """ - 
def filter(self, x, strict=False, allow_downcast=None): - if x is None: - return x + def filter(self, data, strict=False, allow_downcast=None): + if data is None: + return data else: raise TypeError("Expected None!") diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 31e08fd39b..8a4fee7186 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -3,7 +3,7 @@ import warnings from collections.abc import Iterable from numbers import Number -from typing import TypeVar +from typing import Any, TypeVar import numpy as np @@ -1117,8 +1117,8 @@ def __deepcopy__(self, memo): class DenseVariableMeta(MetaType): - def __instancecheck__(self, o): - if type(o) is TensorVariable or isinstance(o, DenseVariableMeta): + def __instancecheck__(self, instance: Any) -> bool: + if type(instance) is TensorVariable or isinstance(instance, DenseVariableMeta): return True return False @@ -1132,8 +1132,8 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): class DenseConstantMeta(MetaType): - def __instancecheck__(self, o): - if type(o) is TensorConstant or isinstance(o, DenseConstantMeta): + def __instancecheck__(self, instance: Any) -> bool: + if type(instance) is TensorConstant or isinstance(instance, DenseConstantMeta): return True return False diff --git a/pytensor/typed_list/type.py b/pytensor/typed_list/type.py index bd10a501b6..7f4b7667e6 100644 --- a/pytensor/typed_list/type.py +++ b/pytensor/typed_list/type.py @@ -25,12 +25,12 @@ def __init__(self, ttype, depth=0): else: self.ttype = TypedListType(ttype, depth - 1) - def filter(self, x, strict=False, allow_downcast=None): + def filter(self, data, strict: bool = False, allow_downcast=None): """ Parameters ---------- - x + data Value to filter. strict If true, only native python list will be accepted. @@ -39,13 +39,13 @@ def filter(self, x, strict=False, allow_downcast=None): """ if strict: - if not isinstance(x, list): + if not isinstance(data, list): raise TypeError("Expected a python list") else: - x = [self.ttype.filter(y) for y in x] + data = [self.ttype.filter(y) for y in data] - if all(self.ttype.is_valid_value(y) for y in x): - return x + if all(self.ttype.is_valid_value(y) for y in data): + return data else: raise TypeError(f"Expected all elements to be {self.ttype}") diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index db6a66036d..8e8b5cd709 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -89,10 +89,13 @@ def clone( shape = self.shape return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs) - def filter(self, value, strict=False, allow_downcast=None): + def filter(self, data, strict: bool = False, allow_downcast=None): # XTensorType behaves like TensorType at runtime, so we filter the same way. 
return TensorType.filter( - self, value, strict=strict, allow_downcast=allow_downcast + typing.cast(TensorType, self), + data, + strict=strict, + allow_downcast=allow_downcast, ) @staticmethod diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 761787e54c..7bb9ca8285 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -37,7 +37,7 @@ def as_variable(x): class MyType(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return data def __eq__(self, other): diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index d0d8b6c5fb..7dfb8f2434 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -33,14 +33,14 @@ def __str__(self): def __repr__(self): return str(self.thingy) - def filter(self, x, strict=False, allow_downcast=None): + def filter(self, data, strict=False, allow_downcast=None): # Dummy filter: we want this type to represent strings that # start with `self.thingy`. - if not isinstance(x, str): + if not isinstance(data, str): raise TypeError("Invalid type") - if not x.startswith(self.thingy): + if not data.startswith(self.thingy): raise ValueError("Invalid value") - return x + return data # Added to make those tests pass in DebugMode @staticmethod diff --git a/tests/graph/utils.py b/tests/graph/utils.py index 2e14fc79a4..74fe9235f9 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -14,7 +14,7 @@ def is_variable(x): class MyType(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return data def __eq__(self, other): @@ -28,7 +28,7 @@ def __repr__(self): class MyType2(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return data def __eq__(self, other): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 5dad9eb8e3..a273bd53db 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -41,7 +41,7 @@ class MyType(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return data def __eq__(self, other): diff --git a/tests/link/test_link.py b/tests/link/test_link.py index 410329aa4c..dafaaace53 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -61,7 +61,7 @@ def as_variable(x): class TDouble(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return float(data) diff --git a/tests/tensor/test_merge.py b/tests/tensor/test_merge.py index 561b5d7c89..7b098f21ac 100644 --- a/tests/tensor/test_merge.py +++ b/tests/tensor/test_merge.py @@ -15,7 +15,7 @@ def is_variable(x): class MyType(Type): - def filter(self, data): + def filter(self, data, strict=False, allow_downcast=None): return data def __eq__(self, other): diff --git a/tests/tensor/test_optimize.py b/tests/tensor/test_optimize.py index b211381e30..6b1691d147 100644 --- a/tests/tensor/test_optimize.py +++ b/tests/tensor/test_optimize.py @@ -289,9 +289,9 @@ def test_optimize_grad_disconnected_numerical_inp(optimize_op): @pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar)) def test_optimize_grad_disconnected_non_numerical_inp(optimize_op): class StrType(Type): - def filter(self, x, **kwargs): - if isinstance(x, str): - return x + def filter(self, data, strict=False, allow_downcast=None): + if isinstance(data, str): + return data raise TypeError class SmileOrFrown(Op): From 4f56ee08ed4d5b942f3ca18f835ba5f0a0d91f14 Mon 
Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 19:10:31 +0100 Subject: [PATCH 10/17] Various typing improvements --- pytensor/breakpoint.py | 20 +++++++++++--------- pytensor/configdefaults.py | 2 +- pytensor/tensor/rewriting/basic.py | 11 +++++------ pytensor/xtensor/vectorization.py | 8 ++++---- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index 3d59b5c24c..cb9255d589 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -4,6 +4,7 @@ from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.type import TensorType class PdbBreakpoint(Op): @@ -50,23 +51,26 @@ class PdbBreakpoint(Op): # as the individual error values breakpointOp = PdbBreakpoint("MSE too high") condition = pt.gt(mse.sum(), 100) - mse, monitored_input, monitored_target = breakpointOp(condition, mse, - input, target) + mse, monitored_input, monitored_target = breakpointOp( + condition, mse, input, target + ) # Compile the pytensor function fct = pytensor.function([input, target], mse) # Use the function - print fct([10, 0], [10, 5]) # Will NOT activate the breakpoint - print fct([0, 0], [10, 5]) # Will activate the breakpoint + print(fct([10, 0], [10, 5])) # Will NOT activate the breakpoint + print(fct([0, 0], [10, 5])) # Will activate the breakpoint """ __props__ = ("name",) - def __init__(self, name): + def __init__(self, name: str): self.name = name + self.view_map = {} + self.inp_types: list[TensorType] = [] def make_node(self, condition, *monitored_vars): # Ensure that condition is an PyTensor tensor @@ -83,13 +87,11 @@ def make_node(self, condition, *monitored_vars): # (view_map and var_types) in that instance and then apply it on the # inputs. new_op = PdbBreakpoint(name=self.name) - new_op.view_map = {} - new_op.inp_types = [] - for i in range(len(monitored_vars)): + for i, var in enumerate(monitored_vars): # Every output i is a view of the input i+1 because of the input # condition. 
new_op.view_map[i] = [i + 1] - new_op.inp_types.append(monitored_vars[i].type) + new_op.inp_types.append(var.type) # Build the Apply node inputs = [condition, *monitored_vars] diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 0847ca29f6..2350489249 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -101,7 +101,7 @@ def _good_seem_param(seed): return True try: int(seed) - except Exception: + except ValueError: return False return True diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index e309c9f485..57a6988d86 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -34,7 +34,6 @@ from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, - Rewriter, copy_stack_trace, dfs_rewriter, in2out, @@ -158,7 +157,7 @@ def register_useless( ): if isinstance(node_rewriter, str): - def register(inner_rewriter: RewriteDatabase | Rewriter): + def register(inner_rewriter: RewriteDatabase | NodeRewriter): return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs) return register @@ -176,7 +175,7 @@ def register_canonicalize( ): if isinstance(node_rewriter, str): - def register(inner_rewriter: RewriteDatabase | Rewriter): + def register(inner_rewriter: RewriteDatabase | NodeRewriter): return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) return register @@ -193,7 +192,7 @@ def register_stabilize( ): if isinstance(node_rewriter, str): - def register(inner_rewriter: RewriteDatabase | Rewriter): + def register(inner_rewriter: RewriteDatabase | NodeRewriter): return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs) return register @@ -210,7 +209,7 @@ def register_specialize( ): if isinstance(node_rewriter, str): - def register(inner_rewriter: RewriteDatabase | Rewriter): + def register(inner_rewriter: RewriteDatabase | NodeRewriter): return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs) return register @@ -227,7 +226,7 @@ def register_uncanonicalize( ): if isinstance(node_rewriter, str): - def register(inner_rewriter: RewriteDatabase | Rewriter): + def register(inner_rewriter: RewriteDatabase | NodeRewriter): return register_uncanonicalize( inner_rewriter, node_rewriter, *tags, **kwargs ) diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py index a6cbb2b5c3..9b44f4e0f1 100644 --- a/pytensor/xtensor/vectorization.py +++ b/pytensor/xtensor/vectorization.py @@ -106,9 +106,9 @@ def make_node(self, *inputs): dummy_core_inputs = [] for inp, core_inp_dims in zip(inputs, core_inputs_dims): try: - core_static_shape = [ + core_static_shape = tuple( inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims - ] + ) except IndexError: raise ValueError( f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}" @@ -251,9 +251,9 @@ def make_node(self, rng, *extra_dim_lengths_and_params): dummy_core_inputs = [] for param, core_param_dims in zip(params, param_core_dims): try: - core_static_shape = [ + core_static_shape = tuple( param.type.shape[param.type.dims.index(d)] for d in core_param_dims - ] + ) except ValueError: raise ValueError( f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}" From 515022407b50ad3a3c376e6f5ea94a6e74830881 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 20:42:14 +0100 Subject: [PATCH 11/17] Replace tuple[T] with tuple[T, ...] 
when appropriate --- pytensor/compile/function/types.py | 4 ++-- pytensor/link/numba/dispatch/blockwise.py | 2 +- pytensor/tensor/subtensor.py | 4 ++-- pytensor/tensor/type.py | 2 +- pytensor/xtensor/reduction.py | 2 +- pytensor/xtensor/shape.py | 2 +- tests/link/jax/test_slinalg.py | 4 ++-- tests/link/numba/test_slinalg.py | 4 ++-- tests/tensor/test_slinalg.py | 20 ++++++++++---------- tests/tensor/test_subtensor.py | 2 +- 10 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index d77f11d84d..7b59b59a0c 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -1637,7 +1637,7 @@ def __init__( if any(self.refeed): warnings.warn("Inputs with default values are deprecated.", FutureWarning) - def create(self, input_storage=None, storage_map=None): + def create(self, input_storage=None, storage_map=None) -> Function: """ Create a function. @@ -1730,7 +1730,7 @@ def create(self, input_storage=None, storage_map=None): import_time = pytensor.link.c.cmodule.import_time - start_import_time self.profile.import_time += import_time - fn = self.function_builder( + fn: Function = self.function_builder( _fn, _i, _o, diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index e0b086e89c..ab1ba97080 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -30,7 +30,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:]) core_node = blockwise_op._create_dummy_core_node( - cast(tuple[TensorVariable], node.inputs[:nin]), + cast(tuple[TensorVariable, ...], node.inputs[:nin]), propagate_unbatched_core_inputs=True, ) core_op_fn, core_op_key = numba_funcify_and_cache_key( diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 1e21e67726..37010c58a2 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3112,7 +3112,7 @@ def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]: def flip( - arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None + arr: TensorVariable, axis: int | tuple[int, ...] | TensorVariable | None = None ) -> TensorVariable: """ Reverse the order of elements in an tensor along the given axis. @@ -3122,7 +3122,7 @@ def flip( arr: TensorVariable Input tensor. - axis: int | tuple[int] | TensorVariable, optional + axis: int | tuple[int, ...] | TensorVariable, optional Axis or axes along which to flip over. The default is to flip over all of the axes of the input tensor. Returns diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index e096806f7e..3a166e717a 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -880,7 +880,7 @@ def vector( name: str | None = None, *, dtype: Optional["DTypeLike"] = None, - shape: tuple[ST] | None = (None,), + shape: tuple[ST, ...] | None = (None,), ) -> "TensorVariable": """Return a symbolic vector variable. 
diff --git a/pytensor/xtensor/reduction.py b/pytensor/xtensor/reduction.py index 1379ae4fc9..fdd6469995 100644 --- a/pytensor/xtensor/reduction.py +++ b/pytensor/xtensor/reduction.py @@ -52,7 +52,7 @@ def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: return (dim,) elif dim is None or dim is Ellipsis: x = as_xtensor(x) - return typing.cast(tuple[str], x.type.dims) + return typing.cast(tuple[str, ...], x.type.dims) return dim diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 0a7960228f..225ec5f126 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -251,7 +251,7 @@ def transpose( # No-op transpose return x - return Transpose(dims=typing.cast(tuple[str], dim))(x) + return Transpose(dims=typing.cast(tuple[str, ...], dim))(x) class Concat(XOp): diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 49490994b1..4aa5af84ae 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -271,7 +271,7 @@ def test_jax_eigvalsh(lower): @pytest.mark.parametrize("method", ["direct", "bilinear"]) @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) def test_jax_solve_discrete_lyapunov( - method: Literal["direct", "bilinear"], shape: tuple[int] + method: Literal["direct", "bilinear"], shape: tuple[int, ...] ): A = pt.tensor(name="A", shape=shape) B = pt.tensor(name="B", shape=shape) @@ -297,7 +297,7 @@ def test_jax_solve_discrete_lyapunov( ) @pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"]) @pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"]) -def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]): +def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int, ...]): rng = np.random.default_rng() A = pt.tensor( "A", diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index a97465e68a..dba14465fd 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -115,7 +115,7 @@ class TestSolves: @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str) def test_solve( self, - b_shape: tuple[int], + b_shape: tuple[int, ...], assume_a: Literal["gen", "sym", "pos"], lower: bool, overwrite_a: bool, @@ -236,7 +236,7 @@ def A_func(x): @pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) def test_solve_triangular( self, - b_shape: tuple[int], + b_shape: tuple[int, ...], lower: bool, transposed: bool, unit_diagonal: bool, diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index c2cf51f634..e447926532 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -316,7 +316,7 @@ def test_infer_shape(self, b_shape): "assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids ) def test_solve_correctness( - self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool + self, b_size: tuple[int, ...], assume_a: str, lower: bool, transposed: bool ): rng = np.random.default_rng(utt.fetch_seed()) A = pt.tensor("A", shape=(5, 5)) @@ -370,7 +370,7 @@ def test_solve_correctness( config.floatX == "float32", reason="Gradients not numerically stable in float32" ) def test_solve_gradient( - self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool + self, b_size: tuple[int, ...], assume_a: str, lower: bool, transposed: bool ): rng = np.random.default_rng(utt.fetch_seed()) @@ -447,7 +447,7 @@ def test_infer_shape(self, b_shape): 
@pytest.mark.parametrize("lower", [True, False]) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("unit_diagonal", [True, False]) - def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal): + def test_correctness(self, b_shape: tuple[int, ...], lower, trans, unit_diagonal): rng = np.random.default_rng(utt.fetch_seed()) A = pt.tensor("A", shape=(5, 5)) b = pt.tensor("b", shape=b_shape) @@ -650,7 +650,7 @@ def test_solve_dtype(self): @pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"]) @pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"]) def test_lu_decomposition( - permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int] + permute_l: bool, p_indices: bool, complex: bool, shape: tuple[int, ...] ): dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}" @@ -749,7 +749,7 @@ def factor_and_solve(A, b, sum=False, **lu_kwargs): @pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"]) @pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"]) - def test_lu_solve(self, b_shape: tuple[int], trans): + def test_lu_solve(self, b_shape: tuple[int, ...], trans): rng = np.random.default_rng(utt.fetch_seed()) A = pt.tensor("A", shape=(5, 5)) b = pt.tensor("b", shape=b_shape) @@ -785,7 +785,7 @@ def T(x): @pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"]) @pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"]) - def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool): + def test_lu_solve_gradient(self, b_shape: tuple[int, ...], trans: bool): rng = np.random.default_rng(utt.fetch_seed()) A_val = rng.normal(size=(5, 5)).astype(config.floatX) @@ -925,7 +925,7 @@ def recover_Q(A, X, continuous=True): @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) @pytest.mark.parametrize("method", ["direct", "bilinear"]) def test_solve_discrete_lyapunov( - use_complex, shape: tuple[int], method: Literal["direct", "bilinear"] + use_complex, shape: tuple[int, ...], method: Literal["direct", "bilinear"] ): rng = np.random.default_rng(utt.fetch_seed()) dtype = config.floatX @@ -962,7 +962,7 @@ def test_solve_discrete_lyapunov( @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"]) @pytest.mark.parametrize("method", ["direct", "bilinear"]) def test_solve_discrete_lyapunov_gradient( - use_complex, shape: tuple[int], method: Literal["direct", "bilinear"] + use_complex, shape: tuple[int, ...], method: Literal["direct", "bilinear"] ): if config.floatX == "float32": pytest.skip(reason="Not enough precision in float32 to get a good gradient") @@ -982,7 +982,7 @@ def test_solve_discrete_lyapunov_gradient( @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) @pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) -def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool): +def test_solve_continuous_lyapunov(shape: tuple[int, ...], use_complex: bool): dtype = config.floatX if use_complex and dtype == "float32": pytest.skip( @@ -1023,7 +1023,7 @@ def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool): @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"]) @pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"]) -def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex): +def test_solve_continuous_lyapunov_grad(shape: 
tuple[int, ...], use_complex): if config.floatX == "float32": pytest.skip(reason="Not enough precision in float32 to get a good gradient") if use_complex: diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 6f79694e25..47bebb72a8 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -3145,7 +3145,7 @@ def test_slice_at_axis(): @pytest.mark.parametrize( "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] ) -def test_flip(size: tuple[int]): +def test_flip(size: tuple[int, ...]): from itertools import combinations ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4 From 0cb3898d32a2a8ef3e0b6baafbec78d5cab8f4ef Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 21:31:19 +0100 Subject: [PATCH 12/17] Use existing divmod --- pytensor/xtensor/type.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 8e8b5cd709..620eca4956 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -14,7 +14,7 @@ as_tensor_variable, specify_shape, ) -from pytensor.tensor.math import variadic_mul +from pytensor.tensor.math import divmod, variadic_mul try: @@ -318,7 +318,7 @@ def __mod__(self, other): return px.math.mod(self, other) def __divmod__(self, other): - return px.math.divmod(self, other) + return divmod(self, other) def __truediv__(self, other): return px.math.true_div(self, other) @@ -345,7 +345,7 @@ def __rmod__(self, other): return px.math.mod(other, self) def __rdivmod__(self, other): - return px.math.divmod(other, self) + return divmod(other, self) def __rpow__(self, other): return px.math.pow(other, self) From 5b9e06b4ad5c15a509225a839555ed38d6e6ac7f Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 22:03:19 +0100 Subject: [PATCH 13/17] Standardize Op.perform arguments --- pytensor/compile/builders.py | 4 +- pytensor/compile/ops.py | 18 ++-- pytensor/graph/fg.py | 2 +- pytensor/link/jax/ops.py | 8 +- pytensor/raise_op.py | 4 +- pytensor/sparse/basic.py | 134 ++++++++++++------------- pytensor/sparse/math.py | 72 ++++++------- pytensor/sparse/rewriting.py | 8 +- pytensor/tensor/basic.py | 56 +++++------ pytensor/tensor/blas.py | 28 +++--- pytensor/tensor/conv/abstract_conv.py | 18 ++-- pytensor/tensor/elemwise.py | 16 +-- pytensor/tensor/extra_ops.py | 20 ++-- pytensor/tensor/math.py | 6 +- pytensor/tensor/nlinalg.py | 54 +++++----- pytensor/tensor/optimize.py | 24 ++--- pytensor/tensor/random/op.py | 6 +- pytensor/tensor/shape.py | 24 ++--- pytensor/tensor/signal/conv.py | 10 +- pytensor/tensor/slinalg.py | 82 +++++++-------- pytensor/tensor/special.py | 12 +-- pytensor/tensor/subtensor.py | 16 +-- pytensor/tensor/type_other.py | 6 +- pytensor/typed_list/basic.py | 46 ++++----- pytensor/xtensor/basic.py | 2 +- tests/compile/test_debugmode.py | 60 +++++------ tests/graph/rewriting/test_kanren.py | 6 +- tests/graph/rewriting/test_unify.py | 10 +- tests/graph/test_compute_test_value.py | 8 +- tests/graph/test_op.py | 8 +- tests/graph/utils.py | 4 +- tests/link/c/test_basic.py | 8 +- tests/link/c/test_cmodule.py | 4 +- tests/link/c/test_op.py | 4 +- tests/link/c/test_type.py | 4 +- tests/link/jax/test_basic.py | 4 +- tests/link/jax/test_scan.py | 4 +- tests/link/numba/test_basic.py | 28 +++--- tests/link/test_link.py | 4 +- tests/link/test_vm.py | 8 +- tests/scalar/test_basic.py | 4 +- tests/scan/test_basic.py | 2 +- tests/sparse/test_basic.py | 4 +- tests/tensor/rewriting/test_shape.py | 
12 +-- tests/tensor/test_blockwise.py | 18 ++-- tests/tensor/test_elemwise.py | 2 +- tests/tensor/test_slinalg.py | 4 +- tests/test_rop.py | 6 +- 48 files changed, 447 insertions(+), 445 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 7f78abba52..5b2df99d20 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -875,8 +875,8 @@ def clone(self): res.fgraph = res.fgraph.clone() return res - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): variables = self.fn(*inputs) # zip strict not specified because we are in a hot loop - for output, variable in zip(outputs, variables): + for output, variable in zip(output_storage, variables): output[0] = variable diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 51398cd7d8..1d3ac709a4 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -44,8 +44,8 @@ class TypeCastingOp(COp): __props__: tuple = () _f16_ok: bool = True - def perform(self, node, inputs, outputs_storage): - outputs_storage[0][0] = inputs[0] + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] def __str__(self): return f"{self.__class__.__name__}" @@ -160,15 +160,15 @@ def __init__(self): def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, args, outs): - if hasattr(args[0], "copy"): + def perform(self, node, inputs, output_storage): + if hasattr(inputs[0], "copy"): # when args[0] is a an ndarray of 0 dimensions, # this return a numpy.dtype and not an ndarray # So when the args have a copy attribute we use it # as this don't have this problem - outs[0][0] = args[0].copy() + output_storage[0][0] = inputs[0].copy() else: - outs[0][0] = copy.deepcopy(args[0]) + output_storage[0][0] = copy.deepcopy(inputs[0]) def c_code_cache_version(self): version = [] @@ -253,13 +253,13 @@ def __hash__(self): def __str__(self): return f"FromFunctionOp{{{self.__fn.__name__}}}" - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): outs = self.__fn(*inputs) if not isinstance(outs, list | tuple): outs = (outs,) - assert len(outs) == len(outputs) + assert len(outs) == len(output_storage) for i in range(len(outs)): - outputs[i][0] = outs[i] + output_storage[i][0] = outs[i] def __reduce__(self): mod = self.__fn.__module__ diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 2728d50a50..64336e9e32 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -40,7 +40,7 @@ def __init__(self, idx): def make_node(self, inp): return Apply(self, [inp], []) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise RuntimeError("Output Ops should never be evaluated") def __str__(self): diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 38978eef75..d4adbead6c 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -111,18 +111,18 @@ def make_node(self, *inputs: Variable) -> Apply: outputs = [output_type() for output_type in self.output_types] return Apply(self, filtered_inputs, outputs) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): """Execute the JAX function and store results in output storage.""" results = self.jitted_func(*inputs) if not isinstance(results, tuple): raise TypeError("JAX function must return a tuple of outputs.") - if len(results) != len(outputs): + if len(results) != len(output_storage): raise ValueError( f"JAX 
function returned {len(results)} outputs, but " - f"{len(outputs)} were expected." + f"{len(output_storage)} were expected." ) for output_container, result, out_type in zip( - outputs, results, self.output_types + output_storage, results, self.output_types ): output_container[0] = np.array(result, dtype=out_type.dtype) diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index e23078b8ae..fd0c424d17 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -81,8 +81,8 @@ def make_node(self, value: Variable, *conds: Variable): [value.type()], ) - def perform(self, node, inputs, outputs): - (out,) = outputs + def perform(self, node, inputs, output_storage): + (out,) = output_storage val, *conds = inputs out[0] = val if not all(conds): diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 60ac79f149..e47f81206b 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -300,15 +300,15 @@ def make_node(self, csm): data = TensorType(dtype=csm.type.dtype, shape=(None,))() return Apply(self, [csm], [data, ivector(), ivector(), ivector()]) - def perform(self, node, inputs, out): + def perform(self, node, inputs, output_storage): (csm,) = inputs - out[0][0] = csm.data + output_storage[0][0] = csm.data if str(csm.data.dtype) == "int32": - out[0][0] = np.asarray(out[0][0], dtype="int32") + output_storage[0][0] = np.asarray(output_storage[0][0], dtype="int32") # backport - out[1][0] = np.asarray(csm.indices, dtype="int32") - out[2][0] = np.asarray(csm.indptr, dtype="int32") - out[3][0] = np.asarray(csm.shape, dtype="int32") + output_storage[1][0] = np.asarray(csm.indices, dtype="int32") + output_storage[2][0] = np.asarray(csm.indptr, dtype="int32") + output_storage[3][0] = np.asarray(csm.shape, dtype="int32") def grad(self, inputs, g): # g[1:] is all integers, so their Jacobian in this op @@ -440,10 +440,10 @@ def make_node(self, data, indices, indptr, shape): [SparseTensorType(dtype=data.type.dtype, format=self.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): # for efficiency, if remap does nothing, then do not apply it (data, indices, indptr, _shape) = inputs - (out,) = outputs + (out,) = output_storage if len(_shape) != 2: raise ValueError("Shape should be an array of length 2") @@ -543,7 +543,7 @@ def make_node( [gout_data], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): ( x_data, x_indices, @@ -554,7 +554,7 @@ def perform(self, node, inputs, outputs): g_indptr, _g_shape, ) = inputs - (g_out,) = outputs + (g_out,) = output_storage if len(x_indptr) - 1 == x_shape[0]: sp_dim = x_shape[1] else: @@ -595,9 +595,9 @@ def make_node(self, x): self, [x], [SparseTensorType(dtype=self.out_type, format=x.format)()] ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = x.astype(self.out_type) @@ -701,9 +701,9 @@ def make_node(self, x): [TensorType(dtype=x.type.dtype, shape=(None, None))()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (out,) = outputs + (out,) = output_storage if _is_dense(x): warn( "You just called DenseFromSparse on a dense matrix.", @@ -783,9 +783,9 @@ def make_node(self, x): self, [x], [SparseTensorType(dtype=x.type.dtype, format=self.format)()] ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) 
= inputs - (out,) = outputs + (out,) = output_storage out[0] = SparseTensorType.format_cls[self.format](x) def grad(self, inputs, gout): @@ -834,10 +834,10 @@ def make_node(self, x, index): return Apply(self, [x, ind], [x.type()]) - def perform(self, node, inp, outputs): - (out,) = outputs - x = inp[0] - indices = inp[1] + def perform(self, node, inputs, output_storage): + (out,) = output_storage + x = inputs[0] + indices = inputs[1] assert _is_sparse(x) out[0] = x[indices] @@ -877,11 +877,11 @@ def make_node(self, x, index, gz): return Apply(self, [x, ind, gz], [x.type()]) - def perform(self, node, inp, outputs): - (out,) = outputs - x = inp[0] - indices = inp[1] - gz = inp[2] + def perform(self, node, inputs, output_storage): + (out,) = output_storage + x = inputs[0] + indices = inputs[1] + gz = inputs[2] if x.format in ["csr"]: y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1])) @@ -922,11 +922,11 @@ def make_node(self, x, ind1, ind2): return Apply(self, [x, ind1, ind2], [vector()]) - def perform(self, node, inp, outputs): - (out,) = outputs - x = inp[0] - ind1 = inp[1] - ind2 = inp[2] + def perform(self, node, inputs, output_storage): + (out,) = output_storage + x = inputs[0] + ind1 = inputs[1] + ind2 = inputs[2] # SciPy returns the corresponding elements as a `matrix`-type instance, # which isn't what we want, so we convert it into an `ndarray` out[0] = np.asarray(x[ind1, ind2]).flatten() @@ -964,12 +964,12 @@ def make_node(self, x, ind1, ind2, gz): return Apply(self, [x, ind1, ind2, gz], [x.type()]) - def perform(self, node, inp, outputs): - (out,) = outputs - x = inp[0] - ind1 = inp[1] - ind2 = inp[2] - gz = inp[3] + def perform(self, node, inputs, output_storage): + (out,) = output_storage + x = inputs[0] + ind1 = inputs[1] + ind2 = inputs[2] + gz = inputs[3] if x.format in ["csr"]: y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1])) @@ -1104,9 +1104,9 @@ def make_node(self, x, index): return Apply(self, input_op, [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, start1, stop1, step1, start2, stop2, step2) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = x[start1:stop1:step1, start2:stop2:step2] @@ -1165,9 +1165,9 @@ def make_node(self, x, index): return Apply(self, input_op, [scalar(dtype=x.dtype)]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, ind1, ind2) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = np.asarray(x[ind1, ind2], x.dtype) @@ -1216,9 +1216,9 @@ def make_node(self, x): ], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = x.transpose() @@ -1262,9 +1262,9 @@ def make_node(self, x): assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = -x @@ -1302,9 +1302,9 @@ def make_node(self, x, s): raise ValueError("x was not a csc matrix") return Apply(self, [x, s], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, s) = inputs - (z,) = outputs + (z,) = output_storage _M, N = x.shape assert x.format == "csc" assert s.shape == (N,) @@ -1349,9 +1349,9 @@ def make_node(self, x, s): assert x.format in ("csr", 
"csc") return Apply(self, [x, s], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, s) = inputs - (z,) = outputs + (z,) = output_storage M, N = x.shape assert x.format == "csc" assert s.shape == (M,) @@ -1460,9 +1460,9 @@ def make_node(self, x): assert x.format in ("csr", "csc") return Apply(self, [x], [tensor(dtype=x.dtype, shape=(None,))]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage N, M = x.shape if N != M: raise ValueError("Diag only apply on square matrix") @@ -1506,8 +1506,8 @@ def make_node(self, diag): return Apply(self, [diag], [SparseTensorType(dtype=diag.dtype, format="csc")()]) - def perform(self, node, inputs, outputs): - (z,) = outputs + def perform(self, node, inputs, output_storage): + (z,) = output_storage diag = inputs[0] N = len(diag) @@ -1562,9 +1562,9 @@ def make_node(self, x): assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage if self.inplace: z[0] = x.sort_indices() else: @@ -1644,11 +1644,11 @@ def __str__(self): class HStack(Stack): - def perform(self, node, block, outputs): - (out,) = outputs - for b in block: + def perform(self, node, inputs, output_storage): + (out,) = output_storage + for b in inputs: assert _is_sparse(b) - out[0] = scipy.sparse.hstack(block, format=self.format, dtype=self.dtype) + out[0] = scipy.sparse.hstack(inputs, format=self.format, dtype=self.dtype) # Some version of scipy (at least 0.14.0.dev-c4314b0) # Do not cast to the wanted dtype. if out[0].dtype != self.dtype: @@ -1717,11 +1717,11 @@ def hstack(blocks, format=None, dtype=None): class VStack(Stack): - def perform(self, node, block, outputs): - (out,) = outputs - for b in block: + def perform(self, node, inputs, output_storage): + (out,) = output_storage + for b in inputs: assert _is_sparse(b) - out[0] = scipy.sparse.vstack(block, format=self.format, dtype=self.dtype) + out[0] = scipy.sparse.vstack(inputs, format=self.format, dtype=self.dtype) # Some version of scipy (at least 0.14.0.dev-c4314b0) # Do not cast to the wanted dtype. if out[0].dtype != self.dtype: @@ -1824,9 +1824,9 @@ def make_node(self, x): assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage if self.inplace: c = x else: @@ -1905,9 +1905,9 @@ def make_node(self, x, values, ilist): # take `x_.shape` as input and not `x`. 
return Apply(self, [x_.shape, values_, ilist_], [csc_matrix(dtype=x.dtype)]) - def perform(self, node, inp, out_): - out_shape, values, ilist = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + out_shape, values, ilist = inputs + (out,) = output_storage rows, cols = values.shape assert rows == len(ilist) indptr = np.arange(cols + 1) * rows diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 972de80d89..e84e048f18 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -326,9 +326,9 @@ def make_node(self, x): z = TensorType(dtype=x.dtype, shape=out_shape)() return Apply(self, [x], [z]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage if self.axis is None: z[0] = np.asarray(x.sum()) else: @@ -431,9 +431,9 @@ def make_node(self, x, y): [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and psb._is_sparse(y) assert x.shape == y.shape out[0] = x + y @@ -491,9 +491,9 @@ def make_node(self, x, y): [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and psb._is_sparse(y) assert x.shape == y.shape assert x.data.shape == y.data.shape @@ -532,9 +532,9 @@ def make_node(self, x, y): [TensorType(dtype=out_dtype, shape=y.type.shape)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_dense(y) # The asarray is needed as in some case, this return a @@ -594,9 +594,9 @@ def make_node(self, x, y): [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and not psb._is_sparse(y) assert x.shape[1] == y.shape[0] out[0] = x.__class__(x + (x.toarray() != 0) * y) @@ -723,9 +723,9 @@ def make_node(self, x, y): [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and psb._is_sparse(y) assert len(x.shape) == 2 assert y.shape == x.shape @@ -766,9 +766,9 @@ def make_node(self, x, y): out = psb.SparseTensorType(dtype=dtype, format=x.type.format)() return Apply(self, [x, y], [out]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and psb._is_dense(y) if len(y.shape) == 0: out_dtype = node.outputs[0].dtype @@ -874,9 +874,9 @@ def make_node(self, x, y): [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and not psb._is_sparse(y) assert x.shape[1] == y.shape[0] out[0] = x.__class__(x.toarray() * y) @@ -1002,9 +1002,9 @@ def make_node(self, x, y): self, 
[x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()] ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) and psb._is_sparse(y) assert x.shape == y.shape out[0] = self.comparison(x, y).astype("uint8") @@ -1044,9 +1044,9 @@ def make_node(self, x, y): out = TensorType(dtype="uint8", shape=(None, None))() return Apply(self, [x, y], [out]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y) = inputs - (out,) = outputs + (out,) = output_storage assert psb._is_sparse(x) assert x.shape == y.shape assert psb._is_dense(y) @@ -1255,15 +1255,15 @@ def make_node(self, x, y): outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()] return Apply(self, inputs, outputs) - def perform(self, node, inp, out_): + def perform(self, node, inputs, output_storage): # TODO # -Verify that output is sufficiently sparse, # and raise a warning if it is not. # -Also determine that we are storing the # output in the best storage format? - x, y = inp - (out,) = out_ + x, y = inputs + (out,) = output_storage rval = x.dot(y) if not scipy_sparse.issparse(rval): rval = getattr(scipy_sparse, x.format + "_matrix")(rval) @@ -1388,9 +1388,9 @@ def make_node(self, a, b): ], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a, b) = inputs - (out,) = outputs + (out,) = output_storage if a.shape[1] != b.shape[0]: raise ValueError( "shape mismatch in StructuredDot.perform", (a.shape, b.shape) @@ -1509,9 +1509,9 @@ def make_node(self, a_indices, a_indptr, b, g_ab): [tensor(dtype=g_ab.dtype, shape=(None,))], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a_indices, a_indptr, b, g_ab) = inputs - (out,) = outputs + (out,) = output_storage g_a_data = np.zeros(a_indices.shape, dtype=g_ab.dtype) for j in range(len(a_indptr) - 1): ind0 = a_indptr[j] @@ -1641,9 +1641,9 @@ def make_node(self, a_indices, a_indptr, b, g_ab): self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))] ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a_indices, a_indptr, b, g_ab) = inputs - (out,) = outputs + (out,) = output_storage g_a_data = np.zeros(a_indices.shape, dtype=g_ab.dtype) for i in range(len(a_indptr) - 1): # loop over rows ind0 = a_indptr[i] @@ -1825,9 +1825,9 @@ def make_node(self, x, y, p): return Apply(self, [x, y, p], [p.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, y, p) = inputs - (out,) = outputs + (out,) = output_storage if psb._is_sparse(x): raise TypeError(x) @@ -1924,9 +1924,9 @@ def make_node(self, x, y): return Apply(self, [x, y], [tensor(dtype=dtype_out, shape=shape_out)]) - def perform(self, node, inputs, out): + def perform(self, node, inputs, output_storage): x, y = inputs - out = out[0] + output_storage = output_storage[0] x_is_sparse = psb._is_sparse(x) y_is_sparse = psb._is_sparse(y) @@ -1938,7 +1938,7 @@ def perform(self, node, inputs, out): if x_is_sparse and y_is_sparse: rval = rval.toarray() - out[0] = np.asarray(rval, dtype=node.outputs[0].dtype) + output_storage[0] = np.asarray(rval, dtype=node.outputs[0].dtype) def grad(self, inputs, gout): (x, y) = inputs @@ -2060,9 +2060,9 @@ def make_node(self, alpha, x, y, z): [tensor(dtype=dtype_out, shape=(None, None))], ) - def perform(self, node, 
inputs, outputs): + def perform(self, node, inputs, output_storage): (alpha, x, y, z) = inputs - (out,) = outputs + (out,) = output_storage x_is_sparse = psb._is_sparse(x) y_is_sparse = psb._is_sparse(y) diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index d992635298..daed09181a 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -272,9 +272,9 @@ def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): ) return r - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a_val, a_ind, a_ptr, a_nrows, b) = inputs - (out,) = outputs + (out,) = output_storage a = scipy.sparse.csc_matrix( (a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False ) @@ -471,9 +471,9 @@ def make_node(self, a_val, a_ind, a_ptr, b): ) return r - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a_val, a_ind, a_ptr, b) = inputs - (out,) = outputs + (out,) = output_storage a = scipy.sparse.csr_matrix( (a_val, a_ind, a_ptr), (len(a_ptr) - 1, b.shape[0]), copy=True ) # use view_map before setting this to False diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 4fb4a33a5d..82f87e6202 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -621,9 +621,9 @@ def make_node(self, s): return Apply(self, [s], [tensor(dtype=s.type.dtype, shape=())]) - def perform(self, node, inp, out_): - (s,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (s,) = inputs + (out,) = output_storage out[0] = np.asarray(s) def infer_shape(self, fgraph, node, in_shapes): @@ -976,12 +976,12 @@ def make_node(self, a): output = [TensorType(dtype="int64", shape=(None,))() for i in range(a.ndim)] return Apply(self, [a], output) - def perform(self, node, inp, out_): - a = inp[0] + def perform(self, node, inputs, output_storage): + a = inputs[0] result_tuple = np.nonzero(a) for i, res in enumerate(result_tuple): - out_[i][0] = res.astype("int64") + output_storage[i][0] = res.astype("int64") def grad(self, inp, grads): return [grad_undefined(self, 0, inp[0])] @@ -1108,9 +1108,9 @@ def make_node(self, N, M, k): [TensorType(dtype=self.dtype, shape=(None, None))()], ) - def perform(self, node, inp, out_): - N, M, k = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + N, M, k = inputs + (out,) = output_storage out[0] = np.tri(N, M, k, dtype=self.dtype) def infer_shape(self, fgraph, node, in_shapes): @@ -1394,9 +1394,9 @@ def make_node(self, n, m, k): [TensorType(dtype=self.dtype, shape=static_shape)()], ) - def perform(self, node, inp, out_): - n, m, k = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + n, m, k = inputs + (out,) = output_storage out[0] = np.eye(n, m, k, dtype=self.dtype) def infer_shape(self, fgraph, node, in_shapes): @@ -1649,8 +1649,8 @@ def value_is_scalar_zero(x: TensorVariable) -> bool: and (x.unique_value == 0) ) - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage v = inputs[0] sh = tuple(int(i) for i in inputs[1:]) self._check_runtime_broadcast(node, v, sh) @@ -1912,8 +1912,8 @@ def make_node(self, *inputs): otype = TensorType(dtype, shape=(len(inputs),)) return Apply(self, inputs, [otype()]) - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage # not calling pytensor._asarray as optimization if (out[0] is None) or (out[0].size != len(inputs)): 
out[0] = np.asarray(inputs, dtype=node.outputs[0].dtype) @@ -2072,9 +2072,9 @@ def make_node(self, x, default): raise TypeError("Both arguments must have compatible types") return Apply(self, [x, default], [default.type()]) - def perform(self, node, inp, out_): - x, default = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + x, default = inputs + (out,) = output_storage if x is None: # why copy? PyTensor can't yet understand out[0] being a view of # either x or y, so we can be a view of x, but only a copy of y. @@ -2257,7 +2257,7 @@ def make_node(self, x, axis, splits): return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs_storage): + def perform(self, node, inputs, output_storage): x, axis, splits = inputs if len(splits) != self.len_splits: @@ -2270,7 +2270,7 @@ def perform(self, node, inputs, outputs_storage): raise ValueError("Split sizes cannot be negative") split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis) - for out_storage, out in zip(outputs_storage, split_outs, strict=False): + for out_storage, out in zip(output_storage, split_outs, strict=False): out_storage[0] = out def infer_shape(self, fgraph, node, in_shapes): @@ -3633,9 +3633,9 @@ def _rec_perform(self, node, x, y, inverse, out, curdim): else: raise ValueError(f"Dimension mismatch: {xs0}, {ys0}") - def perform(self, node, inp, out): - x, y = inp - (outs,) = out + def perform(self, node, inputs, output_storage): + x, y = inputs + (outs,) = output_storage x_s = x.shape y_s = y.shape assert len(x_s) == len(y_s) @@ -4305,8 +4305,8 @@ def make_node(self, a, choices): o = TensorType(choice.dtype, shape=static_out_shape) return Apply(self, [a, choice], [o()]) - def perform(self, node, inputs, outputs): - (z,) = outputs + def perform(self, node, inputs, output_storage): + (z,) = output_storage a = inputs[0] choice = inputs[1] # TODO reuse out? @@ -4352,8 +4352,8 @@ def debug_perform(self, node, inputs, out_): self.perform(node, inputs, out_) out_[0][0].fill(-123456789) - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage sh = tuple(int(i) for i in inputs) if out[0] is None or out[0].shape != sh: out[0] = np.empty(sh, dtype=self.dtype) diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 47b6af80d1..569942354b 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -198,7 +198,7 @@ def make_node(self, y, alpha, A, x, beta): return Apply(self, inputs, [y.type()]) - def perform(self, node, inputs, out_storage): + def perform(self, node, inputs, output_storage): from scipy.linalg.blas import get_blas_funcs y, alpha, A, x, beta = inputs @@ -232,7 +232,7 @@ def perform(self, node, inputs, out_storage): # trans flag don't seam to cause slowdown. 
# out_storage[0][0] = gemv(alpha, A, x, beta, y, # overwrite_y=self.inplace) - out_storage[0][0] = gemv( + output_storage[0][0] = gemv( alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True ) else: @@ -244,7 +244,7 @@ def perform(self, node, inputs, out_storage): out += beta * y else: out += y - out_storage[0][0] = np.asarray(out, dtype=y.dtype) + output_storage[0][0] = np.asarray(out, dtype=y.dtype) def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] @@ -909,9 +909,9 @@ def make_node(self, *inputs): output = z.type() return Apply(self, inputs, [output]) - def perform(self, node, inp, out): - z, a, x, y, b = inp - (zout,) = out + def perform(self, node, inputs, output_storage): + z, a, x, y, b = inputs + (zout,) = output_storage assert a.shape == () assert b.shape == () if not self.inplace: @@ -1241,9 +1241,9 @@ def make_node(self, x, y, a): outputs = [tensor(dtype=x.type.dtype, shape=sz)] return Apply(self, [x, y, a], outputs) - def perform(self, node, inp, out): - x, y, scalar = inp - (z,) = out + def perform(self, node, inputs, output_storage): + x, y, scalar = inputs + (z,) = output_storage try: z[0] = np.asarray(scalar * np.dot(x, y)) except ValueError as e: @@ -1354,14 +1354,14 @@ def extract_static_dim(dim_x, dim_y): out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [x, y], [out]) - def perform(self, node, inp, out): - x, y = inp - (z,) = out + def perform(self, node, inputs, output_storage): + x, y = inputs + (z,) = output_storage if x.shape[0] != y.shape[0]: raise TypeError( - f"Inputs [{', '.join(map(str, inp))}] must have the" - f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]." + f"Inputs [{', '.join(map(str, inputs))}] must have the" + f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inputs)}]." 
             )
 
         z[0] = np.matmul(x, y)
diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py
index 9adb6354b2..41324ce103 100644
--- a/pytensor/tensor/conv/abstract_conv.py
+++ b/pytensor/tensor/conv/abstract_conv.py
@@ -2502,8 +2502,8 @@ def make_node(self, img, kern):
         output = img.type.clone(shape=out_shape)()
         return Apply(self, [img, kern], [output])

-    def perform(self, node, inp, out_):
-        img, kern = inp
+    def perform(self, node, inputs, output_storage):
+        img, kern = inputs
         img = np.asarray(img)
         kern = np.asarray(kern)

@@ -2515,7 +2515,7 @@ def perform(self, node, inp, out_):
             raise NotImplementedError(
                 f"Unshared convolution not implemented for {int(self.convdim)}D"
             )
-        (o,) = out_
+        (o,) = output_storage

         mode = self.border_mode
         pad = border_mode_to_pad(mode, self.convdim, dil_kernshp)
@@ -2839,12 +2839,12 @@ def make_node(self, img, topgrad, shape, add_assert_shape=True):
         output = img.type.clone(shape=out_shape)()
         return Apply(self, [img, topgrad, shape], [output])

-    def perform(self, node, inp, out_):
-        img, topgrad, shape = inp
+    def perform(self, node, inputs, output_storage):
+        img, topgrad, shape = inputs
         img = np.asarray(img)
         topgrad = np.asarray(topgrad)

-        (o,) = out_
+        (o,) = output_storage

         if self.unshared and self.convdim != 2:
             raise NotImplementedError(
@@ -3208,11 +3208,11 @@ def make_node(self, kern, topgrad, shape, add_assert_shape=True):
         output = kern.type.clone(shape=out_shape)()
         return Apply(self, [kern, topgrad, shape], [output])

-    def perform(self, node, inp, out_):
-        kern, topgrad, shape = inp
+    def perform(self, node, inputs, output_storage):
+        kern, topgrad, shape = inputs
         kern = np.asarray(kern)
         topgrad = np.asarray(topgrad)
-        (o,) = out_
+        (o,) = output_storage

         if self.unshared and self.convdim != 2:
             raise NotImplementedError(
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 1616666a63..eabd9ca03e 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -224,8 +224,8 @@ def __str__(self):
             return f"Transpose{{axes={self.shuffle}}}"
         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

-    def perform(self, node, inp, out):
-        (res,) = inp
+    def perform(self, node, inputs, output_storage):
+        (res,) = inputs

         # This C-like impl is very slow in Python compared to transpose+reshape
         # new_order = self._new_order
@@ -243,7 +243,7 @@ def perform(self, node, inp, out):
         new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
             new_shape.insert(augm, 1)
-        out[0][0] = res.reshape(new_shape)
+        output_storage[0][0] = res.reshape(new_shape)

     def infer_shape(self, fgraph, node, shapes):
         (ishp,) = shapes
@@ -1396,9 +1396,9 @@ def __str__(self):
         else:
             return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"

-    def perform(self, node, inp, out):
-        (input,) = inp
-        (output,) = out
+    def perform(self, node, inputs, output_storage):
+        (input,) = inputs
+        (output,) = output_storage
         axis = self.axis

         out_dtype = node.outputs[0].type.dtype
@@ -1412,9 +1412,9 @@ def perform(self, node, inp, out):

         input = np.array(input, dtype=acc_dtype)

-        out = self.ufunc.reduce(input, axis=axis, dtype=acc_dtype)
+        reduced = self.ufunc.reduce(input, axis=axis, dtype=acc_dtype)

-        output[0] = np.asarray(out, dtype=out_dtype)
+        output[0] = np.asarray(reduced, dtype=out_dtype)

     def infer_shape(self, fgraph, node, shapes):
         (ishape,) = shapes
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 0c6e59d72f..ae35794a80 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py @@ -845,9 +845,9 @@ def make_node(self, M): raise TypeError(f"{self.__class__.__name__} only works on integer input") return Apply(self, [M], [dvector()]) - def perform(self, node, inputs, out_): + def perform(self, node, inputs, output_storage): M = inputs[0] - (out,) = out_ + (out,) = output_storage out[0] = np.bartlett(M) def infer_shape(self, fgraph, node, in_shapes): @@ -1314,17 +1314,17 @@ def make_node(self, indices, dims): def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] * len(node.outputs) - def perform(self, node, inp, out): - indices, dims = inp + def perform(self, node, inputs, output_storage): + indices, dims = inputs res = np.unravel_index(indices, dims, order=self.order) - assert len(res) == len(out) - for i in range(len(out)): + assert len(res) == len(output_storage) + for i in range(len(output_storage)): ret = np.asarray(res[i], node.outputs[0].dtype) if ret.base is not None: # NumPy will return a view when it can. # But we don't want that. ret = ret.copy() - out[i][0] = ret + output_storage[i][0] = ret def unravel_index(indices, dims, order="C"): @@ -1391,10 +1391,10 @@ def make_node(self, *inp): def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] - def perform(self, node, inp, out): - *multi_index, dims = inp + def perform(self, node, inputs, output_storage): + *multi_index, dims = inputs res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order) - out[0][0] = np.asarray(res, "int64") + output_storage[0][0] = np.asarray(res, "int64") def ravel_multi_index(multi_index, dims, mode="raise", order="C"): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index fa424e4679..6d9bb8e082 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -184,10 +184,10 @@ def prepare_node(self, node, storage_map, compute_map, impl): "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format." 
) - def perform(self, node, inp, outs): - (x,) = inp + def perform(self, node, inputs, output_storage): + (x,) = inputs axes = self.axis - (max_idx,) = outs + (max_idx,) = output_storage if axes is None: axes = tuple(range(x.ndim)) # Numpy does not support multiple axes for argmax diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 1d0d564abd..a7d939f40b 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -44,9 +44,9 @@ def make_node(self, x): out_dtype = x.dtype return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage z[0] = np.linalg.pinv(x, hermitian=self.hermitian) def L_op(self, inputs, outputs, g_outputs): @@ -128,9 +128,9 @@ def make_node(self, x): out_dtype = x.dtype return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage z[0] = np.linalg.inv(x) def grad(self, inputs, g_outputs): @@ -232,9 +232,9 @@ def make_node(self, x): o = scalar(dtype=out_dtype) return Apply(self, [x], [o]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs + (z,) = output_storage try: z[0] = np.asarray(np.linalg.det(x)) except Exception as e: @@ -275,9 +275,9 @@ def make_node(self, x): det = scalar(dtype=out_dtype) return Apply(self, [x], [sign, det]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (sign, det) = outputs + (sign, det) = output_storage try: sign[0], det[0] = (np.array(z) for z in np.linalg.slogdet(x)) except Exception as e: @@ -347,7 +347,7 @@ def make_node(self, x): return Apply(self, [x], [w, v]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs dtype = np.promote_types(x.dtype, np.complex64) @@ -355,8 +355,8 @@ def perform(self, node, inputs, outputs): # If the imaginary part of the eigenvalues is zero, numpy automatically casts them to real. We require # a statically known return dtype, so we have to cast back to complex to avoid dtype mismatch. - outputs[0][0] = w.astype(dtype, copy=False) - outputs[1][0] = v.astype(dtype, copy=False) + output_storage[0][0] = w.astype(dtype, copy=False) + output_storage[1][0] = v.astype(dtype, copy=False) def infer_shape(self, fgraph, node, shapes): (x_shapes,) = shapes @@ -416,9 +416,9 @@ def make_node(self, x): v = matrix(dtype=w_dtype) return Apply(self, [x], [w, v]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (w, v) = outputs + (w, v) = output_storage w[0], v[0] = np.linalg.eigh(x, self.UPLO) def L_op(self, inputs, outputs, output_grads): @@ -490,7 +490,7 @@ def make_node(self, x, w, v, gw, gv): out = matrix(dtype=out_dtype) return Apply(self, [x, w, v, gw, gv], [out]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): """ Implements the "reverse-mode" gradient for the eigensystem of a square matrix. @@ -521,7 +521,7 @@ def G(n): # Make sure we return the right dtype even if NumPy performed # upcasting in self.tri0. 
- outputs[0][0] = np.asarray(out, dtype=node.outputs[0].dtype) + output_storage[0][0] = np.asarray(out, dtype=node.outputs[0].dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -581,14 +581,14 @@ def make_node(self, x): else: return Apply(self, [x], [s]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs assert x.ndim == 2, "The input of svd function should be a matrix." if self.compute_uv: - u, s, vt = outputs + u, s, vt = output_storage u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) else: - (s,) = outputs + (s,) = output_storage s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) def infer_shape(self, fgraph, node, shapes): @@ -760,12 +760,12 @@ def make_node(self, x, y, rcond): ], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): zz = np.linalg.lstsq(inputs[0], inputs[1], inputs[2]) - outputs[0][0] = zz[0] - outputs[1][0] = zz[1] - outputs[2][0] = np.asarray(zz[2]) - outputs[3][0] = zz[3] + output_storage[0][0] = zz[0] + output_storage[1][0] = zz[1] + output_storage[2][0] = np.asarray(zz[2]) + output_storage[3][0] = zz[3] lstsq = Lstsq() @@ -1029,9 +1029,9 @@ def make_node(self, a): out = a.type() return Apply(self, [a], [out]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a,) = inputs - (x,) = outputs + (x,) = output_storage x[0] = np.linalg.tensorinv(a, self.ind) def infer_shape(self, fgraph, node, shapes): @@ -1090,12 +1090,12 @@ def make_node(self, a, b): x = matrix(dtype=out_dtype) return Apply(self, [a, b], [x]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): ( a, b, ) = inputs - (x,) = outputs + (x,) = output_storage x[0] = np.linalg.tensorsolve(a, b, self.axes) diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 7653d01b54..daf6a8f88c 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -393,7 +393,7 @@ def __init__( def __str__(self): return f"{self.__class__.__name__}(method={self.method})" - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): global optimize if optimize is None: import scipy.optimize as optimize @@ -412,8 +412,8 @@ def perform(self, node, inputs, outputs): **self.optimizer_kwargs, ) - outputs[0][0] = np.array(res.x, dtype=x0.dtype) - outputs[1][0] = np.bool_(res.success) + output_storage[0][0] = np.array(res.x, dtype=x0.dtype) + output_storage[1][0] = np.bool_(res.success) def L_op(self, inputs, outputs, output_grads): # TODO: Handle disconnected inputs @@ -554,7 +554,7 @@ def __str__(self): ) return f"{self.__class__.__name__}({str_args})" - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): global optimize if optimize is None: import scipy.optimize as optimize @@ -574,8 +574,8 @@ def perform(self, node, inputs, outputs): f.clear_cache() - outputs[0][0] = res.x.reshape(x0.shape).astype(x0.dtype) - outputs[1][0] = np.bool_(res.success) + output_storage[0][0] = res.x.reshape(x0.shape).astype(x0.dtype) + output_storage[1][0] = np.bool_(res.success) def L_op(self, inputs, outputs, output_grads): # TODO: Handle disconnected inputs @@ -732,7 +732,7 @@ def __str__(self): ) return f"{self.__class__.__name__}({str_args})" - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): global optimize if optimize is None: import scipy.optimize as 
optimize @@ -753,8 +753,8 @@ def perform(self, node, inputs, outputs): **self.optimizer_kwargs, ) - outputs[0][0] = np.array(res.root) - outputs[1][0] = np.bool_(res.converged) + output_storage[0][0] = np.array(res.root) + output_storage[1][0] = np.bool_(res.converged) def L_op(self, inputs, outputs, output_grads): # TODO: Handle disconnected inputs @@ -922,7 +922,7 @@ def build_fn(self): self._fn_wrapped = LRUCache1(fn) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): global optimize if optimize is None: import scipy.optimize as optimize @@ -944,8 +944,8 @@ def perform(self, node, inputs, outputs): # There's a reshape here to cover the case where variables is a scalar. Scipy will still return a # (1, 1) matrix in in this case, which causes errors downstream (since pytensor expects a scalar). - outputs[0][0] = res.x.reshape(variables.shape).astype(variables.dtype) - outputs[1][0] = np.bool_(res.success) + output_storage[0][0] = res.x.reshape(variables.shape).astype(variables.dtype) + output_storage[1][0] = np.bool_(res.success) def L_op(self, inputs, outputs, output_grads): # TODO: Handle disconnected inputs diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index b7bd39125c..6afa5a29b2 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -418,15 +418,15 @@ def dist_params(self, node: Apply) -> Sequence[Variable]: """Return the node inpust corresponding to dist params""" return node.inputs[2:] - def perform(self, node: Apply, inputs, outputs): + def perform(self, node: Apply, inputs, output_storage): rng, size, *args = inputs # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. if not self.inplace: rng = custom_rng_deepcopy(rng) - outputs[0][0] = rng - outputs[1][0] = np.asarray( + output_storage[0][0] = rng + output_storage[1][0] = np.asarray( self.rng_fn(rng, *args, None if size is None else tuple(size)), dtype=self.dtype, ) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index c6fefc6b63..8316ae2f28 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -81,9 +81,9 @@ def make_node(self, x): return Apply(self, [x], [out_var]) - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (x,) = inputs + (out,) = output_storage out[0] = np.asarray(np.shape(x), dtype="int64") def infer_shape(self, fgraph, node, in_shapes): @@ -257,9 +257,9 @@ def make_node(self, x): raise TypeError(f"{x} has too few dimensions for Shape_i") return Apply(self, [x], [pytensor.tensor.type.lscalar()]) - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (x,) = inputs + (out,) = output_storage if out[0] is None: out[0] = np.asarray(np.shape(x)[self.i], dtype="int64") else: @@ -442,9 +442,9 @@ def make_node(self, x, *shape): return Apply(self, [x, *shape], [out_var]) - def perform(self, node, inp, out_): - x, *shape = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + x, *shape = inputs + (out,) = output_storage ndim = len(shape) if x.ndim != ndim: raise AssertionError( @@ -709,9 +709,9 @@ def make_node(self, x, shp): return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)]) - def perform(self, node, inp, out_): - x, shp = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + x, shp = inputs + (out,) = output_storage if len(shp) != self.ndim: raise ValueError( "Shape argument 
to Reshape has incorrect" diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index e061019cef..1ffbffa137 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -117,11 +117,13 @@ class Convolve1d(AbstractConvolveNd, COp): # type: ignore[misc] __props__ = () ndim = 1 - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): # We use numpy_convolve as that's what scipy would use if method="direct" was passed. # And mode != "same", which this Op doesn't cover anyway. in1, in2, full_mode = inputs - outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid") + output_storage[0][0] = numpy_convolve( + in1, in2, mode="full" if full_mode else "valid" + ) def c_code_cache_version(self): return (2,) @@ -252,7 +254,7 @@ class Convolve2d(AbstractConvolveNd, Op): # type: ignore[misc] def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"): self.method = method - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): in1, in2, full_mode = inputs if isinstance(full_mode, np.bool): @@ -260,7 +262,7 @@ def perform(self, node, inputs, outputs): # Conditional, because numba will produce a bool, not np.bool_ full_mode = full_mode.item() mode = "full" if full_mode else "valid" - outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method) + output_storage[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method) def convolve2d( diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 5ce5e8da12..eb232a9bff 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -66,9 +66,9 @@ def make_node(self, x): dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): [x] = inputs - [out] = outputs + [out] = output_storage (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) @@ -278,7 +278,7 @@ def __init__( destroy_map[0] = [1] self.destroy_map = destroy_map - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError( "SolveBase should be subclassed with an perform method" ) @@ -516,7 +516,7 @@ def make_node(self, x): P = tensor(shape=x.type.shape, dtype=P_dtype) return Apply(self, inputs=[x], outputs=[P, L, U]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): [A] = inputs out = scipy_linalg.lu( @@ -527,12 +527,12 @@ def perform(self, node, inputs, outputs): p_indices=self.p_indices, ) - outputs[0][0] = out[0] - outputs[1][0] = out[1] + output_storage[0][0] = out[0] + output_storage[1][0] = out[1] if not self.permute_l: # In all cases except permute_l, there are three returns - outputs[2][0] = out[2] + output_storage[2][0] = out[2] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": if 0 in allowed_inplace_inputs: @@ -661,7 +661,7 @@ def make_node(self, pivots): permutations = pivots.type.clone(dtype="int64")() return Apply(self, [pivots], [permutations]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): [pivots] = inputs p_inv = np.arange(len(pivots), dtype="int64") @@ -669,9 +669,9 @@ def perform(self, node, inputs, outputs): p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] if self.inverse: - outputs[0][0] = p_inv + output_storage[0][0] = p_inv else: - 
outputs[0][0] = np.argsort(p_inv) + output_storage[0][0] = np.argsort(p_inv) def pivot_to_permutation(p: TensorLike, inverse=False): @@ -714,13 +714,13 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": else: return self - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): A = inputs[0] # Quick return for empty arrays if A.size == 0: - outputs[0][0] = np.empty_like(A) - outputs[1][0] = np.array([], dtype=np.int32) + output_storage[0][0] = np.empty_like(A) + output_storage[1][0] = np.array([], dtype=np.int32) return if self.check_finite and not np.isfinite(A).all(): @@ -739,8 +739,8 @@ def perform(self, node, inputs, outputs): stacklevel=2, ) - outputs[0][0] = LU - outputs[1][0] = p + output_storage[0][0] = LU + output_storage[1][0] = p def L_op(self, inputs, outputs, output_gradients): [A] = inputs @@ -902,7 +902,7 @@ def __init__(self, *, unit_diagonal=False, **kwargs): super().__init__(**kwargs) self.unit_diagonal = unit_diagonal - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): A, b = inputs if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()): @@ -918,7 +918,7 @@ def perform(self, node, inputs, outputs): # Quick return for empty arrays if b.size == 0: - outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype) + output_storage[0][0] = np.empty_like(b, dtype=trtrs.dtype) return if A.flags["F_CONTIGUOUS"]: @@ -948,7 +948,7 @@ def perform(self, node, inputs, outputs): elif info < 0: raise ValueError(f"illegal value in {-info}-th argument of internal trtrs") - outputs[0][0] = x + output_storage[0][0] = x def L_op(self, inputs, outputs, output_gradients): res = super().L_op(inputs, outputs, output_gradients) @@ -1071,9 +1071,9 @@ def __init__(self, *, assume_a="gen", **kwargs): super().__init__(**kwargs) self.assume_a = assume_a - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): a, b = inputs - outputs[0][0] = scipy_linalg.solve( + output_storage[0][0] = scipy_linalg.solve( a=a, b=b, lower=self.lower, @@ -1232,8 +1232,8 @@ def make_node(self, a, b): w = vector(dtype=out_dtype) return Apply(self, [a, b], [w]) - def perform(self, node, inputs, outputs): - (w,) = outputs + def perform(self, node, inputs, output_storage): + (w,) = output_storage if len(inputs) == 2: w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) else: @@ -1289,7 +1289,7 @@ def make_node(self, a, b, gw): out2 = matrix(dtype=out_dtype) return Apply(self, [a, b, gw], [out1, out2]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (a, b, gw) = inputs w, v = scipy_linalg.eigh(a, b, lower=self.lower) gA = v.dot(np.diag(gw).dot(v.T)) @@ -1298,8 +1298,8 @@ def perform(self, node, inputs, outputs): # See EighGrad comments for an explanation of these lines out1 = self.tri0(gA) + self.tri1(gA).T out2 = self.tri0(gB) + self.tri1(gB).T - outputs[0][0] = np.asarray(out1, dtype=node.outputs[0].dtype) - outputs[1][0] = np.asarray(out2, dtype=node.outputs[1].dtype) + output_storage[0][0] = np.asarray(out1, dtype=node.outputs[0].dtype) + output_storage[1][0] = np.asarray(out2, dtype=node.outputs[1].dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0], shapes[1]] @@ -1325,9 +1325,9 @@ def make_node(self, A): return Apply(self, [A], [expm]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (A,) = inputs - (expm,) = outputs + (expm,) = 
output_storage expm[0] = scipy_linalg.expm(A) def L_op(self, inputs, outputs, output_grads): @@ -1911,7 +1911,7 @@ def _call_and_get_lwork(self, fn, *args, lwork, **kwargs): return fn(*args, lwork=lwork, **kwargs) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs M, N = x.shape @@ -1933,25 +1933,25 @@ def perform(self, node, inputs, outputs): R = np.triu(qr[:N, :]) if self.mode == "r" and self.pivoting: - outputs[0][0] = R - outputs[1][0] = jpvt + output_storage[0][0] = R + output_storage[1][0] = jpvt return elif self.mode == "r": - outputs[0][0] = R + output_storage[0][0] = R return elif self.mode == "raw" and self.pivoting: - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R - outputs[3][0] = jpvt + output_storage[0][0] = qr + output_storage[1][0] = tau + output_storage[2][0] = R + output_storage[3][0] = jpvt return elif self.mode == "raw": - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R + output_storage[0][0] = qr + output_storage[1][0] = tau + output_storage[2][0] = R return (gor_un_gqr,) = get_lapack_funcs(("orgqr",), (qr,)) @@ -1974,11 +1974,11 @@ def perform(self, node, inputs, outputs): gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1 ) - outputs[0][0] = Q - outputs[1][0] = R + output_storage[0][0] = Q + output_storage[1][0] = R if self.pivoting: - outputs[2][0] = jpvt + output_storage[2][0] = jpvt def L_op(self, inputs, outputs, output_grads): """ diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index b845e69b37..f0fa8a1cf7 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -37,8 +37,8 @@ def make_node(self, dy, sm): return Apply(self, [dy, sm], [sm.type()]) - def perform(self, node, input_storage, output_storage): - dy, sm = input_storage + def perform(self, node, inputs, output_storage): + dy, sm = inputs dy_times_sm = dy * sm dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm @@ -268,8 +268,8 @@ def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, input_storage, output_storage): - (x,) = input_storage + def perform(self, node, inputs, output_storage): + (x,) = inputs (z,) = output_storage z[0] = scipy.special.softmax(x, axis=self.axis) @@ -523,8 +523,8 @@ def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, input_storage, output_storage): - (x,) = input_storage + def perform(self, node, inputs, output_storage): + (x,) = inputs (z,) = output_storage z[0] = scipy.special.log_softmax(x, axis=self.axis) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 37010c58a2..ee7ec8bfb4 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -921,8 +921,8 @@ def make_node(self, x, *inputs): [tensor(dtype=x.type.dtype, shape=out_shape)], ) - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage x = inputs[0] cdata = get_idx_list(inputs, self.idx_list) @@ -2112,8 +2112,8 @@ def make_node(self, x, ilist): out_shape = (ilist_.type.shape[0], *x_.type.shape[1:]) return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()]) - def perform(self, node, inp, output_storage): - x, i = inp + def perform(self, node, inputs, output_storage): + x, i = inputs # Numpy take is always slower when out is provided # https://github.com/numpy/numpy/issues/28636 @@ -2742,8 +2742,8 @@ def is_bool_index(idx): assert node.outputs[0].ndim == len(res_shape) 
return [res_shape] - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage check_advanced_indexing_dimensions(inputs[0], inputs[1:]) rval = inputs[0].__getitem__(tuple(inputs[1:])) # When there are no arrays, we are not actually doing advanced @@ -2880,12 +2880,12 @@ def make_node(self, x, y, *inputs): [x.type()], ) - def perform(self, node, inputs, out_): + def perform(self, node, inputs, output_storage): x, y, *indices = inputs check_advanced_indexing_dimensions(x, indices) - (out,) = out_ + (out,) = output_storage if not self.inplace: out[0] = x.copy() else: diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 16c4b0fb41..7e59687ed6 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -39,9 +39,9 @@ def make_node(self, slc, stop=None, step=None): inp = [slc, stop, step] return Apply(self, list(map(as_int_none_variable, inp)), [slicetype()]) - def perform(self, node, inp, out_): - (out,) = out_ - out[0] = slice(*inp) + def perform(self, node, inputs, output_storage): + (out,) = output_storage + out[0] = slice(*inputs) def grad(self, inputs, grads): return [DisconnectedType()() for i in inputs] diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 188581b2c9..59d99bb76c 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -92,9 +92,9 @@ def make_node(self, x, index): else: raise TypeError("Expected scalar or slice as index.") - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, index) = inputs - (out,) = outputs + (out,) = output_storage if not isinstance(index, slice): index = int(index) out[0] = x[index] @@ -153,9 +153,9 @@ def make_node(self, x, toAppend): assert x.ttype == toAppend.type, (x.ttype, toAppend.type) return Apply(self, [x, toAppend], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, toAppend) = inputs - (out,) = outputs + (out,) = output_storage if not self.inplace: out[0] = list(x) else: @@ -232,9 +232,9 @@ def make_node(self, x, toAppend): assert toAppend.type.is_super(x.type) return Apply(self, [x, toAppend], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, toAppend) = inputs - (out,) = outputs + (out,) = output_storage if not self.inplace: out[0] = list(x) else: @@ -321,9 +321,9 @@ def make_node(self, x, index, toInsert): assert isinstance(index, TensorVariable) and index.ndim == 0 return Apply(self, [x, index, toInsert], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, index, toInsert) = inputs - (out,) = outputs + (out,) = output_storage if not self.inplace: out[0] = list(x) else: @@ -397,9 +397,9 @@ def make_node(self, x, toRemove): assert x.ttype == toRemove.type return Apply(self, [x, toRemove], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x, toRemove) = inputs - (out,) = outputs + (out,) = output_storage if not self.inplace: out[0] = list(x) else: @@ -453,12 +453,12 @@ def make_node(self, x): assert isinstance(x.type, TypedListType) return Apply(self, [x], [x.type()]) - def perform(self, node, inp, outputs): - (out,) = outputs + def perform(self, node, inputs, output_storage): + (out,) = output_storage if not self.inplace: - out[0] = list(inp[0]) + out[0] = list(inputs[0]) else: - out[0] = 
inp[0] + out[0] = inputs[0] out[0].reverse() def __str__(self): @@ -514,14 +514,14 @@ def make_node(self, x, elem): assert x.ttype == elem.type return Apply(self, [x, elem], [lscalar()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): """ Inelegant workaround for ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() being thrown when trying to remove a matrix from a matrices list """ (x, elem) = inputs - (out,) = outputs + (out,) = output_storage for y in range(len(x)): if node.inputs[0].ttype.values_eq(x[y], elem): out[0] = np.asarray(y, dtype="int64") @@ -543,14 +543,14 @@ def make_node(self, x, elem): assert x.ttype == elem.type return Apply(self, [x, elem], [lscalar()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): """ Inelegant workaround for ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() being thrown when trying to remove a matrix from a matrices list """ (x, elem) = inputs - (out,) = outputs + (out,) = output_storage out[0] = 0 for y in range(len(x)): if node.inputs[0].ttype.values_eq(x[y], elem): @@ -589,9 +589,9 @@ def make_node(self, x): assert isinstance(x.type, TypedListType) return Apply(self, [x], [lscalar()]) - def perform(self, node, x, outputs): - (out,) = outputs - out[0] = np.asarray(len(x[0]), "int64") + def perform(self, node, inputs, output_storage): + (out,) = output_storage + out[0] = np.asarray(len(inputs[0]), "int64") def __str__(self): return self.__class__.__name__ @@ -638,8 +638,8 @@ def make_node(self, a): return Apply(self, a2, [tl]) - def perform(self, node, inputs, outputs): - (out,) = outputs + def perform(self, node, inputs, output_storage): + (out,) = output_storage # We need to make sure that we don't get a view on our inputs out[0] = [_lessbroken_deepcopy(inp) for inp in inputs] diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 5c1f700b9f..df64fcd9db 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -9,7 +9,7 @@ class XOp(Op): """A base class for XOps that shouldn't be materialized""" - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError( f"xtensor operation {self} must be lowered to equivalent tensor operations" ) diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 95d0074fc3..340411eb0d 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -48,9 +48,9 @@ def make_node(self, a, b): r = Apply(self, [a, b], [a.type()]) return r - def perform(self, node, inp, out_): - a, b = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + a, b = inputs + (out,) = output_storage z = a + b # ERROR TO ADD THIS CRAPPY OFFSET if self.py_offset: @@ -364,9 +364,9 @@ def make_node(self, a, b): c = a.type() return Apply(self, [a, b], [c]) - def perform(self, node, inp, out): - a, b = inp - (c,) = out + def perform(self, node, inputs, output_storage): + a, b = inputs + (c,) = output_storage c[0] = a c[0] += b @@ -394,9 +394,9 @@ def make_node(self, a, b): c = b.type() return Apply(self, [a, b], [c]) - def perform(self, node, inp, out): - _a, b = inp - (c,) = out + def perform(self, node, inputs, output_storage): + _a, b = inputs + (c,) = output_storage c[0] = b class BadAddSlice(Op): @@ -404,9 +404,9 @@ def make_node(self, a, b): c = b.type() return Apply(self, [a, b], [c]) - def 
perform(self, node, inp, out): - _a, b = inp - (c,) = out + def perform(self, node, inputs, output_storage): + _a, b = inputs + (c,) = output_storage c[0] = b[1:3] def test_badviewmap_ref(self): @@ -452,9 +452,9 @@ def make_node(self, a, b): d = a.type() return Apply(self, [a, b], [c, d]) - def perform(self, node, inp, out): - a, _b = inp - c, d = out + def perform(self, node, inputs, output_storage): + a, _b = inputs + c, d = output_storage c[0] = a d[0] = a[1:] @@ -476,9 +476,9 @@ def make_node(self, a, b): d = a.type() return Apply(self, [a, b], [c, d]) - def perform(self, node, inp, out): - a, _b = inp - c, d = out + def perform(self, node, inputs, output_storage): + a, _b = inputs + c, d = output_storage r = a * 2 c[0] = r d[0] = r[1:] @@ -502,9 +502,9 @@ def make_node(self, a, b): d = a.type() return Apply(self, [a, b], [c, d]) - def perform(self, node, inp, out): - a, _b = inp - c, d = out + def perform(self, node, inputs, output_storage): + a, _b = inputs + c, d = output_storage r = a * 1 c[0] = r d[0] = r[1:] @@ -527,9 +527,9 @@ def make_node(self, a, b): d = a.type() return Apply(self, [a, b], [c, d]) - def perform(self, node, inp, out): - a, _b = inp - c, d = out + def perform(self, node, inputs, output_storage): + a, _b = inputs + c, d = output_storage r = a * 1 c[0] = r[:-1] d[0] = r[1:] @@ -618,10 +618,10 @@ def make_node(self, a, b): r = Apply(self, [a, b], [a.type()]) return r - def perform(self, node, inp, out_): + def perform(self, node, inputs, output_storage): # print 'executing python perform' - a, b = inp - (out,) = out_ + a, b = inputs + (out,) = output_storage z = a + b # print 'out[0] was:', out[0] out[0] = z @@ -714,9 +714,9 @@ def make_node(self, v): out_c_type = type_class(dtype=v.dtype, shape=(None, 1)) return Apply(self, [v], [out_r_type(), out_c_type()]) - def perform(self, node, inp, out): - (v,) = inp - r, c = out + def perform(self, node, inputs, output_storage): + (v,) = inputs + r, c = output_storage lv = v.shape[0] if (r[0] is None) or (r[0].shape != (1, lv)): r[0] = np.empty((1, lv), dtype=node.outputs[0].type.dtype) diff --git a/tests/graph/rewriting/test_kanren.py b/tests/graph/rewriting/test_kanren.py index 7cb66a4ba0..db26f0735e 100644 --- a/tests/graph/rewriting/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -101,9 +101,9 @@ def make_node(self, *inputs): outputs = [MyType()(), MyType()()] return Apply(self, list(inputs), outputs) - def perform(self, node, inputs, outputs): - outputs[0] = np.array(inputs[0]) - outputs[1] = np.array(inputs[0]) + def perform(self, node, inputs, output_storage): + output_storage[0] = np.array(inputs[0]) + output_storage[1] = np.array(inputs[0]) x = MyVariable("x") y = MyVariable("y") diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 5ce8d04105..6c15aaa52e 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -29,7 +29,7 @@ def __init__(self, a): def make_node(self, *inputs): return Apply(self, list(inputs), [pt.vector()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError() @@ -40,7 +40,7 @@ def __init__(self, a): def make_node(self, *inputs): return Apply(self, list(inputs), [pt.vector()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError() @@ -136,9 +136,9 @@ def make_node(self, *inputs): outputs = [MyType()(), MyType()()] return Apply(self, list(inputs), outputs) - def 
perform(self, node, inputs, outputs): - outputs[0] = np.array(inputs[0]) - outputs[1] = np.array(inputs[0]) + def perform(self, node, inputs, output_storage): + output_storage[0] = np.array(inputs[0]) + output_storage[1] = np.array(inputs[0]) x_pt = pt.vector("x") op1_np = MyMultiOutOp() diff --git a/tests/graph/test_compute_test_value.py b/tests/graph/test_compute_test_value.py index c535cc97cc..ef403dace0 100644 --- a/tests/graph/test_compute_test_value.py +++ b/tests/graph/test_compute_test_value.py @@ -62,8 +62,8 @@ def __init__(self, inplace): def make_node(self, input): return Apply(self, [input], [input.type()]) - def perform(self, node, inputs, outputs): - outputs[0][0] = inputs[0] + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] test_input = SomeType()() orig_object = object() @@ -282,9 +282,9 @@ def make_node(self, input): output = input.type() return Apply(self, [input], [output]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (input,) = inputs - (output,) = outputs + (output,) = output_storage output[0] = input + 1 i = ps.int32("i") diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 7dfb8f2434..282aa7e27f 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -118,9 +118,9 @@ class DoubleOp(Op): itypes = [dmatrix] otypes = [dmatrix] - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): inp = inputs[0] - output = outputs[0] + output = output_storage[0] output[0] = inp * 2 x_input = dmatrix("x_input") @@ -213,7 +213,7 @@ class TestOp(pytensor.graph.op.Op): itypes = [dvector, dvector, dvector] otypes = [dvector] - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): pass msg = r"^Invalid input types for Op.*" @@ -255,7 +255,7 @@ def make_node(self, input): outputs = [input.type()] return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError() if multi_output: diff --git a/tests/graph/utils.py b/tests/graph/utils.py index 74fe9235f9..d3dc23a113 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -70,8 +70,8 @@ def make_node(self, *inputs): outputs = [MyType()() for i in range(self.n_outs)] return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): - outputs[0] = np.array(inputs, dtype=object) + def perform(self, node, inputs, output_storage): + output_storage[0] = np.array(inputs, dtype=object) def __str__(self): return self.name diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index 33fa19afd7..da80018a22 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -106,8 +106,8 @@ def make_node(self, *inputs): def __str__(self): return self.name - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage out[0] = self.impl(*inputs) def c_code_cache_version(self): @@ -603,8 +603,8 @@ def make_node(self, *inputs): def __str__(self): return self.name - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage out[0] = sum(*inputs) def c_code_cache_version(self): diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index 5ebcd19d57..51efd5d410 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -56,8 +56,8 @@ def make_node(self, *inputs): 
outputs = [vector()] return Apply(self, inputs, outputs) - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage out[0] = inputs[0][0] + 1 def c_code(self, node, name, inp, out, sub): diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index 4cf6058a78..d4aed3c26c 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -123,9 +123,9 @@ def make_node(self, input): output = input.type() return Apply(self, [input], [output]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (input,) = inputs - (output,) = outputs + (output,) = output_storage output[0] = input + 1 i = ps.int32("i") diff --git a/tests/link/c/test_type.py b/tests/link/c/test_type.py index 26f1a07f53..dc24a03a8d 100644 --- a/tests/link/c/test_type.py +++ b/tests/link/c/test_type.py @@ -112,10 +112,10 @@ def get_params(self, node): def make_node(self, a, b): return Apply(self, [ps.as_scalar(a), ps.as_scalar(b)], [ps.float64()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): op = self.params_type.filter(self.get_params(node)) a, b = inputs - (o,) = outputs + (o,) = output_storage if op == self.params_type.ADD: o[0] = a + b elif op == self.params_type.SUB: diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 4a6eee1890..5f86fb92a2 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -109,9 +109,9 @@ def __init__(self): def make_node(self, *args): return Apply(self, list(args), [x.type() for x in args]) - def perform(self, inputs, outputs): + def perform(self, inputs, output_storage): for i, inp in enumerate(inputs): - outputs[i][0] = inp[0] + output_storage[i][0] = inp[0] @jax_funcify.register(TestOp) def jax_funcify_TestOp(op, **kwargs): diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 01a12914fd..753b9c8f9b 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -357,8 +357,8 @@ def make_node(self, x): x = pt.as_tensor_variable(x) return Apply(self, [x], [pt.tensor(shape=(None,) * x.type.ndim)]) - def perform(self, node, inputs, outputs): - outputs[0][0] = inputs[0] + 1 + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + 1 @jax_funcify.register(IncWithoutStaticShape) def _(op, **kwargs): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index a273bd53db..8150ddde1e 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -60,9 +60,9 @@ class MySingleOut(Op): def make_node(self, a, b): return Apply(self, [a, b], [a.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): res = (inputs[0] + inputs[1]).astype(inputs[0][0].dtype) - outputs[0][0] = res + output_storage[0][0] = res class ScalarMyMultiOut(ScalarOp): @@ -80,10 +80,10 @@ def make_node(self, a, b): b = as_scalar(b) return Apply(self, [a, b], [a.type(), b.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): res1, res2 = self.impl(inputs[0], inputs[1]) - outputs[0][0] = res1 - outputs[1][0] = res2 + output_storage[0][0] = res1 + output_storage[1][0] = res2 scalar_my_multi_out = Elemwise(ScalarMyMultiOut()) @@ -105,10 +105,10 @@ def impl(a, b): def make_node(self, a, b): return Apply(self, [a, b], [a.type(), b.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, 
output_storage): res1, res2 = self.impl(inputs[0], inputs[1]) - outputs[0][0] = res1 - outputs[1][0] = res2 + output_storage[0][0] = res1 + output_storage[1][0] = res2 my_multi_out = Elemwise(MyMultiOut()) @@ -599,18 +599,18 @@ class BaseOp(Op): otypes = [pt.dscalar] class FuncifiedOp(BaseOp): - def perform(self, node, inputs, outputs): - outputs[0][0] = inputs[0] + 1 + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + 1 class FuncifiedAndCachedOp(BaseOp): - def perform(self, node, inputs, outputs): - outputs[0][0] = inputs[0] * 2 + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] * 2 class FuncifiedAndDefaultCachedOp(BaseOp): __props__ = () - def perform(self, node, inputs, outputs): - outputs[0][0] = inputs[0] - 3 + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] - 3 @numba_basic.numba_funcify.register(FuncifiedOp) def _(op, node, **kwargs): diff --git a/tests/link/test_link.py b/tests/link/test_link.py index dafaaace53..bad4ad2183 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -93,8 +93,8 @@ def make_node(self, *inputs): def __str__(self): return self.name - def perform(self, node, inputs, out_): - (out,) = out_ + def perform(self, node, inputs, output_storage): + (out,) = output_storage out[0] = self.impl(*inputs) diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index 517fbe940b..f9bc52d0e2 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -22,7 +22,7 @@ class SomeOp(Op): - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): pass def make_node(self, x): @@ -298,10 +298,10 @@ def __init__(self): def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): assert self.nb_run == 0 self.nb_run += 1 - outputs[0][0] = inputs[0].copy() + output_storage[0][0] = inputs[0].copy() def test_vm_gc(): @@ -406,7 +406,7 @@ def test_VMLinker_make_vm_no_cvm(): def test_VMLinker_exception(): class BadOp(Op): - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): pass def make_node(self, x): diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 0f4b0db607..37622ae103 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -232,8 +232,8 @@ class MultiOutOp(ScalarOp): def make_node(self, x): return Apply(self, [x], [x.type(), x.type()]) - def perform(self, node, inputs, outputs): - outputs[1][0] = outputs[0][0] = inputs[0] + def perform(self, node, inputs, output_storage): + output_storage[1][0] = output_storage[0][0] = inputs[0] def c_code(self, *args): return "dummy" diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index a8847b3cf6..1fc5b15aff 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -2341,7 +2341,7 @@ class MyOp(Op): def make_node(self, input): return Apply(self, [input], [vector()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise Exception("blah") # def c_code(self, node, name, inputs, outputs, sub): diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 3117932fc1..01d7a05bd1 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -297,9 +297,9 @@ def make_node(self, x): x = as_sparse_variable(x) return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, 
inputs, output_storage): (x,) = inputs - (out,) = outputs + (out,) = output_storage assert _is_sparse(x) out[0] = -x diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 43df9ffd23..8c60884629 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -157,9 +157,9 @@ def make_node(self, x): x = as_tensor_variable(x) return Apply(self, [x], [x.type()]) - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (x,) = inputs + (out,) = output_storage out[0] = x.copy() # def infer_shape(self, fgraph, node, (xshp,)): @@ -174,9 +174,9 @@ def make_node(self, x): x = as_tensor_variable(x) return Apply(self, [x], [x.type()]) - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (x,) = inputs + (out,) = output_storage out[0] = x.copy() def infer_shape(self, fgraph, node, xshp_): diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 1af02dfb54..2f1fcc9b69 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -59,13 +59,13 @@ class NodeDependentPerformOp(Op): def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): [x] = inputs if node.inputs[0].type.dtype.startswith("float"): y = x + 1 else: y = x - 1 - outputs[0][0] = y + output_storage[0][0] = y blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()") x = tensor("x", shape=(3,), dtype="float32") @@ -306,9 +306,9 @@ def make_node(self, a, b): d = tensor(shape=(None,)) return Apply(self, [a_identity, b_identity], [c, d]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): a, b = inputs - c, d = outputs + c, d = output_storage c[0] = np.arange(a.size + b.size, dtype=config.floatX) d[0] = np.arange(a.sum() + b.sum(), dtype=config.floatX) @@ -694,7 +694,7 @@ class MixedDtypeCoreOp(Op): itypes = [scalar().type] otypes = [scalar().type, scalar(dtype=int).type] - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError() def L_op(self, inputs, outputs, output_gradients): @@ -774,7 +774,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs): def make_node(self, x, y, z): return Apply(self, [x, y, z], [x.type(), y.type(), z.type()]) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): [x, y, z] = inputs if 0 not in self.inplace: x = x.copy() @@ -782,9 +782,9 @@ def perform(self, node, inputs, outputs): y = y.copy() if 2 not in self.inplace: z = z.copy() - outputs[0][0] = x - outputs[1][0] = y - outputs[2][0] = z + output_storage[0][0] = x + output_storage[1][0] = y + output_storage[2][0] = z core_op = CoreOp(inplace=()) blockwise_op = Blockwise(core_op, signature="(),(),()->(),(),()") diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index b8a9ee4e0d..d70747ed3a 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1108,7 +1108,7 @@ def make_node(self, *inputs): outputs = [float_op(), int_op()] return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): raise NotImplementedError() def L_op(self, inputs, outputs, output_gradients): diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 
e447926532..776f2ad815 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -203,9 +203,9 @@ def test_eigvalsh_grad(): class TestSolveBase: class SolveTest(SolveBase): - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): A, b = inputs - outputs[0][0] = scipy.linalg.solve(A, b) + output_storage[0][0] = scipy.linalg.solve(A, b) @pytest.mark.parametrize( "A_func, b_func, error_message", diff --git a/tests/test_rop.py b/tests/test_rop.py index 2e7d4691bb..50cd27c87f 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -45,9 +45,9 @@ class BreakRop(Op): def make_node(self, x): return Apply(self, [x], [x.type()]) - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ + def perform(self, node, inputs, output_storage): + (x,) = inputs + (out,) = output_storage out[0] = x def grad(self, inp, grads): From 666da8d3b8e97036396f19a4507e92471b453482 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 22:14:37 +0100 Subject: [PATCH 14/17] Standardize Op.L_op arguments --- pytensor/scalar/basic.py | 148 ++++++++++++++++----------------- pytensor/scalar/math.py | 40 ++++----- pytensor/scan/op.py | 84 ++++++++++--------- pytensor/tensor/basic.py | 14 ++-- pytensor/tensor/blockwise.py | 6 +- pytensor/tensor/elemwise.py | 6 +- pytensor/tensor/extra_ops.py | 4 +- pytensor/tensor/math.py | 20 ++--- pytensor/tensor/nlinalg.py | 4 +- pytensor/tensor/slinalg.py | 16 ++-- pytensor/tensor/special.py | 6 +- pytensor/xtensor/basic.py | 12 +-- tests/tensor/test_blockwise.py | 4 +- tests/tensor/test_elemwise.py | 4 +- tests/tensor/test_optimize.py | 8 +- 15 files changed, 189 insertions(+), 187 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index af0b0b7173..d74924b749 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1275,8 +1275,8 @@ def impl(self, *inputs): def grad(self, inputs, output_gradients): raise MethodNotDefined("grad", type(self), self.__class__.__name__) - def L_op(self, inputs, outputs, output_gradients): - return self.grad(inputs, output_gradients) + def L_op(self, inputs, outputs, output_grads): + return self.grad(inputs, output_grads) def __eq__(self, other): return type(self) is type(other) and getattr( @@ -1409,7 +1409,7 @@ def __hash__(self): def output_types(self, *input_dtypes): return [bool] if getattr(self, "bool", False) else [int8] - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): x, y = inputs assert outputs[0].type == bool return [ @@ -1445,7 +1445,7 @@ def __hash__(self): def output_types(self, *input_dtypes): return [bool] if getattr(self, "bool", False) else [int8] - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs assert outputs[0].type == bool return [x.zeros_like(dtype=config.floatX)] @@ -1644,9 +1644,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {cond} ? {ift} : {iff};" - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (cond, ift, iff) = inputs - (gz,) = gout + (gz,) = output_grads first_part = switch(cond, gz, 0.0) second_part = switch(cond, 0.0, gz) @@ -1806,9 +1806,9 @@ def c_code(self, node, name, inputs, outputs, sub): # Test for both y>x and x>=y to detect NaN return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? 
({x}): nan("")));' - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: # max is currently defined for complex_types, # but the gradient for complex is not. @@ -1852,9 +1852,9 @@ def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError() return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));' - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: # min is currently defined for complex_types, # but the gradient for complex is not. @@ -1900,8 +1900,8 @@ def c_code(self, node, name, inputs, outputs, sub): else: return z + " = " + op.join(inputs) + ";" - def L_op(self, inputs, outputs, gout): - (gz,) = gout + def L_op(self, inputs, outputs, output_grads): + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -1995,9 +1995,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {x} - {y};" - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -2282,9 +2282,9 @@ def c_code(self, node, name, inputs, outputs, sub): """ ) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if outputs[0].type in discrete_types: # The gradient does not flow in if the output is discrete return [ @@ -2310,9 +2310,9 @@ def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("type not supported", type) return f"{z} = pow({x}, {y});" - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() @@ -2400,9 +2400,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {x} < {min} ? {min} : {x} > {max} ? 
{max} : {x};" - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x, mn, mx) = inputs - (gz,) = gout + (gz,) = output_grads assert gz.type not in complex_types gx = ((x >= mn) & (x <= mx)) * gz gmn = (x < mn) * gz @@ -2589,9 +2589,9 @@ def make_node(self, x): def impl(self, x): return np.abs(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if outputs[0].type in discrete_types: if x.type in discrete_types: return [x.zeros_like(dtype=config.floatX)] @@ -2874,9 +2874,9 @@ class Neg(UnaryScalarOp): def impl(self, x): return -x - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if outputs[0].type in discrete_types: if x.type in discrete_types: return [x.zeros_like(dtype=config.floatX)] @@ -2911,9 +2911,9 @@ class Reciprocal(UnaryScalarOp): def impl(self, x): return np.float32(1.0) / x - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -2953,9 +2953,9 @@ def impl(self, x): return np.log(x, dtype=np.float32) return np.log(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -2999,9 +2999,9 @@ def impl(self, x): return np.log2(x, dtype=np.float32) return np.log2(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3042,9 +3042,9 @@ def impl(self, x): return np.log10(x, dtype=np.float32) return np.log10(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3083,9 +3083,9 @@ def impl(self, x): return np.log1p(x, dtype=np.float32) return np.log1p(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3121,9 +3121,9 @@ def impl(self, x): return np.exp(x, dtype=np.float32) return np.exp(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3157,9 +3157,9 @@ def impl(self, x): return np.exp2(x, dtype=np.float32) return np.exp2(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3193,9 +3193,9 @@ def impl(self, x): return np.expm1(x, dtype=np.float32) return np.expm1(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type 
in discrete_types: @@ -3227,9 +3227,9 @@ class Sqr(UnaryScalarOp): def impl(self, x): return x * x - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3260,9 +3260,9 @@ def impl(self, x): return np.sqrt(x, dtype=np.float32) return np.sqrt(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3296,9 +3296,9 @@ def impl(self, x): return np.deg2rad(x, dtype=np.float32) return np.deg2rad(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3331,9 +3331,9 @@ def impl(self, x): return np.rad2deg(x, dtype=np.float32) return np.rad2deg(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3368,9 +3368,9 @@ def impl(self, x): return np.cos(x, dtype=np.float32) return np.cos(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3404,9 +3404,9 @@ def impl(self, x): return np.arccos(x, dtype=np.float32) return np.arccos(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3442,9 +3442,9 @@ def impl(self, x): return np.sin(x, dtype=np.float32) return np.sin(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3478,9 +3478,9 @@ def impl(self, x): return np.arcsin(x, dtype=np.float32) return np.arcsin(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3514,9 +3514,9 @@ def impl(self, x): return np.tan(x, dtype=np.float32) return np.tan(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3550,9 +3550,9 @@ def impl(self, x): return np.arctan(x, dtype=np.float32) return np.arctan(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3588,9 +3588,9 @@ def impl(self, y, x): return np.arctan2(y, x, dtype=np.float32) return np.arctan2(y, x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (y, x) = 
inputs - (gz,) = gout + (gz,) = output_grads if gz.type in complex_types: raise NotImplementedError() else: @@ -3637,9 +3637,9 @@ def impl(self, x): return np.cosh(x, dtype=np.float32) return np.cosh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3673,9 +3673,9 @@ def impl(self, x): return np.arccosh(x, dtype=np.float32) return np.arccosh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3714,9 +3714,9 @@ def impl(self, x): return np.sinh(x, dtype=np.float32) return np.sinh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3750,9 +3750,9 @@ def impl(self, x): return np.arcsinh(x, dtype=np.float32) return np.arcsinh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3792,9 +3792,9 @@ def impl(self, x): return np.tanh(x, dtype=np.float32) return np.tanh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -3828,9 +3828,9 @@ def impl(self, x): return np.arctanh(x, dtype=np.float32) return np.arctanh(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index b8beaf2c27..e223c16e03 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -54,9 +54,9 @@ class Erf(UnaryScalarOp): def impl(self, x): return special.erf(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -88,9 +88,9 @@ class Erfc(UnaryScalarOp): def impl(self, x): return special.erfc(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -137,9 +137,9 @@ class Erfcx(UnaryScalarOp): def impl(self, x): return special.erfcx(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -193,9 +193,9 @@ class Erfinv(UnaryScalarOp): def impl(self, x): return special.erfinv(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if 
outputs[0].type in discrete_types: @@ -228,9 +228,9 @@ class Erfcinv(UnaryScalarOp): def impl(self, x): return special.erfcinv(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -288,9 +288,9 @@ class Gamma(UnaryScalarOp): def impl(self, x): return special.gamma(x) - def L_op(self, inputs, outputs, gout): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -323,9 +323,9 @@ class GammaLn(UnaryScalarOp): def impl(self, x): return special.gammaln(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -363,9 +363,9 @@ class Psi(UnaryScalarOp): def impl(self, x): return special.psi(x) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() if outputs[0].type in discrete_types: @@ -460,9 +460,9 @@ class TriGamma(UnaryScalarOp): def impl(self, x): return special.polygamma(1, x) - def L_op(self, inputs, outputs, outputs_gradients): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (g_out,) = outputs_gradients + (g_out,) = output_grads if x in complex_types: raise NotImplementedError("gradient not implemented for complex types") return [g_out * polygamma(2, x)] @@ -559,9 +559,9 @@ def output_types_preference(n_type, x_type): def impl(self, n, x): return special.polygamma(n, x) - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): (n, x) = inputs - (g_out,) = output_gradients + (g_out,) = output_grads if x in complex_types: raise NotImplementedError("gradient not implemented for complex types") return [ diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 6efeddc8bb..b00c9ee083 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2429,9 +2429,9 @@ def connection_pattern(self, node): node.tag.connection_pattern = connection_pattern return connection_pattern - def L_op(self, inputs, outs, dC_douts): - if not isinstance(outs, list | tuple): - outs = [outs] + def L_op(self, inputs, outputs, output_grads): + if not isinstance(outputs, list | tuple): + outputs = [outputs] # `grad_step` equals the number of steps the original scan node has # done (if the original scan is a while loop than this number is the # length of the output sequence) @@ -2440,17 +2440,18 @@ def L_op(self, inputs, outs, dC_douts): # then a mit_sot info = self.info if info.n_nit_sot > 0: - grad_steps = self.outer_nitsot_outs(outs)[0].shape[0] + grad_steps = self.outer_nitsot_outs(outputs)[0].shape[0] elif info.n_sit_sot > 0: - grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1 + grad_steps = self.outer_sitsot_outs(outputs)[0].shape[0] - 1 elif info.n_mit_sot > 0: grad_steps = ( - self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[info.n_mit_mot] + self.outer_mitsot_outs(outputs)[0].shape[0] + + self.mintaps[info.n_mit_mot] ) else: grad_steps = inputs[0] if info.as_while: - n_steps = outs[0].shape[0] + n_steps = outputs[0].shape[0] # Restrict the number of grad steps according to # 
self.truncate_gradient @@ -2473,7 +2474,7 @@ def L_op(self, inputs, outs, dC_douts): + self.inner_sitsot_outs(self_outputs) + self.inner_nitsot_outs(self_outputs) ) - scan_node = outs[0].owner + scan_node = outputs[0].owner connection_pattern = self.connection_pattern(scan_node) def get_inp_idx(iidx): @@ -2603,8 +2604,10 @@ def compute_all_gradients(known_grads): info.n_seqs + pos ] - if not isinstance(dC_douts[outer_oidx].type, DisconnectedType): - dtypes.append(dC_douts[outer_oidx].dtype) + if not isinstance( + output_grads[outer_oidx].type, DisconnectedType + ): + dtypes.append(output_grads[outer_oidx].dtype) if dtypes: new_dtype = pytensor.scalar.upcast(*dtypes) else: @@ -2614,10 +2617,10 @@ def compute_all_gradients(known_grads): # nit-sot outputs # If not disconnected assume the output gradient type is a valid type for the input gradient if isinstance( - dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType + output_grads[idx - n_extra_mit_mot_outs].type, DisconnectedType ): continue - dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0]) + dC_dXt = safe_new(output_grads[idx - n_extra_mit_mot_outs][0]) dC_dXts.append(dC_dXt) # Handle cases where the very same variable may be used as different outputs @@ -2626,7 +2629,7 @@ def compute_all_gradients(known_grads): dc_dxts_idx = 0 for i in range(len(diff_outputs)): if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance( - dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType + output_grads[i - n_extra_mit_mot_outs].type, DisconnectedType ): # Special case where we don't have a dC_dXt for disconnected nitsot outputs continue @@ -2706,15 +2709,15 @@ def compute_all_gradients(known_grads): outmaxtap = np.max(info.mit_mot_out_slices[: info.n_mit_mot][idx]) else: outmaxtap = 0 - seq = outs[idx] + seq = outputs[idx] for k in taps: if outmaxtap - k != 0: nw_seq = seq[k - mintap : -(outmaxtap - k)][::-1] else: nw_seq = seq[k - mintap :][::-1] outer_inp_seqs.append(nw_seq) - outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] - for x in self.outer_nitsot_outs(dC_douts): + outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outputs)] + for x in self.outer_nitsot_outs(output_grads): if not isinstance(x.type, DisconnectedType): if info.as_while: # equivalent to x[:n_steps][::-1] @@ -2733,15 +2736,15 @@ def compute_all_gradients(known_grads): else: n = inputs[0].tag.test_value for taps, x in zip( - info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + info.mit_sot_in_slices, self.outer_mitsot_outs(outputs), strict=True ): mintap = np.min(taps) if hasattr(x[::-1][:mintap], "test_value"): assert x[::-1][:mintap].tag.test_value.shape[0] == n - for x in self.outer_sitsot_outs(outs): + for x in self.outer_sitsot_outs(outputs): if hasattr(x[::-1][:-1].tag, "test_value"): assert x[::-1][:-1].tag.test_value.shape[0] == n - for x in self.outer_nitsot_outs(outs): + for x in self.outer_nitsot_outs(outputs): if hasattr(x[::-1].tag, "test_value"): if info.as_while: assert x[n_steps - 1 :: -1].tag.test_value.shape[0] == n @@ -2750,11 +2753,11 @@ def compute_all_gradients(known_grads): outer_inp_seqs += [ x[::-1][: np.min(taps)] for taps, x in zip( - info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + info.mit_sot_in_slices, self.outer_mitsot_outs(outputs), strict=True ) ] - outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] - outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] + outer_inp_seqs += [x[::-1][:-1] for x in 
self.outer_sitsot_outs(outputs)] + outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outputs)] # Restrict the length of the outer sequences to the number of grad # steps @@ -2779,11 +2782,11 @@ def compute_all_gradients(known_grads): n_mitmot_inps = 0 for idx, taps in enumerate(info.mit_mot_in_slices): - if isinstance(dC_douts[idx].type, DisconnectedType): - out = outs[idx] + if isinstance(output_grads[idx].type, DisconnectedType): + out = outputs[idx] outer_inp_mitmot.append(pt.zeros_like(out)) else: - outer_inp_mitmot.append(dC_douts[idx][::-1]) + outer_inp_mitmot.append(output_grads[idx][::-1]) mitmot_inp_taps.append([]) mitmot_out_taps.append([]) undefined_msg = None @@ -2856,10 +2859,10 @@ def compute_all_gradients(known_grads): offset = info.n_mit_mot for idx, taps in enumerate(info.mit_sot_in_slices): - if isinstance(dC_douts[idx + offset].type, DisconnectedType): - outer_inp_mitmot.append(outs[idx + offset].zeros_like()) + if isinstance(output_grads[idx + offset].type, DisconnectedType): + outer_inp_mitmot.append(outputs[idx + offset].zeros_like()) else: - outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) + outer_inp_mitmot.append(output_grads[idx + offset][::-1]) mitmot_inp_taps.append([]) mitmot_out_taps.append([]) inner_inp_mitmot.append(dC_dXts[out_pos]) @@ -2910,20 +2913,20 @@ def compute_all_gradients(known_grads): for idx in range(info.n_sit_sot): mitmot_inp_taps.append([0, 1]) mitmot_out_taps.append([1]) - if not isinstance(dC_douts[idx + offset].type, DisconnectedType): - outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) + if not isinstance(output_grads[idx + offset].type, DisconnectedType): + outer_inp_mitmot.append(output_grads[idx + offset][::-1]) else: if isinstance(dC_dinps_t[ins_pos].type, NullType): # Cannot use dC_dinps_t[ins_pos].dtype, so we use # floatX instead, as it is a dummy value that will not # be used anyway. 
outer_inp_mitmot.append( - pt.zeros(outs[idx + offset].shape, dtype=config.floatX) + pt.zeros(outputs[idx + offset].shape, dtype=config.floatX) ) else: outer_inp_mitmot.append( pt.zeros( - outs[idx + offset].shape, dtype=dC_dinps_t[ins_pos].dtype + outputs[idx + offset].shape, dtype=dC_dinps_t[ins_pos].dtype ) ) @@ -3071,14 +3074,14 @@ def compute_all_gradients(known_grads): name=f"grad_of_{self.name}" if self.name else None, allow_gc=self.allow_gc, ) - outputs = local_op(*outer_inputs, return_list=True) + outs = local_op(*outer_inputs, return_list=True) # Re-order the gradients correctly gradients = [DisconnectedType()()] offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs for p, (x, t) in enumerate( zip( - outputs[offset : offset + info.n_seqs], + outs[offset : offset + info.n_seqs], type_outs[offset : offset + info.n_seqs], strict=True, ) @@ -3110,7 +3113,7 @@ def compute_all_gradients(known_grads): gradients.append(NullType(t)()) end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot - for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)): + for p, (x, t) in enumerate(zip(outs[:end], type_outs[:end], strict=True)): if t == "connected": # If the forward scan is in as_while mode, we need to pad # the gradients, so that they match the size of the input @@ -3141,11 +3144,10 @@ def compute_all_gradients(known_grads): gradients.append(NullType(t)()) start = len(gradients) - node = outs[0].owner for idx in range(info.n_untraced_sit_sot_outs): disconnected = True - connected_flags = self.connection_pattern(node)[idx + start] - for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): + connected_flags = self.connection_pattern(scan_node)[idx + start] + for dC_dout, connected in zip(output_grads, connected_flags, strict=True): if not isinstance(dC_dout.type, DisconnectedType) and connected: disconnected = False if disconnected: @@ -3162,7 +3164,7 @@ def compute_all_gradients(known_grads): end = begin + n_sitsot_outs for p, (x, t) in enumerate( - zip(outputs[begin:end], type_outs[begin:end], strict=True) + zip(outs[begin:end], type_outs[begin:end], strict=True) ): if t == "connected": gradients.append(x[-1]) @@ -3189,9 +3191,9 @@ def compute_all_gradients(known_grads): # because through the recurrence they can become nonzero for idx in range(len(gradients)): disconnected = True - for kdx in range(len(node.outputs)): + for kdx in range(len(scan_node.outputs)): if connection_pattern[idx][kdx] and not isinstance( - dC_douts[kdx].type, DisconnectedType + output_grads[kdx].type, DisconnectedType ): disconnected = False if disconnected: diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 82f87e6202..5e71c6813e 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2285,12 +2285,12 @@ def infer_shape(self, fgraph, node, in_shapes): out_shapes.append(temp) return out_shapes - def L_op(self, inputs, outputs, g_outputs): + def L_op(self, inputs, outputs, output_grads): """Join the gradients along the axis that was used to split x.""" _x, axis, n = inputs # If all the output gradients are disconnected, then so are the inputs - if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs): + if builtins.all(isinstance(g.type, DisconnectedType) for g in output_grads): return [ DisconnectedType()(), grad_undefined(self, 1, axis), @@ -2298,7 +2298,7 @@ def L_op(self, inputs, outputs, g_outputs): ] # Else, we have to make them zeros before joining them new_g_outputs = [] - for o, g in zip(outputs, 
g_outputs, strict=True): + for o, g in zip(outputs, output_grads, strict=True): if isinstance(g.type, DisconnectedType): new_g_outputs.append(o.zeros_like()) else: @@ -2692,11 +2692,11 @@ def R_op(self, inputs, eval_points): return [None] return self.make_node(inputs[0], *eval_points[1:]).outputs - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): """The gradient wrt a join op is a `Split`, used to partition the gradient along the `axis` which was used for joining. """ - [gz] = grads + [gz] = output_grads [out] = outputs axis, *tensors = inputs @@ -3356,9 +3356,9 @@ def c_code_cache_version(self): def connection_pattern(self, node): return [[True], [False], [True]] - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): start, _stop, step = inputs - (gz,) = grads + (gz,) = output_grads # `start` and `step` affect the output values # but the outputs are integers so there's # no gradient through them. diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index fea9e768d3..035684083e 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -409,7 +409,7 @@ def connection_pattern(self, node): return [[True for _ in node.outputs] for _ in node.inputs] - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): batch_ndim = self.batch_ndim(outputs[0].owner) # Obtain core_op gradients @@ -429,7 +429,7 @@ def L_op(self, inputs, outputs, output_gradients): if isinstance(output_grad.type, NullType | DisconnectedType) else core_output.type() for output_grad, core_output in zip( - output_gradients, core_outputs, strict=True + output_grads, core_outputs, strict=True ) ] @@ -444,7 +444,7 @@ def L_op(self, inputs, outputs, output_gradients): replace=dict( zip( core_inputs + core_outputs + core_output_gradients, - inputs + outputs + output_gradients, + inputs + outputs + output_grads, strict=True, ) ), diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index eabd9ca03e..b495b3aa1c 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -509,11 +509,11 @@ def connection_pattern(self, node): return [[True for output in node.outputs] for ipt in node.inputs] - def L_op(self, inputs, outs, ograds): + def L_op(self, inputs, outputs, output_grads): from pytensor.tensor.math import sum as pt_sum # Compute grad with respect to broadcasted input - rval = self._bgrad(inputs, outs, ograds) + rval = self._bgrad(inputs, outputs, output_grads) # sum out the broadcasted dimensions for i, ipt in enumerate(inputs): @@ -526,7 +526,7 @@ def L_op(self, inputs, outs, ograds): to_sum = [ j for j, in_s in enumerate(ipt.type.shape) - if in_s == 1 and outs[0].type.shape[j] != 1 + if in_s == 1 and outputs[0].type.shape[j] != 1 ] if to_sum: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index ae35794a80..97150e3774 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -322,9 +322,9 @@ def perform(self, node, inputs, output_storage): else: z[0] = np.cumprod(x, axis=self.axis) - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): (x,) = inputs - (gi,) = output_gradients + (gi,) = output_grads reverse_slicing = [slice(None, None, None)] * gi.ndim reverse_slicing[self.axis] = slice(None, None, -1) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 6d9bb8e082..ba45f834d1 100644 --- a/pytensor/tensor/math.py +++ 
b/pytensor/tensor/math.py @@ -408,7 +408,7 @@ def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) return type(self)(axis=axis) - def L_op(self, inputs, outputs, grads): + def L_op(self, inputs, outputs, output_grads): # The strict sense mathematical gradient of the maximum function is # not calculated here for it is not defined at every point where some # coordinates are identical. However, since the latter set has null @@ -424,7 +424,7 @@ def L_op(self, inputs, outputs, grads): # does it automatically [x] = inputs [out] = outputs - [g_out] = grads + [g_out] = output_grads axis = tuple(range(x.ndim)) if self.axis is None else self.axis out_pad = expand_dims(out, axis) @@ -3454,13 +3454,13 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None): upcast_discrete_output=True, ) - def L_op(self, inp, out, grads): - (x,) = inp + def L_op(self, inputs, outputs, output_grads): + (x,) = inputs - if out[0].dtype not in continuous_dtypes: + if outputs[0].dtype not in continuous_dtypes: return [x.zeros_like(dtype=config.floatX)] - (gz,) = grads + (gz,) = output_grads gz = as_tensor_variable(gz) axis = self.axis if axis is None: @@ -3545,7 +3545,7 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=Fals ) self.no_zeros_in_input = no_zeros_in_input - def L_op(self, inp, out, grads): + def L_op(self, inputs, outputs, output_grads): """ The grad of this Op could be very easy, if it is was not for the case where zeros are present in a given "group" (ie. elements reduced @@ -3591,10 +3591,10 @@ def L_op(self, inp, out, grads): based on the result of this count. """ - (prod_in,) = inp - (gz,) = grads + (prod_in,) = inputs + (gz,) = output_grads - if out[0].dtype in discrete_dtypes or self.acc_dtype in discrete_dtypes: + if outputs[0].dtype in discrete_dtypes or self.acc_dtype in discrete_dtypes: # There is an int conversion in the way return [prod_in.zeros_like(dtype=config.floatX)] diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index a7d939f40b..640cacd70b 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -49,7 +49,7 @@ def perform(self, node, inputs, output_storage): (z,) = output_storage z[0] = np.linalg.pinv(x, hermitian=self.hermitian) - def L_op(self, inputs, outputs, g_outputs): + def L_op(self, inputs, outputs, output_grads): r"""The gradient function should return .. math:: V\frac{\partial X^+}{\partial X}, @@ -63,7 +63,7 @@ def L_op(self, inputs, outputs, g_outputs): """ (x,) = inputs (z,) = outputs - (gz,) = g_outputs + (gz,) = output_grads x_dot_z = ptm.dot(x, z) z_dot_x = ptm.dot(z, x) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index eb232a9bff..f91caf14b3 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -119,7 +119,7 @@ def perform(self, node, inputs, output_storage): # Transpose result if input was transposed out[0] = c.T if c_contiguous_input else c - def L_op(self, inputs, outputs, gradients): + def L_op(self, inputs, outputs, output_grads): """ Cholesky decomposition reverse-mode gradient update. 
@@ -132,7 +132,7 @@ def L_op(self, inputs, outputs, gradients): """ - dz = gradients[0] + dz = output_grads[0] chol_x = outputs[0] # Replace the cholesky decomposition with 1 if there are nans @@ -309,7 +309,7 @@ def infer_shape(self, fgraph, node, shapes): cols = Bshape[1] return [(rows, cols)] - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`. Symbolic expression for updates taken from [#]_. @@ -327,7 +327,7 @@ def L_op(self, inputs, outputs, output_gradients): # C is a scalar representing the entire graph # `output_gradients` is (dC/dc,) # We need to return (dC/d[inv(A)], dC/db) - c_bar = output_gradients[0] + c_bar = output_grads[0] props_dict = self._props_dict() props_dict["lower"] = not self.lower @@ -742,9 +742,9 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = LU output_storage[1][0] = p - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): [A] = inputs - LU_bar, _ = output_gradients + LU_bar, _ = output_grads LU, p_indices = outputs eye = ptb.identity_like(A) @@ -950,8 +950,8 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = x - def L_op(self, inputs, outputs, output_gradients): - res = super().L_op(inputs, outputs, output_gradients) + def L_op(self, inputs, outputs, output_grads): + res = super().L_op(inputs, outputs, output_grads) if self.lower: res[0] = ptb.tril(res[0]) diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index f0fa8a1cf7..4b874e726e 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -273,9 +273,9 @@ def perform(self, node, inputs, output_storage): (z,) = output_storage z[0] = scipy.special.softmax(x, axis=self.axis) - def L_op(self, inp, outputs, grads): - (_x,) = inp - (g_sm,) = grads + def L_op(self, inputs, outputs, output_grads): + (_x,) = inputs + (g_sm,) = output_grads return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])] def R_op(self, inputs, eval_points): diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index df64fcd9db..9425eb642a 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -31,9 +31,9 @@ def make_node(self, x): output = TensorType(x.type.dtype, shape=x.type.shape)() return Apply(self, [x], [output]) - def L_op(self, inputs, outs, g_outs): + def L_op(self, inputs, outputs, output_grads): [x] = inputs - [g_out] = g_outs + [g_out] = output_grads return [xtensor_from_tensor(g_out, dims=x.type.dims)] @@ -53,8 +53,8 @@ def make_node(self, x): output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape) return Apply(self, [x], [output]) - def L_op(self, inputs, outs, g_outs): - [g_out] = g_outs + def L_op(self, inputs, outputs, output_grads): + [g_out] = output_grads return [tensor_from_xtensor(g_out)] @@ -74,9 +74,9 @@ def make_node(self, x): output = x.type.clone(dims=self.new_dims)() return Apply(self, [x], [output]) - def L_op(self, inputs, outs, g_outs): + def L_op(self, inputs, outputs, output_grads): [x] = inputs - [g_out] = g_outs + [g_out] = output_grads return [rename(g_out, dims=x.type.dims)] diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 2f1fcc9b69..976b9501d8 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -697,8 +697,8 @@ class MixedDtypeCoreOp(Op): def perform(self, node, inputs, output_storage): raise NotImplementedError() - def L_op(self, 
inputs, outputs, output_gradients): - return [ones_like(inputs[0]) * output_gradients[0]] + def L_op(self, inputs, outputs, output_grads): + return [ones_like(inputs[0]) * output_grads[0]] op = Blockwise(MixedDtypeCoreOp()) x = vector("x") diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index d70747ed3a..16e3ad95f7 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1111,8 +1111,8 @@ def make_node(self, *inputs): def perform(self, node, inputs, output_storage): raise NotImplementedError() - def L_op(self, inputs, outputs, output_gradients): - return [inputs[0].ones_like() * output_gradients[0]] + def L_op(self, inputs, outputs, output_grads): + return [inputs[0].ones_like() * output_grads[0]] op = Elemwise(MixedDtypeScalarOp()) x = vector("x") diff --git a/tests/tensor/test_optimize.py b/tests/tensor/test_optimize.py index 6b1691d147..bb5c581f45 100644 --- a/tests/tensor/test_optimize.py +++ b/tests/tensor/test_optimize.py @@ -239,8 +239,8 @@ def perform(self, node, inputs, output_storage): assert x.ndim == 0 output_storage[0][0] = x - def L_op(self, inputs, outputs, out_grads): - return out_grads + def L_op(self, inputs, outputs, output_grads): + return output_grads x = scalar("x") x_check = AssertScalar()(x) @@ -313,9 +313,9 @@ def connection_pattern(self, node): # Gradient connected only to first input return [[True], [False]] - def L_op(self, inputs, outputs, output_gradients): + def L_op(self, inputs, outputs, output_grads): [_x, str_emoji] = inputs - [g] = output_gradients + [g] = output_grads return [ self(g, str_emoji), disconnected_type(), From b59c24ecb71dd68b172e3cddba5d4fa680de3779 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 22:33:02 +0100 Subject: [PATCH 15/17] Standardize Op.grad arguments --- pytensor/breakpoint.py | 4 +- pytensor/compile/ops.py | 4 +- pytensor/gradient.py | 20 +++--- pytensor/ifelse.py | 16 ++--- pytensor/link/jax/ops.py | 6 +- pytensor/printing.py | 4 +- pytensor/raise_op.py | 4 +- pytensor/scalar/basic.py | 72 ++++++++++---------- pytensor/scalar/math.py | 84 ++++++++++++------------ pytensor/sparse/basic.py | 74 ++++++++++----------- pytensor/sparse/math.py | 48 +++++++------- pytensor/tensor/basic.py | 46 ++++++------- pytensor/tensor/blas.py | 6 +- pytensor/tensor/conv/abstract_conv.py | 36 +++++----- pytensor/tensor/elemwise.py | 6 +- pytensor/tensor/extra_ops.py | 22 +++---- pytensor/tensor/fourier.py | 4 +- pytensor/tensor/math.py | 22 +++---- pytensor/tensor/nlinalg.py | 8 +-- pytensor/tensor/random/op.py | 2 +- pytensor/tensor/shape.py | 18 ++--- pytensor/tensor/slinalg.py | 8 +-- pytensor/tensor/special.py | 14 ++-- pytensor/tensor/subtensor.py | 28 ++++---- pytensor/tensor/type_other.py | 2 +- pytensor/tensor/xlogx.py | 8 +-- tests/sparse/test_basic.py | 4 +- tests/tensor/conv/c_conv3d_corr3d_ref.py | 22 +++---- tests/tensor/conv/c_conv_corr_ref.py | 22 +++---- tests/tensor/test_elemwise.py | 4 +- tests/test_gradient.py | 30 ++++----- tests/test_rop.py | 4 +- 32 files changed, 327 insertions(+), 325 deletions(-) diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index cb9255d589..e1ca00273a 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -143,8 +143,8 @@ def perform(self, node, inputs, output_storage): for i in range(len(output_storage)): output_storage[i][0] = inputs[i + 1] - def grad(self, inputs, output_gradients): - return [DisconnectedType()(), *output_gradients] + def grad(self, inputs, output_grads): + return 
[DisconnectedType()(), *output_grads] def infer_shape(self, fgraph, inputs, input_shapes): # Return the shape of every input but the condition (first input) diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 1d3ac709a4..9994699b08 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -92,8 +92,8 @@ def make_node(self, x): def infer_shape(self, fgraph, node, input_shapes): return input_shapes - def grad(self, args, g_outs): - return g_outs + def grad(self, inputs, output_grads): + return output_grads view_op = ViewOp() diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 8f6530fad1..d835e030cd 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -2316,8 +2316,8 @@ def _is_zero(x): class ZeroGrad(ViewOp): - def grad(self, args, g_outs): - return [g_out.zeros_like() for g_out in g_outs] + def grad(self, inputs, output_grads): + return [g_out.zeros_like() for g_out in output_grads] def R_op(self, inputs, eval_points): if eval_points[0] is None: @@ -2354,8 +2354,8 @@ def zero_grad(x): class UndefinedGrad(ViewOp): - def grad(self, args, g_outs): - return [grad_undefined(self, i, arg) for i, arg in enumerate(args)] + def grad(self, inputs, output_grads): + return [grad_undefined(self, i, arg) for i, arg in enumerate(inputs)] def R_op(self, inputs, eval_points): return [None] @@ -2392,8 +2392,8 @@ def undefined_grad(x): class DisconnectedGrad(ViewOp): - def grad(self, args, g_outs): - return [disconnected_type() for g_out in g_outs] + def grad(self, inputs, output_grads): + return [disconnected_type() for g_out in output_grads] def R_op(self, inputs, eval_points): return [None] @@ -2447,10 +2447,10 @@ def __init__(self, clip_lower_bound, clip_upper_bound): if not self.clip_upper_bound >= self.clip_lower_bound: raise ValueError("`clip_upper_bound` should be >= `clip_lower_bound`") - def grad(self, args, g_outs): + def grad(self, inputs, output_grads): return [ pytensor.tensor.clip(g_out, self.clip_lower_bound, self.clip_upper_bound) - for g_out in g_outs + for g_out in output_grads ] @@ -2490,8 +2490,8 @@ class GradScale(ViewOp): def __init__(self, multiplier): self.multiplier = multiplier - def grad(self, args, g_outs): - return [self.multiplier * g_out for g_out in g_outs] + def grad(self, inputs, output_grads): + return [self.multiplier * g_out for g_out in output_grads] def grad_scale(x, multiplier): diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index f8e033a431..cbb78bd81c 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -239,10 +239,10 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any): def R_op(self, inputs, eval_points): return self(inputs[0], *eval_points[1:], return_list=True) - def grad(self, ins, grads): - condition = ins[0] - inputs_true_branch = ins[1:][: self.n_outs] - inputs_false_branch = ins[1:][self.n_outs :] + def grad(self, inputs, output_grads): + condition = inputs[0] + inputs_true_branch = inputs[1:][: self.n_outs] + inputs_false_branch = inputs[1:][self.n_outs :] if self.name is not None: nw_name_t = self.name + "_grad_t" @@ -260,19 +260,19 @@ def grad(self, ins, grads): # dtypes. 
inputs_true_grad = ( [condition] - + grads + + output_grads + [ - pt.basic.zeros_like(t, dtype=grads[i].dtype) + pt.basic.zeros_like(t, dtype=output_grads[i].dtype) for i, t in enumerate(inputs_true_branch) ] ) inputs_false_grad = ( [condition] + [ - pt.basic.zeros_like(f, dtype=grads[i].dtype) + pt.basic.zeros_like(f, dtype=output_grads[i].dtype) for i, f in enumerate(inputs_false_branch) ] - + grads + + output_grads ) # `condition` does affect the elements of the output so it is connected. diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index d4adbead6c..318bb17efa 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -135,14 +135,14 @@ def perform_jax(self, *inputs): return outputs[0] return outputs - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): """Compute gradients using JAX's vector-Jacobian product (VJP).""" import jax # Find indices of outputs that need gradients connected_output_indices = [ i - for i, output_grad in enumerate(output_gradients) + for i, output_grad in enumerate(output_grads) if not isinstance(output_grad.type, DisconnectedType) ] @@ -190,7 +190,7 @@ def restricted_function(*input_values): ) return vjp_op( - *[*inputs, *[output_gradients[i] for i in connected_output_indices]], + *[*inputs, *[output_grads[i] for i in connected_output_indices]], return_list=True, ) diff --git a/pytensor/printing.py b/pytensor/printing.py index eae2392609..e814ad4d66 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -804,8 +804,8 @@ def perform(self, node, inputs, output_storage): xout[0] = xin self.global_fn(self, xin) - def grad(self, input, output_gradients): - return output_gradients + def grad(self, inputs, output_grads): + return output_grads def R_op(self, inputs, eval_points): return list(eval_points) diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index fd0c424d17..97e001b42b 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -88,8 +88,8 @@ def perform(self, node, inputs, output_storage): if not all(conds): raise self.exc_type(self.msg) - def grad(self, input, output_gradients): - return output_gradients + [DisconnectedType()()] * (len(input) - 1) + def grad(self, inputs, output_grads): + return output_grads + [DisconnectedType()()] * (len(inputs) - 1) def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d74924b749..4f96d48aff 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1272,7 +1272,7 @@ def perform(self, node, inputs, output_storage): def impl(self, *inputs): raise MethodNotDefined("impl", type(self), self.__class__.__name__) - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): raise MethodNotDefined("grad", type(self), self.__class__.__name__) def L_op(self, inputs, outputs, output_grads): @@ -1683,7 +1683,7 @@ def output_types(self, *input_types): ) return upcast_out(*input_types[0]) - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): return [inputs[0].zeros_like(dtype=config.floatX)] @@ -1701,7 +1701,7 @@ def output_types(self, *input_types): ) return upcast_out(*input_types[0]) - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): a, b = inputs return [ a.zeros_like(dtype=config.floatX), @@ -1942,8 +1942,8 @@ def c_code(self, node, name, inputs, outputs, sub): else: return z + " = " + op.join(inputs) + ";" - def grad(self, inputs, gout): - (gz,) = 
gout + def grad(self, inputs, output_grads): + (gz,) = output_grads retval = [] # The following 3 lines verify that gz is complex when the @@ -2045,9 +2045,9 @@ def c_code(self, node, name, inputs, outputs, sub): return f"{z} = ((double){x}) / {y};" return f"{z} = {x} / {y};" - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: raise NotImplementedError() @@ -2166,7 +2166,7 @@ def c_code(self, node, name, inputs, outputs, sub): def c_code_cache_version(self): return (6,) - def grad(self, inputs, g_output): + def grad(self, inputs, output_grads): return [inp.zeros_like(dtype=config.floatX) for inp in inputs] @@ -2440,9 +2440,9 @@ def connection_pattern(self, node): return [[False], [True]] - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (_x, y) = inputs - (gz,) = gout + (gz,) = output_grads if y.type in continuous_types: # x is disconnected because the elements of x are not used return DisconnectedType()(), gz @@ -2466,9 +2466,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {x};" - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in continuous_types: return (gz,) else: @@ -2505,9 +2505,9 @@ def c_code(self, node, name, inputs, outputs, sub): return f"{z} = ({x}) ? 1 : 0;" return f"{z} = ({node.outputs[0].type.dtype_specs()[1]}){x};" - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if self.o_type in continuous_types: return [gz] else: @@ -2636,9 +2636,9 @@ def impl(self, x): # casting to output type is handled by filter return np.sign(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads rval = x.zeros_like() if rval.type in discrete_types: @@ -2677,9 +2677,9 @@ class Ceil(UnaryScalarOp): def impl(self, x): return np.ceil(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads rval = x.zeros_like() if rval.type in discrete_types: @@ -2703,9 +2703,9 @@ class Floor(UnaryScalarOp): def impl(self, x): return np.floor(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads rval = x.zeros_like() if rval.type in discrete_types: @@ -2729,9 +2729,9 @@ class Trunc(UnaryScalarOp): def impl(self, x): return np.trunc(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads return [x.zeros_like(dtype=config.floatX)] def c_code(self, node, name, inputs, outputs, sub): @@ -2757,9 +2757,9 @@ class RoundHalfToEven(UnaryScalarOp): def impl(self, x): return np.round(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads rval = x.zeros_like() if rval.type in discrete_types: @@ -2843,9 +2843,9 @@ class RoundHalfAwayFromZero(UnaryScalarOp): def impl(self, x): return round_half_away_from_zero_vec(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (_gz,) = gout + (_gz,) = output_grads rval = x.zeros_like() if rval.type in discrete_types: @@ -3865,9 +3865,9 @@ class Real(UnaryScalarOp): def impl(self, x): return np.real(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (_x,) = 
inputs - (gz,) = gout + (gz,) = output_grads return [complex(gz, 0)] def c_code(self, *args, **kwargs): @@ -3883,9 +3883,9 @@ class Imag(UnaryScalarOp): def impl(self, x): return np.imag(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.type in complex_types: return [complex(0, gz)] elif x.type in float_types: @@ -3906,7 +3906,7 @@ class Angle(UnaryScalarOp): def impl(self, x): return np.angle(x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): # y = x.imag # r = sqrt(y**2 + x.real**2) # g = y/r @@ -3918,7 +3918,7 @@ def grad(self, inputs, gout): # theta = -numpy.arcsin(g)+numpy.pi (c,) = inputs - (gtheta,) = gout + (gtheta,) = output_grads x = real(c) y = imag(c) r = _abs(c) @@ -3957,9 +3957,9 @@ def output_types_preference(x, y): def impl(self, x, y): return builtins.complex(x, y) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads return [cast(real(gz), x.type.dtype), cast(imag(gz), y.type.dtype)] def c_code(self, *args, **kwargs): @@ -4002,9 +4002,9 @@ def impl(self, r, theta): else: return np.complex128(builtins.complex(x, y)) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (r, theta) = inputs - (gz,) = gout + (gz,) = output_grads gr = gz * complex_from_polar(1, theta) gtheta = gz * complex_from_polar(r, -theta) return [gr, gtheta] diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index e223c16e03..317ea59b8d 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -263,9 +263,9 @@ class Owens_t(BinaryScalarOp): def impl(self, h, a): return special.owens_t(h, a) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (h, a) = inputs - (gz,) = grads + (gz,) = output_grads return [ gz * (-1) @@ -586,9 +586,9 @@ class GammaInc(BinaryScalarOp): def impl(self, k, x): return special.gammainc(k, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (k, x) = inputs - (gz,) = grads + (gz,) = output_grads return [ gz * gammainc_grad(k, x), gz * exp(-x + (k - 1) * log(x) - gammaln(k)), @@ -633,9 +633,9 @@ class GammaIncC(BinaryScalarOp): def impl(self, k, x): return special.gammaincc(k, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (k, x) = inputs - (gz,) = grads + (gz,) = output_grads return [ gz * gammaincc_grad(k, x), gz * -exp(-x + (k - 1) * log(x) - gammaln(k)), @@ -680,9 +680,9 @@ class GammaIncInv(BinaryScalarOp): def impl(self, k, x): return special.gammaincinv(k, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (k, x) = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, k), gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)), @@ -705,9 +705,9 @@ class GammaIncCInv(BinaryScalarOp): def impl(self, k, x): return special.gammainccinv(k, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (k, x) = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, k), gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)), @@ -1004,9 +1004,9 @@ class Jv(BinaryScalarOp): def impl(self, v, x): return special.jv(v, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): v, x = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, v), gz * (jv(v - 1, x) - jv(v + 1, x)) / 2.0, @@ -1029,9 +1029,9 @@ class 
J1(UnaryScalarOp): def impl(self, x): return special.j1(x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads return [gz * (j0(x) - jv(2, x)) / 2.0] def c_code(self, node, name, inp, out, sub): @@ -1056,9 +1056,9 @@ class J0(UnaryScalarOp): def impl(self, x): return special.j0(x) - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads return [gz * -1 * j1(x)] def c_code(self, node, name, inp, out, sub): @@ -1083,9 +1083,9 @@ class Iv(BinaryScalarOp): def impl(self, v, x): return special.iv(v, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): v, x = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, v), gz * (iv(v - 1, x) + iv(v + 1, x)) / 2.0, @@ -1108,9 +1108,9 @@ class I1(UnaryScalarOp): def impl(self, x): return special.i1(x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads return [gz * (i0(x) + iv(2, x)) / 2.0] def c_code(self, *args, **kwargs): @@ -1130,9 +1130,9 @@ class I0(UnaryScalarOp): def impl(self, x): return special.i0(x) - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads return [gz * i1(x)] def c_code(self, *args, **kwargs): @@ -1152,9 +1152,9 @@ class Ive(BinaryScalarOp): def impl(self, v, x): return special.ive(v, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): v, x = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, v), gz @@ -1207,9 +1207,9 @@ class Sigmoid(UnaryScalarOp): def impl(self, x): return special.expit(x) - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads y = sigmoid(x) rval = gz * y * (1.0 - y) @@ -1275,9 +1275,9 @@ def impl(self, x): else: return x - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads return [gz * sigmoid(x)] def c_code(self, node, name, inp, out, sub): @@ -1343,9 +1343,9 @@ def impl(self, x): else: return np.log(-np.expm1(x)) - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads res = true_div(-1.0, expm1(-x)) # Correct gradient at 0.0 to be -inf res = switch(isinf(res), -np.inf, res) @@ -1378,9 +1378,9 @@ class BetaInc(ScalarOp): def impl(self, a, b, x): return special.betainc(a, b, x) - def grad(self, inp, grads): - a, b, x = inp - (gz,) = grads + def grad(self, inputs, output_grads): + a, b, x = inputs + (gz,) = output_grads return [ gz * betainc_grad(a, b, x, True), @@ -1636,9 +1636,9 @@ class BetaIncInv(ScalarOp): def impl(self, a, b, x): return special.betaincinv(a, b, x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (a, b, x) = inputs - (gz,) = grads + (gz,) = output_grads return [ grad_not_implemented(self, 0, a), grad_not_implemented(self, 0, b), @@ -1675,9 +1675,9 @@ class Hyp2F1(ScalarOp): def impl(self, a, b, c, z): return special.hyp2f1(a, b, c, z) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): a, b, c, z = inputs - (gz,) = grads + (gz,) = output_grads grad_a, grad_b, grad_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2]) return [ gz * grad_a, diff --git a/pytensor/sparse/basic.py 
b/pytensor/sparse/basic.py index e47f81206b..02c2df729a 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -310,7 +310,7 @@ def perform(self, node, inputs, output_storage): output_storage[2][0] = np.asarray(csm.indptr, dtype="int32") output_storage[3][0] = np.asarray(csm.shape, dtype="int32") - def grad(self, inputs, g): + def grad(self, inputs, output_grads): # g[1:] is all integers, so their Jacobian in this op # is 0. We thus don't need to worry about what their values # are. @@ -320,11 +320,11 @@ def grad(self, inputs, g): # g[1:] is connected, or this grad method wouldn't have been # called, so we should report zeros (csm,) = inputs - if isinstance(g[0].type, DisconnectedType): + if isinstance(output_grads[0].type, DisconnectedType): return [csm.zeros_like()] _data, indices, indptr, _shape = csm_properties(csm) - return [CSM(csm.format)(g[0], indices, indptr, _shape)] + return [CSM(csm.format)(output_grads[0], indices, indptr, _shape)] # don't make this a function or it breaks some optimizations below @@ -470,9 +470,9 @@ def perform(self, node, inputs, output_storage): def connection_pattern(self, node): return [[True], [False], [False], [False]] - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x_data, x_indices, x_indptr, x_shape) = inputs - (g_out,) = gout + (g_out,) = output_grads g_data, g_indices, g_indptr, g_shape = csm_properties(g_out) # unpack the data vector and wrap it as a 1d TensorType g_data = csm_grad()( @@ -601,8 +601,8 @@ def perform(self, node, inputs, output_storage): assert _is_sparse(x) out[0] = x.astype(self.out_type) - def grad(self, inputs, outputs_gradients): - gz = outputs_gradients[0] + def grad(self, inputs, output_grads): + gz = output_grads[0] if gz.dtype in complex_dtypes: raise NotImplementedError("grad not implemented for complex types") @@ -713,9 +713,9 @@ def perform(self, node, inputs, output_storage): out[0] = x.toarray() assert _is_dense(out[0]) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if self.sparse_grad: left = sp_ones_like(x) right = gz @@ -788,9 +788,9 @@ def perform(self, node, inputs, output_storage): (out,) = output_storage out[0] = SparseTensorType.format_cls[self.format](x) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads gx = dense_from_sparse(gz) gx = specify_broadcastable( gx, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b) @@ -841,9 +841,9 @@ def perform(self, node, inputs, output_storage): assert _is_sparse(x) out[0] = x[indices] - def grad(self, inputs, g_outputs): + def grad(self, inputs, output_grads): x, indices = inputs - (gout,) = g_outputs + (gout,) = output_grads return [ get_item_list_grad(x, indices, gout), grad_undefined(self, 1, indices, "No gradient for this input"), @@ -931,9 +931,9 @@ def perform(self, node, inputs, output_storage): # which isn't what we want, so we convert it into an `ndarray` out[0] = np.asarray(x[ind1, ind2]).flatten() - def grad(self, inputs, g_outputs): + def grad(self, inputs, output_grads): x, ind1, ind2 = inputs - (gout,) = g_outputs + (gout,) = output_grads return [ get_item_2lists_grad(x, ind1, ind2, gout), grad_undefined(self, 1, ind1, "No gradient for this input"), @@ -1222,9 +1222,9 @@ def perform(self, node, inputs, output_storage): assert _is_sparse(x) out[0] = x.transpose() - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + 
(gz,) = output_grads assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (transpose(gz),) @@ -1268,9 +1268,9 @@ def perform(self, node, inputs, output_storage): assert _is_sparse(x) out[0] = -x - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (-gz,) @@ -1316,11 +1316,11 @@ def perform(self, node, inputs, output_storage): z[0] = y - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): from pytensor.sparse.math import sp_sum (x, s) = inputs - (gz,) = gout + (gz,) = output_grads return [col_scale(gz, s), sp_sum(x * gz, axis=0)] def infer_shape(self, fgraph, node, ins_shapes): @@ -1367,11 +1367,11 @@ def perform(self, node, inputs, output_storage): z[0] = scipy.sparse.csc_matrix((y_data, indices, indptr), (M, N)) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): from pytensor.sparse.math import sp_sum (x, s) = inputs - (gz,) = gout + (gz,) = output_grads return [row_scale(gz, s), sp_sum(x * gz, axis=1)] def infer_shape(self, fgraph, node, ins_shapes): @@ -1468,9 +1468,9 @@ def perform(self, node, inputs, output_storage): raise ValueError("Diag only apply on square matrix") z[0] = x.diagonal() - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (_x,) = inputs - (gz,) = gout + (gz,) = output_grads return [square_diagonal(gz)] def infer_shape(self, fgraph, nodes, shapes): @@ -1518,8 +1518,8 @@ def perform(self, node, inputs, output_storage): z[0] = scipy.sparse.csc_matrix(tup, copy=True) - def grad(self, inputs, gout): - (gz,) = gout + def grad(self, inputs, output_grads): + (gz,) = output_grads return [diag(gz)] def infer_shape(self, fgraph, nodes, shapes): @@ -1570,8 +1570,8 @@ def perform(self, node, inputs, output_storage): else: z[0] = x.sorted_indices() - def grad(self, inputs, output_grad): - return [output_grad[0]] + def grad(self, inputs, output_grads): + return [output_grads[0]] def infer_shape(self, fgraph, node, i0_shapes): return i0_shapes @@ -1654,8 +1654,8 @@ def perform(self, node, inputs, output_storage): if out[0].dtype != self.dtype: out[0] = out[0].astype(self.dtype) - def grad(self, inputs, gout): - (gz,) = gout + def grad(self, inputs, output_grads): + (gz,) = output_grads is_continuous = [ (inputs[i].dtype in tensor_continuous_dtypes) for i in range(len(inputs)) ] @@ -1727,8 +1727,8 @@ def perform(self, node, inputs, output_storage): if out[0].dtype != self.dtype: out[0] = out[0].astype(self.dtype) - def grad(self, inputs, gout): - (gz,) = gout + def grad(self, inputs, output_grads): + (gz,) = output_grads is_continuous = [ (inputs[i].dtype in tensor_continuous_dtypes) for i in range(len(inputs)) ] @@ -1834,9 +1834,9 @@ def perform(self, node, inputs, output_storage): c.eliminate_zeros() z[0] = c - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (_x,) = inputs - (gz,) = gout + (gz,) = output_grads return [gz] def infer_shape(self, fgraph, node, i0_shapes): @@ -1932,8 +1932,8 @@ def connection_pattern(self, node): rval = [[True], [True], [False]] return rval - def grad(self, inputs, grads): - (g_output,) = grads + def grad(self, inputs, output_grads): + (g_output,) = output_grads _x, _y = inputs[:2] idx_list = inputs[2:] diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index e84e048f18..1c0b4dcf36 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -334,9 +334,9 @@ def perform(self, node, inputs, output_storage): 
else: z[0] = np.asarray(x.sum(self.axis)).ravel() - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads if x.dtype not in psb.continuous_dtypes: return [x.zeros_like(dtype=config.floatX)] if self.structured: @@ -438,9 +438,9 @@ def perform(self, node, inputs, output_storage): assert x.shape == y.shape out[0] = x + y - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) and psb._is_sparse_variable(y) assert psb._is_sparse_variable(gz) return gz, gz @@ -500,8 +500,8 @@ def perform(self, node, inputs, output_storage): out[0] = x.copy() out[0].data += y.data - def grad(self, inputs, gout): - (gz,) = gout + def grad(self, inputs, output_grads): + (gz,) = output_grads is_continuous = [(i.dtype in psb.continuous_dtypes) for i in inputs] derivative = {True: gz, False: None} return [derivative[b] for b in is_continuous] @@ -541,9 +541,9 @@ def perform(self, node, inputs, output_storage): # numpy.matrixlib.defmatrix.matrix object and not an ndarray. out[0] = np.asarray(x + y, dtype=node.outputs[0].type.dtype) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) and psb._is_dense_variable(y) assert psb._is_dense_variable(gz) return psb.sp_ones_like(x) * gz, gz @@ -601,9 +601,9 @@ def perform(self, node, inputs, output_storage): assert x.shape[1] == y.shape[0] out[0] = x.__class__(x + (x.toarray() != 0) * y) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) and not psb._is_sparse_variable(y) assert psb._is_sparse_variable(gz) return gz, sp_sum(gz, axis=0, sparse_grad=True) @@ -733,9 +733,9 @@ def perform(self, node, inputs, output_storage): # x * y calls dot... out[0] = x.multiply(y) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads return y * gz, x * gz def infer_shape(self, fgraph, node, shapes): @@ -822,9 +822,9 @@ def perform(self, node, inputs, output_storage): ) out[0] = type(x)(x.toarray() * y) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) and psb._is_dense_variable(y) assert psb._is_sparse_variable(gz) return y * gz, psb.dense_from_sparse(x * gz) @@ -881,9 +881,9 @@ def perform(self, node, inputs, output_storage): assert x.shape[1] == y.shape[0] out[0] = x.__class__(x.toarray() * y) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) and psb._is_dense_variable(y) assert psb._is_sparse_variable(gz) @@ -1292,9 +1292,9 @@ def perform(self, node, inputs, output_storage): ) out[0] = rval - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(gz) assert psb._is_sparse_variable(x) @@ -1422,12 +1422,12 @@ def perform(self, node, inputs, output_storage): # _asarray function documentation. 
out[0] = np.asarray(variable, str(variable.dtype)) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): # a is sparse, b is dense, g_out is dense # ga = g_out x b.T # gb = a.T x g_out (a, b) = inputs - (g_out,) = gout + (g_out,) = output_grads return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] def infer_shape(self, fgraph, node, shapes): @@ -1839,9 +1839,9 @@ def perform(self, node, inputs, output_storage): out[0] = p.__class__(p.multiply(np.dot(x, y.T))) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y, p) = inputs - (gz,) = gout + (gz,) = output_grads rval = [dot(p * gz, y), dot((p * gz).T, x), grad_not_implemented(self, 2, p)] return rval @@ -1940,9 +1940,9 @@ def perform(self, node, inputs, output_storage): output_storage[0] = np.asarray(rval, dtype=node.outputs[0].dtype) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, y) = inputs - (gz,) = gout + (gz,) = output_grads assert psb._is_sparse_variable(x) or psb._is_sparse_variable(y) rval = [] diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5e71c6813e..1d477d44e5 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -629,9 +629,9 @@ def perform(self, node, inputs, output_storage): def infer_shape(self, fgraph, node, in_shapes): return [()] - def grad(self, inp, grads): - (s,) = inp - (dt,) = grads + def grad(self, inputs, output_grads): + (s,) = inputs + (dt,) = output_grads if s.type.dtype in float_dtypes: assert dt.type.dtype in float_dtypes return [scalar_from_tensor(dt)] @@ -690,9 +690,9 @@ def perform(self, node, inputs, output_storage): def infer_shape(self, fgraph, node, in_shapes): return [()] - def grad(self, inp, grads): - (_s,) = inp - (dt,) = grads + def grad(self, inputs, output_grads): + (_s,) = inputs + (dt,) = output_grads return [tensor_from_scalar(dt)] def R_op(self, inputs, eval_points): @@ -983,8 +983,8 @@ def perform(self, node, inputs, output_storage): for i, res in enumerate(result_tuple): output_storage[i][0] = res.astype("int64") - def grad(self, inp, grads): - return [grad_undefined(self, 0, inp[0])] + def grad(self, inputs, output_grads): + return [grad_undefined(self, 0, inputs[0])] _nonzero = Nonzero() @@ -1117,8 +1117,8 @@ def infer_shape(self, fgraph, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] - def grad(self, inp, grads): - return [grad_undefined(self, i, inp[i]) for i in range(3)] + def grad(self, inputs, output_grads): + return [grad_undefined(self, i, inputs[i]) for i in range(3)] def tri(N, M=None, k=0, dtype=None): @@ -1403,8 +1403,8 @@ def infer_shape(self, fgraph, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] - def grad(self, inp, grads): - return [grad_undefined(self, i, inp[i]) for i in range(3)] + def grad(self, inputs, output_grads): + return [grad_undefined(self, i, inputs[i]) for i in range(3)] @staticmethod def is_offset_zero(node) -> bool: @@ -1736,9 +1736,9 @@ def connection_pattern(self, node): return rval - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): x = inputs[0] - gz = grads[0] + gz = output_grads[0] n_axes_to_sum = gz.ndim - x.ndim # The number of dimensions added axis = list(range(n_axes_to_sum)) @@ -1749,7 +1749,7 @@ def grad(self, inputs, grads): zip( inputs[0].type.shape, # We need the dimensions corresponding to x - grads[0].type.shape[-inputs[0].ndim :], + output_grads[0].type.shape[-inputs[0].ndim :], strict=False, ) ): @@ -1955,12 +1955,12 @@ 
def c_code(self, node, name, inp, out_, props): def infer_shape(self, fgraph, node, ishapes): return [(len(ishapes),)] - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): # If the output is of an integer dtype, no gradient shall pass if self.dtype in discrete_dtypes: return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs] - grads = [output_gradients[0][i] for i in range(len(inputs))] + grads = [output_grads[0][i] for i in range(len(inputs))] return grads def R_op(self, inputs, eval_points): @@ -3668,11 +3668,11 @@ def infer_shape(self, fgraph, node, in_shapes): out_shape = [maximum(sx, sy) for sx, sy in zip(shp_x, shp_y, strict=True)] return [out_shape] - def grad(self, inp, grads): + def grad(self, inputs, output_grads): from pytensor.tensor.math import Sum - x, y = inp - (gz,) = grads + x, y = inputs + (gz,) = output_grads # First, compute the gradient wrt the broadcasted x. # If 'inverse' is False (0), apply the inverse of y on gz. # Else, apply y on gz. @@ -3886,12 +3886,12 @@ def c_code(self, node, nodename, input_names, output_names, sub): def c_code_cache_version(self): return (0,) - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): # Avoid circular import from pytensor.tensor.subtensor import set_subtensor (x,) = inputs - (gz,) = gout + (gz,) = output_grads axis1, axis2, offset = self.axis1, self.axis2, self.offset @@ -4403,7 +4403,7 @@ def do_constant_folding(self, fgraph, node): def connection_pattern(self, node): return [[False] for i in node.inputs] - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): return [DisconnectedType()() for i in inputs] def R_op(self, inputs, eval_points): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 569942354b..8184ee8730 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1597,9 +1597,9 @@ def c_code_cache_version(self): return (6, blas_header_version()) - def grad(self, inp, grads): - x, y = inp - (gz,) = grads + def grad(self, inputs, output_grads): + x, y = inputs + (gz,) = output_grads xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1)) ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz) diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 41324ce103..aacd577cc0 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -2680,9 +2680,9 @@ def __init__( unshared=unshared, ) - def grad(self, inp, grads): - bottom, weights = inp - (top,) = grads + def grad(self, inputs, output_grads): + bottom, weights = inputs + (top,) = output_grads # Don't add the assert again, as it was already added in the forward. 
d_bottom = AbstractConv2d_gradInputs( self.imshp, @@ -2740,9 +2740,9 @@ def __init__( num_groups=num_groups, ) - def grad(self, inp, grads): - bottom, weights = inp - (top,) = grads + def grad(self, inputs, output_grads): + bottom, weights = inputs + (top,) = output_grads d_bottom = AbstractConv3d_gradInputs( self.imshp, self.kshp, @@ -3037,9 +3037,9 @@ def __init__( unshared=unshared, ) - def grad(self, inp, grads): - bottom, top = inp[:2] - (weights,) = grads + def grad(self, inputs, output_grads): + bottom, top = inputs[:2] + (weights,) = output_grads d_bottom = AbstractConv2d_gradInputs( self.imshp, self.kshp, @@ -3098,9 +3098,9 @@ def __init__( num_groups=num_groups, ) - def grad(self, inp, grads): - bottom, top = inp[:2] - (weights,) = grads + def grad(self, inputs, output_grads): + bottom, top = inputs[:2] + (weights,) = output_grads d_bottom = AbstractConv3d_gradInputs( self.imshp, self.kshp, @@ -3419,9 +3419,9 @@ def __init__( unshared=unshared, ) - def grad(self, inp, grads): - weights, top = inp[:2] - (bottom,) = grads + def grad(self, inputs, output_grads): + weights, top = inputs[:2] + (bottom,) = output_grads d_weights = AbstractConv2d_gradWeights( self.imshp, self.kshp, @@ -3480,9 +3480,9 @@ def __init__( num_groups=num_groups, ) - def grad(self, inp, grads): - weights, top = inp[:2] - (bottom,) = grads + def grad(self, inputs, output_grads): + weights, top = inputs[:2] + (bottom,) = output_grads d_weights = AbstractConv3d_gradWeights( self.imshp, self.kshp, diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index b495b3aa1c..678561bbdd 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -260,9 +260,9 @@ def R_op(self, inputs, eval_points): return [None] return self(*eval_points, return_list=True) - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads + def grad(self, inputs, output_grads): + (x,) = inputs + (gz,) = output_grads grad_order = ["x"] * x.type.ndim for i, v in enumerate(self.new_order): if v != "x": diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 97150e3774..2cfd9b21f7 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -76,8 +76,8 @@ def perform(self, node, inputs, output_storage): assert x.flags["C_CONTIGUOUS"] y[0] = x - def grad(self, inputs, dout): - return [ptb.as_tensor_variable(dout[0])] + def grad(self, inputs, output_grads): + return [ptb.as_tensor_variable(output_grads[0])] def c_code(self, node, name, inames, onames, sub): (x,) = inames @@ -210,7 +210,7 @@ def c_code(self, node, name, inames, onames, sub): def c_code_cache_version(self): return (2,) - def grad(self, inputs, output_gradients): + def grad(self, inputs, output_grads): num_ins = len(inputs) if num_ins == 3: x, v, _sorter = inputs @@ -701,9 +701,9 @@ def perform(self, node, inputs, output_storage): def connection_pattern(self, node): return [[True], [False]] - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x, repeats) = inputs - (gz,) = gout + (gz,) = output_grads axis = self.axis # Use IncSubtensor to sum the gradients that belong to the repeated entries of x @@ -932,15 +932,15 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = a - def grad(self, inp, cost_grad): + def grad(self, inputs, output_grads): """ Notes ----- The gradient is currently implemented for matrices only. 
""" - a, _val = inp - grad = cost_grad[0] + a, _val = inputs + grad = output_grads[0] if a.dtype.startswith("complex"): return [None, None] elif a.ndim > 2: @@ -1059,14 +1059,14 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = a - def grad(self, inp, cost_grad): + def grad(self, inputs, output_grads): """ Notes ----- The gradient is currently implemented for matrices only. """ - a, _val, offset = inp - grad = cost_grad[0] + a, _val, offset = inputs + grad = output_grads[0] height, width = grad.shape if a.dtype.startswith("complex"): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 033d46222c..a3d0caf71c 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -127,7 +127,7 @@ def perform(self, node, inputs, output_storage): axis = inputs[2] output_storage[0][0] = np.fft.fft(a, n=int(n), axis=axis.item()) - def grad(self, inputs, cost_grad): + def grad(self, inputs, output_grads): """ In defining the gradient, the Finite Fourier Transform is viewed as a complex-differentiable function of a complex variable @@ -135,7 +135,7 @@ def grad(self, inputs, cost_grad): a = inputs[0] n = inputs[1] axis = inputs[2] - grad = cost_grad[0] + grad = output_grads[0] if not isinstance(axis, TensorConstant): raise NotImplementedError( f"{self.__class__.__name__}: gradient is currently implemented" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index ba45f834d1..cdf09706fd 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -266,8 +266,8 @@ def infer_shape(self, fgraph, node, shapes): def R_op(self, inputs, eval_points): raise ValueError("Argmax is non-diifferentiable") - def grad(self, inp, grads): - (x,) = inp + def grad(self, inputs, output_grads): + (x,) = inputs return [x.zeros_like()] @@ -3025,9 +3025,9 @@ def make_node(self, x, y): def perform(self, node, inputs, output_storage): output_storage[0][0] = np.matmul(*inputs) - def grad(self, inp, grads): - x, y = inp - (gz,) = grads + def grad(self, inputs, output_grads): + x, y = inputs + (gz,) = output_grads xgrad = self(gz, y.T) ygrad = self(x.T, gz) @@ -3394,8 +3394,8 @@ def make_node(self, input): ret = super().make_node(input) return ret - def grad(self, inp, grads): - (x,) = inp + def grad(self, inputs, output_grads): + (x,) = inputs return [x.zeros_like(config.floatX)] def clone(self, **kwargs): @@ -3424,8 +3424,8 @@ def make_node(self, input): ret = super().make_node(input) return ret - def grad(self, inp, grads): - (x,) = inp + def grad(self, inputs, output_grads): + (x,) = inputs return [x.zeros_like(config.floatX)] def clone(self, **kwargs): @@ -3761,10 +3761,10 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None): upcast_discrete_output=True, ) - def grad(self, inp, grads): + def grad(self, inputs, output_grads): from pytensor.gradient import grad_not_implemented - (a,) = inp + (a,) = inputs a_grad = grad_not_implemented( self, 0, diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 640cacd70b..e180984dda 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -133,7 +133,7 @@ def perform(self, node, inputs, output_storage): (z,) = output_storage z[0] = np.linalg.inv(x) - def grad(self, inputs, g_outputs): + def grad(self, inputs, output_grads): r"""The gradient function should return .. 
math:: V\frac{\partial X^{-1}}{\partial X}, @@ -148,7 +148,7 @@ def grad(self, inputs, g_outputs): """ (x,) = inputs xi = self(x) - (gz,) = g_outputs + (gz,) = output_grads # ptm.dot(gz.T,xi) return [-matrix_dot(xi, gz.T, xi).T] @@ -240,8 +240,8 @@ def perform(self, node, inputs, output_storage): except Exception as e: raise ValueError("Failed to compute determinant", x) from e - def grad(self, inputs, g_outputs): - (gz,) = g_outputs + def grad(self, inputs, output_grads): + (gz,) = output_grads (x,) = inputs return [gz * self(x) * matrix_inverse(x).T] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 6afa5a29b2..18387b6a3c 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -431,7 +431,7 @@ def perform(self, node: Apply, inputs, output_storage): dtype=self.dtype, ) - def grad(self, inputs, outputs): + def grad(self, inputs, output_grads): return [ pytensor.gradient.grad_undefined( self, k, inp, "No gradient defined for random variables" diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 8316ae2f28..71de47496e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -97,7 +97,7 @@ def connection_pattern(self, node): # part of the graph return [[False]] - def grad(self, inp, grads): + def grad(self, inputs, output_grads): # the grad returns the gradient with respect to the # elements of a tensor variable # the elements of the tensor variable do not participate @@ -313,12 +313,12 @@ def connection_pattern(self, node): # part of the graph return [[False]] - def grad(self, inp, grads): + def grad(self, inputs, output_grads): return [ pytensor.gradient.grad_not_implemented( op=self, x_pos=0, - x=inp[0], + x=inputs[0], comment="No gradient for the shape of a matrix is implemented.", ) ] @@ -471,9 +471,9 @@ def infer_shape(self, fgraph, node, shapes): def connection_pattern(self, node): return [[True], *[[False]] * len(node.inputs[1:])] - def grad(self, inp, grads): - _x, *shape = inp - (gz,) = grads + def grad(self, inputs, output_grads): + _x, *shape = inputs + (gz,) = output_grads return [specify_shape(gz, shape)] + [ pytensor.gradient.DisconnectedType()() for _ in range(len(shape)) ] @@ -722,9 +722,9 @@ def perform(self, node, inputs, output_storage): def connection_pattern(self, node): return [[True], [False]] - def grad(self, inp, grads): - x, _shp = inp - (g_out,) = grads + def grad(self, inputs, output_grads): + x, _shp = inputs + (g_out,) = output_grads return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType()()] def R_op(self, inputs, eval_points): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f91caf14b3..97dd648cf0 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1239,9 +1239,9 @@ def perform(self, node, inputs, output_storage): else: w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) - def grad(self, inputs, g_outputs): + def grad(self, inputs, output_grads): a, b = inputs - (gw,) = g_outputs + (gw,) = output_grads return EigvalshGrad(self.lower)(a, b, gw) def infer_shape(self, fgraph, node, shapes): @@ -1673,7 +1673,7 @@ def __init__(self, n_inputs): raise ValueError("n_inputs must be greater than 0") self.n_inputs = n_inputs - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): shapes = pt.stack([i.shape for i in inputs]) index_end = shapes.cumsum(0) index_begin = index_end - shapes @@ -1684,7 +1684,7 @@ def grad(self, inputs, gout): ) for i in range(len(inputs)) ] - return [gout[0][slc] for 
slc in slices] + return [output_grads[0][slc] for slc in slices] def infer_shape(self, fgraph, nodes, shapes): first, second = zip(*shapes, strict=True) diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index 4b874e726e..bbb64747d3 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -44,9 +44,9 @@ def perform(self, node, inputs, output_storage): dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm output_storage[0][0] = dx - def grad(self, inp, grads): - dy, sm = inp - (g,) = grads + def grad(self, inputs, output_grads): + dy, sm = inputs + (g,) = output_grads tmp = g + neg(sum(g * sm, axis=self.axis, keepdims=True)) g_dy = tmp * sm @@ -528,10 +528,12 @@ def perform(self, node, inputs, output_storage): (z,) = output_storage z[0] = scipy.special.log_softmax(x, axis=self.axis) - def grad(self, inp, grads): - (x,) = inp + def grad(self, inputs, output_grads): + (x,) = inputs sm = Softmax(axis=self.axis)(x) - return [grads[0] - sum(grads[0], axis=self.axis, keepdims=True) * sm] + return [ + output_grads[0] - sum(output_grads[0], axis=self.axis, keepdims=True) * sm + ] def R_op(self, inputs, eval_points): # I think the Jacobian is symmetric so the R_op diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index ee7ec8bfb4..5f13b96f84 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -975,8 +975,8 @@ def _is_constant(const, x): assert len(outshp) == node.outputs[0].ndim return [outshp] - def grad(self, inputs, grads): - (gz,) = grads + def grad(self, inputs, output_grads): + (gz,) = output_grads x = inputs[0] rest = inputs[1:] if x.dtype in discrete_dtypes: @@ -1998,8 +1998,8 @@ def connection_pattern(self, node): return rval - def grad(self, inputs, grads): - (g_output,) = grads + def grad(self, inputs, output_grads): + (g_output,) = output_grads x, y = inputs[:2] idx_list = inputs[2:] @@ -2124,9 +2124,9 @@ def connection_pattern(self, node): return rval - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): x, ilist = inputs - (gz,) = grads + (gz,) = output_grads assert len(inputs) == 2 if self.sparse_grad: if x.type.ndim != 2: @@ -2499,8 +2499,8 @@ def connection_pattern(self, node): rval = [[True], [True], [False]] return rval - def grad(self, inputs, grads): - (g_output,) = grads + def grad(self, inputs, output_grads): + (g_output,) = output_grads x, y, idx_list = inputs if x.dtype in discrete_dtypes: # The output dtype is the same as x @@ -2760,8 +2760,8 @@ def connection_pattern(self, node): return rval - def grad(self, inputs, grads): - (gz,) = grads + def grad(self, inputs, output_grads): + (gz,) = output_grads x = inputs[0] if x.dtype in discrete_dtypes: # The output dtype is the same as x @@ -2911,10 +2911,10 @@ def R_op(self, inputs, eval_points): return [None] return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs - def grad(self, inpt, output_gradients): - x, y = inpt[:2] - idxs = inpt[2:] - (outgrad,) = output_gradients + def grad(self, inputs, output_grads): + x, y = inputs[:2] + idxs = inputs[2:] + (outgrad,) = output_grads if x.dtype in discrete_dtypes: # The output dtype is the same as x gx = x.zeros_like(dtype=config.floatX) diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 7e59687ed6..b43a0f6c9e 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -43,7 +43,7 @@ def perform(self, node, inputs, output_storage): (out,) = output_storage out[0] = slice(*inputs) - def 
grad(self, inputs, grads): + def grad(self, inputs, output_grads): return [DisconnectedType()() for i in inputs] diff --git a/pytensor/tensor/xlogx.py b/pytensor/tensor/xlogx.py index 896996de58..429d38e26f 100644 --- a/pytensor/tensor/xlogx.py +++ b/pytensor/tensor/xlogx.py @@ -15,9 +15,9 @@ def impl(self, x): return 0.0 return x * np.log(x) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = grads + (gz,) = output_grads return [gz * (1 + ps.log(x))] def c_code(self, node, name, inputs, outputs, sub): @@ -46,9 +46,9 @@ def impl(self, x, y): return 0.0 return x * np.log(y) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): x, y = inputs - (gz,) = grads + (gz,) = output_grads return [gz * ps.log(y), gz * x / y] def c_code(self, node, name, inputs, outputs, sub): diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 01d7a05bd1..1e349be3dd 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -303,9 +303,9 @@ def perform(self, node, inputs, output_storage): assert _is_sparse(x) out[0] = -x - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (x,) = inputs - (gz,) = gout + (gz,) = output_grads assert _is_sparse_variable(x) and _is_sparse_variable(gz) if self.structured: return (sp_ones_like(x) * dense_from_sparse(gz),) diff --git a/tests/tensor/conv/c_conv3d_corr3d_ref.py b/tests/tensor/conv/c_conv3d_corr3d_ref.py index 306442e6e9..b0c26c60cc 100644 --- a/tests/tensor/conv/c_conv3d_corr3d_ref.py +++ b/tests/tensor/conv/c_conv3d_corr3d_ref.py @@ -636,9 +636,9 @@ def c_code(self, node, nodename, inp, out_, sub): (top,) = out_ return super().c_code_helper(bottom, weights, top, sub) - def grad(self, inp, grads): - bottom, weights = inp - (top,) = grads + def grad(self, inputs, output_grads): + bottom, weights = inputs + (top,) = output_grads d_bottom = Corr3dMMGradInputs( self.border_mode, self.subsample, @@ -750,9 +750,9 @@ def c_code(self, node, nodename, inp, out_, sub): (weights,) = out_ return super().c_code_helper(bottom, weights, top, sub, height, width, depth) - def grad(self, inp, grads): - bottom, top = inp[:2] - (weights,) = grads + def grad(self, inputs, output_grads): + bottom, top = inputs[:2] + (weights,) = output_grads d_bottom = Corr3dMMGradInputs( self.border_mode, self.subsample, @@ -766,7 +766,7 @@ def grad(self, inp, grads): num_groups=self.num_groups, )(bottom, weights) d_height_width_depth = ( - (pytensor.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () + (pytensor.gradient.DisconnectedType()(),) * 3 if len(inputs) == 5 else () ) return (d_bottom, d_top, *d_height_width_depth) @@ -890,9 +890,9 @@ def c_code(self, node, nodename, inp, out_, sub): (bottom,) = out_ return super().c_code_helper(bottom, weights, top, sub, height, width, depth) - def grad(self, inp, grads): - weights, top = inp[:2] - (bottom,) = grads + def grad(self, inputs, output_grads): + weights, top = inputs[:2] + (bottom,) = output_grads d_weights = Corr3dMMGradWeights( self.border_mode, self.subsample, @@ -906,7 +906,7 @@ def grad(self, inp, grads): num_groups=self.num_groups, )(bottom, weights) d_height_width_depth = ( - (pytensor.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () + (pytensor.gradient.DisconnectedType()(),) * 3 if len(inputs) == 5 else () ) return (d_weights, d_top, *d_height_width_depth) diff --git a/tests/tensor/conv/c_conv_corr_ref.py b/tests/tensor/conv/c_conv_corr_ref.py index 6ab212e1e4..160d503569 100644 --- 
a/tests/tensor/conv/c_conv_corr_ref.py +++ b/tests/tensor/conv/c_conv_corr_ref.py @@ -695,9 +695,9 @@ def c_code(self, node, nodename, inp, out_, sub): (top,) = out_ return super().c_code_helper(bottom, weights, top, sub) - def grad(self, inp, grads): - bottom, weights = inp - (top,) = grads + def grad(self, inputs, output_grads): + bottom, weights = inputs + (top,) = output_grads d_bottom = CorrMM_gradInputs( self.border_mode, self.subsample, @@ -821,9 +821,9 @@ def c_code(self, node, nodename, inp, out_, sub): (weights,) = out_ return super().c_code_helper(bottom, weights, top, sub, height, width) - def grad(self, inp, grads): - bottom, top = inp[:2] - (weights,) = grads + def grad(self, inputs, output_grads): + bottom, top = inputs[:2] + (weights,) = output_grads d_bottom = CorrMM_gradInputs( self.border_mode, self.subsample, @@ -839,7 +839,7 @@ def grad(self, inp, grads): self.unshared, )(bottom, weights) d_height_width = ( - (pytensor.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () + (pytensor.gradient.DisconnectedType()(),) * 2 if len(inputs) == 4 else () ) return (d_bottom, d_top, *d_height_width) @@ -953,9 +953,9 @@ def c_code(self, node, nodename, inp, out_, sub): (bottom,) = out_ return super().c_code_helper(bottom, weights, top, sub, height, width) - def grad(self, inp, grads): - weights, top = inp[:2] - (bottom,) = grads + def grad(self, inputs, output_grads): + weights, top = inputs[:2] + (bottom,) = output_grads d_weights = CorrMM_gradWeights( self.border_mode, self.subsample, @@ -971,7 +971,7 @@ def grad(self, inp, grads): self.unshared, )(bottom, weights) d_height_width = ( - (pytensor.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () + (pytensor.gradient.DisconnectedType()(),) * 2 if len(inputs) == 4 else () ) return (d_weights, d_top, *d_height_width) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 16e3ad95f7..d40fcd9551 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1005,9 +1005,9 @@ def __init__(self): def impl(self, n, x): return x * n - def grad(self, inputs, gout): + def grad(self, inputs, output_grads): (n, _x) = inputs - (gz,) = gout + (gz,) = output_grads dy_dx = n return [pytensor.gradient.grad_not_implemented(self, 0, n), gz * dy_dx] diff --git a/tests/test_gradient.py b/tests/test_gradient.py index a79746da6d..02ca0c0a9a 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -87,9 +87,9 @@ def make_node(self): outputs = [vector()] return Apply(self, inputs, outputs) - def grad(self, inp, grads): - (_x,) = inp - (_gz,) = grads + def grad(self, inputs, output_grads): + (_x,) = inputs + (_gz,) = output_grads def perform(self, *args, **kwargs): raise NotImplementedError() @@ -108,7 +108,7 @@ def make_node(self, *inputs): outputs = [vector()] return Apply(self, inputs, outputs) - def grad(self, inputs, grads): + def grad(self, inputs, output_grads): return [inputs[0].zeros_like()] def perform(self, *args, **kwargs): @@ -134,7 +134,7 @@ def make_node(self): outputs = [matrix()] return Apply(self, inputs, outputs) - def grad(self, inp, grads): + def grad(self, inputs, output_grads): return (gval,) def perform(self, *args, **kwargs): @@ -156,9 +156,9 @@ def make_node(self): outputs = [scalar(), scalar()] return Apply(self, inputs, outputs) - def grad(self, inp, grads): - (_x,) = inp - _gz1, _gz2 = grads + def grad(self, inputs, output_grads): + (_x,) = inputs + _gz1, _gz2 = output_grads return (gval,) def perform(self, *args, **kwargs): @@ -181,9 +181,9 @@ def 
make_node(self): outputs = [matrix()] return Apply(self, inputs, outputs) - def grad(self, inp, grads): - _x0, _x1 = inp - (_gz,) = grads + def grad(self, inputs, output_grads): + _x0, _x1 = inputs + (_gz,) = output_grads return (gval0, gval1) def perform(self, *args, **kwargs): @@ -207,7 +207,7 @@ def make_node(self): outputs = [matrix(), matrix()] return Apply(self, inputs, outputs) - def grad(self, inp, grads): + def grad(self, inputs, output_grads): return gval0, gval1 def perform(self, *args, **kwargs): @@ -230,9 +230,9 @@ def make_node(self): outputs = [scalar("b"), scalar("d")] return Apply(self, inputs, outputs) - def grad(self, inp, grads): - _x0, _x1 = inp - _gz0, _gz1 = grads + def grad(self, inputs, output_grads): + _x0, _x1 = inputs + _gz0, _gz1 = output_grads return self.gval0, self.gval1 def perform(self, *args, **kwargs): diff --git a/tests/test_rop.py b/tests/test_rop.py index 50cd27c87f..953cce5b45 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -50,8 +50,8 @@ def perform(self, node, inputs, output_storage): (out,) = output_storage out[0] = x - def grad(self, inp, grads): - return [grad_undefined(self, 0, inp[0])] + def grad(self, inputs, output_grads): + return [grad_undefined(self, 0, inputs[0])] def R_op(self, inputs, eval_points): return [None] From c17f657c2408ac042e99bea6fb8dc2f4ba63113b Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 22:46:11 +0100 Subject: [PATCH 16/17] Standardize Op.c_code arguments --- pytensor/compile/ops.py | 12 +-- pytensor/link/c/op.py | 6 +- pytensor/raise_op.py | 10 +-- pytensor/scalar/basic.py | 10 +-- pytensor/scalar/loop.py | 10 +-- pytensor/scalar/math.py | 100 +++++++++++------------ pytensor/tensor/basic.py | 38 ++++----- pytensor/tensor/blas.py | 24 +++--- pytensor/tensor/blas_c.py | 12 +-- pytensor/tensor/elemwise.py | 8 +- pytensor/tensor/extra_ops.py | 20 ++--- pytensor/tensor/math.py | 12 +-- pytensor/tensor/shape.py | 18 ++-- pytensor/tensor/special.py | 18 ++-- pytensor/tensor/subtensor.py | 12 +-- pytensor/typed_list/basic.py | 36 ++++---- tests/compile/test_debugmode.py | 18 ++-- tests/link/c/test_basic.py | 42 +++++----- tests/link/c/test_cmodule.py | 14 ++-- tests/link/c/test_op.py | 4 +- tests/link/c/test_type.py | 16 ++-- tests/tensor/conv/c_conv3d_corr3d_ref.py | 22 ++--- tests/tensor/conv/c_conv_corr_ref.py | 22 ++--- 23 files changed, 242 insertions(+), 242 deletions(-) diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 9994699b08..4059c4dead 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -50,9 +50,9 @@ def perform(self, node, inputs, output_storage): def __str__(self): return f"{self.__class__.__name__}" - def c_code(self, node, nodename, inp, out, sub): - (iname,) = inp - (oname,) = out + def c_code(self, node, name, inputs, outputs, sub): + (iname,) = inputs + (oname,) = outputs fail = sub["fail"] itype = node.inputs[0].type.__class__ @@ -192,9 +192,9 @@ def c_code_cache_version(self): version.append(1) return tuple(version) - def c_code(self, node, name, inames, onames, sub): - (iname,) = inames - (oname,) = onames + def c_code(self, node, name, inputs, outputs, sub): + (iname,) = inputs + (oname,) = outputs fail = sub["fail"] itype = node.inputs[0].type.__class__ diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 2a0170f98d..46444cb280 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -571,7 +571,7 @@ def c_init_code_struct(self, node, name, sub): else: return super().c_init_code_struct(node, name, sub) - def 
c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): if self.func_name is not None: assert "code" not in self.code_sections @@ -587,7 +587,7 @@ def c_code(self, node, name, inp, out, sub): return f""" {define_macros} {{ - if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{ + if ({self.func_name}({self.format_c_function_args(inputs, outputs)}{params}) != 0) {{ {sub["fail"]} }} }} @@ -599,7 +599,7 @@ def c_code(self, node, name, inp, out, sub): def_macros, undef_macros = self.get_c_macros(node, name) def_sub, undef_sub = get_sub_macros(sub) - def_io, undef_io = get_io_macros(inp, out) + def_io, undef_io = get_io_macros(inputs, outputs) return ( f"{def_macros}\n{def_sub}\n{def_io}\n{op_code}" diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index 97e001b42b..f36160f2e7 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -94,15 +94,15 @@ def grad(self, inputs, output_grads): def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) - def c_code(self, node, name, inames, onames, props): + def c_code(self, node, name, inputs, outputs, sub): if not isinstance(node.inputs[0].type, DenseTensorType | ScalarType): raise NotImplementedError( f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" ) - value_name, *cond_names = inames - out_name = onames[0] - fail_code = props["fail"] - param_struct_name = props["params"] + value_name, *cond_names = inputs + out_name = outputs[0] + fail_code = sub["fail"] + param_struct_name = sub["params"] msg = self.msg.replace('"', '\\"').replace("\n", "\\n") all_conds = " && ".join(cond_names) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 4f96d48aff..3cf87ec3b1 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1294,7 +1294,7 @@ def __str__(self): def c_code_cache_version(self): return (4,) - def c_code_contiguous(self, node, name, inp, out, sub): + def c_code_contiguous(self, node, name, inputs, outputs, sub): """ This function is called by Elemwise when all inputs and outputs are c_contiguous. This allows to use the SIMD version of this op. @@ -4406,15 +4406,15 @@ def c_code_template(self): return self._c_code - def c_code(self, node, nodename, inames, onames, sub): + def c_code(self, node, name, inputs, outputs, sub): d = dict( chain( - zip((f"i{i}" for i in range(len(inames))), inames, strict=True), - zip((f"o{i}" for i in range(len(onames))), onames, strict=True), + zip((f"i{i}" for i in range(len(inputs))), inputs, strict=True), + zip((f"o{i}" for i in range(len(outputs))), outputs, strict=True), ), **sub, ) - d["nodename"] = nodename + d["nodename"] = name if "id" not in sub: # The use of a dummy id is safe as the code is in a separate block. # It won't generate conflicting variable name. 
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index de2dfb4f30..f780ba5f16 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -312,15 +312,15 @@ def c_code_template(self): return self._c_code - def c_code(self, node, nodename, inames, onames, sub): + def c_code(self, node, name, inputs, outputs, sub): d = dict( chain( - zip((f"i{i}" for i in range(len(inames))), inames, strict=True), - zip((f"o{i}" for i in range(len(onames))), onames, strict=True), + zip((f"i{i}" for i in range(len(inputs))), inputs, strict=True), + zip((f"o{i}" for i in range(len(outputs))), outputs, strict=True), ), **sub, ) - d["nodename"] = nodename + d["nodename"] = name if "id" not in sub: # The use of a dummy id is safe as the code is in a separate block. # It won't generate conflicting variable name. @@ -328,7 +328,7 @@ def c_code(self, node, nodename, inames, onames, sub): # When called inside Elemwise we don't have access to the dtype # via the usual `f"dtype_{inames[i]}"` variable - d["n_steps"] = inames[0] + d["n_steps"] = inputs[0] d["n_steps_dtype"] = "npy_" + node.inputs[0].dtype res = self.c_code_template % d diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 317ea59b8d..9ff1a79bc5 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -70,9 +70,9 @@ def L_op(self, inputs, outputs, output_grads): ) return (gz * cst * exp(-x * x),) - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in complex_types: raise NotImplementedError("type not supported", type) cast = node.outputs[0].type.dtype_specs()[1] @@ -104,9 +104,9 @@ def L_op(self, inputs, outputs, output_grads): ) return (-gz * cst * exp(-x * x),) - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in complex_types: raise NotImplementedError("type not supported", type) cast = node.outputs[0].type.dtype_specs()[1] @@ -162,9 +162,9 @@ def c_support_code(self, **kwargs): # Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package return (C_CODE_PATH / "Faddeeva.cc").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype @@ -209,7 +209,7 @@ def L_op(self, inputs, outputs, output_grads): ) return (gz * cst * exp(erfinv(x) ** 2),) - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): # TODO: erfinv() is not provided by the C standard library # x, = inp # z, = out @@ -244,7 +244,7 @@ def L_op(self, inputs, outputs, output_grads): ) return (-gz * cst * exp(erfcinv(x) ** 2),) - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): # TODO: erfcinv() is not provided by the C standard library # x, = inp # z, = out @@ -336,9 +336,9 @@ def L_op(self, inputs, outputs, output_grads): return [gz * psi(x)] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs # no c code for complex # [u]int* will be casted to float64 before computation if node.inputs[0].type in complex_types: @@ -439,9 +439,9 @@ def 
c_support_code(self, **kwargs): #endif """ - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype return f"{z} = ({dtype}) _psi({x});" @@ -523,9 +523,9 @@ def c_support_code(self, **kwargs): #endif """ - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: return f"""{z} = _tri_gamma({x});""" @@ -597,9 +597,9 @@ def grad(self, inputs, output_grads): def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - k, x = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + k, x = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype return f"""{z} = @@ -644,9 +644,9 @@ def grad(self, inputs, output_grads): def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - k, x = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + k, x = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype return f"""{z} = @@ -943,9 +943,9 @@ def impl(self, k, x): def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - k, x = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + k, x = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype return f"""{z} = @@ -975,9 +975,9 @@ def impl(self, k, x): def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - k, x = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + k, x = inputs + (z,) = outputs if node.inputs[0].type in float_types: dtype = "npy_" + node.outputs[0].dtype return f"""{z} = @@ -1034,9 +1034,9 @@ def grad(self, inputs, output_grads): (gz,) = output_grads return [gz * (j0(x) - jv(2, x)) / 2.0] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: return f"""{z} = j1({x});""" @@ -1061,9 +1061,9 @@ def grad(self, inputs, output_grads): (gz,) = output_grads return [gz * -1 * j1(x)] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: return f"""{z} = j0({x});""" @@ -1217,9 +1217,9 @@ def grad(self, inputs, output_grads): return [rval] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: if node.inputs[0].type == float64: @@ -1280,9 +1280,9 @@ def grad(self, inputs, output_grads): (gz,) = output_grads return [gz * sigmoid(x)] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs # We use the same limits for all precisions, which may be suboptimal. 
The reference # paper only looked at double precision if node.inputs[0].type in float_types: @@ -1351,9 +1351,9 @@ def grad(self, inputs, output_grads): res = switch(isinf(res), -np.inf, res) return [gz * res] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs if node.inputs[0].type in float_types: if node.inputs[0].type == float64: @@ -1396,9 +1396,9 @@ def grad(self, inputs, output_grads): def c_support_code(self, **kwargs): return (C_CODE_PATH / "incbet.c").read_text(encoding="utf-8") - def c_code(self, node, name, inp, out, sub): - (a, b, x) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (a, b, x) = inputs + (z,) = outputs if ( node.inputs[0].type in float_types and node.inputs[1].type in float_types diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 1d477d44e5..9f1183da1c 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1665,9 +1665,9 @@ def perform(self, node, inputs, output_storage): # reuse the allocated memory. out[0][...] = v # broadcast v to fill us up - def c_code(self, node, name, inp, out, sub): - vv = inp[0] - (zz,) = out + def c_code(self, node, name, inputs, outputs, sub): + vv = inputs[0] + (zz,) = outputs fail = sub["fail"] v_static_shape = node.inputs[0].type.shape @@ -1675,7 +1675,7 @@ def c_code(self, node, name, inp, out, sub): v_ndim = len(v_static_shape) o_ndim = len(o_static_shape) is_zero = self.value_is_scalar_zero(node.inputs[0]) - assert o_ndim == len(inp[1:]) + assert o_ndim == len(inputs[1:]) # Declare variables code = f""" @@ -1684,7 +1684,7 @@ def c_code(self, node, name, inp, out, sub): """ # Initialize shape - for i, shp_i in enumerate(inp[1:]): + for i, shp_i in enumerate(inputs[1:]): code += f""" shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0]; """ @@ -1924,19 +1924,19 @@ def perform(self, node, inputs, output_storage): def c_code_cache_version(self): return (2,) - def c_code(self, node, name, inp, out_, props): - (out,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + (out,) = outputs # Shouldn't use PyArray_TYPE(inp[0]) for the dtype # when len(inp) == 0 (we need to support this case. # So there will be (1 * nb_dtype) + ((nb len(inp) - 1 )) # different c code with the following algo - out_shape = len(inp) + out_shape = len(inputs) out_num = np.dtype(node.outputs[0].dtype).num # don't use dtype_%(out)s as when check_input=False, it isn't defined. 
out_dtype = node.outputs[0].type.dtype_specs()[1] - if len(inp) > 0: + if len(inputs) > 0: assert self.dtype == node.inputs[0].dtype - out_num = f"PyArray_TYPE({inp[0]})" + out_num = f"PyArray_TYPE({inputs[0]})" ret = f""" npy_intp dims[1]; @@ -1946,7 +1946,7 @@ def c_code(self, node, name, inp, out_, props): {out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0); }} """ - for idx, i in enumerate(inp): + for idx, i in enumerate(inputs): ret += f""" *(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i})); """ @@ -3334,9 +3334,9 @@ def perform(self, node, inputs, output_storage): start.item(), stop.item(), step.item(), dtype=self.dtype ) - def c_code(self, node, nodename, input_names, output_names, sub): - [start_name, stop_name, step_name] = input_names - [out_name] = output_names + def c_code(self, node, name, inputs, outputs, sub): + [start_name, stop_name, step_name] = inputs + [out_name] = outputs typenum = np.dtype(self.dtype).num return f""" double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0]; @@ -3858,9 +3858,9 @@ def perform(self, node, inputs, output_storage): out = out.copy() output_storage[0][0] = out - def c_code(self, node, nodename, input_names, output_names, sub): - [x_name] = input_names - [out_name] = output_names + def c_code(self, node, name, inputs, outputs, sub): + [x_name] = inputs + [out_name] = outputs return f""" Py_XDECREF({out_name}); @@ -4358,8 +4358,8 @@ def perform(self, node, inputs, output_storage): if out[0] is None or out[0].shape != sh: out[0] = np.empty(sh, dtype=self.dtype) - def c_code(self, node, name, inputs, out_, sub): - (out,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + (out,) = outputs fail = sub["fail"] shps = inputs nd = len(shps) diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 8184ee8730..4858fc11d4 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1096,9 +1096,9 @@ def infer_shape(self, fgraph, node, input_shapes): #undef REAL """ - def c_code(self, node, name, inp, out, sub): - _z, _a, _x, _y, _b = inp - (_zout,) = out + def c_code(self, node, name, inputs, outputs, sub): + _z, _a, _x, _y, _b = inputs + (_zout,) = outputs if node.inputs[0].type.dtype.startswith("complex"): raise MethodNotDefined(f"{self.__class__.__name__}.c_code") full_code = self.build_gemm_call() % dict(locals(), **sub) @@ -1185,9 +1185,9 @@ def infer_shape(self, fgraph, node, input_shapes): double b = 0.0; """ - def c_code(self, node, name, inp, out, sub): # DEBUG - _x, _y = inp - (_zout,) = out + def c_code(self, node, name, inputs, outputs, sub): # DEBUG + _x, _y = inputs + (_zout,) = outputs if node.inputs[0].type.dtype.startswith("complex"): raise MethodNotDefined(f"{self.__class__.__name__}.c_code") if len(self.c_libraries()) <= 0: @@ -1283,9 +1283,9 @@ def infer_shape(self, fgraph, node, input_shapes): double b = 0.0; """ - def c_code(self, node, name, inp, out, sub): - _x, _y, _a = inp - (_zout,) = out + def c_code(self, node, name, inputs, outputs, sub): + _x, _y, _a = inputs + (_zout,) = outputs if node.inputs[0].type.dtype.startswith("complex"): raise MethodNotDefined(f"{self.__class__.__name__}.c_code") if len(self.c_libraries()) <= 0: @@ -1462,13 +1462,13 @@ def c_lib_dirs(self, **kwargs): def c_header_dirs(self, **kwargs): return ldflags(libs=False, include_dir=True) - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): # Can only compile if linked to blas libraries if len(self.c_libraries()) <= 0: 
raise NotImplementedError() - _x, _y = inp - (_z,) = out + _x, _y = inputs + (_z,) = outputs fail = sub["fail"] # generate contiguity condition diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index f0c8f4995a..84086819b5 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -317,9 +317,9 @@ class CGer(BaseBLAS, Ger): destructive=bool_t, ) - def c_code(self, node, name, inp, out, sub): - A, a, x, y = inp - (Z,) = out + def c_code(self, node, name, inputs, outputs, sub): + A, a, x, y = inputs + (Z,) = outputs code = ger_c_code(A, a, x, y, Z, fail=sub["fail"], params=sub["params"]) return code @@ -590,9 +590,9 @@ class CGemv(BaseBLAS, Gemv): def __init__(self, inplace): super().__init__(inplace) - def c_code(self, node, name, inp, out, sub): - y, alpha, A, x, beta = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + y, alpha, A, x, beta = inputs + (z,) = outputs code = gemv_c_code( y, A, diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 678561bbdd..853f78efa5 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1072,7 +1072,7 @@ def _c_all(self, node, nodename, inames, onames, sub): """ return decl, checks, alloc, loop, "" - def c_code(self, node, nodename, inames, onames, sub): + def c_code(self, node, name, inputs, outputs, sub): if ( any(i.dtype == "float16" for i in node.inputs) or any(o.dtype == "float16" for o in node.outputs) @@ -1082,7 +1082,7 @@ def c_code(self, node, nodename, inames, onames, sub): ): # Disable C code for float16 vars raise NotImplementedError() - code = "\n".join(self._c_all(node, nodename, inames, onames, sub)) + code = "\n".join(self._c_all(node, name, inputs, outputs, sub)) return code def c_headers(self, **kwargs): @@ -1583,8 +1583,8 @@ def _c_all(self, node, name, input_names, output_names, sub): return setup, alloc, loop, cast - def c_code(self, node, name, inames, onames, sub): - code = "\n".join(self._c_all(node, name, inames, onames, sub)) + def c_code(self, node, name, inputs, outputs, sub): + code = "\n".join(self._c_all(node, name, inputs, outputs, sub)) return code def c_headers(self, **kwargs): diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 2cfd9b21f7..104e9514e6 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -79,9 +79,9 @@ def perform(self, node, inputs, output_storage): def grad(self, inputs, output_grads): return [ptb.as_tensor_variable(output_grads[0])] - def c_code(self, node, name, inames, onames, sub): - (x,) = inames - (y,) = onames + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (y,) = outputs code = f""" if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{ // check to see if output is contiguous first @@ -183,15 +183,15 @@ def c_init_code_struct(self, node, name, sub): Py_DECREF(tmp_{name}); """ - def c_code(self, node, name, inames, onames, sub): + def c_code(self, node, name, inputs, outputs, sub): sorter = None if len(node.inputs) == 3: - x, v, sorter = inames + x, v, sorter = inputs else: - x, v = inames + x, v = inputs if not sorter: sorter = "NULL" - (z,) = onames + (z,) = outputs fail = sub["fail"] return f""" @@ -344,9 +344,9 @@ def L_op(self, inputs, outputs, output_grads): def infer_shape(self, fgraph, node, shapes): return shapes - def c_code(self, node, name, inames, onames, sub): - (x,) = inames - (z,) = onames + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs fail = sub["fail"] params = 
sub["params"] diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index cdf09706fd..4d4086502c 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -204,9 +204,9 @@ def perform(self, node, inputs, output_storage): max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (argmax,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (argmax,) = outputs fail = sub["fail"] params = sub["params"] if self.axis is None: @@ -3739,9 +3739,9 @@ def impl(self, x, y): return x return x * y - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = (({x} == 0) ? ({y}) : (({y} == 0) ? ({x}) : (({y})*({x}))) );" def c_code_cache_version(self): diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 71de47496e..61263a7908 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -108,9 +108,9 @@ def grad(self, inputs, output_grads): def R_op(self, inputs, eval_points): return [None] - def c_code(self, node, name, inames, onames, sub): - (iname,) = inames - (oname,) = onames + def c_code(self, node, name, inputs, outputs, sub): + (iname,) = inputs + (oname,) = outputs fail = sub["fail"] itype = node.inputs[0].type.__class__ @@ -287,9 +287,9 @@ def c_code_cache_version(self): return tuple(version) - def c_code(self, node, name, inames, onames, sub): - (iname,) = inames - (oname,) = onames + def c_code(self, node, name, inputs, outputs, sub): + (iname,) = inputs + (oname,) = outputs fail = sub["fail"] # i is then 'params->i', not just 'params'. i = sub["params"] + "->i" @@ -484,14 +484,14 @@ def R_op(self, inputs, eval_points): return [None] return self.make_node(eval_points[0], *inputs[1:]).outputs - def c_code(self, node, name, i_names, o_names, sub): + def c_code(self, node, name, inputs, outputs, sub): if not isinstance(node.inputs[0].type, DenseTensorType): raise NotImplementedError( f"Specify_shape c_code not implemented for input type {node.inputs[0].type}" ) - x_name, *shape_names = i_names - (o_name,) = o_names + x_name, *shape_names = inputs + (o_name,) = outputs fail = sub["fail"] code = dedent( diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index bbb64747d3..c1c6803839 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -62,9 +62,9 @@ def infer_shape(self, fgraph, node, shape): def c_code_cache_version(self): return (6,) - def c_code(self, node, name, inp, out, sub): - dy, sm = inp - (dx,) = out + def c_code(self, node, name, inputs, outputs, sub): + dy, sm = inputs + (dx,) = outputs axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] @@ -291,9 +291,9 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (sm,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (sm,) = outputs axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] # dtype = node.inputs[0].type.dtype_specs()[1] @@ -548,9 +548,9 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (sm,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (sm,) = outputs axis = self.axis if self.axis is not None else 
"NPY_RAVEL_AXIS" fail = sub["fail"] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 5f13b96f84..f825926f11 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2156,7 +2156,7 @@ def infer_shape(self, fgraph, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] - def c_code(self, node, name, input_names, output_names, sub): + def c_code(self, node, name, inputs, outputs, sub): if self.__class__ is not AdvancedSubtensor1: raise MethodNotDefined( "c_code defined for AdvancedSubtensor1, not for child class", @@ -2169,8 +2169,8 @@ def c_code(self, node, name, input_names, output_names, sub): # We can know ahead of time that all indices are valid, so we can use a faster mode mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP - a_name, i_name = input_names[0], input_names[1] - output_name = output_names[0] + a_name, i_name = inputs[0], inputs[1] + output_name = outputs[0] fail = sub["fail"] if mode == "NPY_RAISE": # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer @@ -2330,9 +2330,9 @@ def copy_of_x(self, x): return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0, NPY_ARRAY_ENSURECOPY, NULL)""" - def c_code(self, node, name, input_names, output_names, sub): - x, y, idx = input_names - [out] = output_names + def c_code(self, node, name, inputs, outputs, sub): + x, y, idx = inputs + [out] = outputs copy_of_x = self.copy_of_x(x) params = sub["params"] fail = sub["fail"] diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 59d99bb76c..7c8d7770df 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -102,9 +102,9 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): - x_name, index = inp[0], inp[1] - output_name = out[0] + def c_code(self, node, name, inputs, outputs, sub): + x_name, index = inputs[0], inputs[1] + output_name = outputs[0] fail = sub["fail"] return f""" {output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index}))); @@ -167,10 +167,10 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend()") - x_name, toAppend = inp[0], inp[1] - output_name = out[0] + x_name, toAppend = inputs[0], inputs[1] + output_name = outputs[0] fail = sub["fail"] if not self.inplace: init = f""" @@ -247,10 +247,10 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend()") - x_name, toAppend = inp[0], inp[1] - output_name = out[0] + x_name, toAppend = inputs[0], inputs[1] + output_name = outputs[0] fail = sub["fail"] if not self.inplace: init = f""" @@ -335,10 +335,10 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): + def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend()") - x_name, index, toInsert = inp[0], inp[1], inp[2] - output_name = out[0] + x_name, index, 
toInsert = inputs[0], inputs[1], inputs[2] + output_name = outputs[0] fail = sub["fail"] if not self.inplace: init = f""" @@ -464,9 +464,9 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): - x_name = inp[0] - output_name = out[0] + def c_code(self, node, name, inputs, outputs, sub): + x_name = inputs[0] + output_name = outputs[0] fail = sub["fail"] if not self.inplace: init = f""" @@ -596,9 +596,9 @@ def perform(self, node, inputs, output_storage): def __str__(self): return self.__class__.__name__ - def c_code(self, node, name, inp, out, sub): - x_name = inp[0] - output_name = out[0] + def c_code(self, node, name, inputs, outputs, sub): + x_name = inputs[0] + output_name = outputs[0] return f""" if(!{output_name}) {output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0); diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 340411eb0d..563eb151d4 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -61,9 +61,9 @@ def perform(self, node, inputs, output_storage): def c_code_cache_version(self): return (1,) - def c_code(self, node, name, inp, out, sub): - a, b = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + a, b = inputs + (z,) = outputs fail = sub["fail"] return f""" if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}} @@ -148,9 +148,9 @@ def dontuse_perform(self, node, inp, out_): def c_code_cache_version(self): return (2,) - def c_code(self, node, name, inp, out, sub): - (a,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (a,) = inputs + (z,) = outputs if "inplace" in self.behaviour: z_code = f""" {{Py_XDECREF({z});}} @@ -629,9 +629,9 @@ def perform(self, node, inputs, output_storage): def c_code_cache_version(self): return (1,) - def c_code(self, node, name, inp, out, sub): - a, b = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + a, b = inputs + (z,) = outputs debug = 0 fail = sub["fail"] return f""" diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index da80018a22..7f17cb2139 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -125,9 +125,9 @@ def __init__(self): class Add(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} + {y};" def impl(self, x, y): @@ -138,9 +138,9 @@ def impl(self, x, y): class Sub(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} - {y};" def impl(self, x, y): @@ -151,9 +151,9 @@ def impl(self, x, y): class BadSub(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} - {y};" def impl(self, x, y): @@ -164,9 +164,9 @@ def impl(self, x, y): class Mul(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} * {y};" def impl(self, x, y): @@ -177,9 +177,9 @@ def impl(self, x, y): class Div(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, 
inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} / {y};" def impl(self, x, y): @@ -499,9 +499,9 @@ def test_duallinker_mismatch(): class AddFail(Binary): - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs fail = sub["fail"] return f"""{z} = {x} + {y}; PyErr_SetString(PyExc_RuntimeError, "failing here"); @@ -610,9 +610,9 @@ def perform(self, node, inputs, output_storage): def c_code_cache_version(self): return (1,) - def c_code(self, node, name, inp, out, sub): - x, y = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + x, y = inputs + (z,) = outputs return f"{z} = {x} + {y};" x = tdouble("x") diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index 51efd5d410..f193af732c 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -35,10 +35,10 @@ class MyOp(DeepCopyOp): def c_code_cache_version(self): return () - def c_code(self, node, name, inames, onames, sub): + def c_code(self, node, name, inputs, outputs, sub): MyOp.nb_called += 1 - (iname,) = inames - (oname,) = onames + (iname,) = inputs + (oname,) = outputs fail = sub["fail"] itype = node.inputs[0].type.__class__ if itype in self.c_code_and_version: @@ -46,7 +46,7 @@ def c_code(self, node, name, inames, onames, sub): rand = np.random.random() return f'printf("{rand}\\n");{code % locals()}' # Else, no C code - return super(DeepCopyOp, self).c_code(node, name, inames, onames, sub) + return super(DeepCopyOp, self).c_code(node, name, inputs, outputs, sub) class MyAdd(COp): @@ -60,9 +60,9 @@ def perform(self, node, inputs, output_storage): (out,) = output_storage out[0] = inputs[0][0] + 1 - def c_code(self, node, name, inp, out, sub): - (x,) = inp - (z,) = out + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs return f"{z} = {x} + 1;" diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index d4aed3c26c..b30b688688 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -78,9 +78,9 @@ def c_support_code_struct(self, node, name): def c_init_code_struct(self, node, name, sub): return f"counter{name} = 0;" - def c_code(self, node, name, input_names, outputs_names, sub): + def c_code(self, node, name, inputs, outputs, sub): return f""" -{outputs_names[0]} = counter{name}; +{outputs[0]} = counter{name}; counter{name}++; """ diff --git a/tests/link/c/test_type.py b/tests/link/c/test_type.py index dc24a03a8d..66259c5eff 100644 --- a/tests/link/c/test_type.py +++ b/tests/link/c/test_type.py @@ -24,11 +24,11 @@ def c_support_code(self, **kwargs): } """ - def c_code(self, node, name, inps, outs, sub): + def c_code(self, node, name, inputs, outputs, sub): return f""" -Py_XDECREF({outs[0]}); -{outs[0]} = (void *){inps[0]}; -Py_INCREF({inps[0]}); +Py_XDECREF({outputs[0]}); +{outputs[0]} = (void *){inputs[0]}; +Py_INCREF({inputs[0]}); """ # FIXME: should it not be outs[0]? 
@@ -52,11 +52,11 @@ def c_support_code(self, **kwargs): } """ - def c_code(self, node, name, inps, outs, sub): + def c_code(self, node, name, inputs, outputs, sub): return f""" -Py_XDECREF({outs[0]}); -{outs[0]} = (PyArrayObject *){inps[0]}; -Py_INCREF({outs[0]}); +Py_XDECREF({outputs[0]}); +{outputs[0]} = (PyArrayObject *){inputs[0]}; +Py_INCREF({outputs[0]}); """ def c_code_cache_version(self): diff --git a/tests/tensor/conv/c_conv3d_corr3d_ref.py b/tests/tensor/conv/c_conv3d_corr3d_ref.py index b0c26c60cc..067be760bd 100644 --- a/tests/tensor/conv/c_conv3d_corr3d_ref.py +++ b/tests/tensor/conv/c_conv3d_corr3d_ref.py @@ -631,9 +631,9 @@ def infer_shape(self, fgraph, node, input_shape): ) return [res] - def c_code(self, node, nodename, inp, out_, sub): - bottom, weights = inp - (top,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + bottom, weights = inputs + (top,) = outputs return super().c_code_helper(bottom, weights, top, sub) def grad(self, inputs, output_grads): @@ -744,10 +744,10 @@ def infer_shape(self, fgraph, node, input_shape): kD = imshp[2] + 2 * padD - (topshp[2] - 1) * dD return [(nkern, ssize, kH, kW, kD)] - def c_code(self, node, nodename, inp, out_, sub): - bottom, top = inp[:2] - height, width, depth = inp[2:] or (None, None, None) - (weights,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + bottom, top = inputs[:2] + height, width, depth = inputs[2:] or (None, None, None) + (weights,) = outputs return super().c_code_helper(bottom, weights, top, sub, height, width, depth) def grad(self, inputs, output_grads): @@ -884,10 +884,10 @@ def infer_shape(self, fgraph, node, input_shape): out_shp = (out_shp0, out_shp1, out_shp2) return [(bsize, ssize, *out_shp)] - def c_code(self, node, nodename, inp, out_, sub): - weights, top = inp[:2] - height, width, depth = inp[2:] or (None, None, None) - (bottom,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + weights, top = inputs[:2] + height, width, depth = inputs[2:] or (None, None, None) + (bottom,) = outputs return super().c_code_helper(bottom, weights, top, sub, height, width, depth) def grad(self, inputs, output_grads): diff --git a/tests/tensor/conv/c_conv_corr_ref.py b/tests/tensor/conv/c_conv_corr_ref.py index 160d503569..4a96251095 100644 --- a/tests/tensor/conv/c_conv_corr_ref.py +++ b/tests/tensor/conv/c_conv_corr_ref.py @@ -690,9 +690,9 @@ def infer_shape(self, fgraph, node, input_shape): ) return [res] - def c_code(self, node, nodename, inp, out_, sub): - bottom, weights = inp - (top,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + bottom, weights = inputs + (top,) = outputs return super().c_code_helper(bottom, weights, top, sub) def grad(self, inputs, output_grads): @@ -815,10 +815,10 @@ def infer_shape(self, fgraph, node, input_shape): else: return [(nkern, ssize, kH, kW)] - def c_code(self, node, nodename, inp, out_, sub): - bottom, top = inp[:2] - height, width = inp[2:] or (None, None) - (weights,) = out_ + def c_code(self, node, name, inputs, outputs, sub): + bottom, top = inputs[:2] + height, width = inputs[2:] or (None, None) + (weights,) = outputs return super().c_code_helper(bottom, weights, top, sub, height, width) def grad(self, inputs, output_grads): @@ -947,10 +947,10 @@ def infer_shape(self, fgraph, node, input_shape): out_shp = (out_shp0, out_shp1) return [(bsize, ssize, *out_shp)] - def c_code(self, node, nodename, inp, out_, sub): - weights, top = inp[:2] - height, width = inp[2:] or (None, None) - (bottom,) = out_ + def c_code(self, node, name, 
inputs, outputs, sub): + weights, top = inputs[:2] + height, width = inputs[2:] or (None, None) + (bottom,) = outputs return super().c_code_helper(bottom, weights, top, sub, height, width) def grad(self, inputs, output_grads): From dee236c59e81a62d4ceaf516034285fb2db37bbb Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 17 Dec 2025 23:24:48 +0100 Subject: [PATCH 17/17] Standardize various method overrides --- pytensor/graph/destroyhandler.py | 94 +++++++++++------------ pytensor/graph/features.py | 22 +++--- pytensor/graph/rewriting/basic.py | 18 ++--- pytensor/graph/rewriting/db.py | 8 +- pytensor/graph/rewriting/kanren.py | 2 +- pytensor/misc/ordered_set.py | 4 +- pytensor/printing.py | 96 ++++++++++++------------ pytensor/scalar/basic.py | 14 ++-- pytensor/tensor/basic.py | 8 +- pytensor/tensor/elemwise.py | 14 ++-- pytensor/tensor/math.py | 89 +++++++++++++++++----- pytensor/tensor/rewriting/math.py | 2 +- pytensor/tensor/rewriting/shape.py | 26 +++---- pytensor/tensor/subtensor.py | 14 ++-- tests/graph/rewriting/test_basic.py | 2 +- tests/graph/test_types.py | 4 +- tests/tensor/conv/c_conv3d_corr3d_ref.py | 2 +- tests/tensor/conv/c_conv_corr_ref.py | 2 +- tests/tensor/rewriting/test_elemwise.py | 4 +- tests/test_ifelse.py | 4 +- 20 files changed, 239 insertions(+), 190 deletions(-) diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py index 1fe59f2c6d..764d6e693b 100644 --- a/pytensor/graph/destroyhandler.py +++ b/pytensor/graph/destroyhandler.py @@ -521,72 +521,72 @@ def fast_destroy(self, fgraph, app, reason): # assert len(v) <= 1 # assert len(d) <= 1 - def on_import(self, fgraph, app, reason): + def on_import(self, fgraph, node, reason): """ Add Apply instance to set which must be computed. """ - if app in self.debug_all_apps: + if node in self.debug_all_apps: raise ProtocolError("double import") - self.debug_all_apps.add(app) + self.debug_all_apps.add(node) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # If it's a destructive op, add it to our watch list - dmap = app.op.destroy_map - vmap = app.op.view_map + dmap = node.op.destroy_map + vmap = node.op.view_map if dmap: - self.destroyers.add(app) + self.destroyers.add(node) if self.algo == "fast": - self.fast_destroy(fgraph, app, reason) + self.fast_destroy(fgraph, node, reason) # add this symbol to the forward and backward maps for o_idx, i_idx_list in vmap.items(): if len(i_idx_list) > 1: raise NotImplementedError( - "destroying this output invalidates multiple inputs", (app.op) + "destroying this output invalidates multiple inputs", (node.op) ) - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] self.view_i[o] = i self.view_o.setdefault(i, OrderedSet()).add(o) # update self.clients - for i, input in enumerate(app.inputs): - self.clients.setdefault(input, {}).setdefault(app, 0) - self.clients[input][app] += 1 + for i, input in enumerate(node.inputs): + self.clients.setdefault(input, {}).setdefault(node, 0) + self.clients[input][node] += 1 - for i, output in enumerate(app.outputs): + for i, output in enumerate(node.outputs): self.clients.setdefault(output, {}) self.stale_droot = True - def on_prune(self, fgraph, app, reason): + def on_prune(self, fgraph, node, reason): """ Remove Apply instance from set which must be computed. 
""" - if app not in self.debug_all_apps: + if node not in self.debug_all_apps: raise ProtocolError("prune without import") - self.debug_all_apps.remove(app) + self.debug_all_apps.remove(node) # UPDATE self.clients - for input in set(app.inputs): - del self.clients[input][app] + for input in set(node.inputs): + del self.clients[input][node] - if app.op.destroy_map: - self.destroyers.remove(app) + if node.op.destroy_map: + self.destroyers.remove(node) # Note: leaving empty client dictionaries in the struct. # Why? It's a pain to remove them. I think they aren't doing any harm, they will be # deleted on_detach(). # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): + for o_idx, i_idx_list in node.op.view_map.items(): if len(i_idx_list) > 1: # destroying this output invalidates multiple inputs raise NotImplementedError() - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] del self.view_i[o] @@ -595,53 +595,53 @@ def on_prune(self, fgraph, app, reason): del self.view_o[i] self.stale_droot = True - if app in self.fail_validate: - del self.fail_validate[app] + if node in self.fail_validate: + del self.fail_validate[node] - def on_change_input(self, fgraph, app, i, old_r, new_r, reason): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): """ - app.inputs[i] changed from old_r to new_r. + node.inputs[i] changed from var to new_var. """ - if isinstance(app.op, Output): - # app == 'output' is special key that means FunctionGraph is redefining which nodes are being + if isinstance(node.op, Output): + # node == 'output' is special key that means FunctionGraph is redefining which nodes are being # considered 'outputs' of the graph. pass else: - if app not in self.debug_all_apps: + if node not in self.debug_all_apps: raise ProtocolError("change without import") # UPDATE self.clients - self.clients[old_r][app] -= 1 - if self.clients[old_r][app] == 0: - del self.clients[old_r][app] + self.clients[var][node] -= 1 + if self.clients[var][node] == 0: + del self.clients[var][node] - self.clients.setdefault(new_r, {}).setdefault(app, 0) - self.clients[new_r][app] += 1 + self.clients.setdefault(new_var, {}).setdefault(node, 0) + self.clients[new_var][node] += 1 # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): + for o_idx, i_idx_list in node.op.view_map.items(): if len(i_idx_list) > 1: # destroying this output invalidates multiple inputs raise NotImplementedError() i_idx = i_idx_list[0] - output = app.outputs[o_idx] + output = node.outputs[o_idx] if i_idx == i: - if app.inputs[i_idx] is not new_r: + if node.inputs[i_idx] is not new_var: raise ProtocolError("wrong new_r on change") - self.view_i[output] = new_r + self.view_i[output] = new_var - self.view_o[old_r].remove(output) - if not self.view_o[old_r]: - del self.view_o[old_r] + self.view_o[var].remove(output) + if not self.view_o[var]: + del self.view_o[var] - self.view_o.setdefault(new_r, OrderedSet()).add(output) + self.view_o.setdefault(new_var, OrderedSet()).add(output) if self.algo == "fast": - if app in self.fail_validate: - del self.fail_validate[app] - self.fast_destroy(fgraph, app, reason) + if node in self.fail_validate: + del self.fail_validate[node] + self.fast_destroy(fgraph, node, reason) self.stale_droot = True def validate(self, fgraph): diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index f03b58dfcb..74ea7f8971 100644 --- a/pytensor/graph/features.py +++ 
b/pytensor/graph/features.py @@ -416,11 +416,11 @@ def on_detach(self, fgraph): del fgraph.revert del self.history[fgraph] - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): if self.history[fgraph] is None: return h = self.history[fgraph] - h.append(LambdaExtract(fgraph, node, i, r, reason)) + h.append(LambdaExtract(fgraph, node, i, var, reason)) def revert(self, fgraph, checkpoint): """ @@ -544,9 +544,9 @@ def on_attach(self, fgraph): raise ValueError("Full History already attached to another fgraph") self.fg = fgraph - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - self.bw.append(LambdaExtract(fgraph, node, i, r, reason)) - self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason)) + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): + self.bw.append(LambdaExtract(fgraph, node, i, var, reason)) + self.fw.append(LambdaExtract(fgraph, node, i, new_var, reason)) self.pointer += 1 if self.callback: self.callback() @@ -832,15 +832,15 @@ class PreserveVariableAttributes(Feature): This preserve some variables attributes and tag during optimization. """ - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): # Don't change the name of constants - if r.owner and r.name is not None and new_r.name is None: - new_r.name = r.name + if var.owner and var.name is not None and new_var.name is None: + new_var.name = var.name if ( - getattr(r.tag, "nan_guard_mode_check", False) - and getattr(new_r.tag, "nan_guard_mode_check", False) is False + getattr(var.tag, "nan_guard_mode_check", False) + and getattr(new_var.tag, "nan_guard_mode_check", False) is False ): - new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check + new_var.tag.nan_guard_mode_check = var.tag.nan_guard_mode_check class NoOutputFromInplace(Feature): diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 20d0fe52f0..c4a701fdc2 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -550,15 +550,15 @@ def on_attach(self, fgraph): def clone(self): return type(self)() - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): if node in self.nodes_seen: # If inputs to a node change, it's not guaranteed that the node is # distinct from the other nodes in `self.nodes_seen`. 
self.nodes_seen.discard(node) self.process_node(fgraph, node) - if isinstance(new_r, AtomicVariable): - self.process_atomic(fgraph, new_r) + if isinstance(new_var, AtomicVariable): + self.process_atomic(fgraph, new_var) def on_import(self, fgraph, node, reason): for c in node.inputs: @@ -973,7 +973,7 @@ def __init__(self, fn, tracks=None, requirements=()): ) self.requirements = requirements - def transform(self, fgraph, node, enforce_tracks: bool = True): + def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs): if enforce_tracks and self._tracks: node_op = node.op if not ( @@ -1230,7 +1230,7 @@ def tracks(self): t.extend(at) return t - def transform(self, fgraph, node, enforce_tracks=False): + def transform(self, fgraph, node, enforce_tracks=False, *args, **kwargs): if len(self.rewrites) == 0: return @@ -1385,7 +1385,7 @@ def __init__(self, op1, op2, transfer_tags=True): def tracks(self): return [self.op1] - def transform(self, fgraph, node, enforce_tracks=True): + def transform(self, fgraph, node, enforce_tracks=True, *args, **kwargs): if enforce_tracks and (node.op != self.op1): return False repl = self.op2.make_node(*node.inputs) @@ -1713,9 +1713,9 @@ def on_prune(self, fgraph, node, reason): if self.pruner: self.pruner(node) - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): if self.chin: - self.chin(node, i, r, new_r, reason) + self.chin(node, i, var, new_var, reason) def on_detach(self, fgraph): # To allow pickling this object @@ -2160,7 +2160,7 @@ def on_import(self, fgraph, node, reason): self.nb_imported += 1 self.changed = True - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): self.changed = True def reset(self): diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index b5f0de04aa..b4bf8a5de9 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -396,10 +396,10 @@ def __init__( self.__position__ = {} self.failure_callback = failure_callback - def register(self, name, obj, *tags, **kwargs): + def register(self, name, rewriter, *tags, **kwargs): position = kwargs.pop("position", "last") - super().register(name, obj, *tags, **kwargs) + super().register(name, rewriter, *tags, **kwargs) if position == "last": if len(self.__position__) == 0: @@ -497,8 +497,8 @@ def __init__( self.node_rewriter = node_rewriter self.__name__: str = "" - def register(self, name, obj, *tags, position="last", **kwargs): - super().register(name, obj, *tags, position=position, **kwargs) + def register(self, name, rewriter, *tags, position="last", **kwargs): + super().register(name, rewriter, *tags, position=position, **kwargs) def query(self, *tags, **kwtags): rewrites = list(super().query(*tags, **kwtags)) diff --git a/pytensor/graph/rewriting/kanren.py b/pytensor/graph/rewriting/kanren.py index 8b45d85da8..9a58850137 100644 --- a/pytensor/graph/rewriting/kanren.py +++ b/pytensor/graph/rewriting/kanren.py @@ -74,7 +74,7 @@ def results_filter( self.node_filter = node_filter super().__init__() - def transform(self, fgraph, node, enforce_tracks: bool = True): + def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs): if self.node_filter(node) is False: return False diff --git a/pytensor/misc/ordered_set.py b/pytensor/misc/ordered_set.py index a33cc53c32..a0c0378e86 100644 --- a/pytensor/misc/ordered_set.py +++ 
b/pytensor/misc/ordered_set.py @@ -11,8 +11,8 @@ def __init__(self, iterable: Iterable | None = None) -> None: else: self.values = dict.fromkeys(iterable) - def __contains__(self, value) -> bool: - return value in self.values + def __contains__(self, x) -> bool: + return x in self.values def __iter__(self) -> Iterator: yield from self.values diff --git a/pytensor/printing.py b/pytensor/printing.py index e814ad4d66..f2555cc6ca 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -861,11 +861,11 @@ def __init__(self, operator, precedence, assoc="left"): self.assoc = assoc assert self.assoc in VALID_ASSOC - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] pprinter = pstate.pprinter - node = output.owner + node = var.owner if node is None: raise TypeError( f"operator {self.operator} cannot represent a variable that is " @@ -903,7 +903,7 @@ def process(self, output, pstate): r = f"({s})" else: r = s - pstate.memo[output] = r + pstate.memo[var] = r return r @@ -916,17 +916,17 @@ def __init__(self, *patterns): else: self.patterns.append((pattern[0], pattern[1:])) - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] pprinter = pstate.pprinter - node = output.owner + node = var.owner if node is None: raise TypeError( f"Patterns {self.patterns} cannot represent a variable that is " "not the result of an operation" ) - idx = node.outputs.index(output) + idx = node.outputs.index(var) pattern, precedences = self.patterns[idx] precedences += (1000,) * (len(node.inputs) - len(precedences)) @@ -942,7 +942,7 @@ def pp_process(input, new_precedence): ) } r = pattern % d - pstate.memo[output] = r + pstate.memo[var] = r return r @@ -963,17 +963,17 @@ def __init__(self, names: list[str], keywords: list[str] | None = None): self.keywords = keywords - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] pprinter = pstate.pprinter - node = output.owner + node = var.owner if node is None: raise TypeError( f"function {self.names} cannot represent a variable that is " "not the result of an operation" ) - idx = node.outputs.index(output) + idx = node.outputs.index(var) name = self.names[idx] with set_precedence(pstate): inputs_str = ", ".join( @@ -988,16 +988,16 @@ def process(self, output, pstate): r = f"{name}({inputs_str}{keywords_str})" - pstate.memo[output] = r + pstate.memo[var] = r return r class IgnorePrinter(Printer): - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] pprinter = pstate.pprinter - node = output.owner + node = var.owner if node is None: raise TypeError( f"function {self.function} cannot represent a variable that is" @@ -1005,19 +1005,19 @@ def process(self, output, pstate): ) input = node.inputs[0] r = f"{pprinter.process(input, pstate)}" - pstate.memo[output] = r + pstate.memo[var] = r return r class LeafPrinter(Printer): - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] - if output.name in greek: - r = greek[output.name] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] + if var.name in greek: + r = 
greek[var.name] else: - r = str(output) - pstate.memo[output] = r + r = str(var) + pstate.memo[var] = r return r @@ -1025,11 +1025,11 @@ def process(self, output, pstate): class ConstantPrinter(Printer): - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] - r = str(output.data) - pstate.memo[output] = r + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] + r = str(var.data) + pstate.memo[var] = r return r @@ -1037,18 +1037,18 @@ def process(self, output, pstate): class DefaultPrinter(Printer): - def process(self, output, pstate): - if output in pstate.memo: - return pstate.memo[output] + def process(self, var, pstate): + if var in pstate.memo: + return pstate.memo[var] pprinter = pstate.pprinter - node = output.owner + node = var.owner if node is None: - return leaf_printer.process(output, pstate) + return leaf_printer.process(var, pstate) with set_precedence(pstate): args = ", ".join(pprinter.process(input, pstate) for input in node.inputs) r = f"{node.op}({args})" - pstate.memo[output] = r + pstate.memo[var] = r return r @@ -1066,19 +1066,19 @@ def assign(self, condition: Op | type | Callable, printer: Printer): else: self.printers.insert(0, (condition, printer)) - def process(self, r: Variable, pstate: PrinterState | None = None) -> str: + def process(self, var: Variable, pstate: PrinterState | None = None) -> str: if pstate is None: pstate = PrinterState(pprinter=self) elif isinstance(pstate, dict): pstate = PrinterState(pprinter=self, **pstate) - if getattr(r, "owner", None) is not None: - if r.owner.op in self.printers_dict: - return self.printers_dict[r.owner.op].process(r, pstate) - if type(r.owner.op) in self.printers_dict: - return self.printers_dict[type(r.owner.op)].process(r, pstate) + if getattr(var, "owner", None) is not None: + if var.owner.op in self.printers_dict: + return self.printers_dict[var.owner.op].process(var, pstate) + if type(var.owner.op) in self.printers_dict: + return self.printers_dict[type(var.owner.op)].process(var, pstate) for condition, printer in self.printers: - if condition(pstate, r): - return printer.process(r, pstate) + if condition(pstate, var): + return printer.process(var, pstate) return "" def clone(self): diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 3cf87ec3b1..f3c3a65ddd 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1234,9 +1234,9 @@ def make_node(self, *inputs): ) return Apply(self, inputs, outputs) - def output_types(self, types): + def output_types(self, input_types): if self.output_types_preference is not None: - variables = self.output_types_preference(*types) + variables = self.output_types_preference(*input_types) if not isinstance(variables, list | tuple) or any( not isinstance(x, CType) for x in variables ): @@ -1661,8 +1661,8 @@ def L_op(self, inputs, outputs, output_grads): return (condition_grad, first_part, second_part) - def output_types(self, types): - (_cond_t, ift_t, iff_t) = types + def output_types(self, input_types): + (_cond_t, ift_t, iff_t) = input_types return upcast_out(ift_t, iff_t) @@ -2018,11 +2018,11 @@ def L_op(self, inputs, outputs, output_grads): class TrueDiv(BinaryScalarOp): nfunc_spec = ("true_divide", 2, 1) - def output_types(self, types): - if all(t in discrete_types for t in types): + def output_types(self, input_types): + if all(t in discrete_types for t in input_types): return [get_scalar_type(config.floatX)] else: - return super().output_types(types) + return 
super().output_types(input_types) def impl(self, x, y): x = np.asarray(x) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 9f1183da1c..1007396ec2 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1973,12 +1973,12 @@ def R_op(self, inputs, eval_points): class MakeVectorPrinter(Printer): - def process(self, r, pstate): - if r.owner is None: + def process(self, var, pstate): + if var.owner is None: raise TypeError("Can only print make_vector.") - elif isinstance(r.owner.op, MakeVector): + elif isinstance(var.owner.op, MakeVector): with set_precedence(pstate): - s = [pstate.pprinter.process(inp) for inp in r.owner.inputs] + s = [pstate.pprinter.process(inp) for inp in var.owner.inputs] return f"[{', '.join(s)}]" else: raise TypeError("Can only print make_vector.") diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 853f78efa5..6d49c503d0 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -285,12 +285,12 @@ def __p(self, new_order, pstate, r): return f"{pstate.pprinter.process(r)}.T" return f"DimShuffle{{{', '.join(str(o) for o in new_order)}}}({pstate.pprinter.process(r)})" - def process(self, r, pstate): - if r.owner is None: + def process(self, var, pstate): + if var.owner is None: raise TypeError("Can only print DimShuffle.") - elif isinstance(r.owner.op, DimShuffle): - ord = r.owner.op.new_order - return self.__p(ord, pstate, r.owner.inputs[0]) + elif isinstance(var.owner.op, DimShuffle): + ord = var.owner.op.new_order + return self.__p(ord, pstate, var.owner.inputs[0]) else: raise TypeError("Can only print DimShuffle.") @@ -1094,8 +1094,8 @@ def c_header_dirs(self, **kwargs): def c_support_code(self, **kwargs): return self.scalar_op.c_support_code(**kwargs) - def c_support_code_apply(self, node, nodename): - support_code = self.scalar_op.c_support_code_apply(node, nodename + "_scalar_") + def c_support_code_apply(self, node, name): + support_code = self.scalar_op.c_support_code_apply(node, name + "_scalar_") return support_code def c_code_cache_version_apply(self, node): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 4d4086502c..6f22724460 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -404,8 +404,15 @@ class Max(NonZeroDimsCAReduce): def __init__(self, axis): super().__init__(ps.maximum, axis) - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis return type(self)(axis=axis) def L_op(self, inputs, outputs, output_grads): @@ -462,8 +469,15 @@ class Min(NonZeroDimsCAReduce): def __init__(self, axis): super().__init__(ps.minimum, axis) - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis return type(self)(axis=axis) @@ -3398,8 +3412,15 @@ def grad(self, inputs, output_grads): (x,) = inputs return [x.zeros_like(config.floatX)] - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis return type(self)(axis=axis) @@ -3428,8 +3449,15 @@ def grad(self, inputs, output_grads): (x,) = inputs return [x.zeros_like(config.floatX)] - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) + def 
clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis return type(self)(axis=axis) @@ -3485,10 +3513,17 @@ def R_op(self, inputs, eval_points): return [None] return self(*eval_points, return_list=True) - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - dtype = kwargs.get("dtype", self.dtype) - acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis + dtype = dtype or self.dtype + acc_dtype = acc_dtype or self.acc_dtype return type(self)(axis=axis, dtype=dtype, acc_dtype=acc_dtype) @@ -3666,10 +3701,17 @@ def L_op(self, inputs, outputs, output_grads): def c_code_cache_version(self): return (1,) - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - dtype = kwargs.get("dtype", self.dtype) - acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis + dtype = dtype or self.dtype + acc_dtype = acc_dtype or self.acc_dtype no_zeros_in_input = kwargs.get("no_zeros_in_input", self.no_zeros_in_input) return type(self)( axis=axis, @@ -3775,10 +3817,17 @@ def grad(self, inputs, output_grads): ) return [a_grad] - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - dtype = kwargs.get("dtype", self.dtype) - acc_dtype = kwargs.get("acc_dtype", self.acc_dtype) + def clone( + self, + axis=None, + dtype=None, + acc_dtype=None, + upcast_discrete_output=None, + **kwargs, + ): + axis = axis or self.axis + dtype = dtype or self.dtype + acc_dtype = acc_dtype or self.acc_dtype return type(self)(axis=axis, dtype=dtype, acc_dtype=acc_dtype) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e9c5651c8a..9cbf296c81 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1320,7 +1320,7 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): return ct + num, denum - def transform(self, fgraph, node, enforce_tracks=True): + def transform(self, fgraph, node, enforce_tracks=True, *args, **kwargs): op = node.op if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}): return False diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index af953c79fd..a392fb20f2 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -588,17 +588,17 @@ def on_import(self, fgraph, node, reason): for r, s in zip(node.outputs, o_shapes, strict=True): self.set_shape(r, s) - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): + if new_var not in self.shape_of: # It happen that the fgraph didn't called on_import for some # new_r. This happen when new_r don't have an # owner(i.e. it is a constant or an input of the graph) # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) + self.init_r(new_var) # This tells us that r and new_r must have the same shape if # we didn't know that the shapes are related, now we do. 
- self.update_shape(new_r, r) + self.update_shape(new_var, var) # change_input happens in two cases: # 1) we are trying to get rid of r, or @@ -608,10 +608,10 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): # replace the shape_i of r with the shape of new_r. Say that # r is *scheduled*. # At that point, node is no longer a client of r, but of new_r - for shpnode, idx in fgraph.clients[r] + [(node, i)]: + for shpnode, idx in fgraph.clients[var] + [(node, i)]: if isinstance(shpnode.op, Shape_i): idx = shpnode.op.i - repl = self.shape_of[new_r][idx] + repl = self.shape_of[new_var][idx] if repl.owner is shpnode: # This mean the replacement shape object is # exactly the same as the current shape object. So @@ -631,20 +631,20 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): if shpnode.outputs[0] in ancestors([repl]): raise InconsistencyError( "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" + f"node: {node}, i: {i}, r: {var}, new_r: {new_var}" ) - self.scheduled[shpnode] = new_r + self.scheduled[shpnode] = new_var # In case 2, if r is a variable that we've scheduled for shape update, # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] + unscheduled = [k for k, v in self.scheduled.items() if v == var] for k in unscheduled: del self.scheduled[k] # In either case, r could be in shape_of.values(), that is, r itself # is the shape of something. In that case, we want to update # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): + for v in self.shape_of_reverse_index.get(var, []): # The reverse index is only approximate. It is not updated on # deletion of variables, or on change_input so it might be the # case that there are a few extra `v`'s in it that no longer have @@ -652,9 +652,9 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): # entirely. The important thing is that it permits to recall # all variables with r in their shape. 
for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + if svi == var: + self.set_shape_i(v, ii, new_var) + self.shape_of_reverse_index[var] = set() def same_shape( self, diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index f825926f11..f212f2d0b8 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -1378,8 +1378,8 @@ def R_op(self, inputs, eval_points): class SubtensorPrinter(Printer): - def process(self, r, pstate): - return self._process(r.owner.op.idx_list, r.owner.inputs, pstate) + def process(self, var, pstate): + return self._process(var.owner.op.idx_list, var.owner.inputs, pstate) def _process(self, idxs, op_inputs, pstate): inputs = list(op_inputs) @@ -2027,15 +2027,15 @@ def grad(self, inputs, output_grads): class IncSubtensorPrinter(SubtensorPrinter): - def process(self, r, pstate): - x, _y, *idx_args = r.owner.inputs + def process(self, var, pstate): + x, _y, *idx_args = var.owner.inputs - res = self._process(r.owner.op.idx_list, [x, *idx_args], pstate) + res = self._process(var.owner.op.idx_list, [x, *idx_args], pstate) with set_precedence(pstate, 1000): - y_str = pstate.pprinter.process(r.owner.inputs[1], pstate) + y_str = pstate.pprinter.process(var.owner.inputs[1], pstate) - if r.owner.op.set_instead_of_inc: + if var.owner.op.set_instead_of_inc: res = f"set_subtensor({res}, {y_str})" else: res = f"inc_subtensor({res}, {y_str})" diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index aef4ad7a18..883259785c 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -50,7 +50,7 @@ class AssertNoChanges(Feature): """A `Feature` that raises an error when nodes are changed in a graph.""" - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input(self, fgraph, node, i, var, new_var, reason=None): raise AssertionError() diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index 9e9e8611ab..c244f0af2b 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -22,8 +22,8 @@ def __repr__(self): class MyType2(MyType): - def is_super(self, other): - if self.thingy <= other.thingy: + def is_super(self, otype): + if self.thingy <= otype.thingy: return True diff --git a/tests/tensor/conv/c_conv3d_corr3d_ref.py b/tests/tensor/conv/c_conv3d_corr3d_ref.py index 067be760bd..8e051141cd 100644 --- a/tests/tensor/conv/c_conv3d_corr3d_ref.py +++ b/tests/tensor/conv/c_conv3d_corr3d_ref.py @@ -208,7 +208,7 @@ def c_code_cache_version(self): # raise this whenever modifying any of the support_code_files return (8, self.openmp, blas_header_version()) - def c_support_code_apply(self, node, nodename): + def c_support_code_apply(self, node, name): # REMEMBER TO RAISE c_code_cache_version when changing any of # these files sub = {} diff --git a/tests/tensor/conv/c_conv_corr_ref.py b/tests/tensor/conv/c_conv_corr_ref.py index 4a96251095..ff97ca32b0 100644 --- a/tests/tensor/conv/c_conv_corr_ref.py +++ b/tests/tensor/conv/c_conv_corr_ref.py @@ -223,7 +223,7 @@ def c_code_cache_version(self): # raise this whenever modifying any of the support_code_files return (10, self.openmp, blas_header_version()) - def c_support_code_apply(self, node, nodename): + def c_support_code_apply(self, node, name): # REMEMBER TO RAISE c_code_cache_version when changing any of # these files sub = {} diff --git a/tests/tensor/rewriting/test_elemwise.py 
b/tests/tensor/rewriting/test_elemwise.py index 197dd30f36..82cb427a45 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1485,10 +1485,10 @@ def __init__(self, n, *args, **kwargs): def impl(self, x): return x * self.n - def c_support_code_apply(self, node, nodename): + def c_support_code_apply(self, node, name): n = str(self.n) return f""" - float {nodename}_timesn(float x) {{ return x * {n}; }} + float {name}_timesn(float x) {{ return x * {n}; }} """ def c_code(self, node, name, inputs, outputs, sub): diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index 7beae3af07..8cd61b1781 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -576,7 +576,7 @@ def make_node(self, c1, t1, c2, t2, c3, t3, f3): assert t3.type == f3.type return Apply(self, [c1, t1, c2, t2, c3, t3, f3], [t1.type()]) - def make_thunk(self, node, storage_map, compute_map, no_recycling, impl): + def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): input_computed = [compute_map[v] for v in node.inputs] output_computed = [compute_map[v] for v in node.outputs] input_registers = [storage_map[v] for v in node.inputs] @@ -651,7 +651,7 @@ class NotImplementedOp(Op): def make_node(self, x): return Apply(self, [x], [x.type()]) - def make_thunk(self, node, storage_map, compute_map, no_recycling, impl): + def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): def thunk(): raise NotImplementedOpException()
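
Note on the pattern running through these hunks: the parameter renames
(`output`/`r` -> `var`, `types` -> `input_types`, `nodename` -> `name`,
`new_r` -> `new_var`) and the added keyword parameters and defaults
(`impl=None` in `make_thunk`, `reason=None` in `on_change_input`, the
explicit `axis=None, dtype=None, acc_dtype=None, ...` in `clone`) align
each override's signature with its base-class method, presumably to satisfy
mypy's override (Liskov) compatibility check, which compares parameter
names as well as types because callers may pass arguments as keywords.
The sketch below is a minimal, self-contained illustration of that check
using hypothetical classes, not the PyTensor ones; the `Cloneable` part
only demonstrates the usual None-check idiom for keyword defaults when
falsy values like 0 or () are legitimate.

    class Base:
        def process(self, var: object, pstate: dict) -> str:
            raise NotImplementedError()


    class Renamed(Base):
        # mypy: Signature of "process" incompatible with supertype "Base"
        # -- the keyword call obj.process(var=..., pstate=...) would break here.
        def process(self, output: object, pstate: dict) -> str:
            return str(output)


    class Aligned(Base):
        # Same parameter names as the base class: accepted by mypy.
        def process(self, var: object, pstate: dict) -> str:
            return str(var)


    class Cloneable:
        def __init__(self, axis=None):
            self.axis = axis

        def clone(self, axis=None, **kwargs):
            # An explicit None check keeps valid falsy values such as
            # axis=0 or axis=() from silently falling back to the stored
            # default, which a bare `axis or self.axis` would do.
            axis = axis if axis is not None else self.axis
            return type(self)(axis=axis)

When a rename is intentional, marking the base-class parameter
positional-only (with `/`) is the usual way to keep mypy satisfied
without matching names in every override.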