@@ -70,9 +70,9 @@ def L_op(self, inputs, outputs, output_grads):
7070 )
7171 return (gz * cst * exp (- x * x ),)
7272
73- def c_code (self , node , name , inp , out , sub ):
74- (x ,) = inp
75- (z ,) = out
73+ def c_code (self , node , name , inputs , outputs , sub ):
74+ (x ,) = inputs
75+ (z ,) = outputs
7676 if node .inputs [0 ].type in complex_types :
7777 raise NotImplementedError ("type not supported" , type )
7878 cast = node .outputs [0 ].type .dtype_specs ()[1 ]
@@ -104,9 +104,9 @@ def L_op(self, inputs, outputs, output_grads):
104104 )
105105 return (- gz * cst * exp (- x * x ),)
106106
107- def c_code (self , node , name , inp , out , sub ):
108- (x ,) = inp
109- (z ,) = out
107+ def c_code (self , node , name , inputs , outputs , sub ):
108+ (x ,) = inputs
109+ (z ,) = outputs
110110 if node .inputs [0 ].type in complex_types :
111111 raise NotImplementedError ("type not supported" , type )
112112 cast = node .outputs [0 ].type .dtype_specs ()[1 ]
@@ -162,9 +162,9 @@ def c_support_code(self, **kwargs):
162162 # Using Faddeeva.cc source file from: http://ab-initio.mit.edu/wiki/index.php/Faddeeva_Package
163163 return (C_CODE_PATH / "Faddeeva.cc" ).read_text (encoding = "utf-8" )
164164
165- def c_code (self , node , name , inp , out , sub ):
166- (x ,) = inp
167- (z ,) = out
165+ def c_code (self , node , name , inputs , outputs , sub ):
166+ (x ,) = inputs
167+ (z ,) = outputs
168168
169169 if node .inputs [0 ].type in float_types :
170170 dtype = "npy_" + node .outputs [0 ].dtype
@@ -209,7 +209,7 @@ def L_op(self, inputs, outputs, output_grads):
209209 )
210210 return (gz * cst * exp (erfinv (x ) ** 2 ),)
211211
212- def c_code (self , node , name , inp , out , sub ):
212+ def c_code (self , node , name , inputs , outputs , sub ):
213213 # TODO: erfinv() is not provided by the C standard library
214214 # x, = inp
215215 # z, = out
@@ -244,7 +244,7 @@ def L_op(self, inputs, outputs, output_grads):
244244 )
245245 return (- gz * cst * exp (erfcinv (x ) ** 2 ),)
246246
247- def c_code (self , node , name , inp , out , sub ):
247+ def c_code (self , node , name , inputs , outputs , sub ):
248248 # TODO: erfcinv() is not provided by the C standard library
249249 # x, = inp
250250 # z, = out
@@ -336,9 +336,9 @@ def L_op(self, inputs, outputs, output_grads):
336336
337337 return [gz * psi (x )]
338338
339- def c_code (self , node , name , inp , out , sub ):
340- (x ,) = inp
341- (z ,) = out
339+ def c_code (self , node , name , inputs , outputs , sub ):
340+ (x ,) = inputs
341+ (z ,) = outputs
342342 # no c code for complex
343343 # [u]int* will be casted to float64 before computation
344344 if node .inputs [0 ].type in complex_types :
@@ -439,9 +439,9 @@ def c_support_code(self, **kwargs):
439439 #endif
440440 """
441441
442- def c_code (self , node , name , inp , out , sub ):
443- (x ,) = inp
444- (z ,) = out
442+ def c_code (self , node , name , inputs , outputs , sub ):
443+ (x ,) = inputs
444+ (z ,) = outputs
445445 if node .inputs [0 ].type in float_types :
446446 dtype = "npy_" + node .outputs [0 ].dtype
447447 return f"{ z } = ({ dtype } ) _psi({ x } );"
@@ -523,9 +523,9 @@ def c_support_code(self, **kwargs):
523523 #endif
524524 """
525525
526- def c_code (self , node , name , inp , out , sub ):
527- (x ,) = inp
528- (z ,) = out
526+ def c_code (self , node , name , inputs , outputs , sub ):
527+ (x ,) = inputs
528+ (z ,) = outputs
529529 if node .inputs [0 ].type in float_types :
530530 return f"""{ z } =
531531 _tri_gamma({ x } );"""
@@ -597,9 +597,9 @@ def grad(self, inputs, output_grads):
597597 def c_support_code (self , ** kwargs ):
598598 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
599599
600- def c_code (self , node , name , inp , out , sub ):
601- k , x = inp
602- (z ,) = out
600+ def c_code (self , node , name , inputs , outputs , sub ):
601+ k , x = inputs
602+ (z ,) = outputs
603603 if node .inputs [0 ].type in float_types :
604604 dtype = "npy_" + node .outputs [0 ].dtype
605605 return f"""{ z } =
@@ -644,9 +644,9 @@ def grad(self, inputs, output_grads):
644644 def c_support_code (self , ** kwargs ):
645645 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
646646
647- def c_code (self , node , name , inp , out , sub ):
648- k , x = inp
649- (z ,) = out
647+ def c_code (self , node , name , inputs , outputs , sub ):
648+ k , x = inputs
649+ (z ,) = outputs
650650 if node .inputs [0 ].type in float_types :
651651 dtype = "npy_" + node .outputs [0 ].dtype
652652 return f"""{ z } =
@@ -943,9 +943,9 @@ def impl(self, k, x):
943943 def c_support_code (self , ** kwargs ):
944944 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
945945
946- def c_code (self , node , name , inp , out , sub ):
947- k , x = inp
948- (z ,) = out
946+ def c_code (self , node , name , inputs , outputs , sub ):
947+ k , x = inputs
948+ (z ,) = outputs
949949 if node .inputs [0 ].type in float_types :
950950 dtype = "npy_" + node .outputs [0 ].dtype
951951 return f"""{ z } =
@@ -975,9 +975,9 @@ def impl(self, k, x):
975975 def c_support_code (self , ** kwargs ):
976976 return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
977977
978- def c_code (self , node , name , inp , out , sub ):
979- k , x = inp
980- (z ,) = out
978+ def c_code (self , node , name , inputs , outputs , sub ):
979+ k , x = inputs
980+ (z ,) = outputs
981981 if node .inputs [0 ].type in float_types :
982982 dtype = "npy_" + node .outputs [0 ].dtype
983983 return f"""{ z } =
@@ -1034,9 +1034,9 @@ def grad(self, inputs, output_grads):
10341034 (gz ,) = output_grads
10351035 return [gz * (j0 (x ) - jv (2 , x )) / 2.0 ]
10361036
1037- def c_code (self , node , name , inp , out , sub ):
1038- (x ,) = inp
1039- (z ,) = out
1037+ def c_code (self , node , name , inputs , outputs , sub ):
1038+ (x ,) = inputs
1039+ (z ,) = outputs
10401040 if node .inputs [0 ].type in float_types :
10411041 return f"""{ z } =
10421042 j1({ x } );"""
@@ -1061,9 +1061,9 @@ def grad(self, inputs, output_grads):
10611061 (gz ,) = output_grads
10621062 return [gz * - 1 * j1 (x )]
10631063
1064- def c_code (self , node , name , inp , out , sub ):
1065- (x ,) = inp
1066- (z ,) = out
1064+ def c_code (self , node , name , inputs , outputs , sub ):
1065+ (x ,) = inputs
1066+ (z ,) = outputs
10671067 if node .inputs [0 ].type in float_types :
10681068 return f"""{ z } =
10691069 j0({ x } );"""
@@ -1217,9 +1217,9 @@ def grad(self, inputs, output_grads):
12171217
12181218 return [rval ]
12191219
1220- def c_code (self , node , name , inp , out , sub ):
1221- (x ,) = inp
1222- (z ,) = out
1220+ def c_code (self , node , name , inputs , outputs , sub ):
1221+ (x ,) = inputs
1222+ (z ,) = outputs
12231223
12241224 if node .inputs [0 ].type in float_types :
12251225 if node .inputs [0 ].type == float64 :
@@ -1280,9 +1280,9 @@ def grad(self, inputs, output_grads):
12801280 (gz ,) = output_grads
12811281 return [gz * sigmoid (x )]
12821282
1283- def c_code (self , node , name , inp , out , sub ):
1284- (x ,) = inp
1285- (z ,) = out
1283+ def c_code (self , node , name , inputs , outputs , sub ):
1284+ (x ,) = inputs
1285+ (z ,) = outputs
12861286 # We use the same limits for all precisions, which may be suboptimal. The reference
12871287 # paper only looked at double precision
12881288 if node .inputs [0 ].type in float_types :
@@ -1351,9 +1351,9 @@ def grad(self, inputs, output_grads):
13511351 res = switch (isinf (res ), - np .inf , res )
13521352 return [gz * res ]
13531353
1354- def c_code (self , node , name , inp , out , sub ):
1355- (x ,) = inp
1356- (z ,) = out
1354+ def c_code (self , node , name , inputs , outputs , sub ):
1355+ (x ,) = inputs
1356+ (z ,) = outputs
13571357
13581358 if node .inputs [0 ].type in float_types :
13591359 if node .inputs [0 ].type == float64 :
@@ -1396,9 +1396,9 @@ def grad(self, inputs, output_grads):
13961396 def c_support_code (self , ** kwargs ):
13971397 return (C_CODE_PATH / "incbet.c" ).read_text (encoding = "utf-8" )
13981398
1399- def c_code (self , node , name , inp , out , sub ):
1400- (a , b , x ) = inp
1401- (z ,) = out
1399+ def c_code (self , node , name , inputs , outputs , sub ):
1400+ (a , b , x ) = inputs
1401+ (z ,) = outputs
14021402 if (
14031403 node .inputs [0 ].type in float_types
14041404 and node .inputs [1 ].type in float_types
0 commit comments