From c122e54445cfb96ef63478077d6eb874a6c7d892 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 20 Oct 2025 06:57:25 +0800 Subject: [PATCH 1/5] Refactor tests to use `device` parameter instead of hardcoded `"cuda"` device --- tests/test_abs.py | 4 +--- tests/test_add.py | 4 +--- tests/test_addmm.py | 4 +--- tests/test_bitwise_and.py | 4 +--- tests/test_bitwise_not.py | 4 +--- tests/test_bitwise_or.py | 4 +--- tests/test_bmm.py | 4 +--- tests/test_clamp.py | 4 +--- tests/test_cos.py | 5 +---- tests/test_div.py | 4 +--- tests/test_dropout.py | 4 +--- tests/test_eq.py | 4 +--- tests/test_exp.py | 5 +---- tests/test_ge.py | 4 +--- tests/test_gelu.py | 4 +--- tests/test_gt.py | 4 +--- tests/test_isinf.py | 4 +--- tests/test_isnan.py | 4 +--- tests/test_layer_norm.py | 6 +++--- tests/test_le.py | 4 +--- tests/test_lt.py | 4 +--- tests/test_mm.py | 10 +++++----- tests/test_mul.py | 4 +--- tests/test_ne.py | 4 +--- tests/test_neg.py | 4 +--- tests/test_pow.py | 4 +--- tests/test_relu.py | 4 +--- tests/test_rms_norm.py | 4 +--- tests/test_rotary_position_embedding.py | 16 ++++++++++++---- tests/test_rsqrt.py | 4 +--- tests/test_scaled_dot_product_attention.py | 11 +++++++---- tests/test_sigmoid.py | 4 +--- tests/test_silu.py | 4 +--- tests/test_sin.py | 5 +---- tests/test_softmax.py | 4 +--- tests/test_sub.py | 4 +--- tests/test_tanh.py | 4 +--- tests/utils.py | 6 ++++-- 38 files changed, 64 insertions(+), 120 deletions(-) diff --git a/tests/test_abs.py b/tests/test_abs.py index 3ebe40c..0abd3f5 100644 --- a/tests/test_abs.py +++ b/tests/test_abs.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_abs(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.abs(input) diff --git a/tests/test_add.py b/tests/test_add.py index 7e67ed3..69109dd 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_add(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) alpha = gauss() diff --git a/tests/test_addmm.py b/tests/test_addmm.py index 8c105c2..9b8f095 100644 --- a/tests/test_addmm.py +++ b/tests/test_addmm.py @@ -9,9 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(m, n, k, dtype, atol, rtol): - device = "cuda" - +def test_addmm(m, n, k, dtype, device, atol, rtol): input = torch.randn((m, n), dtype=dtype, device=device) x = torch.randn((m, k), dtype=dtype, device=device) y = torch.randn((k, n), dtype=dtype, device=device) diff --git a/tests/test_bitwise_and.py b/tests/test_bitwise_and.py index b286af2..0566eca 100644 --- a/tests/test_bitwise_and.py +++ b/tests/test_bitwise_and.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments(False)) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_bitwise_and(shape, dtype, device, atol, rtol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bitwise_not.py b/tests/test_bitwise_not.py index 4a70a96..4a3b4d9 100644 --- a/tests/test_bitwise_not.py +++ b/tests/test_bitwise_not.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available 
@pytest.mark.parametrize(*generate_arguments(False)) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_bitwise_not(shape, dtype, device, atol, rtol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bitwise_or.py b/tests/test_bitwise_or.py index 7f5a4d6..f2de543 100644 --- a/tests/test_bitwise_or.py +++ b/tests/test_bitwise_or.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments(False)) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_bitwise_or(shape, dtype, device, atol, rtol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bmm.py b/tests/test_bmm.py index 96337a8..3ad3b14 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -10,9 +10,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(m, n, k, dtype, atol, rtol): - device = "cuda" - +def test_bmm(m, n, k, dtype, device, atol, rtol): b = random.randint(4, 16) input = torch.randn((b, m, k), dtype=dtype, device=device) other = torch.randn((b, k, n), dtype=dtype, device=device) diff --git a/tests/test_clamp.py b/tests/test_clamp.py index 608fd3d..9a3f1bb 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_clamp(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) min = torch.randn(shape, dtype=dtype, device=device) max = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_cos.py b/tests/test_cos.py index 005588c..1604232 100644 --- a/tests/test_cos.py +++ b/tests/test_cos.py @@ -8,13 +8,10 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_cos(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. 
if dtype is torch.float16: return - - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.cos(input) diff --git a/tests/test_div.py b/tests/test_div.py index bf3644f..e227f94 100644 --- a/tests/test_div.py +++ b/tests/test_div.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_div(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_dropout.py b/tests/test_dropout.py index 1181e79..eeb3577 100644 --- a/tests/test_dropout.py +++ b/tests/test_dropout.py @@ -11,9 +11,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_dropout(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) p = random.uniform(0, 1) diff --git a/tests/test_eq.py b/tests/test_eq.py index 6c44191..1c1d3f2 100644 --- a/tests/test_eq.py +++ b/tests/test_eq.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_eq(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_exp.py b/tests/test_exp.py index f0d66d3..2cc153c 100644 --- a/tests/test_exp.py +++ b/tests/test_exp.py @@ -8,13 +8,10 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_exp(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. 
if dtype is torch.float16: return - - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.exp(input) diff --git a/tests/test_ge.py b/tests/test_ge.py index 0bf2a49..86ae6c4 100644 --- a/tests/test_ge.py +++ b/tests/test_ge.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_ge(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_gelu.py b/tests/test_gelu.py index 6395a2e..51f85a4 100644 --- a/tests/test_gelu.py +++ b/tests/test_gelu.py @@ -9,9 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_gelu(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) for approximate in ("none", "tanh"): diff --git a/tests/test_gt.py b/tests/test_gt.py index 509b6e3..8c9d310 100644 --- a/tests/test_gt.py +++ b/tests/test_gt.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_gt(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_isinf.py b/tests/test_isinf.py index b4f8d80..e1af79e 100644 --- a/tests/test_isinf.py +++ b/tests/test_isinf.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_isinf(shape, dtype, device, atol, rtol): def generate_inf_tensor(shape, dtype, device): x = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_isnan.py b/tests/test_isnan.py index 7f87bcc..73f8d45 100644 --- a/tests/test_isnan.py +++ b/tests/test_isnan.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_isnan(shape, dtype, device, atol, rtol): def generate_nan_tensor(shape, dtype, device): nan_prob = 0.4 prob_tensor = torch.rand(shape, device=device) diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py index fbe2c73..0ae7597 100644 --- a/tests/test_layer_norm.py +++ b/tests/test_layer_norm.py @@ -13,9 +13,9 @@ @pytest.mark.parametrize("bias_is_none", (False, True)) @pytest.mark.parametrize("weight_is_none", (False, True)) @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps): - device = "cuda" - +def test_layer_norm( + shape, dtype, device, atol, rtol, weight_is_none, bias_is_none, eps +): input = torch.randn(shape, dtype=dtype, device=device) normalized_shape = shape[-random.randint(1, len(shape)) :] if weight_is_none: diff --git a/tests/test_le.py b/tests/test_le.py index b6d3b28..8e35e3a 100644 --- a/tests/test_le.py +++ b/tests/test_le.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_le(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_lt.py b/tests/test_lt.py index 1cda1f5..a2ba58f 100644 --- a/tests/test_lt.py +++ 
b/tests/test_lt.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_lt(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_mm.py b/tests/test_mm.py index 8874f32..f61a178 100644 --- a/tests/test_mm.py +++ b/tests/test_mm.py @@ -11,6 +11,8 @@ def generate_arguments(): arguments = [] for dtype in (torch.float32, torch.float16): + device = "cuda" + if dtype is torch.float32: atol = 0.001 rtol = 0.001 @@ -25,16 +27,14 @@ def generate_random_size(): n = generate_random_size() k = generate_random_size() - arguments.append((m, n, k, dtype, atol, rtol)) + arguments.append((m, n, k, dtype, device, atol, rtol)) - return "m, n, k, dtype, atol, rtol", arguments + return "m, n, k, dtype, device, atol, rtol", arguments @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(m, n, k, dtype, atol, rtol): - device = "cuda" - +def test_mm(m, n, k, dtype, device, atol, rtol): input = torch.randn((m, k), dtype=dtype, device=device) other = torch.randn((k, n), dtype=dtype, device=device) diff --git a/tests/test_mul.py b/tests/test_mul.py index 78d7951..78b7161 100644 --- a/tests/test_mul.py +++ b/tests/test_mul.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_mul(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_ne.py b/tests/test_ne.py index d1ca7ec..ae2be2b 100644 --- a/tests/test_ne.py +++ b/tests/test_ne.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_ne(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_neg.py b/tests/test_neg.py index 243a6b1..2ea8cd1 100644 --- a/tests/test_neg.py +++ b/tests/test_neg.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_neg(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.neg(input) diff --git a/tests/test_pow.py b/tests/test_pow.py index 5ffc1a0..8fd2ccc 100644 --- a/tests/test_pow.py +++ b/tests/test_pow.py @@ -8,13 +8,11 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_pow(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. 
if dtype is torch.float16: return - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) exponent = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_relu.py b/tests/test_relu.py index 9fdade5..629cd20 100644 --- a/tests/test_relu.py +++ b/tests/test_relu.py @@ -9,9 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_relu(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) for inplace in (False, True): diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index 646906a..fcafb2e 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -12,9 +12,7 @@ @pytest.mark.parametrize("eps", (None, 0, 1e-5, 1e-3)) @pytest.mark.parametrize("weight_is_none", (False, True)) @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol, weight_is_none, eps): - device = "cuda" - +def test_rms_norm(shape, dtype, device, atol, rtol, weight_is_none, eps): input = torch.randn(shape, dtype=dtype, device=device) normalized_shape = shape[-random.randint(1, len(shape)) :] if weight_is_none: diff --git a/tests/test_rotary_position_embedding.py b/tests/test_rotary_position_embedding.py index a017c9d..c8aea87 100644 --- a/tests/test_rotary_position_embedding.py +++ b/tests/test_rotary_position_embedding.py @@ -46,6 +46,7 @@ def _generate_sin_and_cos_tables( @skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) @pytest.mark.parametrize( "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001)) ) @@ -55,11 +56,18 @@ def _generate_sin_and_cos_tables( @pytest.mark.parametrize("num_heads", (1, 8)) @pytest.mark.parametrize("seq_len", (1, 128)) @pytest.mark.parametrize("batch_size", (1, 4)) -def test_cuda( - batch_size, seq_len, num_heads, emb_dim, interleaved, inplace, dtype, atol, rtol +def test_rotary_position_embedding( + batch_size, + seq_len, + num_heads, + emb_dim, + interleaved, + inplace, + dtype, + device, + atol, + rtol, ): - device = "cuda" - input = torch.randn( batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device ) diff --git a/tests/test_rsqrt.py b/tests/test_rsqrt.py index d87fafb..bf391e3 100644 --- a/tests/test_rsqrt.py +++ b/tests/test_rsqrt.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_rsqrt(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.rsqrt(input) diff --git a/tests/test_scaled_dot_product_attention.py b/tests/test_scaled_dot_product_attention.py index 433ffa9..0a6790e 100644 --- a/tests/test_scaled_dot_product_attention.py +++ b/tests/test_scaled_dot_product_attention.py @@ -22,6 +22,7 @@ def _generate_random_size(): is_causal_values = (False, True) scales = (None, random.uniform(0.05, 0.5)) dtypes = (torch.float32, torch.float16) + devices = ("cuda",) with_kv_cache_values = (False, True) causal_variants = (None, CausalVariant.LOWER_RIGHT, CausalVariant.UPPER_LEFT) @@ -30,6 +31,7 @@ def _generate_random_size(): is_causal, scale, dtype, + device, with_kv_cache, causal_variant, ) in itertools.product( @@ -37,6 +39,7 @@ def _generate_random_size(): is_causal_values, scales, dtypes, + devices, with_kv_cache_values, causal_variants, ): @@ -77,20 +80,21 @@ def _generate_random_size(): causal_variant, with_kv_cache, dtype, + device, 
atol, rtol, ) ) return ( - "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, atol, rtol", + "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, device, atol, rtol", arguments, ) @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda( +def test_scaled_dot_product_attention( batch_size, num_heads_q, seq_len_q, @@ -104,11 +108,10 @@ def test_cuda( causal_variant, with_kv_cache, dtype, + device, atol, rtol, ): - device = "cuda" - shape_q = (batch_size, num_heads_q, seq_len_q, head_dim) shape_kv = (batch_size, num_heads_kv, seq_len_kv, head_dim) diff --git a/tests/test_sigmoid.py b/tests/test_sigmoid.py index d4fc499..84f03e7 100644 --- a/tests/test_sigmoid.py +++ b/tests/test_sigmoid.py @@ -8,13 +8,11 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_sigmoid(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. if dtype is torch.float16: return - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.sigmoid(input) diff --git a/tests/test_silu.py b/tests/test_silu.py index cba8ae2..17cbea6 100644 --- a/tests/test_silu.py +++ b/tests/test_silu.py @@ -9,9 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_silu(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) # TODO: Add `inplace` tests later. diff --git a/tests/test_sin.py b/tests/test_sin.py index d395b97..eddcfda 100644 --- a/tests/test_sin.py +++ b/tests/test_sin.py @@ -8,13 +8,10 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_sin(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. 
if dtype is torch.float16: return - - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.sin(input) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index 48eaad0..fce1f1f 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -10,9 +10,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_softmax(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) dim = random.randint(0, input.ndim - 1) dtype = random.choice([torch.float16, torch.float32, torch.float64]) diff --git a/tests/test_sub.py b/tests/test_sub.py index 726f78f..d6b39e2 100644 --- a/tests/test_sub.py +++ b/tests/test_sub.py @@ -8,9 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): - device = "cuda" - +def test_sub(shape, dtype, device, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) alpha = gauss() diff --git a/tests/test_tanh.py b/tests/test_tanh.py index e2bc49e..f7b9e0b 100644 --- a/tests/test_tanh.py +++ b/tests/test_tanh.py @@ -8,13 +8,11 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cuda(shape, dtype, atol, rtol): +def test_tanh(shape, dtype, device, atol, rtol): # TODO: Test for `float16` later. if dtype is torch.float16: return - device = "cuda" - input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.tanh(input) diff --git a/tests/utils.py b/tests/utils.py index 23af61d..3de38eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,8 @@ def generate_arguments(use_float=True): for ndim in range(1, 5): for dtype in dtype_arr: + device = "cuda" + if dtype is torch.float32: atol = 0.001 rtol = 0.001 @@ -20,9 +22,9 @@ def generate_arguments(use_float=True): atol = 0.01 rtol = 0.01 - arguments.append((_random_shape(ndim), dtype, atol, rtol)) + arguments.append((_random_shape(ndim), dtype, device, atol, rtol)) - return "shape, dtype, atol, rtol", arguments + return "shape, dtype, device, atol, rtol", arguments def gauss(mu=0.0, sigma=1.0): From 490d23a18508993bc9b288a43d5516aa2a8651ad Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 20 Oct 2025 07:32:19 +0800 Subject: [PATCH 2/5] Add configurable defaults for `_cached_make` kernel compilation parameters --- src/ntops/torch/utils.py | 45 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py index 1e0d5ba..e9b2dde 100644 --- a/src/ntops/torch/utils.py +++ b/src/ntops/torch/utils.py @@ -6,10 +6,55 @@ import ntops +class _CachedMakeDefaultConfig: + def __init__(self, num_warps=None, num_stages=None, max_num_configs=None): + self.num_warps = num_warps + + self.num_stages = num_stages + + self.max_num_configs = max_num_configs + + +_cached_make_default_config = _CachedMakeDefaultConfig() + + +def get_default_num_warps(): + return _cached_make_default_config.num_warps + + +def set_default_num_warps(num_warps): + _cached_make_default_config.num_warps = num_warps + + +def get_default_num_stages(): + return _cached_make_default_config.num_stages + + +def set_default_num_stages(num_stages): + _cached_make_default_config.num_stages = num_stages + + +def get_default_max_num_configs(): + return _cached_make_default_config.max_num_configs + + +def 
set_default_max_num_configs(max_num_configs): + _cached_make_default_config.max_num_configs = max_num_configs + + @functools.cache def _cached_make( premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords ): + if num_warps is None: + num_warps = _cached_make_default_config.num_warps + + if num_stages is None: + num_stages = _cached_make_default_config.num_stages + + if max_num_configs is None: + max_num_configs = _cached_make_default_config.max_num_configs + return ninetoothed.make( *premake(*args, **keywords), num_warps=num_warps, From a5776c65992bb66865e223772a85967f8017bee9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 20 Oct 2025 07:35:39 +0800 Subject: [PATCH 3/5] Configure default `max_num_configs` in `pytest_configure` --- tests/conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 3f29fd4..0f89217 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,15 @@ import pytest import torch +import ntops.torch.utils + def pytest_configure(): torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + ntops.torch.utils.set_default_max_num_configs(_DEFAULT_MAX_NUM_CONFIGS) + def pytest_collectstart(collector): if isinstance(collector, pytest.Module): @@ -25,6 +29,9 @@ def set_seed_per_test(request): _set_random_seed(_hash(_test_case_path_from_request(request))) +_DEFAULT_MAX_NUM_CONFIGS = 3 + + def _set_random_seed(seed): random.seed(seed) torch.manual_seed(seed) From 53746de9ec3aa3ce744570d8f4448563615c28df Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 20 Oct 2025 21:01:05 +0800 Subject: [PATCH 4/5] Refactor tests to use `rtol, atol` parameter order instead of `atol, rtol` --- tests/test_abs.py | 4 ++-- tests/test_add.py | 4 ++-- tests/test_addmm.py | 4 ++-- tests/test_bitwise_and.py | 2 +- tests/test_bitwise_not.py | 2 +- tests/test_bitwise_or.py | 2 +- tests/test_bmm.py | 4 ++-- tests/test_clamp.py | 4 ++-- tests/test_cos.py | 4 ++-- tests/test_div.py | 4 ++-- tests/test_dropout.py | 2 +- tests/test_eq.py | 2 +- tests/test_exp.py | 4 ++-- tests/test_ge.py | 2 +- tests/test_gelu.py | 4 ++-- tests/test_gt.py | 2 +- tests/test_isinf.py | 2 +- tests/test_isnan.py | 2 +- tests/test_layer_norm.py | 4 ++-- tests/test_le.py | 2 +- tests/test_lt.py | 2 +- tests/test_mm.py | 8 ++++---- tests/test_mul.py | 4 ++-- tests/test_ne.py | 2 +- tests/test_neg.py | 4 ++-- tests/test_pow.py | 4 ++-- tests/test_relu.py | 4 ++-- tests/test_rms_norm.py | 4 ++-- tests/test_rotary_position_embedding.py | 6 +++--- tests/test_rsqrt.py | 4 ++-- tests/test_scaled_dot_product_attention.py | 8 ++++---- tests/test_sigmoid.py | 4 ++-- tests/test_silu.py | 4 ++-- tests/test_sin.py | 4 ++-- tests/test_softmax.py | 4 ++-- tests/test_sub.py | 4 ++-- tests/test_tanh.py | 4 ++-- tests/utils.py | 4 ++-- 38 files changed, 69 insertions(+), 69 deletions(-) diff --git a/tests/test_abs.py b/tests/test_abs.py index 0abd3f5..d643cf8 100644 --- a/tests/test_abs.py +++ b/tests/test_abs.py @@ -8,10 +8,10 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_abs(shape, dtype, device, atol, rtol): +def test_abs(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.abs(input) reference_output = torch.abs(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert 
torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_add.py b/tests/test_add.py index 69109dd..1a34e97 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_add(shape, dtype, device, atol, rtol): +def test_add(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) alpha = gauss() @@ -16,4 +16,4 @@ def test_add(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.add(input, other, alpha=alpha) reference_output = torch.add(input, other, alpha=alpha) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_addmm.py b/tests/test_addmm.py index 9b8f095..56361d6 100644 --- a/tests/test_addmm.py +++ b/tests/test_addmm.py @@ -9,7 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_addmm(m, n, k, dtype, device, atol, rtol): +def test_addmm(m, n, k, dtype, device, rtol, atol): input = torch.randn((m, n), dtype=dtype, device=device) x = torch.randn((m, k), dtype=dtype, device=device) y = torch.randn((k, n), dtype=dtype, device=device) @@ -19,4 +19,4 @@ def test_addmm(m, n, k, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.addmm(input, x, y, beta=beta, alpha=alpha) reference_output = torch.addmm(input, x, y, beta=beta, alpha=alpha) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_bitwise_and.py b/tests/test_bitwise_and.py index 0566eca..77ddf95 100644 --- a/tests/test_bitwise_and.py +++ b/tests/test_bitwise_and.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments(False)) -def test_bitwise_and(shape, dtype, device, atol, rtol): +def test_bitwise_and(shape, dtype, device, rtol, atol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bitwise_not.py b/tests/test_bitwise_not.py index 4a3b4d9..93e48a5 100644 --- a/tests/test_bitwise_not.py +++ b/tests/test_bitwise_not.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments(False)) -def test_bitwise_not(shape, dtype, device, atol, rtol): +def test_bitwise_not(shape, dtype, device, rtol, atol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bitwise_or.py b/tests/test_bitwise_or.py index f2de543..2d8096e 100644 --- a/tests/test_bitwise_or.py +++ b/tests/test_bitwise_or.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments(False)) -def test_bitwise_or(shape, dtype, device, atol, rtol): +def test_bitwise_or(shape, dtype, device, rtol, atol): if dtype == torch.bool: prob = 0.5 input = torch.rand(shape, dtype=torch.float32, device=device) > prob diff --git a/tests/test_bmm.py b/tests/test_bmm.py index 3ad3b14..2b66d9f 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -10,7 +10,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_bmm(m, n, k, dtype, device, atol, rtol): +def test_bmm(m, n, k, dtype, device, rtol, atol): b = random.randint(4, 16) input = torch.randn((b, m, k), 
dtype=dtype, device=device) other = torch.randn((b, k, n), dtype=dtype, device=device) @@ -18,4 +18,4 @@ def test_bmm(m, n, k, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.bmm(input, other) reference_output = torch.bmm(input, other) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_clamp.py b/tests/test_clamp.py index 9a3f1bb..4ba0de9 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_clamp(shape, dtype, device, atol, rtol): +def test_clamp(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) min = torch.randn(shape, dtype=dtype, device=device) max = torch.randn(shape, dtype=dtype, device=device) @@ -16,4 +16,4 @@ def test_clamp(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.clamp(input, min, max) reference_output = torch.clamp(input, min, max) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_cos.py b/tests/test_cos.py index 1604232..d663310 100644 --- a/tests/test_cos.py +++ b/tests/test_cos.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_cos(shape, dtype, device, atol, rtol): +def test_cos(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. if dtype is torch.float16: return @@ -17,4 +17,4 @@ def test_cos(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.cos(input) reference_output = torch.cos(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_div.py b/tests/test_div.py index e227f94..64478a2 100644 --- a/tests/test_div.py +++ b/tests/test_div.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_div(shape, dtype, device, atol, rtol): +def test_div(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) @@ -21,5 +21,5 @@ def test_div(shape, dtype, device, atol, rtol): reference_output = torch.div(input, other, rounding_mode=rounding_mode) assert torch.allclose( - ninetoothed_output, reference_output, atol=atol, rtol=rtol + ninetoothed_output, reference_output, rtol=rtol, atol=atol ) diff --git a/tests/test_dropout.py b/tests/test_dropout.py index eeb3577..89c08fd 100644 --- a/tests/test_dropout.py +++ b/tests/test_dropout.py @@ -11,7 +11,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_dropout(shape, dtype, device, atol, rtol): +def test_dropout(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) p = random.uniform(0, 1) diff --git a/tests/test_eq.py b/tests/test_eq.py index 1c1d3f2..04ebe42 100644 --- a/tests/test_eq.py +++ b/tests/test_eq.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_eq(shape, dtype, device, atol, rtol): +def test_eq(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_exp.py 
b/tests/test_exp.py index 2cc153c..33f7c59 100644 --- a/tests/test_exp.py +++ b/tests/test_exp.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_exp(shape, dtype, device, atol, rtol): +def test_exp(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. if dtype is torch.float16: return @@ -17,4 +17,4 @@ def test_exp(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.exp(input) reference_output = torch.exp(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_ge.py b/tests/test_ge.py index 86ae6c4..bc4a83b 100644 --- a/tests/test_ge.py +++ b/tests/test_ge.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_ge(shape, dtype, device, atol, rtol): +def test_ge(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_gelu.py b/tests/test_gelu.py index 51f85a4..a04b304 100644 --- a/tests/test_gelu.py +++ b/tests/test_gelu.py @@ -9,7 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_gelu(shape, dtype, device, atol, rtol): +def test_gelu(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) for approximate in ("none", "tanh"): @@ -17,5 +17,5 @@ def test_gelu(shape, dtype, device, atol, rtol): reference_output = F.gelu(input) assert torch.allclose( - ninetoothed_output, reference_output, atol=atol, rtol=rtol + ninetoothed_output, reference_output, rtol=rtol, atol=atol ) diff --git a/tests/test_gt.py b/tests/test_gt.py index 8c9d310..cfc03cc 100644 --- a/tests/test_gt.py +++ b/tests/test_gt.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_gt(shape, dtype, device, atol, rtol): +def test_gt(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_isinf.py b/tests/test_isinf.py index e1af79e..e651836 100644 --- a/tests/test_isinf.py +++ b/tests/test_isinf.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_isinf(shape, dtype, device, atol, rtol): +def test_isinf(shape, dtype, device, rtol, atol): def generate_inf_tensor(shape, dtype, device): x = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_isnan.py b/tests/test_isnan.py index 73f8d45..7e29b01 100644 --- a/tests/test_isnan.py +++ b/tests/test_isnan.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_isnan(shape, dtype, device, atol, rtol): +def test_isnan(shape, dtype, device, rtol, atol): def generate_nan_tensor(shape, dtype, device): nan_prob = 0.4 prob_tensor = torch.rand(shape, device=device) diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py index 0ae7597..3d1bd83 100644 --- a/tests/test_layer_norm.py +++ b/tests/test_layer_norm.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("weight_is_none", (False, True)) @pytest.mark.parametrize(*generate_arguments()) def test_layer_norm( - shape, dtype, device, atol, rtol, weight_is_none, bias_is_none, eps + shape, dtype, device, rtol, atol, weight_is_none, bias_is_none, eps ): input = torch.randn(shape, dtype=dtype, device=device) 
normalized_shape = shape[-random.randint(1, len(shape)) :] @@ -34,4 +34,4 @@ def test_layer_norm( input, normalized_shape, weight=weight, bias=bias, eps=eps ) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_le.py b/tests/test_le.py index 8e35e3a..5a6f961 100644 --- a/tests/test_le.py +++ b/tests/test_le.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_le(shape, dtype, device, atol, rtol): +def test_le(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_lt.py b/tests/test_lt.py index a2ba58f..7d2b376 100644 --- a/tests/test_lt.py +++ b/tests/test_lt.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_lt(shape, dtype, device, atol, rtol): +def test_lt(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_mm.py b/tests/test_mm.py index f61a178..57bd387 100644 --- a/tests/test_mm.py +++ b/tests/test_mm.py @@ -27,18 +27,18 @@ def generate_random_size(): n = generate_random_size() k = generate_random_size() - arguments.append((m, n, k, dtype, device, atol, rtol)) + arguments.append((m, n, k, dtype, device, rtol, atol)) - return "m, n, k, dtype, device, atol, rtol", arguments + return "m, n, k, dtype, device, rtol, atol", arguments @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_mm(m, n, k, dtype, device, atol, rtol): +def test_mm(m, n, k, dtype, device, rtol, atol): input = torch.randn((m, k), dtype=dtype, device=device) other = torch.randn((k, n), dtype=dtype, device=device) ninetoothed_output = ntops.torch.mm(input, other) reference_output = torch.mm(input, other) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_mul.py b/tests/test_mul.py index 78b7161..d707cf9 100644 --- a/tests/test_mul.py +++ b/tests/test_mul.py @@ -8,11 +8,11 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_mul(shape, dtype, device, atol, rtol): +def test_mul(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.mul(input, other) reference_output = torch.mul(input, other) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_ne.py b/tests/test_ne.py index ae2be2b..e6568c5 100644 --- a/tests/test_ne.py +++ b/tests/test_ne.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_ne(shape, dtype, device, atol, rtol): +def test_ne(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) diff --git a/tests/test_neg.py b/tests/test_neg.py index 2ea8cd1..4f77f20 100644 --- a/tests/test_neg.py +++ b/tests/test_neg.py @@ -8,10 +8,10 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_neg(shape, dtype, device, 
atol, rtol): +def test_neg(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.neg(input) reference_output = torch.neg(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_pow.py b/tests/test_pow.py index 8fd2ccc..a1df1b0 100644 --- a/tests/test_pow.py +++ b/tests/test_pow.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_pow(shape, dtype, device, atol, rtol): +def test_pow(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. if dtype is torch.float16: return @@ -20,5 +20,5 @@ def test_pow(shape, dtype, device, atol, rtol): reference_output = torch.pow(input, exponent) assert torch.allclose( - ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True + ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True ) diff --git a/tests/test_relu.py b/tests/test_relu.py index 629cd20..19b6ed7 100644 --- a/tests/test_relu.py +++ b/tests/test_relu.py @@ -9,7 +9,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_relu(shape, dtype, device, atol, rtol): +def test_relu(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) for inplace in (False, True): @@ -17,5 +17,5 @@ def test_relu(shape, dtype, device, atol, rtol): reference_output = F.relu(input, inplace) assert torch.allclose( - ninetoothed_output, reference_output, atol=atol, rtol=rtol + ninetoothed_output, reference_output, rtol=rtol, atol=atol ) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index fcafb2e..211589d 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("eps", (None, 0, 1e-5, 1e-3)) @pytest.mark.parametrize("weight_is_none", (False, True)) @pytest.mark.parametrize(*generate_arguments()) -def test_rms_norm(shape, dtype, device, atol, rtol, weight_is_none, eps): +def test_rms_norm(shape, dtype, device, rtol, atol, weight_is_none, eps): input = torch.randn(shape, dtype=dtype, device=device) normalized_shape = shape[-random.randint(1, len(shape)) :] if weight_is_none: @@ -27,4 +27,4 @@ def test_rms_norm(shape, dtype, device, atol, rtol, weight_is_none, eps): input, normalized_shape, weight=weight, eps=eps ) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_rotary_position_embedding.py b/tests/test_rotary_position_embedding.py index c8aea87..bb2749c 100644 --- a/tests/test_rotary_position_embedding.py +++ b/tests/test_rotary_position_embedding.py @@ -48,7 +48,7 @@ def _generate_sin_and_cos_tables( @skip_if_cuda_not_available @pytest.mark.parametrize("device", ("cuda",)) @pytest.mark.parametrize( - "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001)) + "dtype, rtol, atol", ((torch.float32, 0, 0.001), (torch.float16, 0.001, 0.001)) ) @pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize("interleaved", (False, True)) @@ -65,8 +65,8 @@ def test_rotary_position_embedding( inplace, dtype, device, - atol, rtol, + atol, ): input = torch.randn( batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device @@ -86,4 +86,4 @@ def test_rotary_position_embedding( input, sin_table, cos_table, 
interleaved=interleaved ) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_rsqrt.py b/tests/test_rsqrt.py index bf391e3..35252dd 100644 --- a/tests/test_rsqrt.py +++ b/tests/test_rsqrt.py @@ -8,12 +8,12 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_rsqrt(shape, dtype, device, atol, rtol): +def test_rsqrt(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) ninetoothed_output = ntops.torch.rsqrt(input) reference_output = torch.rsqrt(input) assert torch.allclose( - ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True + ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True ) diff --git a/tests/test_scaled_dot_product_attention.py b/tests/test_scaled_dot_product_attention.py index 0a6790e..b1f1046 100644 --- a/tests/test_scaled_dot_product_attention.py +++ b/tests/test_scaled_dot_product_attention.py @@ -81,13 +81,13 @@ def _generate_random_size(): with_kv_cache, dtype, device, - atol, rtol, + atol, ) ) return ( - "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, device, atol, rtol", + "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, device, rtol, atol", arguments, ) @@ -109,8 +109,8 @@ def test_scaled_dot_product_attention( with_kv_cache, dtype, device, - atol, rtol, + atol, ): shape_q = (batch_size, num_heads_q, seq_len_q, head_dim) shape_kv = (batch_size, num_heads_kv, seq_len_kv, head_dim) @@ -183,4 +183,4 @@ def _generate_present_and_slot(tensor): enable_gqa=enable_gqa, ) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_sigmoid.py b/tests/test_sigmoid.py index 84f03e7..c369fa8 100644 --- a/tests/test_sigmoid.py +++ b/tests/test_sigmoid.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_sigmoid(shape, dtype, device, atol, rtol): +def test_sigmoid(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. if dtype is torch.float16: return @@ -18,4 +18,4 @@ def test_sigmoid(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.sigmoid(input) reference_output = torch.sigmoid(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_silu.py b/tests/test_silu.py index 17cbea6..8aae00b 100644 --- a/tests/test_silu.py +++ b/tests/test_silu.py @@ -9,11 +9,11 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_silu(shape, dtype, device, atol, rtol): +def test_silu(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) # TODO: Add `inplace` tests later. 
ninetoothed_output = ntops.torch.silu(input) reference_output = F.silu(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_sin.py b/tests/test_sin.py index eddcfda..5e6cd35 100644 --- a/tests/test_sin.py +++ b/tests/test_sin.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_sin(shape, dtype, device, atol, rtol): +def test_sin(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. if dtype is torch.float16: return @@ -17,4 +17,4 @@ def test_sin(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.sin(input) reference_output = torch.sin(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index fce1f1f..ad86181 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -10,7 +10,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_softmax(shape, dtype, device, atol, rtol): +def test_softmax(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) dim = random.randint(0, input.ndim - 1) dtype = random.choice([torch.float16, torch.float32, torch.float64]) @@ -18,4 +18,4 @@ def test_softmax(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.softmax(input, dim, dtype) reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_sub.py b/tests/test_sub.py index d6b39e2..1909cf0 100644 --- a/tests/test_sub.py +++ b/tests/test_sub.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_sub(shape, dtype, device, atol, rtol): +def test_sub(shape, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) alpha = gauss() @@ -16,4 +16,4 @@ def test_sub(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.sub(input, other, alpha=alpha) reference_output = torch.sub(input, other, alpha=alpha) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_tanh.py b/tests/test_tanh.py index f7b9e0b..7b7d490 100644 --- a/tests/test_tanh.py +++ b/tests/test_tanh.py @@ -8,7 +8,7 @@ @skip_if_cuda_not_available @pytest.mark.parametrize(*generate_arguments()) -def test_tanh(shape, dtype, device, atol, rtol): +def test_tanh(shape, dtype, device, rtol, atol): # TODO: Test for `float16` later. 
if dtype is torch.float16: return @@ -18,4 +18,4 @@ def test_tanh(shape, dtype, device, atol, rtol): ninetoothed_output = ntops.torch.tanh(input) reference_output = torch.tanh(input) - assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/utils.py b/tests/utils.py index 3de38eb..ac1949f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,9 +22,9 @@ def generate_arguments(use_float=True): atol = 0.01 rtol = 0.01 - arguments.append((_random_shape(ndim), dtype, device, atol, rtol)) + arguments.append((_random_shape(ndim), dtype, device, rtol, atol)) - return "shape, dtype, device, atol, rtol", arguments + return "shape, dtype, device, rtol, atol", arguments def gauss(mu=0.0, sigma=1.0): From b0a1a7705541fca45d50546fe38278698cca11b8 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 20 Oct 2025 21:34:20 +0800 Subject: [PATCH 5/5] Refactor tests to use `pytest.mark.parametrize` instead of `for` loops --- tests/test_div.py | 25 ++++++++++++++----------- tests/test_gelu.py | 20 +++++++++++++------- tests/test_relu.py | 12 +++++------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/tests/test_div.py b/tests/test_div.py index 64478a2..8a3878d 100644 --- a/tests/test_div.py +++ b/tests/test_div.py @@ -7,19 +7,22 @@ @skip_if_cuda_not_available +@pytest.mark.parametrize( + "rounding_mode", + [ + None, + pytest.param( + "trunc", marks=pytest.mark.skip(reason="TODO: Test for `trunc` mode later.") + ), + "floor", + ], +) @pytest.mark.parametrize(*generate_arguments()) -def test_div(shape, dtype, device, rtol, atol): +def test_div(shape, rounding_mode, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) other = torch.randn(shape, dtype=dtype, device=device) - for rounding_mode in (None, "trunc", "floor"): - # TODO: Test for `trunc` mode later. 
- if rounding_mode == "trunc": - continue + ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode) + reference_output = torch.div(input, other, rounding_mode=rounding_mode) - ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode) - reference_output = torch.div(input, other, rounding_mode=rounding_mode) - - assert torch.allclose( - ninetoothed_output, reference_output, rtol=rtol, atol=atol - ) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_gelu.py b/tests/test_gelu.py index a04b304..af14670 100644 --- a/tests/test_gelu.py +++ b/tests/test_gelu.py @@ -8,14 +8,20 @@ @skip_if_cuda_not_available +@pytest.mark.parametrize( + "approximate", + ( + "none", + pytest.param( + "tanh", marks=pytest.mark.skip(reason="TODO: Test for `tanh` mode later.") + ), + ), +) @pytest.mark.parametrize(*generate_arguments()) -def test_gelu(shape, dtype, device, rtol, atol): +def test_gelu(shape, approximate, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) - for approximate in ("none", "tanh"): - ninetoothed_output = ntops.torch.gelu(input) - reference_output = F.gelu(input) + ninetoothed_output = ntops.torch.gelu(input, approximate=approximate) + reference_output = F.gelu(input, approximate=approximate) - assert torch.allclose( - ninetoothed_output, reference_output, rtol=rtol, atol=atol - ) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_relu.py b/tests/test_relu.py index 19b6ed7..28cc07c 100644 --- a/tests/test_relu.py +++ b/tests/test_relu.py @@ -8,14 +8,12 @@ @skip_if_cuda_not_available +@pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize(*generate_arguments()) -def test_relu(shape, dtype, device, rtol, atol): +def test_relu(shape, inplace, dtype, device, rtol, atol): input = torch.randn(shape, dtype=dtype, device=device) - for inplace in (False, True): - ninetoothed_output = ntops.torch.relu(input, inplace) - reference_output = F.relu(input, inplace) + ninetoothed_output = ntops.torch.relu(input, inplace) + reference_output = F.relu(input, inplace) - assert torch.allclose( - ninetoothed_output, reference_output, rtol=rtol, atol=atol - ) + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
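
Usage sketch for the defaults added in PATCH 2/5 (a minimal illustration under stated assumptions, not code from the repository): `_cached_make` falls back to the module-level defaults whenever a caller leaves `num_warps`, `num_stages`, or `max_num_configs` as `None`, and PATCH 3/5 applies `set_default_max_num_configs` once in `pytest_configure`. A non-test script could do the same before invoking any op. The setter names, `ntops.torch.utils`, and `ntops.torch.add` come from the patches above; the chosen values, tensor shapes, and the availability of a CUDA device are assumptions made only for this example.

    import torch

    import ntops.torch
    import ntops.torch.utils

    # Configure process-wide compilation defaults before any kernel is built.
    # The values are illustrative; _cached_make only consults them when the
    # corresponding per-call arguments are left as None.
    ntops.torch.utils.set_default_num_warps(4)
    ntops.torch.utils.set_default_num_stages(3)
    ntops.torch.utils.set_default_max_num_configs(3)  # mirrors _DEFAULT_MAX_NUM_CONFIGS in conftest.py

    # Ops compiled after this point pick up the defaults (assumes a CUDA device is present).
    x = torch.randn(1024, dtype=torch.float32, device="cuda")
    y = torch.randn(1024, dtype=torch.float32, device="cuda")
    out = ntops.torch.add(x, y, alpha=1.0)

Because `_cached_make` is memoized with `functools.cache` on its arguments, the defaults are only consulted the first time a given argument combination is compiled, so they should be set before any ops run rather than changed mid-session.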