When bias is enabled, the generated epilogue first adds Bias to the output tile held in shared memory and then accumulates that tile into C with atomicAddx2:
#pragma unroll
for (int i_10 = 0; i_10 < 4; ++i_10) {
__syncthreads();
uint4 __1;
uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072));
uint4 v__1 = *(uint4*)(Bias + (((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 1) * 8)));
((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
*(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072)) = __1;
}
__syncthreads();
#pragma unroll
for (int i_11 = 0; i_11 < 16; ++i_11) {
atomicAddx2((&(C[(((((((int)blockIdx.y) * 65536) + (i_11 * 4096)) + ((((int)threadIdx.x) >> 5) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 31) * 2))])), (&(((half_t*)buf_dyn_shmem)[(((((((i_11 >> 2) * 1024) + (((((int)threadIdx.x) & 31) >> 3) * 256)) + ((i_11 & 3) * 64)) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 3072)])));
}
This atomicAdd path has correctness issues, while the same kernel without atomicAdd is correct.
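As a side note, atomicAddx2 is not defined in the snippet; a minimal sketch of what such a helper presumably looks like is shown below. This is an assumption on our side (treating half_t as an alias of __half and wrapping CUDA's native packed half2 atomicAdd), not the actual BitBLAS definition.

#include <cuda_fp16.h>

// Hypothetical sketch only, not the BitBLAS implementation: accumulate two
// consecutive half values from src into dst with a single packed half2
// atomicAdd (native half2 atomicAdd requires compute capability >= 6.0).
__device__ __forceinline__ void atomicAddx2(__half *dst, __half *src) {
  atomicAdd(reinterpret_cast<__half2 *>(dst),
            *reinterpret_cast<__half2 *>(src));
}

For comparison, the generated epilogue without atomicAdd adds Bias in shared memory in the same way but then writes the tile to C with plain vectorized stores: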
for (int i_14 = 0; i_14 < 4; ++i_14) {
__syncthreads();
uint4 __1;
uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
uint4 v__1 = *(uint4*)(Bias + ((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) & 7) * 8)));
((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
*(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072)) = __1;
}
__syncthreads();
#pragma unroll
for (int i_15 = 0; i_15 < 4; ++i_15) {
*(uint4*)(C + (((((((int)blockIdx.y) * 65536) + (i_15 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_15 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
}
Currently we disable atomicAdd whenever bias is present to avoid this issue.
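One plausible explanation for the mismatch, stated here only as an assumption rather than a confirmed root cause: the atomicAdd epilogue is used when several thread blocks accumulate partial results into the same tile of C (for example under a split-K schedule), and since every block adds Bias to its own partial tile before the atomic accumulation, Bias ends up counted once per accumulating block instead of once per output element. A toy sketch of that effect, with hypothetical names (partial, rows, cols) that do not appear in the generated kernel:

#include <cuda_fp16.h>

// Hypothetical illustration only: partial holds one block's partial sum for a
// tile of C. If several blocks along the reduction axis each execute this
// epilogue, Bias is accumulated once per block rather than once overall.
// (Scalar half atomicAdd requires compute capability >= 7.0.)
__global__ void bias_then_atomic(__half *C, const __half *partial,
                                 const __half *Bias, int rows, int cols) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < rows * cols) {
    // Wrong under multi-block accumulation: the bias rides along with every
    // partial sum. Accumulating only partial[i] here and adding Bias exactly
    // once after all partials have landed would avoid the double counting.
    atomicAdd(&C[i], __hadd(partial[i], Bias[i % cols]));
  }
}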
Reproduce:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
import bitblas.testing
from bitblas import Linear as BitBLASLinear
import torch
import time
import numpy as np
import torch.nn as nn
torch.manual_seed(0)
bitblas.set_log_level("DEBUG")
def correctness_consistent(m, in_features, out_features, bias):
linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
linear_bitblas = BitBLASLinear(
in_features,
out_features,
bias=bias,
A_dtype="float16",
W_dtype="float16",
accum_dtype="float16",
out_dtype="float16",
opt_M=m,
).cuda()
with torch.no_grad():
linear_bitblas.load_and_transform_weight(linear_torch.weight.clone())
if bias:
linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone())
with torch.no_grad():
if not isinstance(m, int):
# When m is a list, average m
m = sum(m) // len(m)
input_data = torch.randn(m, in_features, dtype=torch.float16).cuda()
output_torch = linear_torch(input_data)
output_bitblas = linear_bitblas(input_data)
print(output_torch)
print(output_bitblas)
bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)
def test_correctness_consistent():
correctness_consistent(1, 1024, 1024, False)
correctness_consistent(1, 1024, 1024, True)
correctness_consistent(1024, 1024, 1024, True)
correctness_consistent([1, 1024], 1024, 1024, True)
def correctness_weight_only_dequantize(
m,
in_features,
out_features,
bias,
W_dtype,
group_size,
with_scaling,
with_zeros,
zeros_mode,
):
import numpy as np
from bitblas.quantization.utils import general_compress
from bitblas.cache import global_operator_cache
global_operator_cache.clear()
linear_bitblas = BitBLASLinear(
in_features,
out_features,
bias=bias,
A_dtype="float16",
W_dtype=W_dtype,
accum_dtype="float16",
out_dtype="float16",
group_size=group_size,
with_scaling=with_scaling,
with_zeros=with_zeros,
opt_M=m,
).cuda()
if not isinstance(m, int):
# average m
m = sum(m) // len(m)
input_shape = (m, in_features)
weight_shape = (out_features, in_features)
output_shape = (m, out_features)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
source_format, bit = (
linear_bitblas.bitblas_matmul.source_format,
linear_bitblas.bitblas_matmul.bit,
)
maxq = 2**(bit - 1)
zeros = maxq
if source_format == "uint":
inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
elif source_format == "int":
inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
else:
raise NotImplementedError
inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())
intweight = inputs[1]
intweight = intweight.cpu().to(torch.int8)
if source_format == "int":
intweight = intweight + maxq
if with_zeros:
inputs[1] = inputs[1] - zeros
bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()
ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16))
if bias:
ref_result = ref_result + bias_tensor
with torch.no_grad():
permuted_inputs = []
permuted_inputs.append(inputs[0])
if linear_bitblas.bitblas_matmul.weight_transform is not None:
permuted_inputs.append(
linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda())
else:
permuted_inputs.append(inputs[1])
linear_bitblas.qweight.data = permuted_inputs[-1].clone()
if with_scaling:
if group_size == -1:
group_size = in_features
permuted_inputs.append(
torch.ones([out_features, in_features // group_size], dtype=torch.float16).cuda())
linear_bitblas.scales.data = permuted_inputs[-1].clone()
if with_zeros:
if zeros_mode == "original":
permuted_inputs.append(
torch.ones([out_features, in_features // group_size],
dtype=torch.float16).cuda() * zeros)
elif zeros_mode == "rescale":
original_zeros = (
torch.ones([out_features, in_features // group_size],
dtype=torch.float16).cuda() * zeros)
scaled_zeros = original_zeros * permuted_inputs[-1]
permuted_inputs.append(scaled_zeros)
elif zeros_mode == "quantized":
original_zeros = (
torch.ones([in_features // group_size, out_features], dtype=torch.int8).cuda() *
zeros)
qzeros = general_compress(
original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
permuted_inputs.append(torch.from_numpy(qzeros).cuda())
else:
raise NotImplementedError
linear_bitblas.zeros.data = permuted_inputs[-1].clone()
if bias:
permuted_inputs.append(bias_tensor)
linear_bitblas.bias.data = bias_tensor.clone()
with torch.no_grad():
output_bitblas = linear_bitblas(inputs[0])
rtol = 1e0
atol = 1e0
if zeros_mode == "original":
rtol = 1e2
atol = 1e2
print(output_bitblas)
print(ref_result)
torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol)
def test_correctness_weight_only_dequantize():
correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")
def profile(model, input_data):
model = model.cuda()
model.eval()
def get_runtime(num_repeats=1):
tic = time.time()
for _ in range(num_repeats):
_ = model(input_data)
torch.cuda.synchronize()
return (time.time() - tic) * 1000 / num_repeats
with torch.no_grad():
# print("Warming up ...")
st = time.time()
while time.time() - st < 1.0:
get_runtime() # warmup
warmup_runtime = get_runtime()
num_repeats = max(1, int(1000 / warmup_runtime))
times = get_runtime(num_repeats)
return np.mean(times)
if __name__ == "__main__":
# bitblas.testing.main()
correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)