Skip to content

Commit c17f657

Browse files
committed
Standardize Op.c_code arguments
1 parent b59c24e commit c17f657

File tree

23 files changed

+242
-242
lines changed

23 files changed

+242
-242
lines changed

pytensor/compile/ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def perform(self, node, inputs, output_storage):
5050
def __str__(self):
5151
return f"{self.__class__.__name__}"
5252

53-
def c_code(self, node, nodename, inp, out, sub):
54-
(iname,) = inp
55-
(oname,) = out
53+
def c_code(self, node, name, inputs, outputs, sub):
54+
(iname,) = inputs
55+
(oname,) = outputs
5656
fail = sub["fail"]
5757

5858
itype = node.inputs[0].type.__class__
@@ -192,9 +192,9 @@ def c_code_cache_version(self):
192192
version.append(1)
193193
return tuple(version)
194194

195-
def c_code(self, node, name, inames, onames, sub):
196-
(iname,) = inames
197-
(oname,) = onames
195+
def c_code(self, node, name, inputs, outputs, sub):
196+
(iname,) = inputs
197+
(oname,) = outputs
198198
fail = sub["fail"]
199199

200200
itype = node.inputs[0].type.__class__

pytensor/link/c/op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def c_init_code_struct(self, node, name, sub):
571571
else:
572572
return super().c_init_code_struct(node, name, sub)
573573

574-
def c_code(self, node, name, inp, out, sub):
574+
def c_code(self, node, name, inputs, outputs, sub):
575575
if self.func_name is not None:
576576
assert "code" not in self.code_sections
577577

@@ -587,7 +587,7 @@ def c_code(self, node, name, inp, out, sub):
587587
return f"""
588588
{define_macros}
589589
{{
590-
if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{
590+
if ({self.func_name}({self.format_c_function_args(inputs, outputs)}{params}) != 0) {{
591591
{sub["fail"]}
592592
}}
593593
}}
@@ -599,7 +599,7 @@ def c_code(self, node, name, inp, out, sub):
599599

600600
def_macros, undef_macros = self.get_c_macros(node, name)
601601
def_sub, undef_sub = get_sub_macros(sub)
602-
def_io, undef_io = get_io_macros(inp, out)
602+
def_io, undef_io = get_io_macros(inputs, outputs)
603603

604604
return (
605605
f"{def_macros}\n{def_sub}\n{def_io}\n{op_code}"

pytensor/raise_op.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ def grad(self, inputs, output_grads):
9494
def connection_pattern(self, node):
9595
return [[1]] + [[0]] * (len(node.inputs) - 1)
9696

97-
def c_code(self, node, name, inames, onames, props):
97+
def c_code(self, node, name, inputs, outputs, sub):
9898
if not isinstance(node.inputs[0].type, DenseTensorType | ScalarType):
9999
raise NotImplementedError(
100100
f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}"
101101
)
102-
value_name, *cond_names = inames
103-
out_name = onames[0]
104-
fail_code = props["fail"]
105-
param_struct_name = props["params"]
102+
value_name, *cond_names = inputs
103+
out_name = outputs[0]
104+
fail_code = sub["fail"]
105+
param_struct_name = sub["params"]
106106
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
107107

108108
all_conds = " && ".join(cond_names)

pytensor/scalar/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,7 @@ def __str__(self):
12941294
def c_code_cache_version(self):
12951295
return (4,)
12961296

1297-
def c_code_contiguous(self, node, name, inp, out, sub):
1297+
def c_code_contiguous(self, node, name, inputs, outputs, sub):
12981298
"""
12991299
This function is called by Elemwise when all inputs and outputs are
13001300
c_contiguous. This allows to use the SIMD version of this op.
@@ -4406,15 +4406,15 @@ def c_code_template(self):
44064406

44074407
return self._c_code
44084408

4409-
def c_code(self, node, nodename, inames, onames, sub):
4409+
def c_code(self, node, name, inputs, outputs, sub):
44104410
d = dict(
44114411
chain(
4412-
zip((f"i{i}" for i in range(len(inames))), inames, strict=True),
4413-
zip((f"o{i}" for i in range(len(onames))), onames, strict=True),
4412+
zip((f"i{i}" for i in range(len(inputs))), inputs, strict=True),
4413+
zip((f"o{i}" for i in range(len(outputs))), outputs, strict=True),
44144414
),
44154415
**sub,
44164416
)
4417-
d["nodename"] = nodename
4417+
d["nodename"] = name
44184418
if "id" not in sub:
44194419
# The use of a dummy id is safe as the code is in a separate block.
44204420
# It won't generate conflicting variable name.

pytensor/scalar/loop.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,23 +312,23 @@ def c_code_template(self):
312312

313313
return self._c_code
314314

315-
def c_code(self, node, nodename, inames, onames, sub):
315+
def c_code(self, node, name, inputs, outputs, sub):
316316
d = dict(
317317
chain(
318-
zip((f"i{i}" for i in range(len(inames))), inames, strict=True),
319-
zip((f"o{i}" for i in range(len(onames))), onames, strict=True),
318+
zip((f"i{i}" for i in range(len(inputs))), inputs, strict=True),
319+
zip((f"o{i}" for i in range(len(outputs))), outputs, strict=True),
320320
),
321321
**sub,
322322
)
323-
d["nodename"] = nodename
323+
d["nodename"] = name
324324
if "id" not in sub:
325325
# The use of a dummy id is safe as the code is in a separate block.
326326
# It won't generate conflicting variable name.
327327
d["id"] = "_DUMMY_ID_"
328328

329329
# When called inside Elemwise we don't have access to the dtype
330330
# via the usual `f"dtype_{inames[i]}"` variable
331-
d["n_steps"] = inames[0]
331+
d["n_steps"] = inputs[0]
332332
d["n_steps_dtype"] = "npy_" + node.inputs[0].dtype
333333

334334
res = self.c_code_template % d

pytensor/scalar/math.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def L_op(self, inputs, outputs, output_grads):
7070
)
7171
return (gz * cst * exp(-x * x),)
7272

73-
def c_code(self, node, name, inp, out, sub):
74-
(x,) = inp
75-
(z,) = out
73+
def c_code(self, node, name, inputs, outputs, sub):
74+
(x,) = inputs
75+
(z,) = outputs
7676
if node.inputs[0].type in complex_types:
7777
raise NotImplementedError("type not supported", type)
7878
cast = node.outputs[0].type.dtype_specs()[1]
@@ -104,9 +104,9 @@ def L_op(self, inputs, outputs, output_grads):
104104
)
105105
return (-gz * cst * exp(-x * x),)
106106

107-
def c_code(self, node, name, inp, out, sub):
108-
(x,) = inp
109-
(z,) = out
107+
def c_code(self, node, name, inputs, outputs, sub):
108+
(x,) = inputs
109+
(z,) = outputs
110110
if node.inputs[0].type in complex_types:
111111
raise NotImplementedError("type not supported", type)
112112
cast = node.outputs[0].type.dtype_specs()[1]
@@ -162,9 +162,9 @@ def c_support_code(self, **kwargs):
162162
# Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package
163163
return (C_CODE_PATH / "Faddeeva.cc").read_text(encoding="utf-8")
164164

165-
def c_code(self, node, name, inp, out, sub):
166-
(x,) = inp
167-
(z,) = out
165+
def c_code(self, node, name, inputs, outputs, sub):
166+
(x,) = inputs
167+
(z,) = outputs
168168

169169
if node.inputs[0].type in float_types:
170170
dtype = "npy_" + node.outputs[0].dtype
@@ -209,7 +209,7 @@ def L_op(self, inputs, outputs, output_grads):
209209
)
210210
return (gz * cst * exp(erfinv(x) ** 2),)
211211

212-
def c_code(self, node, name, inp, out, sub):
212+
def c_code(self, node, name, inputs, outputs, sub):
213213
# TODO: erfinv() is not provided by the C standard library
214214
# x, = inp
215215
# z, = out
@@ -244,7 +244,7 @@ def L_op(self, inputs, outputs, output_grads):
244244
)
245245
return (-gz * cst * exp(erfcinv(x) ** 2),)
246246

247-
def c_code(self, node, name, inp, out, sub):
247+
def c_code(self, node, name, inputs, outputs, sub):
248248
# TODO: erfcinv() is not provided by the C standard library
249249
# x, = inp
250250
# z, = out
@@ -336,9 +336,9 @@ def L_op(self, inputs, outputs, output_grads):
336336

337337
return [gz * psi(x)]
338338

339-
def c_code(self, node, name, inp, out, sub):
340-
(x,) = inp
341-
(z,) = out
339+
def c_code(self, node, name, inputs, outputs, sub):
340+
(x,) = inputs
341+
(z,) = outputs
342342
# no c code for complex
343343
# [u]int* will be casted to float64 before computation
344344
if node.inputs[0].type in complex_types:
@@ -439,9 +439,9 @@ def c_support_code(self, **kwargs):
439439
#endif
440440
"""
441441

442-
def c_code(self, node, name, inp, out, sub):
443-
(x,) = inp
444-
(z,) = out
442+
def c_code(self, node, name, inputs, outputs, sub):
443+
(x,) = inputs
444+
(z,) = outputs
445445
if node.inputs[0].type in float_types:
446446
dtype = "npy_" + node.outputs[0].dtype
447447
return f"{z} = ({dtype}) _psi({x});"
@@ -523,9 +523,9 @@ def c_support_code(self, **kwargs):
523523
#endif
524524
"""
525525

526-
def c_code(self, node, name, inp, out, sub):
527-
(x,) = inp
528-
(z,) = out
526+
def c_code(self, node, name, inputs, outputs, sub):
527+
(x,) = inputs
528+
(z,) = outputs
529529
if node.inputs[0].type in float_types:
530530
return f"""{z} =
531531
_tri_gamma({x});"""
@@ -597,9 +597,9 @@ def grad(self, inputs, output_grads):
597597
def c_support_code(self, **kwargs):
598598
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
599599

600-
def c_code(self, node, name, inp, out, sub):
601-
k, x = inp
602-
(z,) = out
600+
def c_code(self, node, name, inputs, outputs, sub):
601+
k, x = inputs
602+
(z,) = outputs
603603
if node.inputs[0].type in float_types:
604604
dtype = "npy_" + node.outputs[0].dtype
605605
return f"""{z} =
@@ -644,9 +644,9 @@ def grad(self, inputs, output_grads):
644644
def c_support_code(self, **kwargs):
645645
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
646646

647-
def c_code(self, node, name, inp, out, sub):
648-
k, x = inp
649-
(z,) = out
647+
def c_code(self, node, name, inputs, outputs, sub):
648+
k, x = inputs
649+
(z,) = outputs
650650
if node.inputs[0].type in float_types:
651651
dtype = "npy_" + node.outputs[0].dtype
652652
return f"""{z} =
@@ -943,9 +943,9 @@ def impl(self, k, x):
943943
def c_support_code(self, **kwargs):
944944
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
945945

946-
def c_code(self, node, name, inp, out, sub):
947-
k, x = inp
948-
(z,) = out
946+
def c_code(self, node, name, inputs, outputs, sub):
947+
k, x = inputs
948+
(z,) = outputs
949949
if node.inputs[0].type in float_types:
950950
dtype = "npy_" + node.outputs[0].dtype
951951
return f"""{z} =
@@ -975,9 +975,9 @@ def impl(self, k, x):
975975
def c_support_code(self, **kwargs):
976976
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
977977

978-
def c_code(self, node, name, inp, out, sub):
979-
k, x = inp
980-
(z,) = out
978+
def c_code(self, node, name, inputs, outputs, sub):
979+
k, x = inputs
980+
(z,) = outputs
981981
if node.inputs[0].type in float_types:
982982
dtype = "npy_" + node.outputs[0].dtype
983983
return f"""{z} =
@@ -1034,9 +1034,9 @@ def grad(self, inputs, output_grads):
10341034
(gz,) = output_grads
10351035
return [gz * (j0(x) - jv(2, x)) / 2.0]
10361036

1037-
def c_code(self, node, name, inp, out, sub):
1038-
(x,) = inp
1039-
(z,) = out
1037+
def c_code(self, node, name, inputs, outputs, sub):
1038+
(x,) = inputs
1039+
(z,) = outputs
10401040
if node.inputs[0].type in float_types:
10411041
return f"""{z} =
10421042
j1({x});"""
@@ -1061,9 +1061,9 @@ def grad(self, inputs, output_grads):
10611061
(gz,) = output_grads
10621062
return [gz * -1 * j1(x)]
10631063

1064-
def c_code(self, node, name, inp, out, sub):
1065-
(x,) = inp
1066-
(z,) = out
1064+
def c_code(self, node, name, inputs, outputs, sub):
1065+
(x,) = inputs
1066+
(z,) = outputs
10671067
if node.inputs[0].type in float_types:
10681068
return f"""{z} =
10691069
j0({x});"""
@@ -1217,9 +1217,9 @@ def grad(self, inputs, output_grads):
12171217

12181218
return [rval]
12191219

1220-
def c_code(self, node, name, inp, out, sub):
1221-
(x,) = inp
1222-
(z,) = out
1220+
def c_code(self, node, name, inputs, outputs, sub):
1221+
(x,) = inputs
1222+
(z,) = outputs
12231223

12241224
if node.inputs[0].type in float_types:
12251225
if node.inputs[0].type == float64:
@@ -1280,9 +1280,9 @@ def grad(self, inputs, output_grads):
12801280
(gz,) = output_grads
12811281
return [gz * sigmoid(x)]
12821282

1283-
def c_code(self, node, name, inp, out, sub):
1284-
(x,) = inp
1285-
(z,) = out
1283+
def c_code(self, node, name, inputs, outputs, sub):
1284+
(x,) = inputs
1285+
(z,) = outputs
12861286
# We use the same limits for all precisions, which may be suboptimal. The reference
12871287
# paper only looked at double precision
12881288
if node.inputs[0].type in float_types:
@@ -1351,9 +1351,9 @@ def grad(self, inputs, output_grads):
13511351
res = switch(isinf(res), -np.inf, res)
13521352
return [gz * res]
13531353

1354-
def c_code(self, node, name, inp, out, sub):
1355-
(x,) = inp
1356-
(z,) = out
1354+
def c_code(self, node, name, inputs, outputs, sub):
1355+
(x,) = inputs
1356+
(z,) = outputs
13571357

13581358
if node.inputs[0].type in float_types:
13591359
if node.inputs[0].type == float64:
@@ -1396,9 +1396,9 @@ def grad(self, inputs, output_grads):
13961396
def c_support_code(self, **kwargs):
13971397
return (C_CODE_PATH / "incbet.c").read_text(encoding="utf-8")
13981398

1399-
def c_code(self, node, name, inp, out, sub):
1400-
(a, b, x) = inp
1401-
(z,) = out
1399+
def c_code(self, node, name, inputs, outputs, sub):
1400+
(a, b, x) = inputs
1401+
(z,) = outputs
14021402
if (
14031403
node.inputs[0].type in float_types
14041404
and node.inputs[1].type in float_types

0 commit comments

Comments
 (0)