From cdb9e67a8cee7f2ad8e4cd196a0231276ac45780 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 11 Dec 2025 16:42:06 -0800 Subject: [PATCH 01/13] Add an fma intrinsic This is equivalent to std::fma. The use-case is when you're in a strict_float context and you want an actual fma instruction. E.g. for bit-exact transcendentals. --- src/CodeGen_LLVM.cpp | 43 ++++++++++++------ src/IR.cpp | 1 + src/IR.h | 10 ++-- src/IROperator.cpp | 13 ++++++ src/IROperator.h | 7 +++ src/StrictifyFloat.cpp | 2 + test/correctness/CMakeLists.txt | 1 + test/correctness/simd_op_check_arm.cpp | 6 +-- test/correctness/simd_op_check_x86.cpp | 20 ++++---- test/correctness/strict_fma.cpp | 63 ++++++++++++++++++++++++++ 10 files changed, 134 insertions(+), 32 deletions(-) create mode 100644 test/correctness/strict_fma.cpp diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 32984f3f2e6f..9cc20ea45363 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3320,10 +3320,19 @@ void CodeGen_LLVM::visit(const Call *op) { } } - Expr call = Call::make(op->type, op->name, new_args, op->call_type); { ScopedValue old_in_strict_float(in_strict_float, true); - value = codegen(unstrictify_float(call.as())); + if (op->is_intrinsic(Call::strict_fma)) { + // Redirect to an llvm fma intrinsic at some good width + const int native_lanes = target.natural_vector_size(op->type.element_of()); + Type t = op->type.with_lanes(native_lanes); + std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(t)); + value = call_intrin(op->type, native_lanes, name, new_args); + } else { + // Lower to something other than a call node + Expr call = Call::make(op->type, op->name, new_args, op->call_type); + value = codegen(unstrictify_float(call.as())); + } } for (size_t i = 0; i < op->args.size(); i++) { @@ -4729,23 +4738,29 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes, const string &name, vector arg_values, bool scalable_vector_result, bool is_reduction) { + auto fix_vector_lanes_of_type = [&](const llvm::Type *t) { + if (intrin_lanes == 1 || is_reduction) { + return t->getScalarType(); + } else { + if (scalable_vector_result && effective_vscale != 0) { + return get_vector_type(result_type->getScalarType(), + intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); + } else { + return get_vector_type(result_type->getScalarType(), + intrin_lanes, VectorTypeConstraint::Fixed); + } + } + }; + llvm::Function *fn = module->getFunction(name); if (!fn) { vector arg_types(arg_values.size()); for (size_t i = 0; i < arg_values.size(); i++) { - arg_types[i] = arg_values[i]->getType(); + llvm::Type *t = arg_values[i]->getType(); + arg_types[i] = fix_vector_lanes_of_type(t); } - llvm::Type *intrinsic_result_type = result_type->getScalarType(); - if (intrin_lanes > 1 && !is_reduction) { - if (scalable_vector_result && effective_vscale != 0) { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); - } else { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes, VectorTypeConstraint::Fixed); - } - } + llvm::Type *intrinsic_result_type = fix_vector_lanes_of_type(result_type); FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); fn->setCallingConv(CallingConv::C); @@ -4780,7 +4795,7 @@ 
Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes if (arg_i_lanes >= arg_lanes) { // Horizontally reducing intrinsics may have // arguments that have more lanes than the - // result. Assume that the horizontally reduce + // result. Assume that they horizontally reduce // neighboring elements... int reduce = arg_i_lanes / arg_lanes; args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce)); diff --git a/src/IR.cpp b/src/IR.cpp index c844c672656a..fdd53d8b6ae8 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -680,6 +680,7 @@ const char *const intrinsic_op_names[] = { "strict_add", "strict_div", "strict_eq", + "strict_fma", "strict_le", "strict_lt", "strict_max", diff --git a/src/IR.h b/src/IR.h index 6dc0204b89ec..5a54717df080 100644 --- a/src/IR.h +++ b/src/IR.h @@ -628,6 +628,7 @@ struct Call : public ExprNode { strict_add, strict_div, strict_eq, + strict_fma, strict_le, strict_lt, strict_max, @@ -793,13 +794,14 @@ struct Call : public ExprNode { return is_intrinsic( {Call::strict_add, Call::strict_div, + Call::strict_eq, + Call::strict_fma, + Call::strict_lt, + Call::strict_le, Call::strict_max, Call::strict_min, Call::strict_mul, - Call::strict_sub, - Call::strict_lt, - Call::strict_le, - Call::strict_eq}); + Call::strict_sub}); } static const IRNodeType _node_type = IRNodeType::Call; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index f1d2254abb27..285744ba6eef 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2280,6 +2280,19 @@ Expr erf(const Expr &x) { return halide_erf(x); } +Expr fma(const Expr &a, const Expr &b, const Expr &c) { + user_assert(a.type().is_float()) << "fma requires floating-point arguments."; + user_assert(a.type() == b.type() && a.type() == c.type()) + << "All arguments to fma must have the same type."; + + // TODO: Once we use LLVM's native bfloat type instead of treating them as + // ints, we should be able to remove this assert. Currently, it tries to + // codegen an integer fma. + user_assert(!a.type().is_bfloat()) << "fma does not yet support bfloat types."; + + return Call::make(a.type(), Call::strict_fma, {a, b, c}, Call::PureIntrinsic); +} + Expr fast_pow(Expr x, Expr y) { if (auto i = as_const_int(y)) { return raise_to_integer_power(std::move(x), *i); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..58ca80eb6d15 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -978,6 +978,13 @@ Expr pow(Expr x, Expr y); * mantissa. Vectorizes cleanly. */ Expr erf(const Expr &x); +/** Fused multiply-add. fma(a, b, c) is equivalent to a * b + c, but only + * rounded once at the end. Halide will turn a * b + c into an fma + * automatically, except in strict_float contexts. This intrinsic only exists in + * order to request a true fma inside a strict_float context. A true fma will be + * emulated on targets without one. */ +Expr fma(const Expr &, const Expr &, const Expr &); + /** Fast vectorizable approximation to some trigonometric functions for * Float(32). Absolute approximation error is less than 1e-5. Slow on x86 if * you don't have at least sse 4.1. 
*/ diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 13dd0873bb12..aeaae9cff564 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -142,6 +142,8 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_fma)) { + return op->args[0] * op->args[1] + op->args[2]; } else { internal_error << "Missing lowering of strict float intrinsic: " << Expr(op) << "\n"; diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 4bce8789875e..144a7829f80f 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -317,6 +317,7 @@ tests(GROUPS correctness store_in.cpp strict_float.cpp strict_float_bounds.cpp + strict_fma.cpp strided_load.cpp target.cpp target_query.cpp diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index c3021578beb5..2952839ab490 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -390,10 +390,8 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmla.i32" : "mla", 2 * w, u32_1 + u32_2 * u32_3); if (w == 1 || w == 2) { // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmla.f32" : "fmla", 2*w, f32_1 + f32_2*f32_3); - if (!arm32) - check(arm32 ? "vmla.f32" : "fmla", 2 * w, f32_1 + f32_2 * f32_3); + check(arm32 ? "vmla.f32" : "fmla", 2 * w, f32_1 + f32_2 * f32_3); + check(arm32 ? "vmla.f32" : "fmla", 2 * w, fma(f32_1, f32_2, f32_3)); } if (!arm32 && target.has_feature(Target::ARMFp16)) { check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index b5f4c0fa9f64..2d1918ff0a3f 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -411,23 +411,23 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(use_avx512 ? "vrsqrt*ps" : "vrsqrtps*ymm", 8, fast_inverse_sqrt(f32_1)); check(use_avx512 ? "vrcp*ps" : "vrcpps*ymm", 8, fast_inverse(f32_1)); -#if 0 - // Not implemented in the front end. - check("vandnps", 8, bool1 & (!bool2)); - check("vandps", 8, bool1 & bool2); - check("vorps", 8, bool1 | bool2); - check("vxorps", 8, bool1 ^ bool2); -#endif + check(use_avx512 ? "kandw" : "vandps", 8, bool_1 & bool_2); + check(use_avx512 ? "korw" : "vorps", 8, bool_1 | bool_2); + check(use_avx512 ? 
"kxorw" : "vxorps", 8, bool_1 ^ bool_2); check("vaddps*ymm", 8, f32_1 + f32_2); check("vaddpd*ymm", 4, f64_1 + f64_2); check("vmulps*ymm", 8, f32_1 * f32_2); check("vmulpd*ymm", 4, f64_1 * f64_2); + check("vfmadd*ps*ymm", 8, f32_1 * f32_2 + f32_3); + check("vfmadd*pd*ymm", 4, f64_1 * f64_2 + f64_3); + check("vfmadd*ps*ymm", 8, fma(f32_1, f32_2, f32_3)); + check("vfmadd*pd*ymm", 4, fma(f64_1, f64_2, f64_3)); check("vsubps*ymm", 8, f32_1 - f32_2); check("vsubpd*ymm", 4, f64_1 - f64_2); - // LLVM no longer generates division instruction when fast-math is on - // check("vdivps", 8, f32_1 / f32_2); - // check("vdivpd", 4, f64_1 / f64_2); + + check("vdivps", 8, strict_float(f32_1 / f32_2)); + check("vdivpd", 4, strict_float(f64_1 / f64_2)); check("vminps*ymm", 8, min(f32_1, f32_2)); check("vminpd*ymm", 4, min(f64_1, f64_2)); check("vmaxps*ymm", 8, max(f32_1, f32_2)); diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp new file mode 100644 index 000000000000..002dd79bb7d8 --- /dev/null +++ b/test/correctness/strict_fma.cpp @@ -0,0 +1,63 @@ +#include "Halide.h" + +using namespace Halide; + +template +int test() { + std::cout << "Testing " << type_of() << "\n"; + Func f{"f"}, g{"g"}; + Param b{"b"}, c{"c"}; + Var x{"x"}; + + f(x) = fma(cast(x), b, c); + g(x) = strict_float(cast(x) * b + c); + + // Use a non-native vector width, to also test legalization + f.vectorize(x, 5); + g.vectorize(x, 5); + + // b.set((T)8769132.122433244233); + // c.set((T)2809.14123423413); + b.set((T)1.111111111); + c.set((T)1.101010101); + Buffer with_fma = f.realize({1024}); + Buffer without_fma = g.realize({1024}); + + bool saw_error = false; + for (int i = 0; i < with_fma.width(); i++) { + if (with_fma(i) == without_fma(i)) { + continue; + } + + saw_error = true; + // The rounding error, if any, ought to be 1 ULP + Bits fma_bits = Internal::reinterpret_bits(with_fma(i)); + Bits no_fma_bits = Internal::reinterpret_bits(without_fma(i)); + if (fma_bits + 1 != no_fma_bits && + fma_bits - 1 != no_fma_bits) { + printf("Difference greater than 1 ULP: %10.10g (0x%llx) vs %10.10g (0x%llx)!\n", + (double)with_fma(i), (long long unsigned)fma_bits, + (double)without_fma(i), (long long unsigned)no_fma_bits); + return -1; + } + } + + if (!saw_error) { + printf("There should have occasionally been a 1 ULP difference between fma and non-fma results\n"); + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + + if (test() || + test() || + test()) { + return -1; + } + + printf("Success!\n"); + return 0; +} From c622de9bcc6db57d91a06a9e982b7ad534dc8e8a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 11 Dec 2025 16:43:52 -0800 Subject: [PATCH 02/13] Add fma to python bindings --- python_bindings/src/halide/halide_/PyIROperator.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 2adbe3bee35c..aa2a667e79bd 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -126,6 +126,7 @@ void define_operators(py::module &m) { m.def("log", &log); m.def("pow", &pow); m.def("erf", &erf); + m.def("fma", &fma); m.def("fast_sin", &fast_sin); m.def("fast_cos", &fast_cos); m.def("fast_log", &fast_log); From 9ea8c63043a60e2da5c45563ca031139e92a12b1 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 12 Dec 2025 13:34:08 -0800 Subject: [PATCH 03/13] Don't even try for fma on arm 32 --- test/correctness/simd_op_check_arm.cpp | 
23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index 2952839ab490..b3238d023dbe 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -388,14 +388,13 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmla.i16" : "mla", 4 * w, u16_1 + u16_2 * u16_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, i32_1 + i32_2 * i32_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, u32_1 + u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - check(arm32 ? "vmla.f32" : "fmla", 2 * w, f32_1 + f32_2 * f32_3); - check(arm32 ? "vmla.f32" : "fmla", 2 * w, fma(f32_1, f32_2, f32_3)); - } - if (!arm32 && target.has_feature(Target::ARMFp16)) { - check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); - check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + if (!arm32) { + check("fmla", 2 * w, f32_1 * f32_2 + f32_3); + check("fmla", 2 * w, fma(f32_1, f32_2, f32_3)); + if (target.has_feature(Target::ARMFp16)) { + check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); + check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + } } // VMLS I, F F, D Multiply Subtract @@ -405,12 +404,8 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmls.i16" : "mls", 4 * w, u16_1 - u16_2 * u16_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, i32_1 - i32_2 * i32_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, u32_1 - u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmls.f32" : "fmls", 2*w, f32_1 - f32_2*f32_3); - if (!arm32) - check(arm32 ? 
"vmls.f32" : "fmls", 2 * w, f32_1 - f32_2 * f32_3); + if (!arm32) { + check("fmls", 2 * w, f32_1 - f32_2 * f32_3); } // VMLAL I - Multiply Accumulate Long From 22f67651dd7ab0fdcfe9dc8dc8f479d2bcd2db82 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 15 Dec 2025 14:14:06 -0800 Subject: [PATCH 04/13] Get fma working in C and GPU backends Also a drive-by fix for fmod --- apps/c_backend/pipeline_generator.cpp | 2 +- src/CodeGen_C.cpp | 25 ++++++++++++++-- src/CodeGen_C_prologue.template.cpp | 26 +++++++++++++++++ src/CodeGen_C_vectors.template.cpp | 32 ++++++++++++++++++++ src/CodeGen_D3D12Compute_Dev.cpp | 4 +++ src/CodeGen_GPU_Dev.cpp | 14 +++++++++ src/CodeGen_GPU_Dev.h | 10 +++++++ src/CodeGen_LLVM.cpp | 7 ++--- src/CodeGen_Metal_Dev.cpp | 1 + src/CodeGen_OpenCL_Dev.cpp | 2 +- src/CodeGen_PTX_Dev.cpp | 25 ++++++++++++---- src/CodeGen_Vulkan_Dev.cpp | 12 ++++++-- src/Lower.cpp | 2 +- src/OffloadGPULoops.cpp | 12 +++++--- src/OffloadGPULoops.h | 2 +- test/correctness/strict_fma.cpp | 42 +++++++++++++++++++++++---- 16 files changed, 190 insertions(+), 28 deletions(-) diff --git a/apps/c_backend/pipeline_generator.cpp b/apps/c_backend/pipeline_generator.cpp index c6a28bc477fa..f4602c9213a5 100644 --- a/apps/c_backend/pipeline_generator.cpp +++ b/apps/c_backend/pipeline_generator.cpp @@ -14,7 +14,7 @@ class Pipeline : public Halide::Generator { Var x, y; Func f, h; - f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13; + f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13 + cast(x % 3.4f + fma(cast(y), 0.5f, 1.2f)); h.define_extern("an_extern_stage", {f}, Int(16), 0, NameMangling::C); output(x, y) = cast(max(0, f(y, x) + f(x, y) + an_extern_func(x, y) + h())); diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index a5dc3298be63..65892bff2c2c 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1351,7 +1351,12 @@ void CodeGen_C::visit(const Mod *op) { string arg0 = print_expr(op->a); string arg1 = print_expr(op->b); ostringstream rhs; - rhs << "fmod(" << arg0 << ", " << arg1 << ")"; + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fmod("; + } else { + rhs << print_type(op->type) << "_ops::fmod("; + } + rhs << arg0 << ", " << arg1 << ")"; print_assignment(op->type, rhs.str()); } else { visit_binop(op->type, op->a, op->b, "%"); @@ -1845,8 +1850,24 @@ void CodeGen_C::visit(const Call *op) { << " + " << print_expr(base_offset) << "), /*rw*/0, /*locality*/0), 0)"; } else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) { rhs << "(sizeof(halide_buffer_t))"; + } else if (op->is_intrinsic(Call::strict_fma)) { + internal_assert(op->args.size() == 3) + << "Wrong number of args for strict_fma: " << op->args.size(); + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } else { + rhs << print_type(op->type) << "_ops::fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } } else if (op->is_strict_float_intrinsic()) { - // This depends on the generated C++ being compiled without -ffast-math + // This depends on the generated C++ being compiled without + // -ffast-math. Note that this would not be correct for strict_fma, so + // we handle it separately above. 
Expr equiv = unstrictify_float(op); rhs << print_expr(equiv); } else if (op->is_intrinsic()) { diff --git a/src/CodeGen_C_prologue.template.cpp b/src/CodeGen_C_prologue.template.cpp index 5d85d585716c..7340c50e3179 100644 --- a/src/CodeGen_C_prologue.template.cpp +++ b/src/CodeGen_C_prologue.template.cpp @@ -257,6 +257,32 @@ inline T halide_cpp_min(const T &a, const T &b) { return (a < b) ? a : b; } +template +inline T halide_cpp_fma(const T &a, const T &b, const T &c) { +#if __has_builtin(__builtin_fma) + return __builtin_fma(a, b, c); +#else + if (sizeof(T) == sizeof(float)) { + return fmaf(a, b, c); + } else { + return (T)fma((double)a, (double)b, (double)c); + } +#endif +} + +template +inline T halide_cpp_fmod(const T &a, const T &b) { +#if __has_builtin(__builtin_fmod) + return __builtin_fmod(a, b); +#else + if (sizeof(T) == sizeof(float)) { + return fmod(a, b); + } else { + return (T)fmod((double)a, (double)b); + } +#endif +} + template inline void halide_maybe_unused(const T &) { } diff --git a/src/CodeGen_C_vectors.template.cpp b/src/CodeGen_C_vectors.template.cpp index 003d2423414d..2ada1bd8afbd 100644 --- a/src/CodeGen_C_vectors.template.cpp +++ b/src/CodeGen_C_vectors.template.cpp @@ -150,6 +150,22 @@ class CppVectorOps { return r; } + static Vec fma(const Vec &a, const Vec &b, const Vec &c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec &a, const Vec &b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + static Mask logical_or(const Vec &a, const Vec &b) { CppVector r; for (size_t i = 0; i < Lanes; i++) { @@ -734,6 +750,22 @@ class NativeVectorOps { #endif } + static Vec fma(const Vec a, const Vec b, const Vec c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec a, const Vec b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + // The relational operators produce signed-int of same width as input; our codegen expects uint8. 
static Mask logical_or(const Vec a, const Vec b) { using T = typename NativeVectorComparisonType::type; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index 4ce641e680ad..99c2ffb13a3b 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1257,6 +1257,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, void CodeGen_D3D12Compute_Dev::init_module() { debug(2) << "D3D12Compute device codegen init_module\n"; + // TODO: we could support strict float intrinsics with the precise qualifier + internal_assert(!any_strict_float) + << "strict float intrinsics not yet supported in d3d12compute backend"; + // wipe the internal kernel source src_stream.str(""); src_stream.clear(); diff --git a/src/CodeGen_GPU_Dev.cpp b/src/CodeGen_GPU_Dev.cpp index acd539664e96..595dbb82b9c5 100644 --- a/src/CodeGen_GPU_Dev.cpp +++ b/src/CodeGen_GPU_Dev.cpp @@ -245,6 +245,20 @@ void CodeGen_GPU_C::visit(const Call *op) { equiv.accept(this); } } + } else if (op->is_intrinsic(Call::strict_fma)) { + // All shader languages have fma + Expr equiv = Call::make(op->type, "fma", op->args, Call::PureExtern); + equiv.accept(this); + } else { + CodeGen_C::visit(op); + } +} + +void CodeGen_GPU_C::visit(const Mod *op) { + if (op->type.is_float()) { + // All shader languages have fmod + Expr equiv = Call::make(op->type, "fmod", {op->a, op->b}, Call::PureExtern); + equiv.accept(this); } else { CodeGen_C::visit(op); } diff --git a/src/CodeGen_GPU_Dev.h b/src/CodeGen_GPU_Dev.h index ee2950464526..be56625dac55 100644 --- a/src/CodeGen_GPU_Dev.h +++ b/src/CodeGen_GPU_Dev.h @@ -77,6 +77,15 @@ struct CodeGen_GPU_Dev { Device = 1, // Device/global memory fence Shared = 2 // Threadgroup/shared memory fence }; + + /** Some GPU APIs need to know what floating point mode we're in at kernel + * emission time, to emit appropriate pragmas. */ + bool any_strict_float = false; + +public: + void set_any_strict_float(bool any_strict_float) { + this->any_strict_float = any_strict_float; + } }; /** A base class for GPU backends that require C-like shader output. @@ -99,6 +108,7 @@ class CodeGen_GPU_C : public CodeGen_C { using CodeGen_C::visit; void visit(const Shuffle *op) override; void visit(const Call *op) override; + void visit(const Mod *op) override; std::string print_extern_call(const Call *op) override; diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 9cc20ea45363..4cf9ba492d93 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3323,11 +3323,8 @@ void CodeGen_LLVM::visit(const Call *op) { { ScopedValue old_in_strict_float(in_strict_float, true); if (op->is_intrinsic(Call::strict_fma)) { - // Redirect to an llvm fma intrinsic at some good width - const int native_lanes = target.natural_vector_size(op->type.element_of()); - Type t = op->type.with_lanes(native_lanes); - std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(t)); - value = call_intrin(op->type, native_lanes, name, new_args); + std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type)); + value = call_intrin(op->type, op->type.lanes(), name, new_args); } else { // Lower to something other than a call node Expr call = Call::make(op->type, op->name, new_args, op->call_type); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index d31e97e30427..a60bd973cb29 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -834,6 +834,7 @@ void CodeGen_Metal_Dev::init_module() { // Write out the Halide math functions. 
src_stream << "#pragma clang diagnostic ignored \"-Wunused-function\"\n" + << "#pragma METAL fp math_mode(" << (any_strict_float ? "safe)\n" : "fast)\n") << "#include \n" << "using namespace metal;\n" // Seems like the right way to go. << "namespace {\n" diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 1c945efb7cc1..807d75444ed4 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -1123,7 +1123,7 @@ void CodeGen_OpenCL_Dev::init_module() { // This identifies the program as OpenCL C (as opposed to SPIR). src_stream << "/*OpenCL C " << target.to_string() << "*/\n"; - src_stream << "#pragma OPENCL FP_CONTRACT ON\n"; + src_stream << "#pragma OPENCL FP_CONTRACT " << (any_strict_float ? "OFF\n" : "ON\n"); // Write out the Halide math functions. src_stream << "inline float float_from_bits(unsigned int x) {return as_float(x);}\n" diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 3204002652c9..28ee05137b38 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -220,6 +220,12 @@ void CodeGen_PTX_Dev::add_kernel(Stmt stmt, } void CodeGen_PTX_Dev::init_module() { + // This class uses multiple inheritance. It's a GPU device code generator, + // and also an llvm-based one. Both of these track strict_float presence, + // but OffloadGPULoops only sets the GPU device code generator flag, so here + // we set the CodeGen_LLVM flag to match. + CodeGen_LLVM::any_strict_float = CodeGen_GPU_Dev::any_strict_float; + init_context(); module = get_initial_module_for_ptx_device(target, context); @@ -249,6 +255,15 @@ void CodeGen_PTX_Dev::init_module() { function_does_not_access_memory(fn); fn->addFnAttr(llvm::Attribute::NoUnwind); } + + if (CodeGen_GPU_Dev::any_strict_float) { + debug(0) << "Setting strict fp math\n"; + set_strict_fp_math(); + in_strict_float = target.has_feature(Target::StrictFloat); + } else { + debug(0) << "Setting fast fp math\n"; + set_fast_fp_math(); + } } void CodeGen_PTX_Dev::visit(const Call *op) { @@ -611,13 +626,13 @@ vector CodeGen_PTX_Dev::compile_to_src() { internal_assert(llvm_target) << "Could not create LLVM target for " << triple.str() << "\n"; TargetOptions options; - options.AllowFPOpFusion = FPOpFusion::Fast; + options.AllowFPOpFusion = CodeGen_GPU_Dev::any_strict_float ? llvm::FPOpFusion::Strict : llvm::FPOpFusion::Fast; #if LLVM_VERSION < 210 - options.UnsafeFPMath = true; + options.UnsafeFPMath = !CodeGen_GPU_Dev::any_strict_float; #endif - options.NoInfsFPMath = true; - options.NoNaNsFPMath = true; - options.HonorSignDependentRoundingFPMathOption = false; + options.NoInfsFPMath = !CodeGen_GPU_Dev::any_strict_float; + options.NoNaNsFPMath = !CodeGen_GPU_Dev::any_strict_float; + options.HonorSignDependentRoundingFPMathOption = !CodeGen_GPU_Dev::any_strict_float; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 671f923ec183..e02a1a55a4ff 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -201,6 +201,7 @@ class CodeGen_Vulkan_Dev : public CodeGen_GPU_Dev { {"fast_pow_f32", GLSLstd450Pow}, {"floor_f16", GLSLstd450Floor}, {"floor_f32", GLSLstd450Floor}, + {"fma", GLSLstd450Fma}, {"log_f16", GLSLstd450Log}, {"log_f32", GLSLstd450Log}, {"sin_f16", GLSLstd450Sin}, @@ -1190,9 +1191,14 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Call *op) { e.accept(this); } } else if (op->is_strict_float_intrinsic()) { - // TODO: Enable/Disable RelaxedPrecision flags? 
- Expr e = unstrictify_float(op); - e.accept(this); + if (op->is_intrinsic(Call::strict_fma)) { + Expr builtin_call = Call::make(op->type, "fma", op->args, Call::PureExtern); + builtin_call.accept(this); + } else { + // TODO: Enable/Disable RelaxedPrecision flags? + Expr e = unstrictify_float(op); + e.accept(this); + } } else if (op->is_intrinsic(Call::IntrinsicOp::sorted_avg)) { internal_assert(op->args.size() == 2); // b > a, so the following works without widening: diff --git a/src/Lower.cpp b/src/Lower.cpp index fcbc66747242..32b64e83a2bd 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -493,7 +493,7 @@ void lower_impl(const vector &output_funcs, if (t.has_gpu_feature()) { debug(1) << "Offloading GPU loops...\n"; - s = inject_gpu_offload(s, t); + s = inject_gpu_offload(s, t, any_strict_float); debug(2) << "Lowering after splitting off GPU loops:\n" << s << "\n\n"; } else { diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index e93f67e8bfef..11b8c3ccecf3 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -245,7 +245,7 @@ class InjectGpuOffload : public IRMutator { } public: - InjectGpuOffload(const Target &target) + InjectGpuOffload(const Target &target, bool any_strict_float) : target(target) { Target device_target = target; // For the GPU target we just want to pass the flags, to avoid the @@ -266,12 +266,16 @@ class InjectGpuOffload : public IRMutator { cgdev[DeviceAPI::D3D12Compute] = new_CodeGen_D3D12Compute_Dev(device_target); } if (target.has_feature(Target::Vulkan)) { - cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(target); + cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(device_target); } if (target.has_feature(Target::WebGPU)) { cgdev[DeviceAPI::WebGPU] = new_CodeGen_WebGPU_Dev(device_target); } + for (auto &i : cgdev) { + i.second->set_any_strict_float(any_strict_float); + } + internal_assert(!cgdev.empty()) << "Requested unknown GPU target: " << target.to_string() << "\n"; } @@ -315,8 +319,8 @@ class InjectGpuOffload : public IRMutator { } // namespace -Stmt inject_gpu_offload(const Stmt &s, const Target &host_target) { - return InjectGpuOffload(host_target).inject(s); +Stmt inject_gpu_offload(const Stmt &s, const Target &host_target, bool any_strict_float) { + return InjectGpuOffload(host_target, any_strict_float).inject(s); } } // namespace Internal diff --git a/src/OffloadGPULoops.h b/src/OffloadGPULoops.h index d927f1a8b780..97cd7737271f 100644 --- a/src/OffloadGPULoops.h +++ b/src/OffloadGPULoops.h @@ -17,7 +17,7 @@ namespace Internal { /** Pull loops marked with GPU device APIs to a separate * module, and call them through the appropriate host runtime module. 
*/ -Stmt inject_gpu_offload(const Stmt &s, const Target &host_target); +Stmt inject_gpu_offload(const Stmt &s, const Target &host_target, bool any_strict_float); } // namespace Internal } // namespace Halide diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp index 002dd79bb7d8..bbd8eea32c87 100644 --- a/test/correctness/strict_fma.cpp +++ b/test/correctness/strict_fma.cpp @@ -12,9 +12,18 @@ int test() { f(x) = fma(cast(x), b, c); g(x) = strict_float(cast(x) * b + c); - // Use a non-native vector width, to also test legalization - f.vectorize(x, 5); - g.vectorize(x, 5); + Target t = get_jit_target_from_environment(); + if (std::is_same_v && + t.has_gpu_feature() && + !t.has_feature(Target::Vulkan)) { // TODO: Vulkan does not yet respect strict_float + Var xo{"xo"}, xi{"xi"}; + f.gpu_tile(x, xo, xi, 32); + g.gpu_tile(x, xo, xi, 32); + } else { + // Use a non-native vector width, to also test legalization + f.vectorize(x, 5); + g.vectorize(x, 5); + } // b.set((T)8769132.122433244233); // c.set((T)2809.14123423413); @@ -23,16 +32,39 @@ int test() { Buffer with_fma = f.realize({1024}); Buffer without_fma = g.realize({1024}); + with_fma.copy_to_host(); + without_fma.copy_to_host(); + bool saw_error = false; for (int i = 0; i < with_fma.width(); i++) { + + Bits fma_bits = Internal::reinterpret_bits(with_fma(i)); + Bits no_fma_bits = Internal::reinterpret_bits(without_fma(i)); + + if constexpr (sizeof(T) >= 4) { + T correct_fma = std::fma((T)i, b.get(), c.get()); + + if (with_fma(i) != correct_fma) { + printf("fma result does not match std::fma:\n" + " fma(%d, %10.10g, %10.10g) = %10.10g (0x%llx)\n" + " but std::fma gives %10.10g (0x%llx)\n", + i, + (double)b.get(), (double)c.get(), + (double)with_fma(i), + (long long unsigned)fma_bits, + (double)correct_fma, + (long long unsigned)Internal::reinterpret_bits(correct_fma)); + return -1; + } + } + if (with_fma(i) == without_fma(i)) { continue; } saw_error = true; // The rounding error, if any, ought to be 1 ULP - Bits fma_bits = Internal::reinterpret_bits(with_fma(i)); - Bits no_fma_bits = Internal::reinterpret_bits(without_fma(i)); + // printf("%llx %llx\n", (long long unsigned)fma_bits, (long long unsigned)no_fma_bits); if (fma_bits + 1 != no_fma_bits && fma_bits - 1 != no_fma_bits) { printf("Difference greater than 1 ULP: %10.10g (0x%llx) vs %10.10g (0x%llx)!\n", From 2f8558b1903436f455a06b62c50dea6fac012025 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 15 Dec 2025 16:04:44 -0800 Subject: [PATCH 05/13] move definition of has_builtin --- src/CodeGen_C_prologue.template.cpp | 5 ++++- src/CodeGen_C_vectors.template.cpp | 4 ---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/CodeGen_C_prologue.template.cpp b/src/CodeGen_C_prologue.template.cpp index 7340c50e3179..72d0b362febe 100644 --- a/src/CodeGen_C_prologue.template.cpp +++ b/src/CodeGen_C_prologue.template.cpp @@ -1,9 +1,12 @@ /* MACHINE GENERATED By Halide. */ - #if !(__cplusplus >= 201103L || _MSVC_LANG >= 201103L) #error "This code requires C++11 (or later); please upgrade your compiler." 
#endif +#if !defined(__has_builtin) +#define __has_builtin(x) 0 +#endif + #include #include #include diff --git a/src/CodeGen_C_vectors.template.cpp b/src/CodeGen_C_vectors.template.cpp index 2ada1bd8afbd..44a9b3c0eee5 100644 --- a/src/CodeGen_C_vectors.template.cpp +++ b/src/CodeGen_C_vectors.template.cpp @@ -2,10 +2,6 @@ #define __has_attribute(x) 0 #endif -#if !defined(__has_builtin) -#define __has_builtin(x) 0 -#endif - namespace { // We can't use std::array because that has its own overload of operator<, etc, From 5ae7b14eed167a2e498abd752b6ba49f022b91be Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 16 Dec 2025 11:52:04 -0800 Subject: [PATCH 06/13] Comment fixes --- src/IROperator.h | 8 ++++---- test/correctness/strict_fma.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/IROperator.h b/src/IROperator.h index 58ca80eb6d15..b9e40b898ec1 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -979,10 +979,10 @@ Expr pow(Expr x, Expr y); Expr erf(const Expr &x); /** Fused multiply-add. fma(a, b, c) is equivalent to a * b + c, but only - * rounded once at the end. Halide will turn a * b + c into an fma - * automatically, except in strict_float contexts. This intrinsic only exists in - * order to request a true fma inside a strict_float context. A true fma will be - * emulated on targets without one. */ + * rounded once at the end. For most targets, when not in a strict_float + * context, Halide will already generate fma instructions from a * b + c. This + * intrinsic's main purpose is to request a true fma inside a strict_float + * context. A true fma will be emulated on targets without one. */ Expr fma(const Expr &, const Expr &, const Expr &); /** Fast vectorizable approximation to some trigonometric functions for diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp index bbd8eea32c87..a5ce8fcf17be 100644 --- a/test/correctness/strict_fma.cpp +++ b/test/correctness/strict_fma.cpp @@ -25,10 +25,9 @@ int test() { g.vectorize(x, 5); } - // b.set((T)8769132.122433244233); - // c.set((T)2809.14123423413); b.set((T)1.111111111); c.set((T)1.101010101); + Buffer with_fma = f.realize({1024}); Buffer without_fma = g.realize({1024}); @@ -63,8 +62,9 @@ int test() { } saw_error = true; - // The rounding error, if any, ought to be 1 ULP - // printf("%llx %llx\n", (long long unsigned)fma_bits, (long long unsigned)no_fma_bits); + // For the specific positive numbers picked above, the rounding error is + // at most 1 ULP. Note that it's possible to make much larger rounding + // errors if you introduce some catastrophic cancellation. if (fma_bits + 1 != no_fma_bits && fma_bits - 1 != no_fma_bits) { printf("Difference greater than 1 ULP: %10.10g (0x%llx) vs %10.10g (0x%llx)!\n", From 55e89567328785cb7f0bd8fccd3b851298af3e93 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 16 Dec 2025 12:25:37 -0800 Subject: [PATCH 07/13] Skip fma test on two legacy platforms --- test/correctness/strict_fma.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp index a5ce8fcf17be..b814c03b716a 100644 --- a/test/correctness/strict_fma.cpp +++ b/test/correctness/strict_fma.cpp @@ -13,9 +13,25 @@ int test() { g(x) = strict_float(cast(x) * b + c); Target t = get_jit_target_from_environment(); + if (std::is_same_v && + t.arch == Target::X86 && + t.os == Target::Windows && + t.bits == 32) { + // Don't try to resolve float16 math library functions on win-32. 
In + // theory LLVM is responsible for this, but at the time of writing + // (12/16/2025) it doesn't seem to work. + printf("Skipping float16 fma test on win-32\n"); + return 0; + } + if (std::is_same_v && t.has_gpu_feature() && - !t.has_feature(Target::Vulkan)) { // TODO: Vulkan does not yet respect strict_float + // Metal on x86 does not seem to respect strict float despite setting + // the appropriate pragma. + !(t.arch == Target::X86 && t.has_feature(Target::Metal)) && + // TODO: Vulkan does not respect strict_float yet: + // https://github.com/halide/Halide/issues/7239 + !t.has_feature(Target::Vulkan)) { Var xo{"xo"}, xi{"xi"}; f.gpu_tile(x, xo, xi, 32); g.gpu_tile(x, xo, xi, 32); From fcfa87106a71f1018c0f83dfc6c1dccfc433ff76 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 17 Dec 2025 16:47:01 -0800 Subject: [PATCH 08/13] Fix double-rounding bug in double -> (b)float16 casts --- src/CodeGen_X86.cpp | 3 +- src/EmulateFloat16Math.cpp | 98 ++++++++++++++++++++++++++++++---- src/Float16.cpp | 46 +++++++++++++--- src/Float16.h | 4 ++ src/IR.cpp | 1 + src/IR.h | 2 + src/StrictifyFloat.cpp | 12 +++++ test/correctness/float16_t.cpp | 76 ++++++++++++++++++++++++-- 8 files changed, 221 insertions(+), 21 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0e63af410cce..3d2388fdf89c 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) { if (target.has_feature(Target::F16C) && dst.code() == Type::Float && src.code() == Type::Float && - (dst.bits() == 16 || src.bits() == 16)) { + (dst.bits() == 16 || src.bits() == 16) && + src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call // Node we use code() == Type::Float instead of is_float(), because we // don't want to catch bfloat casts. diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index 1fda58a838e9..b24ccffae096 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -9,27 +9,46 @@ namespace Halide { namespace Internal { Expr bfloat16_to_float32(Expr e) { + const int lanes = e.type().lanes(); if (e.type().is_bfloat()) { e = reinterpret(e.type().with_code(Type::UInt), e); } - e = cast(UInt(32, e.type().lanes()), e); + e = cast(UInt(32, lanes), e); e = e << 16; - e = reinterpret(Float(32, e.type().lanes()), e); + e = reinterpret(Float(32, lanes), e); e = strict_float(e); return e; } Expr float32_to_bfloat16(Expr e) { internal_assert(e.type().bits() == 32); + const int lanes = e.type().lanes(); e = strict_float(e); - e = reinterpret(UInt(32, e.type().lanes()), e); + e = reinterpret(UInt(32, lanes), e); // We want to round ties to even, so before truncating either // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to // even numbers. 
e += 0x7fff + ((e >> 16) & 1); e = (e >> 16); - e = cast(UInt(16, e.type().lanes()), e); - e = reinterpret(BFloat(16, e.type().lanes()), e); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); + return e; +} + +Expr float64_to_bfloat16(Expr e) { + internal_assert(e.type().bits() == 64); + const int lanes = e.type().lanes(); + e = strict_float(e); + + // First round to float and record any gain of loss of magnitude + Expr f = cast(Float(32, lanes), e); + Expr err = abs(e) - abs(f); + e = reinterpret(UInt(32, lanes), f); + // As above, but break ties using err, if non-zero + e += 0x7fff + (((err >= 0) & ((e >> 16) & 1)) | (err > 0)); + e = (e >> 16); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); return e; } @@ -96,10 +115,11 @@ Expr float32_to_float16(Expr value) { // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. bits += (bits >> 13) & 1; - bits += 0xfff; - bits = bits >> 13; + bits += make_const(UInt(32), ((uint32_t)1 << (13 - 1)) - 1); + bits = cast(u16_t, bits >> 13); + // Rebias the exponent - bits -= 0x1c000; + bits -= 0x4000; // Truncate the top bits of the exponent bits = bits & 0x7fff; bits = select(is_denorm, denorm_bits, @@ -111,6 +131,55 @@ Expr float32_to_float16(Expr value) { return common_subexpression_elimination(reinterpret(f16_t, bits)); } +Expr float64_to_float16(Expr value) { + value = strict_float(value); + + Type f64_t = Float(64, value.type().lanes()); + Type f16_t = Float(16, value.type().lanes()); + Type u64_t = UInt(64, value.type().lanes()); + Type u16_t = UInt(16, value.type().lanes()); + + Expr bits = reinterpret(u64_t, value); + + // Extract the sign bit + Expr sign = bits & make_const(u64_t, (uint64_t)(0x8000000000000000ULL)); + bits = bits ^ sign; + + // Test the endpoints + Expr is_denorm = (bits < make_const(u64_t, (uint64_t)(0x3f10000000000000ULL))); + Expr is_inf = (bits >= make_const(u64_t, (uint64_t)(0x40f0000000000000ULL))); + Expr is_nan = (bits > make_const(u64_t, (uint64_t)(0x7ff0000000000000ULL))); + + // Denorms are linearly spaced, so we can handle them by scaling up the + // input as a float or double by 2^24 and using the existing int-conversion + // rounding instructions. We can scale up by adding 24 to the exponent. + Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f64_t, bits + make_const(u64_t, (uint64_t)(0x0180000000000000ULL))))))); + Expr inf_bits = make_const(u16_t, 0x7c00); + Expr nan_bits = make_const(u16_t, 0x7fff); + + // We want to round to nearest even, so we add either 0.5 if after + // truncation the last bit would be 1, or 0.4999999 if after truncation the + // last bit would be zero, then truncate. 
+ bits += (bits >> 42) & 1; + bits += make_const(UInt(64), ((uint64_t)1 << (42 - 1)) - 1); + bits = bits >> 42; + + // We no longer need the high bits + bits = cast(u16_t, bits); + + // Rebias the exponent + bits -= 0x4000; + // Truncate the top bits of the exponent + bits = bits & 0x7fff; + bits = select(is_denorm, denorm_bits, + is_inf, inf_bits, + is_nan, nan_bits, + cast(u16_t, bits)); + // Recover the sign bit + bits = bits | cast(u16_t, sign >> 48); + return common_subexpression_elimination(reinterpret(f16_t, bits)); +} + namespace { const std::map transcendental_remapping = @@ -171,6 +240,7 @@ Expr lower_float16_cast(const Cast *op) { Type src = op->value.type(); Type dst = op->type; Type f32 = Float(32, dst.lanes()); + Type f64 = Float(64, dst.lanes()); Expr val = op->value; if (src.is_bfloat()) { @@ -183,10 +253,18 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); - val = float32_to_bfloat16(cast(f32, val)); + if (src.bits() > 32) { + val = float64_to_bfloat16(cast(f64, val)); + } else { + val = float32_to_bfloat16(cast(f32, val)); + } } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); - val = float32_to_float16(cast(f32, val)); + if (src.bits() > 32) { + val = float64_to_float16(cast(f64, val)); + } else { + val = float32_to_float16(cast(f32, val)); + } } return cast(dst, val); diff --git a/src/Float16.cpp b/src/Float16.cpp index 80c96a38e6f1..6e7dbbe4b7c0 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -9,7 +9,10 @@ namespace Internal { // Conversion routines to and from float cribbed from Christian Rau's // half library (half.sourceforge.net) -uint16_t float_to_float16(float value) { +template +uint16_t float_to_float16(T value) { + static_assert(std::is_same_v || std::is_same_v, + "float_to_float16 only supports float and double types"); // Start by copying over the sign bit uint16_t bits = std::signbit(value) << 15; @@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) { // We've normalized value as much as possible. Put the integer // portion of it into the mantissa. - float ival; - float frac = std::modf(value, &ival); + T ival; + T frac = std::modf(value, &ival); bits += (uint16_t)(std::abs((int)ival)); // Now consider the fractional part. We round to nearest with ties // going to even. frac = std::abs(frac); - bits += (frac > 0.5f) | ((frac == 0.5f) & bits); + bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits); return bits; } @@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) { return ret >> 16; } +uint16_t float_to_bfloat16(double f) { + // Coming from double is a little tricker. We first narrow to float and + // record if any magnitude was lost of gained in the process. If so we'll + // use that to break ties instead of testing whether or not truncation would + // return odd. + float f32 = (float)f; + const double err = std::abs(f) - (double)std::abs(f32); + uint32_t ret; + memcpy(&ret, &f32, sizeof(float)); + ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0)); + return ret >> 16; +} + float bfloat16_to_float(uint16_t b) { // Assume little-endian floats uint16_t bits[2] = {0, b}; @@ -362,7 +378,17 @@ float16_t::float16_t(double value) } float16_t::float16_t(int value) - : data(float_to_float16(value)) { + : data(float_to_float16((float)value)) { + // integers of any size that map to finite float16s are all representable as + // float, so we can go via the float conversion method. 
+} + +float16_t::float16_t(int64_t value) + : data(float_to_float16((float)value)) { +} + +float16_t::float16_t(uint64_t value) + : data(float_to_float16((float)value)) { } float16_t::operator float() const { @@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value) } bfloat16_t::bfloat16_t(int value) - : data(float_to_bfloat16(value)) { + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(int64_t value) + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(uint64_t value) + : data(float_to_bfloat16((double)value)) { } bfloat16_t::operator float() const { diff --git a/src/Float16.h b/src/Float16.h index d3c285d6c09f..376813cbd507 100644 --- a/src/Float16.h +++ b/src/Float16.h @@ -32,6 +32,8 @@ struct float16_t { explicit float16_t(float value); explicit float16_t(double value); explicit float16_t(int value); + explicit float16_t(int64_t value); + explicit float16_t(uint64_t value); // @} /** Construct a float16_t with the bits initialised to 0. This represents @@ -175,6 +177,8 @@ struct bfloat16_t { explicit bfloat16_t(float value); explicit bfloat16_t(double value); explicit bfloat16_t(int value); + explicit bfloat16_t(int64_t value); + explicit bfloat16_t(uint64_t value); // @} /** Construct a bfloat16_t with the bits initialised to 0. This represents diff --git a/src/IR.cpp b/src/IR.cpp index c844c672656a..c82ae4ebd252 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -678,6 +678,7 @@ const char *const intrinsic_op_names[] = { "sliding_window_marker", "sorted_avg", "strict_add", + "strict_cast", "strict_div", "strict_eq", "strict_le", diff --git a/src/IR.h b/src/IR.h index 6dc0204b89ec..da27019a93c7 100644 --- a/src/IR.h +++ b/src/IR.h @@ -626,6 +626,7 @@ struct Call : public ExprNode { // them as reals and ignoring the existence of nan and inf. Using these // intrinsics instead prevents any such optimizations. 
strict_add, + strict_cast, strict_div, strict_eq, strict_le, @@ -792,6 +793,7 @@ struct Call : public ExprNode { bool is_strict_float_intrinsic() const { return is_intrinsic( {Call::strict_add, + Call::strict_cast, Call::strict_div, Call::strict_max, Call::strict_min, diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 13dd0873bb12..8953ba035888 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -83,6 +83,16 @@ class Strictify : public IRMutator { return IRMutator::visit(op); } } + + Expr visit(const Cast *op) override { + if (op->value.type().is_float() && + op->type.is_float()) { + return Call::make(op->type, Call::strict_cast, + {mutate(op->value)}, Call::PureIntrinsic); + } else { + return IRMutator::visit(op); + } + } }; const std::set strict_externs = { @@ -142,6 +152,8 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_cast)) { + return cast(op->type, op->args[0]); } else { internal_error << "Missing lowering of strict float intrinsic: " << Expr(op) << "\n"; diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index 9e917a6216e6..d206e43202b0 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -301,6 +301,74 @@ int run_test() { } } + { + for (double f : {1.0, -1.0, 0.235, -0.235, 1e-7, -1e-7}) { + { + // Test double -> float16 doesn't have double-rounding issues + float16_t k{f}; + float16_t k_plus_eps = float16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + float16_t to_even = k_is_odd ? k_plus_eps : k; + float16_t to_odd = k_is_odd ? k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + printf("float16 k_is_odd = %d\n", k_is_odd); + + // We expect ties to round to even + assert(float16_t(halfway) == to_even); + // Now let's construct a case where it *should* have rounded to + // odd if rounding directly from double, but rounding via + // float does the wrong thing. + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 + assert(_Float16(halfway_plus_eps) == _Float16(float(to_odd))); +#endif + assert(float16_t(halfway_plus_eps) == to_odd); + + // Now test the same thing in generated code. We need strict float to + // prevent Halide from fusing multiple float casts into one. + Param p; + p.set(halfway_plus_eps); + // halfway plus epsilon rounds to exactly halfway as a float + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + // So if we go via float we get the even outcome, because + // exactly halfway rounds to even + assert(evaluate(strict_float(cast(Float(16), cast(Float(32), p)))) == to_even); + // But if we go direct, we should go to odd, because it's closer + assert(evaluate(strict_float(cast(Float(16), p))) == to_odd); + } + + { + // Test the same things for bfloat + bfloat16_t k{f}; + bfloat16_t k_plus_eps = bfloat16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + + bfloat16_t to_even = k_is_odd ? k_plus_eps : k; + bfloat16_t to_odd = k_is_odd ? 
k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + assert(bfloat16_t(halfway) == to_even); + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); + assert(bfloat16_t(halfway_plus_eps) == to_odd); + + Param p; + p.set(halfway_plus_eps); + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + assert(evaluate(strict_float(cast(BFloat(16), cast(Float(32), p)))) == to_even); + assert(evaluate(strict_float(cast(BFloat(16), p))) == to_odd); + } + } + } + // Enable to read assembly generated by the conversion routines if ((false)) { // Intentional dead code. Extra parens to pacify clang-tidy. Func src, to_f16, from_f16; @@ -309,11 +377,11 @@ int run_test() { to_f16(x) = cast(src(x)); from_f16(x) = cast(to_f16(x)); - src.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - to_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - from_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + src.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + to_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + from_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); - from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime-disable_llvm_loop_unroll-disable_llvm_loop_vectorize")); + from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime")); } // Check infinity handling for both float16_t and Halide codegen. From 9b23ae688f012ac00fae314fcf7a55028ffdf294 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 18 Dec 2025 12:49:40 -0800 Subject: [PATCH 09/13] Share more code between coming from 64 and 32 bits Also add and fix some comments --- src/EmulateFloat16Math.cpp | 152 ++++++++++++++------------------- src/EmulateFloat16Math.h | 4 +- src/Float16.cpp | 2 +- test/correctness/float16_t.cpp | 2 - 4 files changed, 66 insertions(+), 94 deletions(-) diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index b24ccffae096..9ffca1bb3b54 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -20,32 +20,30 @@ Expr bfloat16_to_float32(Expr e) { return e; } -Expr float32_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 32); - const int lanes = e.type().lanes(); - e = strict_float(e); - e = reinterpret(UInt(32, lanes), e); - // We want to round ties to even, so before truncating either - // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to - // even numbers. 
- e += 0x7fff + ((e >> 16) & 1); - e = (e >> 16); - e = cast(UInt(16, lanes), e); - e = reinterpret(BFloat(16, lanes), e); - return e; -} - -Expr float64_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 64); +Expr float_to_bfloat16(Expr e) { const int lanes = e.type().lanes(); e = strict_float(e); + Expr err; // First round to float and record any gain of loss of magnitude - Expr f = cast(Float(32, lanes), e); - Expr err = abs(e) - abs(f); - e = reinterpret(UInt(32, lanes), f); - // As above, but break ties using err, if non-zero - e += 0x7fff + (((err >= 0) & ((e >> 16) & 1)) | (err > 0)); + if (e.type().bits() == 64) { + Expr f = cast(Float(32, lanes), e); + err = abs(e) - abs(f); + e = f; + } else { + internal_assert(e.type().bits() == 32); + } + e = reinterpret(UInt(32, lanes), e); + + // We want to round ties to even, so if we have no error recorded above, + // before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff + // (0.499999) to even numbers. If we have error, break ties using that + // instead. + Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd + if (err.defined()) { + tie_breaker = ((err == 0) & tie_breaker) | (err > 0); + } + e += tie_breaker + 0x7fff; e = (e >> 16); e = cast(UInt(16, lanes), e); e = reinterpret(BFloat(16, lanes), e); @@ -82,90 +80,64 @@ Expr float16_to_float32(Expr value) { return f32; } -Expr float32_to_float16(Expr value) { +Expr float_to_float16(Expr value) { // We're about the sniff the bits of a float, so we should // guard it with strict float to ensure we don't do things // like assume it can't be denormal. value = strict_float(value); - Type f32_t = Float(32, value.type().lanes()); + const int src_bits = value.type().bits(); + + Type float_t = Float(src_bits, value.type().lanes()); Type f16_t = Float(16, value.type().lanes()); - Type u32_t = UInt(32, value.type().lanes()); + Type bits_t = UInt(src_bits, value.type().lanes()); Type u16_t = UInt(16, value.type().lanes()); - Expr bits = reinterpret(u32_t, value); + Expr bits = reinterpret(bits_t, value); // Extract the sign bit - Expr sign = bits & make_const(u32_t, 0x80000000); + Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1)); bits = bits ^ sign; // Test the endpoints - Expr is_denorm = (bits < make_const(u32_t, 0x38800000)); - Expr is_inf = (bits >= make_const(u32_t, 0x47800000)); - Expr is_nan = (bits > make_const(u32_t, 0x7f800000)); + + // Smallest input representable as normal float16 (2^-14) + Expr two_to_the_minus_14 = src_bits == 32 ? + make_const(bits_t, 0x38800000) : + make_const(bits_t, (uint64_t)0x3f10000000000000ULL); + Expr is_denorm = bits < two_to_the_minus_14; + + // Smallest input too big to represent as a float16 (2^16) + Expr two_to_the_16 = src_bits == 32 ? + make_const(bits_t, 0x47800000) : + make_const(bits_t, (uint64_t)0x40f0000000000000ULL); + Expr is_inf = bits >= two_to_the_16; + + // Check if the input is a nan, which is anything bigger than an infinity bit pattern + Expr input_inf_bits = src_bits == 32 ? + make_const(bits_t, 0x7f800000) : + make_const(bits_t, (uint64_t)0x7ff0000000000000ULL); + Expr is_nan = bits > input_inf_bits; // Denorms are linearly spaced, so we can handle them // by scaling up the input as a float and using the // existing int-conversion rounding instructions. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000))))); + Expr two_to_the_24 = src_bits == 32 ? 
+ make_const(bits_t, 0x0c000000) : + make_const(bits_t, (uint64_t)0x0180000000000000ULL); + Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24)))); Expr inf_bits = make_const(u16_t, 0x7c00); Expr nan_bits = make_const(u16_t, 0x7fff); // We want to round to nearest even, so we add either // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. - bits += (bits >> 13) & 1; - bits += make_const(UInt(32), ((uint32_t)1 << (13 - 1)) - 1); - bits = cast(u16_t, bits >> 13); - - // Rebias the exponent - bits -= 0x4000; - // Truncate the top bits of the exponent - bits = bits & 0x7fff; - bits = select(is_denorm, denorm_bits, - is_inf, inf_bits, - is_nan, nan_bits, - cast(u16_t, bits)); - // Recover the sign bit - bits = bits | cast(u16_t, sign >> 16); - return common_subexpression_elimination(reinterpret(f16_t, bits)); -} - -Expr float64_to_float16(Expr value) { - value = strict_float(value); - - Type f64_t = Float(64, value.type().lanes()); - Type f16_t = Float(16, value.type().lanes()); - Type u64_t = UInt(64, value.type().lanes()); - Type u16_t = UInt(16, value.type().lanes()); - - Expr bits = reinterpret(u64_t, value); - - // Extract the sign bit - Expr sign = bits & make_const(u64_t, (uint64_t)(0x8000000000000000ULL)); - bits = bits ^ sign; - - // Test the endpoints - Expr is_denorm = (bits < make_const(u64_t, (uint64_t)(0x3f10000000000000ULL))); - Expr is_inf = (bits >= make_const(u64_t, (uint64_t)(0x40f0000000000000ULL))); - Expr is_nan = (bits > make_const(u64_t, (uint64_t)(0x7ff0000000000000ULL))); - - // Denorms are linearly spaced, so we can handle them by scaling up the - // input as a float or double by 2^24 and using the existing int-conversion - // rounding instructions. We can scale up by adding 24 to the exponent. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f64_t, bits + make_const(u64_t, (uint64_t)(0x0180000000000000ULL))))))); - Expr inf_bits = make_const(u16_t, 0x7c00); - Expr nan_bits = make_const(u16_t, 0x7fff); - - // We want to round to nearest even, so we add either 0.5 if after - // truncation the last bit would be 1, or 0.4999999 if after truncation the - // last bit would be zero, then truncate. - bits += (bits >> 42) & 1; - bits += make_const(UInt(64), ((uint64_t)1 << (42 - 1)) - 1); - bits = bits >> 42; - - // We no longer need the high bits - bits = cast(u16_t, bits); + const int float16_mantissa_bits = 10; + const int input_mantissa_bits = src_bits == 32 ? 
23 : 52; + const int bits_lost = input_mantissa_bits - float16_mantissa_bits; + bits += (bits >> bits_lost) & 1; + bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1); + bits = cast(u16_t, bits >> bits_lost); // Rebias the exponent bits -= 0x4000; @@ -176,7 +148,7 @@ Expr float64_to_float16(Expr value) { is_nan, nan_bits, cast(u16_t, bits)); // Recover the sign bit - bits = bits | cast(u16_t, sign >> 48); + bits = bits | cast(u16_t, sign >> (src_bits - 16)); return common_subexpression_elimination(reinterpret(f16_t, bits)); } @@ -226,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) { Expr e = Call::make(t, it->second, new_args, op->call_type, op->func, op->value_index, op->image, op->param); if (op->type.is_float()) { - e = float32_to_float16(e); + e = float_to_float16(e); } internal_assert(e.type() == op->type); return e; @@ -254,17 +226,19 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); if (src.bits() > 32) { - val = float64_to_bfloat16(cast(f64, val)); + val = cast(f64, val); } else { - val = float32_to_bfloat16(cast(f32, val)); + val = cast(f32, val); } + val = float_to_bfloat16(val); } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); if (src.bits() > 32) { - val = float64_to_float16(cast(f64, val)); + val = cast(f64, val); } else { - val = float32_to_float16(cast(f32, val)); + val = cast(f32, val); } + val = float_to_float16(val); } return cast(dst, val); diff --git a/src/EmulateFloat16Math.h b/src/EmulateFloat16Math.h index de1a5e091588..f61de7456fbc 100644 --- a/src/EmulateFloat16Math.h +++ b/src/EmulateFloat16Math.h @@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *); /** Cast to/from float and bfloat using bitwise math. */ //@{ -Expr float32_to_bfloat16(Expr e); -Expr float32_to_float16(Expr e); +Expr float_to_bfloat16(Expr e); +Expr float_to_float16(Expr e); Expr float16_to_float32(Expr e); Expr bfloat16_to_float32(Expr e); Expr lower_float16_cast(const Cast *op); diff --git a/src/Float16.cpp b/src/Float16.cpp index 6e7dbbe4b7c0..1e9e789f476b 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -346,7 +346,7 @@ uint16_t float_to_bfloat16(float f) { uint16_t float_to_bfloat16(double f) { // Coming from double is a little tricker. We first narrow to float and - // record if any magnitude was lost of gained in the process. If so we'll + // record if any magnitude was lost or gained in the process. If so we'll // use that to break ties instead of testing whether or not truncation would // return odd. float f32 = (float)f; diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index d206e43202b0..120682cd86d8 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -312,8 +312,6 @@ int run_test() { float16_t to_odd = k_is_odd ? 
k : k_plus_eps; float halfway = (float(k) + float(k_plus_eps)) / 2.f; - printf("float16 k_is_odd = %d\n", k_is_odd); - // We expect ties to round to even assert(float16_t(halfway) == to_even); // Now let's construct a case where it *should* have rounded to From 82f24c7078b8650d9199d42348f4975974e336c5 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 19 Dec 2025 09:38:33 -0800 Subject: [PATCH 10/13] handle float16 fmas Hopefully this means we can reenable win-32 testing, because we should no longer trigger the need for a lib call to convert double to float16 --- src/CodeGen_LLVM.cpp | 32 +++++++++++++++++++++++++------- src/Parameter.h | 2 +- test/correctness/strict_fma.cpp | 18 +++++------------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 0fee03b74363..d64660ad9016 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3306,6 +3306,7 @@ void CodeGen_LLVM::visit(const Call *op) { // Evaluate the args first outside the strict scope, as they may use // non-strict operations. std::vector new_args(op->args.size()); + std::vector to_pop; for (size_t i = 0; i < op->args.size(); i++) { const Expr &arg = op->args[i]; if (arg.as() || is_const(arg)) { @@ -3313,6 +3314,7 @@ void CodeGen_LLVM::visit(const Call *op) { } else { std::string name = unique_name('t'); sym_push(name, codegen(arg)); + to_pop.push_back(name); new_args[i] = Variable::make(arg.type(), name); } } @@ -3320,8 +3322,27 @@ void CodeGen_LLVM::visit(const Call *op) { { ScopedValue old_in_strict_float(in_strict_float, true); if (op->is_intrinsic(Call::strict_fma)) { - std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type)); - value = call_intrin(op->type, op->type.lanes(), name, new_args); + if (op->type.is_float() && op->type.bits() <= 16 && + upgrade_type_for_arithmetic(op->type) != op->type) { + // For (b)float16 and below, doing the fma as a + // double-precision fma is exact and is what llvm does. A + // double has enough bits of precision such that the add in + // the fma has no rounding error in the cases where the fma + // is going to return a finite float16. We do this + // legalization manually so that we can use our custom + // vectorizable float16 casts instead of letting llvm call + // library functions. 
+ Type wide_t = Float(64, op->type.lanes()); + for (Expr &e : new_args) { + e = cast(wide_t, e); + } + Expr equiv = Call::make(wide_t, op->name, new_args, op->call_type); + equiv = cast(op->type, equiv); + value = codegen(equiv); + } else { + std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type)); + value = call_intrin(op->type, op->type.lanes(), name, new_args); + } } else { // Lower to something other than a call node Expr call = Call::make(op->type, op->name, new_args, op->call_type); @@ -3329,11 +3350,8 @@ void CodeGen_LLVM::visit(const Call *op) { } } - for (size_t i = 0; i < op->args.size(); i++) { - const Expr &arg = op->args[i]; - if (!arg.as() && !is_const(arg)) { - sym_pop(new_args[i].as()->name); - } + for (const auto &s : to_pop) { + sym_pop(s); } } else if (is_float16_transcendental(op) && !supports_call_as_float16(op)) { diff --git a/src/Parameter.h b/src/Parameter.h index 56a441f5ba35..ffe203241599 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -128,7 +128,7 @@ class Parameter { static_assert(sizeof(T) <= sizeof(halide_scalar_value_t)); const auto sv = scalar_data_checked(type_of()); T t; - memcpy(&t, &sv.u.u64, sizeof(t)); + memcpy((char *)(&t), &sv.u.u64, sizeof(t)); return t; } diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp index b814c03b716a..1e0ad1ddd9fc 100644 --- a/test/correctness/strict_fma.cpp +++ b/test/correctness/strict_fma.cpp @@ -13,16 +13,6 @@ int test() { g(x) = strict_float(cast(x) * b + c); Target t = get_jit_target_from_environment(); - if (std::is_same_v && - t.arch == Target::X86 && - t.os == Target::Windows && - t.bits == 32) { - // Don't try to resolve float16 math library functions on win-32. In - // theory LLVM is responsible for this, but at the time of writing - // (12/16/2025) it doesn't seem to work. - printf("Skipping float16 fma test on win-32\n"); - return 0; - } if (std::is_same_v && t.has_gpu_feature() && @@ -42,7 +32,7 @@ int test() { } b.set((T)1.111111111); - c.set((T)1.101010101); + c.set((T)1.0101010101); Buffer with_fma = f.realize({1024}); Buffer without_fma = g.realize({1024}); @@ -58,7 +48,6 @@ int test() { if constexpr (sizeof(T) >= 4) { T correct_fma = std::fma((T)i, b.get(), c.get()); - if (with_fma(i) != correct_fma) { printf("fma result does not match std::fma:\n" " fma(%d, %10.10g, %10.10g) = %10.10g (0x%llx)\n" @@ -91,7 +80,10 @@ int test() { } if (!saw_error) { - printf("There should have occasionally been a 1 ULP difference between fma and non-fma results\n"); + printf("There should have occasionally been a 1 ULP difference between fma " + "and non-fma results. strict_float may not be respected on this target.\n"); + // Uncomment to inspect assembly + // g.compile_to_assembly("/dev/stdout", {b, c}, get_jit_target_from_environment()); return -1; } From 649ac3ecc3ece46156ad17ad0622648a5cb5f271 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 19 Dec 2025 10:19:59 -0800 Subject: [PATCH 11/13] wasm fix Not ideal. In fact performance is known to be terrible (https://gitlab.com/libeigen/eigen/-/issues/2959) but wasm only has a relaxed fma, not a strict one. 
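
The distinction that matters here is single versus double rounding: a real fma rounds a * b + c once, while a separate multiply and add rounds twice, so the two can disagree by 1 ULP. A standalone C++ sketch of that disagreement (plain std::fmaf, not part of this patch; build with contraction disabled, e.g. -ffp-contract=off, so the compiler does not fuse the split version on its own):

    #include <cmath>
    #include <cstdio>

    int main() {
        // The nearest float to 1/3 is 11184811 * 2^-25, so a * b is exactly
        // 1 + 2^-25: representable in double, but not in float.
        float a = 1.0f / 3.0f;
        float b = 3.0f;
        float c = -1.0f;
        float fused = std::fmaf(a, b, c);  // rounded once: 2^-25
        float prod = a * b;                // rounds to 1.0f
        float split = prod + c;            // so the sum is 0.0f
        printf("fused = %g, split = %g\n", fused, split);
        return 0;
    }

This is the property the strict_fma correctness test in this series relies on when it insists on an occasional 1 ULP difference between the fma and non-fma results.
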
--- src/WasmExecutor.cpp | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index 55ffa7325e22..0aa50ff54ae4 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -914,6 +914,20 @@ wabt::Result wabt_posix_math_2(wabt::interp::Thread &thread, return wabt::Result::Ok; } +template +wabt::Result wabt_posix_math_3(wabt::interp::Thread &thread, + const wabt::interp::Values &args, + wabt::interp::Values &results, + wabt::interp::Trap::Ptr *trap) { + internal_assert(args.size() == 3); + const T in1 = args[0].Get(); + const T in2 = args[1].Get(); + const T in3 = args[2].Get(); + const T out = some_func(in1, in2, in3); + results[0] = wabt::interp::Value::Make(out); + return wabt::Result::Ok; +} + #define WABT_HOST_CALLBACK(x) \ wabt::Result wabt_jit_##x##_callback(wabt::interp::Thread &thread, \ const wabt::interp::Values &args, \ @@ -1998,6 +2012,20 @@ void wasm_jit_posix_math2_callback(const v8::FunctionCallbackInfo &ar args.GetReturnValue().Set(load_scalar(context, out)); } +template +void wasm_jit_posix_math3_callback(const v8::FunctionCallbackInfo &args) { + Isolate *isolate = args.GetIsolate(); + Local context = isolate->GetCurrentContext(); + HandleScope scope(isolate); + + const T in1 = args[0]->NumberValue(context).ToChecked(); + const T in2 = args[1]->NumberValue(context).ToChecked(); + const T in3 = args[2]->NumberValue(context).ToChecked(); + const T out = some_func(in1, in2, in3); + + args.GetReturnValue().Set(load_scalar(context, out)); +} + enum ExternWrapperFieldSlots { kTrampolineWrap, kArgTypesWrap @@ -2122,6 +2150,7 @@ using HostCallbackMap = std::unordered_map} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wabt_posix_math_2} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wabt_posix_math_3} #endif @@ -2131,6 +2160,7 @@ using HostCallbackMap = std::unordered_map; #define DEFINE_CALLBACK(f) {#f, wasm_jit_##f##_callback} #define DEFINE_POSIX_MATH_CALLBACK(t, f) {#f, wasm_jit_posix_math_callback} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wasm_jit_posix_math2_callback} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wasm_jit_posix_math3_callback} #endif const HostCallbackMap &get_host_callback_map() { @@ -2199,7 +2229,11 @@ const HostCallbackMap &get_host_callback_map() { DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf), DEFINE_POSIX_MATH_CALLBACK2(double, fmax), DEFINE_POSIX_MATH_CALLBACK2(float, powf), - DEFINE_POSIX_MATH_CALLBACK2(double, pow)}; + DEFINE_POSIX_MATH_CALLBACK2(double, pow), + + DEFINE_POSIX_MATH_CALLBACK3(float, fmaf), + DEFINE_POSIX_MATH_CALLBACK3(double, fma), + }; return m; } From 7656742692f982a7a805bb5c6b31d19f0d06cb37 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 19 Dec 2025 12:47:30 -0800 Subject: [PATCH 12/13] Skip test on webgpu --- test/correctness/strict_fma.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp index 1e0ad1ddd9fc..f5f736cd5940 100644 --- a/test/correctness/strict_fma.cpp +++ b/test/correctness/strict_fma.cpp @@ -21,7 +21,10 @@ int test() { !(t.arch == Target::X86 && t.has_feature(Target::Metal)) && // TODO: Vulkan does not respect strict_float yet: // https://github.com/halide/Halide/issues/7239 - !t.has_feature(Target::Vulkan)) { + !t.has_feature(Target::Vulkan) && + // WebGPU does not and may never respect strict_float. There's no way to + // ask for it in the language. 
+ !t.has_feature(Target::WebGPU)) { Var xo{"xo"}, xi{"xi"}; f.gpu_tile(x, xo, xi, 32); g.gpu_tile(x, xo, xi, 32); From 6acb53cf1730f4a052db42487dbdfcfed135c8a3 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 19 Dec 2025 12:49:52 -0800 Subject: [PATCH 13/13] Don't check for kandw --- test/correctness/simd_op_check_x86.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 2d1918ff0a3f..0850d8f02bbc 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -411,7 +411,9 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(use_avx512 ? "vrsqrt*ps" : "vrsqrtps*ymm", 8, fast_inverse_sqrt(f32_1)); check(use_avx512 ? "vrcp*ps" : "vrcpps*ymm", 8, fast_inverse(f32_1)); - check(use_avx512 ? "kandw" : "vandps", 8, bool_1 & bool_2); + // Some llvm's don't use kandw, but instead predicate the computation of bool_2 + // using the result of bool_1 + // check(use_avx512 ? "kandw" : "vandps", 8, bool_1 & bool_2); check(use_avx512 ? "korw" : "vorps", 8, bool_1 | bool_2); check(use_avx512 ? "kxorw" : "vxorps", 8, bool_1 ^ bool_2);
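
The float16 and bfloat16 conversion changes earlier in this series all lean on the same round-to-nearest-even trick: bias the discarded bits by just under a half, add one more when the bit that will become the result's low bit is set, then truncate. Here is a standalone scalar sketch of the bfloat16 case (not code from the patch; it assumes an ordinary finite IEEE float and does no NaN or infinity special-casing):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    uint16_t float_to_bfloat16_bits(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        // 0x7FFF is just under half the range of the 16 discarded bits; adding
        // the kept low bit ((bits >> 16) & 1) on top makes exact ties land on
        // an even result, while everything else rounds to nearest.
        bits += 0x7FFFu + ((bits >> 16) & 1u);
        return static_cast<uint16_t>(bits >> 16);
    }

    int main() {
        // 1.00390625 = 1 + 2^-8 sits exactly halfway between the bfloat16
        // values 1.0 and 1.0078125; the tie goes to the even mantissa, 1.0,
        // so this prints 0x3f80.
        printf("0x%04x\n", (unsigned)float_to_bfloat16_bits(1.00390625f));
        return 0;
    }

The vectorized float_to_bfloat16 in EmulateFloat16Math.cpp performs the same arithmetic on UInt(32) lanes, with the extra err term deciding ties when the input started life as a double.
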