From e2639bfe9417a46cc2ea051905c3638b24edee59 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 20 Jan 2026 14:45:43 -0800
Subject: [PATCH] Add support for non-fp32 input dtypes.

---
 KernelBench/level3/33_VanillaRNN.py | 2 +-
 KernelBench/level3/35_LSTM.py       | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/KernelBench/level3/33_VanillaRNN.py b/KernelBench/level3/33_VanillaRNN.py
index c725983f..3525f440 100644
--- a/KernelBench/level3/33_VanillaRNN.py
+++ b/KernelBench/level3/33_VanillaRNN.py
@@ -31,7 +31,7 @@ def forward(self, x: torch.Tensor, initial_hidden=None) -> torch.Tensor:
         """
         if initial_hidden is not None:
             self.hidden.copy_(initial_hidden)
-        self.hidden = self.hidden.to(x.device)
+        self.hidden = self.hidden.to(x.device).to(x.dtype)
         combined = torch.cat((x, self.hidden), dim=1)  # Concatenate input and hidden state
         self.hidden = self.tanh(self.i2h(combined))  # Update hidden state
         output = self.h2o(self.hidden)  # Compute output
diff --git a/KernelBench/level3/35_LSTM.py b/KernelBench/level3/35_LSTM.py
index 7423ad18..1433ddb1 100644
--- a/KernelBench/level3/35_LSTM.py
+++ b/KernelBench/level3/35_LSTM.py
@@ -31,9 +31,9 @@ def forward(self, x, h0=None, c0=None):
         batch_size = x.size(0)

         if h0 is None:
-            h0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=x.device)
+            h0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
         if c0 is None:
-            c0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=x.device)
+            c0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
         out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)
         out = self.fc(out[:, -1, :])  # out: (batch_size, output_size)
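
Why the change is needed: torch.randn and a freshly created hidden-state tensor default to
torch.float32, so once the model itself is cast to a lower-precision dtype, the stale fp32
state either gets promoted by torch.cat or is fed directly into low-precision weights, and
the matmul raises a dtype mismatch. Below is a minimal, self-contained sketch of that
failure mode and of the fix the patch applies; the TinyRNNCell class and its sizes are
illustrative, not part of KernelBench:

import torch
import torch.nn as nn

class TinyRNNCell(nn.Module):
    """Illustrative cell mirroring the pattern in 33_VanillaRNN.py."""
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        # Plain attribute (not a registered buffer), so casting the module
        # to bfloat16/half does NOT convert it: it stays fp32.
        self.hidden = torch.randn(1, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The patched line: follow both the device and the dtype of the
        # input. With .to(x.device) alone, torch.cat would promote
        # `combined` back to fp32 and the low-precision i2h matmul
        # would fail with a dtype mismatch.
        self.hidden = self.hidden.to(x.device).to(x.dtype)
        combined = torch.cat((x, self.hidden), dim=1)
        self.hidden = torch.tanh(self.i2h(combined))
        return self.hidden

# bfloat16 is used here because it is well supported on CPU; on CUDA the
# same pattern applies to float16.
model = TinyRNNCell(4, 8).to(torch.bfloat16)
x = torch.randn(1, 4, dtype=torch.bfloat16)
print(model(x).dtype)  # torch.bfloat16

The 35_LSTM.py hunk is the same idea: without dtype=x.dtype, the generated (h0, c0) stay
fp32 and no longer match a half- or bfloat16-precision nn.LSTM, which expects its initial
states in the same dtype as the input.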