Commit b59c24e

Standardize Op.grad arguments

1 parent 666da8d

32 files changed: +327 −325 lines
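The theme of the diff below: every `Op.grad` override moves to the shared `(self, inputs, output_grads)` signature, replacing ad-hoc names such as `args`/`g_outs`, `ins`/`grads`, and `inputs`/`gout`, and matching `L_op(self, inputs, outputs, output_grads)`. As a minimal sketch of the convention (the `DoubleOp` class is hypothetical, written against PyTensor's public `Op` API, not code from this commit):

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class DoubleOp(Op):
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = 2 * x

    def grad(self, inputs, output_grads):
        # `inputs` are the node's symbolic inputs; `output_grads` holds
        # one gradient term per output. Return one expression per input.
        (gz,) = output_grads
        return [2.0 * gz]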

pytensor/breakpoint.py

Lines changed: 2 additions & 2 deletions
@@ -143,8 +143,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(len(output_storage)):
             output_storage[i][0] = inputs[i + 1]
 
-    def grad(self, inputs, output_gradients):
-        return [DisconnectedType()(), *output_gradients]
+    def grad(self, inputs, output_grads):
+        return [DisconnectedType()(), *output_grads]
 
     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)

pytensor/compile/ops.py

Lines changed: 2 additions & 2 deletions
@@ -92,8 +92,8 @@ def make_node(self, x):
     def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes
 
-    def grad(self, args, g_outs):
-        return g_outs
+    def grad(self, inputs, output_grads):
+        return output_grads
 
 
 view_op = ViewOp()

pytensor/gradient.py

Lines changed: 10 additions & 10 deletions
@@ -2316,8 +2316,8 @@ def _is_zero(x):
 
 
 class ZeroGrad(ViewOp):
-    def grad(self, args, g_outs):
-        return [g_out.zeros_like() for g_out in g_outs]
+    def grad(self, inputs, output_grads):
+        return [g_out.zeros_like() for g_out in output_grads]
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
@@ -2354,8 +2354,8 @@ def zero_grad(x):
 
 
 class UndefinedGrad(ViewOp):
-    def grad(self, args, g_outs):
-        return [grad_undefined(self, i, arg) for i, arg in enumerate(args)]
+    def grad(self, inputs, output_grads):
+        return [grad_undefined(self, i, arg) for i, arg in enumerate(inputs)]
 
     def R_op(self, inputs, eval_points):
         return [None]
@@ -2392,8 +2392,8 @@ def undefined_grad(x):
 
 
 class DisconnectedGrad(ViewOp):
-    def grad(self, args, g_outs):
-        return [disconnected_type() for g_out in g_outs]
+    def grad(self, inputs, output_grads):
+        return [disconnected_type() for g_out in output_grads]
 
     def R_op(self, inputs, eval_points):
         return [None]
@@ -2447,10 +2447,10 @@ def __init__(self, clip_lower_bound, clip_upper_bound):
         if not self.clip_upper_bound >= self.clip_lower_bound:
             raise ValueError("`clip_upper_bound` should be >= `clip_lower_bound`")
 
-    def grad(self, args, g_outs):
+    def grad(self, inputs, output_grads):
         return [
             pytensor.tensor.clip(g_out, self.clip_lower_bound, self.clip_upper_bound)
-            for g_out in g_outs
+            for g_out in output_grads
         ]
 
 
@@ -2490,8 +2490,8 @@ class GradScale(ViewOp):
     def __init__(self, multiplier):
         self.multiplier = multiplier
 
-    def grad(self, args, g_outs):
-        return [self.multiplier * g_out for g_out in g_outs]
+    def grad(self, inputs, output_grads):
+        return [self.multiplier * g_out for g_out in output_grads]
 
 
 def grad_scale(x, multiplier):
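These four `ViewOp` subclasses are identities in the forward pass and only reshape the backward pass. A hedged usage sketch, assuming the public wrappers `grad_clip` and `grad_scale` in `pytensor.gradient` that construct these Ops:

import pytensor.tensor as pt
from pytensor.gradient import grad, grad_clip, grad_scale

x = pt.scalar("x")

# Forward value unchanged; the incoming gradient is clipped to [-1, 1].
g_clip = grad(grad_clip(x, -1.0, 1.0) ** 2, x)

# Forward value unchanged; the incoming gradient is scaled by 0.5.
g_scale = grad(grad_scale(x, 0.5) ** 2, x)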

pytensor/ifelse.py

Lines changed: 8 additions & 8 deletions
@@ -239,10 +239,10 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
     def R_op(self, inputs, eval_points):
         return self(inputs[0], *eval_points[1:], return_list=True)
 
-    def grad(self, ins, grads):
-        condition = ins[0]
-        inputs_true_branch = ins[1:][: self.n_outs]
-        inputs_false_branch = ins[1:][self.n_outs :]
+    def grad(self, inputs, output_grads):
+        condition = inputs[0]
+        inputs_true_branch = inputs[1:][: self.n_outs]
+        inputs_false_branch = inputs[1:][self.n_outs :]
 
         if self.name is not None:
             nw_name_t = self.name + "_grad_t"
@@ -260,19 +260,19 @@ def grad(self, ins, grads):
         # dtypes.
         inputs_true_grad = (
             [condition]
-            + grads
+            + output_grads
             + [
-                pt.basic.zeros_like(t, dtype=grads[i].dtype)
+                pt.basic.zeros_like(t, dtype=output_grads[i].dtype)
                 for i, t in enumerate(inputs_true_branch)
             ]
         )
         inputs_false_grad = (
             [condition]
             + [
-                pt.basic.zeros_like(f, dtype=grads[i].dtype)
+                pt.basic.zeros_like(f, dtype=output_grads[i].dtype)
                 for i, f in enumerate(inputs_false_branch)
             ]
-            + grads
+            + output_grads
         )
 
         # `condition` does affect the elements of the output so it is connected.
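The zero padding exists because each branch's gradient call must supply one entry per `IfElse` input, and the inputs of the branch not being differentiated contribute nothing. The observable effect, as a small sketch (assuming `pytensor.ifelse.ifelse`):

import pytensor.tensor as pt
from pytensor import grad
from pytensor.ifelse import ifelse

c = pt.scalar("c", dtype="bool")
x = pt.scalar("x")
y = pt.scalar("y")

z = ifelse(c, x**2, 3 * y)
gx = grad(z, x)  # 2*x where c is true, 0 otherwise
gy = grad(z, y)  # 3 where c is false, 0 otherwise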

pytensor/link/jax/ops.py

Lines changed: 3 additions & 3 deletions
@@ -135,14 +135,14 @@ def perform_jax(self, *inputs):
             return outputs[0]
         return outputs
 
-    def grad(self, inputs, output_gradients):
+    def grad(self, inputs, output_grads):
         """Compute gradients using JAX's vector-Jacobian product (VJP)."""
         import jax
 
         # Find indices of outputs that need gradients
         connected_output_indices = [
             i
-            for i, output_grad in enumerate(output_gradients)
+            for i, output_grad in enumerate(output_grads)
             if not isinstance(output_grad.type, DisconnectedType)
         ]
 
@@ -190,7 +190,7 @@ def restricted_function(*input_values):
         )
 
         return vjp_op(
-            *[*inputs, *[output_gradients[i] for i in connected_output_indices]],
+            *[*inputs, *[output_grads[i] for i in connected_output_indices]],
             return_list=True,
         )
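For context, this wrapper's `grad` defers to JAX's VJP machinery. A standalone sketch of the underlying `jax.vjp` primitive it builds on (plain JAX, independent of this Op):

import jax
import jax.numpy as jnp


def f(x, y):
    return x * y + jnp.sin(x)


# jax.vjp returns the primal output plus a closure mapping an output
# cotangent (an "output grad") back to input cotangents.
out, vjp_fun = jax.vjp(f, 2.0, 3.0)
gx, gy = vjp_fun(1.0)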

pytensor/printing.py

Lines changed: 2 additions & 2 deletions
@@ -804,8 +804,8 @@ def perform(self, node, inputs, output_storage):
         xout[0] = xin
         self.global_fn(self, xin)
 
-    def grad(self, input, output_gradients):
-        return output_gradients
+    def grad(self, inputs, output_grads):
+        return output_grads
 
     def R_op(self, inputs, eval_points):
         return list(eval_points)

pytensor/raise_op.py

Lines changed: 2 additions & 2 deletions
@@ -88,8 +88,8 @@ def perform(self, node, inputs, output_storage):
         if not all(conds):
             raise self.exc_type(self.msg)
 
-    def grad(self, input, output_gradients):
-        return output_gradients + [DisconnectedType()()] * (len(input) - 1)
+    def grad(self, inputs, output_grads):
+        return output_grads + [DisconnectedType()()] * (len(inputs) - 1)
 
     def connection_pattern(self, node):
         return [[1]] + [[0]] * (len(node.inputs) - 1)
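Only the first input (the checked value) carries gradient; every condition is declared disconnected, in `grad` via `DisconnectedType` and in `connection_pattern` via the `[[0]]` rows. A hedged sketch of what that means for callers, assuming the `Assert` Op exported by `pytensor.raise_op`:

import pytensor.tensor as pt
from pytensor import grad
from pytensor.raise_op import Assert

x = pt.scalar("x")
checked = Assert("x must be positive")(x, x > 0)

# The gradient flows through the value unchanged; the boolean condition
# contributes no term.
g = grad(checked**2, x)  # equivalent to grad(x**2, x)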

pytensor/scalar/basic.py

Lines changed: 36 additions & 36 deletions
@@ -1272,7 +1272,7 @@ def perform(self, node, inputs, output_storage):
     def impl(self, *inputs):
         raise MethodNotDefined("impl", type(self), self.__class__.__name__)
 
-    def grad(self, inputs, output_gradients):
+    def grad(self, inputs, output_grads):
         raise MethodNotDefined("grad", type(self), self.__class__.__name__)
 
     def L_op(self, inputs, outputs, output_grads):
@@ -1683,7 +1683,7 @@ def output_types(self, *input_types):
         )
         return upcast_out(*input_types[0])
 
-    def grad(self, inputs, output_gradients):
+    def grad(self, inputs, output_grads):
         return [inputs[0].zeros_like(dtype=config.floatX)]
 
 
@@ -1701,7 +1701,7 @@ def output_types(self, *input_types):
         )
         return upcast_out(*input_types[0])
 
-    def grad(self, inputs, output_gradients):
+    def grad(self, inputs, output_grads):
         a, b = inputs
         return [
             a.zeros_like(dtype=config.floatX),
@@ -1942,8 +1942,8 @@ def c_code(self, node, name, inputs, outputs, sub):
         else:
             return z + " = " + op.join(inputs) + ";"
 
-    def grad(self, inputs, gout):
-        (gz,) = gout
+    def grad(self, inputs, output_grads):
+        (gz,) = output_grads
         retval = []
 
         # The following 3 lines verify that gz is complex when the
@@ -2045,9 +2045,9 @@ def c_code(self, node, name, inputs, outputs, sub):
             return f"{z} = ((double){x}) / {y};"
         return f"{z} = {x} / {y};"
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x, y) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         if x.type in complex_types:
             raise NotImplementedError()
 
@@ -2166,7 +2166,7 @@ def c_code(self, node, name, inputs, outputs, sub):
     def c_code_cache_version(self):
         return (6,)
 
-    def grad(self, inputs, g_output):
+    def grad(self, inputs, output_grads):
         return [inp.zeros_like(dtype=config.floatX) for inp in inputs]
 
 
@@ -2440,9 +2440,9 @@ def connection_pattern(self, node):
 
         return [[False], [True]]
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (_x, y) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         if y.type in continuous_types:
             # x is disconnected because the elements of x are not used
             return DisconnectedType()(), gz
@@ -2466,9 +2466,9 @@ def c_code(self, node, name, inputs, outputs, sub):
         (z,) = outputs
         return f"{z} = {x};"
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         if x.type in continuous_types:
             return (gz,)
         else:
@@ -2505,9 +2505,9 @@ def c_code(self, node, name, inputs, outputs, sub):
             return f"{z} = ({x}) ? 1 : 0;"
         return f"{z} = ({node.outputs[0].type.dtype_specs()[1]}){x};"
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         if self.o_type in continuous_types:
             return [gz]
         else:
@@ -2636,9 +2636,9 @@ def impl(self, x):
         # casting to output type is handled by filter
         return np.sign(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         rval = x.zeros_like()
 
         if rval.type in discrete_types:
@@ -2677,9 +2677,9 @@ class Ceil(UnaryScalarOp):
     def impl(self, x):
         return np.ceil(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         rval = x.zeros_like()
 
         if rval.type in discrete_types:
@@ -2703,9 +2703,9 @@ class Floor(UnaryScalarOp):
     def impl(self, x):
         return np.floor(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         rval = x.zeros_like()
 
         if rval.type in discrete_types:
@@ -2729,9 +2729,9 @@ class Trunc(UnaryScalarOp):
     def impl(self, x):
         return np.trunc(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         return [x.zeros_like(dtype=config.floatX)]
 
     def c_code(self, node, name, inputs, outputs, sub):
@@ -2757,9 +2757,9 @@ class RoundHalfToEven(UnaryScalarOp):
     def impl(self, x):
         return np.round(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         rval = x.zeros_like()
 
         if rval.type in discrete_types:
@@ -2843,9 +2843,9 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
     def impl(self, x):
         return round_half_away_from_zero_vec(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (_gz,) = gout
+        (_gz,) = output_grads
         rval = x.zeros_like()
 
         if rval.type in discrete_types:
@@ -3865,9 +3865,9 @@ class Real(UnaryScalarOp):
     def impl(self, x):
         return np.real(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (_x,) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         return [complex(gz, 0)]
 
     def c_code(self, *args, **kwargs):
@@ -3883,9 +3883,9 @@ class Imag(UnaryScalarOp):
     def impl(self, x):
         return np.imag(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x,) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         if x.type in complex_types:
             return [complex(0, gz)]
         elif x.type in float_types:
@@ -3906,7 +3906,7 @@ class Angle(UnaryScalarOp):
     def impl(self, x):
         return np.angle(x)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         # y = x.imag
         # r = sqrt(y**2 + x.real**2)
         # g = y/r
@@ -3918,7 +3918,7 @@ def grad(self, inputs, gout):
         # theta = -numpy.arcsin(g)+numpy.pi
 
         (c,) = inputs
-        (gtheta,) = gout
+        (gtheta,) = output_grads
         x = real(c)
         y = imag(c)
         r = _abs(c)
@@ -3957,9 +3957,9 @@ def output_types_preference(x, y):
     def impl(self, x, y):
         return builtins.complex(x, y)
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (x, y) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         return [cast(real(gz), x.type.dtype), cast(imag(gz), y.type.dtype)]
 
     def c_code(self, *args, **kwargs):
@@ -4002,9 +4002,9 @@ def impl(self, r, theta):
         else:
             return np.complex128(builtins.complex(x, y))
 
-    def grad(self, inputs, gout):
+    def grad(self, inputs, output_grads):
         (r, theta) = inputs
-        (gz,) = gout
+        (gz,) = output_grads
         gr = gz * complex_from_polar(1, theta)
         gtheta = gz * complex_from_polar(r, -theta)
         return [gr, gtheta]
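All of the scalar-op changes above are signature-only renames; the gradient math is untouched. One convention they share is worth noting: piecewise-constant ops (`Sgn`, `Ceil`, `Floor`, `Trunc`, the rounding ops) return a zeros-like gradient, cast to `config.floatX` when the output type is discrete. A quick check at the tensor level (a sketch, assuming `pt.floor` and `pytensor.function`):

import pytensor.tensor as pt
from pytensor import function, grad

x = pt.vector("x")
g = grad(pt.floor(x).sum(), x)  # floor is flat almost everywhere
f = function([x], g)
print(f([0.5, 1.7, -2.3]))  # [0. 0. 0.]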
