diff --git a/apps/c_backend/pipeline_generator.cpp b/apps/c_backend/pipeline_generator.cpp index c6a28bc477fa..f4602c9213a5 100644 --- a/apps/c_backend/pipeline_generator.cpp +++ b/apps/c_backend/pipeline_generator.cpp @@ -14,7 +14,7 @@ class Pipeline : public Halide::Generator { Var x, y; Func f, h; - f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13; + f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13 + cast(x % 3.4f + fma(cast(y), 0.5f, 1.2f)); h.define_extern("an_extern_stage", {f}, Int(16), 0, NameMangling::C); output(x, y) = cast(max(0, f(y, x) + f(x, y) + an_extern_func(x, y) + h())); diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 2adbe3bee35c..aa2a667e79bd 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -126,6 +126,7 @@ void define_operators(py::module &m) { m.def("log", &log); m.def("pow", &pow); m.def("erf", &erf); + m.def("fma", &fma); m.def("fast_sin", &fast_sin); m.def("fast_cos", &fast_cos); m.def("fast_log", &fast_log); diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index a5dc3298be63..65892bff2c2c 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1351,7 +1351,12 @@ void CodeGen_C::visit(const Mod *op) { string arg0 = print_expr(op->a); string arg1 = print_expr(op->b); ostringstream rhs; - rhs << "fmod(" << arg0 << ", " << arg1 << ")"; + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fmod("; + } else { + rhs << print_type(op->type) << "_ops::fmod("; + } + rhs << arg0 << ", " << arg1 << ")"; print_assignment(op->type, rhs.str()); } else { visit_binop(op->type, op->a, op->b, "%"); @@ -1845,8 +1850,24 @@ void CodeGen_C::visit(const Call *op) { << " + " << print_expr(base_offset) << "), /*rw*/0, /*locality*/0), 0)"; } else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) { rhs << "(sizeof(halide_buffer_t))"; + } else if (op->is_intrinsic(Call::strict_fma)) { + internal_assert(op->args.size() == 3) + << "Wrong number of args for strict_fma: " << op->args.size(); + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } else { + rhs << print_type(op->type) << "_ops::fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } } else if (op->is_strict_float_intrinsic()) { - // This depends on the generated C++ being compiled without -ffast-math + // This depends on the generated C++ being compiled without + // -ffast-math. Note that this would not be correct for strict_fma, so + // we handle it separately above. Expr equiv = unstrictify_float(op); rhs << print_expr(equiv); } else if (op->is_intrinsic()) { diff --git a/src/CodeGen_C_prologue.template.cpp b/src/CodeGen_C_prologue.template.cpp index 5d85d585716c..72d0b362febe 100644 --- a/src/CodeGen_C_prologue.template.cpp +++ b/src/CodeGen_C_prologue.template.cpp @@ -1,9 +1,12 @@ /* MACHINE GENERATED By Halide. */ - #if !(__cplusplus >= 201103L || _MSVC_LANG >= 201103L) #error "This code requires C++11 (or later); please upgrade your compiler." #endif +#if !defined(__has_builtin) +#define __has_builtin(x) 0 +#endif + #include #include #include @@ -257,6 +260,32 @@ inline T halide_cpp_min(const T &a, const T &b) { return (a < b) ? 
a : b; } +template +inline T halide_cpp_fma(const T &a, const T &b, const T &c) { +#if __has_builtin(__builtin_fma) + return __builtin_fma(a, b, c); +#else + if (sizeof(T) == sizeof(float)) { + return fmaf(a, b, c); + } else { + return (T)fma((double)a, (double)b, (double)c); + } +#endif +} + +template +inline T halide_cpp_fmod(const T &a, const T &b) { +#if __has_builtin(__builtin_fmod) + return __builtin_fmod(a, b); +#else + if (sizeof(T) == sizeof(float)) { + return fmod(a, b); + } else { + return (T)fmod((double)a, (double)b); + } +#endif +} + template inline void halide_maybe_unused(const T &) { } diff --git a/src/CodeGen_C_vectors.template.cpp b/src/CodeGen_C_vectors.template.cpp index 003d2423414d..44a9b3c0eee5 100644 --- a/src/CodeGen_C_vectors.template.cpp +++ b/src/CodeGen_C_vectors.template.cpp @@ -2,10 +2,6 @@ #define __has_attribute(x) 0 #endif -#if !defined(__has_builtin) -#define __has_builtin(x) 0 -#endif - namespace { // We can't use std::array because that has its own overload of operator<, etc, @@ -150,6 +146,22 @@ class CppVectorOps { return r; } + static Vec fma(const Vec &a, const Vec &b, const Vec &c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec &a, const Vec &b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + static Mask logical_or(const Vec &a, const Vec &b) { CppVector r; for (size_t i = 0; i < Lanes; i++) { @@ -734,6 +746,22 @@ class NativeVectorOps { #endif } + static Vec fma(const Vec a, const Vec b, const Vec c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec a, const Vec b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + // The relational operators produce signed-int of same width as input; our codegen expects uint8. 
static Mask logical_or(const Vec a, const Vec b) { using T = typename NativeVectorComparisonType::type; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index 4ce641e680ad..99c2ffb13a3b 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1257,6 +1257,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, void CodeGen_D3D12Compute_Dev::init_module() { debug(2) << "D3D12Compute device codegen init_module\n"; + // TODO: we could support strict float intrinsics with the precise qualifier + internal_assert(!any_strict_float) + << "strict float intrinsics not yet supported in d3d12compute backend"; + // wipe the internal kernel source src_stream.str(""); src_stream.clear(); diff --git a/src/CodeGen_GPU_Dev.cpp b/src/CodeGen_GPU_Dev.cpp index acd539664e96..595dbb82b9c5 100644 --- a/src/CodeGen_GPU_Dev.cpp +++ b/src/CodeGen_GPU_Dev.cpp @@ -245,6 +245,20 @@ void CodeGen_GPU_C::visit(const Call *op) { equiv.accept(this); } } + } else if (op->is_intrinsic(Call::strict_fma)) { + // All shader languages have fma + Expr equiv = Call::make(op->type, "fma", op->args, Call::PureExtern); + equiv.accept(this); + } else { + CodeGen_C::visit(op); + } +} + +void CodeGen_GPU_C::visit(const Mod *op) { + if (op->type.is_float()) { + // All shader languages have fmod + Expr equiv = Call::make(op->type, "fmod", {op->a, op->b}, Call::PureExtern); + equiv.accept(this); } else { CodeGen_C::visit(op); } diff --git a/src/CodeGen_GPU_Dev.h b/src/CodeGen_GPU_Dev.h index ee2950464526..be56625dac55 100644 --- a/src/CodeGen_GPU_Dev.h +++ b/src/CodeGen_GPU_Dev.h @@ -77,6 +77,15 @@ struct CodeGen_GPU_Dev { Device = 1, // Device/global memory fence Shared = 2 // Threadgroup/shared memory fence }; + + /** Some GPU APIs need to know what floating point mode we're in at kernel + * emission time, to emit appropriate pragmas. */ + bool any_strict_float = false; + +public: + void set_any_strict_float(bool any_strict_float) { + this->any_strict_float = any_strict_float; + } }; /** A base class for GPU backends that require C-like shader output. @@ -99,6 +108,7 @@ class CodeGen_GPU_C : public CodeGen_C { using CodeGen_C::visit; void visit(const Shuffle *op) override; void visit(const Call *op) override; + void visit(const Mod *op) override; std::string print_extern_call(const Call *op) override; diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 4a5b45475533..d64660ad9016 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3306,6 +3306,7 @@ void CodeGen_LLVM::visit(const Call *op) { // Evaluate the args first outside the strict scope, as they may use // non-strict operations. std::vector new_args(op->args.size()); + std::vector to_pop; for (size_t i = 0; i < op->args.size(); i++) { const Expr &arg = op->args[i]; if (arg.as() || is_const(arg)) { @@ -3313,21 +3314,44 @@ void CodeGen_LLVM::visit(const Call *op) { } else { std::string name = unique_name('t'); sym_push(name, codegen(arg)); + to_pop.push_back(name); new_args[i] = Variable::make(arg.type(), name); } } - Expr call = Call::make(op->type, op->name, new_args, op->call_type); { ScopedValue old_in_strict_float(in_strict_float, true); - value = codegen(unstrictify_float(call.as())); + if (op->is_intrinsic(Call::strict_fma)) { + if (op->type.is_float() && op->type.bits() <= 16 && + upgrade_type_for_arithmetic(op->type) != op->type) { + // For (b)float16 and below, doing the fma as a + // double-precision fma is exact and is what llvm does. 
A + // double has enough bits of precision such that the add in + // the fma has no rounding error in the cases where the fma + // is going to return a finite float16. We do this + // legalization manually so that we can use our custom + // vectorizable float16 casts instead of letting llvm call + // library functions. + Type wide_t = Float(64, op->type.lanes()); + for (Expr &e : new_args) { + e = cast(wide_t, e); + } + Expr equiv = Call::make(wide_t, op->name, new_args, op->call_type); + equiv = cast(op->type, equiv); + value = codegen(equiv); + } else { + std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type)); + value = call_intrin(op->type, op->type.lanes(), name, new_args); + } + } else { + // Lower to something other than a call node + Expr call = Call::make(op->type, op->name, new_args, op->call_type); + value = codegen(unstrictify_float(call.as())); + } } - for (size_t i = 0; i < op->args.size(); i++) { - const Expr &arg = op->args[i]; - if (!arg.as() && !is_const(arg)) { - sym_pop(new_args[i].as()->name); - } + for (const auto &s : to_pop) { + sym_pop(s); } } else if (is_float16_transcendental(op) && !supports_call_as_float16(op)) { @@ -4739,23 +4763,29 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes, const string &name, vector arg_values, bool scalable_vector_result, bool is_reduction) { + auto fix_vector_lanes_of_type = [&](const llvm::Type *t) { + if (intrin_lanes == 1 || is_reduction) { + return t->getScalarType(); + } else { + if (scalable_vector_result && effective_vscale != 0) { + return get_vector_type(result_type->getScalarType(), + intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); + } else { + return get_vector_type(result_type->getScalarType(), + intrin_lanes, VectorTypeConstraint::Fixed); + } + } + }; + llvm::Function *fn = module->getFunction(name); if (!fn) { vector arg_types(arg_values.size()); for (size_t i = 0; i < arg_values.size(); i++) { - arg_types[i] = arg_values[i]->getType(); + llvm::Type *t = arg_values[i]->getType(); + arg_types[i] = fix_vector_lanes_of_type(t); } - llvm::Type *intrinsic_result_type = result_type->getScalarType(); - if (intrin_lanes > 1 && !is_reduction) { - if (scalable_vector_result && effective_vscale != 0) { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); - } else { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes, VectorTypeConstraint::Fixed); - } - } + llvm::Type *intrinsic_result_type = fix_vector_lanes_of_type(result_type); FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); fn->setCallingConv(CallingConv::C); @@ -4790,7 +4820,7 @@ Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes if (arg_i_lanes >= arg_lanes) { // Horizontally reducing intrinsics may have // arguments that have more lanes than the - // result. Assume that the horizontally reduce + // result. Assume that they horizontally reduce // neighboring elements... 
int reduce = arg_i_lanes / arg_lanes; args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce)); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index d31e97e30427..a60bd973cb29 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -834,6 +834,7 @@ void CodeGen_Metal_Dev::init_module() { // Write out the Halide math functions. src_stream << "#pragma clang diagnostic ignored \"-Wunused-function\"\n" + << "#pragma METAL fp math_mode(" << (any_strict_float ? "safe)\n" : "fast)\n") << "#include \n" << "using namespace metal;\n" // Seems like the right way to go. << "namespace {\n" diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 1c945efb7cc1..807d75444ed4 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -1123,7 +1123,7 @@ void CodeGen_OpenCL_Dev::init_module() { // This identifies the program as OpenCL C (as opposed to SPIR). src_stream << "/*OpenCL C " << target.to_string() << "*/\n"; - src_stream << "#pragma OPENCL FP_CONTRACT ON\n"; + src_stream << "#pragma OPENCL FP_CONTRACT " << (any_strict_float ? "OFF\n" : "ON\n"); // Write out the Halide math functions. src_stream << "inline float float_from_bits(unsigned int x) {return as_float(x);}\n" diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 3204002652c9..28ee05137b38 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -220,6 +220,12 @@ void CodeGen_PTX_Dev::add_kernel(Stmt stmt, } void CodeGen_PTX_Dev::init_module() { + // This class uses multiple inheritance. It's a GPU device code generator, + // and also an llvm-based one. Both of these track strict_float presence, + // but OffloadGPULoops only sets the GPU device code generator flag, so here + // we set the CodeGen_LLVM flag to match. + CodeGen_LLVM::any_strict_float = CodeGen_GPU_Dev::any_strict_float; + init_context(); module = get_initial_module_for_ptx_device(target, context); @@ -249,6 +255,15 @@ void CodeGen_PTX_Dev::init_module() { function_does_not_access_memory(fn); fn->addFnAttr(llvm::Attribute::NoUnwind); } + + if (CodeGen_GPU_Dev::any_strict_float) { + debug(0) << "Setting strict fp math\n"; + set_strict_fp_math(); + in_strict_float = target.has_feature(Target::StrictFloat); + } else { + debug(0) << "Setting fast fp math\n"; + set_fast_fp_math(); + } } void CodeGen_PTX_Dev::visit(const Call *op) { @@ -611,13 +626,13 @@ vector CodeGen_PTX_Dev::compile_to_src() { internal_assert(llvm_target) << "Could not create LLVM target for " << triple.str() << "\n"; TargetOptions options; - options.AllowFPOpFusion = FPOpFusion::Fast; + options.AllowFPOpFusion = CodeGen_GPU_Dev::any_strict_float ? 
llvm::FPOpFusion::Strict : llvm::FPOpFusion::Fast; #if LLVM_VERSION < 210 - options.UnsafeFPMath = true; + options.UnsafeFPMath = !CodeGen_GPU_Dev::any_strict_float; #endif - options.NoInfsFPMath = true; - options.NoNaNsFPMath = true; - options.HonorSignDependentRoundingFPMathOption = false; + options.NoInfsFPMath = !CodeGen_GPU_Dev::any_strict_float; + options.NoNaNsFPMath = !CodeGen_GPU_Dev::any_strict_float; + options.HonorSignDependentRoundingFPMathOption = !CodeGen_GPU_Dev::any_strict_float; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 671f923ec183..e02a1a55a4ff 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -201,6 +201,7 @@ class CodeGen_Vulkan_Dev : public CodeGen_GPU_Dev { {"fast_pow_f32", GLSLstd450Pow}, {"floor_f16", GLSLstd450Floor}, {"floor_f32", GLSLstd450Floor}, + {"fma", GLSLstd450Fma}, {"log_f16", GLSLstd450Log}, {"log_f32", GLSLstd450Log}, {"sin_f16", GLSLstd450Sin}, @@ -1190,9 +1191,14 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Call *op) { e.accept(this); } } else if (op->is_strict_float_intrinsic()) { - // TODO: Enable/Disable RelaxedPrecision flags? - Expr e = unstrictify_float(op); - e.accept(this); + if (op->is_intrinsic(Call::strict_fma)) { + Expr builtin_call = Call::make(op->type, "fma", op->args, Call::PureExtern); + builtin_call.accept(this); + } else { + // TODO: Enable/Disable RelaxedPrecision flags? + Expr e = unstrictify_float(op); + e.accept(this); + } } else if (op->is_intrinsic(Call::IntrinsicOp::sorted_avg)) { internal_assert(op->args.size() == 2); // b > a, so the following works without widening: diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0e63af410cce..3d2388fdf89c 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) { if (target.has_feature(Target::F16C) && dst.code() == Type::Float && src.code() == Type::Float && - (dst.bits() == 16 || src.bits() == 16)) { + (dst.bits() == 16 || src.bits() == 16) && + src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call // Node we use code() == Type::Float instead of is_float(), because we // don't want to catch bfloat casts. diff --git a/src/EmulateFloat16Math.cpp b/src/EmulateFloat16Math.cpp index 1fda58a838e9..9ffca1bb3b54 100644 --- a/src/EmulateFloat16Math.cpp +++ b/src/EmulateFloat16Math.cpp @@ -9,27 +9,44 @@ namespace Halide { namespace Internal { Expr bfloat16_to_float32(Expr e) { + const int lanes = e.type().lanes(); if (e.type().is_bfloat()) { e = reinterpret(e.type().with_code(Type::UInt), e); } - e = cast(UInt(32, e.type().lanes()), e); + e = cast(UInt(32, lanes), e); e = e << 16; - e = reinterpret(Float(32, e.type().lanes()), e); + e = reinterpret(Float(32, lanes), e); e = strict_float(e); return e; } -Expr float32_to_bfloat16(Expr e) { - internal_assert(e.type().bits() == 32); +Expr float_to_bfloat16(Expr e) { + const int lanes = e.type().lanes(); e = strict_float(e); - e = reinterpret(UInt(32, e.type().lanes()), e); - // We want to round ties to even, so before truncating either - // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to - // even numbers. 
- e += 0x7fff + ((e >> 16) & 1); + + Expr err; + // First round to float and record any gain of loss of magnitude + if (e.type().bits() == 64) { + Expr f = cast(Float(32, lanes), e); + err = abs(e) - abs(f); + e = f; + } else { + internal_assert(e.type().bits() == 32); + } + e = reinterpret(UInt(32, lanes), e); + + // We want to round ties to even, so if we have no error recorded above, + // before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff + // (0.499999) to even numbers. If we have error, break ties using that + // instead. + Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd + if (err.defined()) { + tie_breaker = ((err == 0) & tie_breaker) | (err > 0); + } + e += tie_breaker + 0x7fff; e = (e >> 16); - e = cast(UInt(16, e.type().lanes()), e); - e = reinterpret(BFloat(16, e.type().lanes()), e); + e = cast(UInt(16, lanes), e); + e = reinterpret(BFloat(16, lanes), e); return e; } @@ -63,43 +80,67 @@ Expr float16_to_float32(Expr value) { return f32; } -Expr float32_to_float16(Expr value) { +Expr float_to_float16(Expr value) { // We're about the sniff the bits of a float, so we should // guard it with strict float to ensure we don't do things // like assume it can't be denormal. value = strict_float(value); - Type f32_t = Float(32, value.type().lanes()); + const int src_bits = value.type().bits(); + + Type float_t = Float(src_bits, value.type().lanes()); Type f16_t = Float(16, value.type().lanes()); - Type u32_t = UInt(32, value.type().lanes()); + Type bits_t = UInt(src_bits, value.type().lanes()); Type u16_t = UInt(16, value.type().lanes()); - Expr bits = reinterpret(u32_t, value); + Expr bits = reinterpret(bits_t, value); // Extract the sign bit - Expr sign = bits & make_const(u32_t, 0x80000000); + Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1)); bits = bits ^ sign; // Test the endpoints - Expr is_denorm = (bits < make_const(u32_t, 0x38800000)); - Expr is_inf = (bits >= make_const(u32_t, 0x47800000)); - Expr is_nan = (bits > make_const(u32_t, 0x7f800000)); + + // Smallest input representable as normal float16 (2^-14) + Expr two_to_the_minus_14 = src_bits == 32 ? + make_const(bits_t, 0x38800000) : + make_const(bits_t, (uint64_t)0x3f10000000000000ULL); + Expr is_denorm = bits < two_to_the_minus_14; + + // Smallest input too big to represent as a float16 (2^16) + Expr two_to_the_16 = src_bits == 32 ? + make_const(bits_t, 0x47800000) : + make_const(bits_t, (uint64_t)0x40f0000000000000ULL); + Expr is_inf = bits >= two_to_the_16; + + // Check if the input is a nan, which is anything bigger than an infinity bit pattern + Expr input_inf_bits = src_bits == 32 ? + make_const(bits_t, 0x7f800000) : + make_const(bits_t, (uint64_t)0x7ff0000000000000ULL); + Expr is_nan = bits > input_inf_bits; // Denorms are linearly spaced, so we can handle them // by scaling up the input as a float and using the // existing int-conversion rounding instructions. - Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000))))); + Expr two_to_the_24 = src_bits == 32 ? + make_const(bits_t, 0x0c000000) : + make_const(bits_t, (uint64_t)0x0180000000000000ULL); + Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24)))); Expr inf_bits = make_const(u16_t, 0x7c00); Expr nan_bits = make_const(u16_t, 0x7fff); // We want to round to nearest even, so we add either // 0.5 if the integer part is odd, or 0.4999999 if the // integer part is even, then truncate. 
- bits += (bits >> 13) & 1; - bits += 0xfff; - bits = bits >> 13; + const int float16_mantissa_bits = 10; + const int input_mantissa_bits = src_bits == 32 ? 23 : 52; + const int bits_lost = input_mantissa_bits - float16_mantissa_bits; + bits += (bits >> bits_lost) & 1; + bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1); + bits = cast(u16_t, bits >> bits_lost); + // Rebias the exponent - bits -= 0x1c000; + bits -= 0x4000; // Truncate the top bits of the exponent bits = bits & 0x7fff; bits = select(is_denorm, denorm_bits, @@ -107,7 +148,7 @@ Expr float32_to_float16(Expr value) { is_nan, nan_bits, cast(u16_t, bits)); // Recover the sign bit - bits = bits | cast(u16_t, sign >> 16); + bits = bits | cast(u16_t, sign >> (src_bits - 16)); return common_subexpression_elimination(reinterpret(f16_t, bits)); } @@ -157,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) { Expr e = Call::make(t, it->second, new_args, op->call_type, op->func, op->value_index, op->image, op->param); if (op->type.is_float()) { - e = float32_to_float16(e); + e = float_to_float16(e); } internal_assert(e.type() == op->type); return e; @@ -171,6 +212,7 @@ Expr lower_float16_cast(const Cast *op) { Type src = op->value.type(); Type dst = op->type; Type f32 = Float(32, dst.lanes()); + Type f64 = Float(64, dst.lanes()); Expr val = op->value; if (src.is_bfloat()) { @@ -183,10 +225,20 @@ Expr lower_float16_cast(const Cast *op) { if (dst.is_bfloat()) { internal_assert(dst.bits() == 16); - val = float32_to_bfloat16(cast(f32, val)); + if (src.bits() > 32) { + val = cast(f64, val); + } else { + val = cast(f32, val); + } + val = float_to_bfloat16(val); } else if (dst.is_float() && dst.bits() < 32) { internal_assert(dst.bits() == 16); - val = float32_to_float16(cast(f32, val)); + if (src.bits() > 32) { + val = cast(f64, val); + } else { + val = cast(f32, val); + } + val = float_to_float16(val); } return cast(dst, val); diff --git a/src/EmulateFloat16Math.h b/src/EmulateFloat16Math.h index de1a5e091588..f61de7456fbc 100644 --- a/src/EmulateFloat16Math.h +++ b/src/EmulateFloat16Math.h @@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *); /** Cast to/from float and bfloat using bitwise math. */ //@{ -Expr float32_to_bfloat16(Expr e); -Expr float32_to_float16(Expr e); +Expr float_to_bfloat16(Expr e); +Expr float_to_float16(Expr e); Expr float16_to_float32(Expr e); Expr bfloat16_to_float32(Expr e); Expr lower_float16_cast(const Cast *op); diff --git a/src/Float16.cpp b/src/Float16.cpp index 80c96a38e6f1..1e9e789f476b 100644 --- a/src/Float16.cpp +++ b/src/Float16.cpp @@ -9,7 +9,10 @@ namespace Internal { // Conversion routines to and from float cribbed from Christian Rau's // half library (half.sourceforge.net) -uint16_t float_to_float16(float value) { +template +uint16_t float_to_float16(T value) { + static_assert(std::is_same_v || std::is_same_v, + "float_to_float16 only supports float and double types"); // Start by copying over the sign bit uint16_t bits = std::signbit(value) << 15; @@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) { // We've normalized value as much as possible. Put the integer // portion of it into the mantissa. - float ival; - float frac = std::modf(value, &ival); + T ival; + T frac = std::modf(value, &ival); bits += (uint16_t)(std::abs((int)ival)); // Now consider the fractional part. We round to nearest with ties // going to even. 
frac = std::abs(frac); - bits += (frac > 0.5f) | ((frac == 0.5f) & bits); + bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits); return bits; } @@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) { return ret >> 16; } +uint16_t float_to_bfloat16(double f) { + // Coming from double is a little tricker. We first narrow to float and + // record if any magnitude was lost or gained in the process. If so we'll + // use that to break ties instead of testing whether or not truncation would + // return odd. + float f32 = (float)f; + const double err = std::abs(f) - (double)std::abs(f32); + uint32_t ret; + memcpy(&ret, &f32, sizeof(float)); + ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0)); + return ret >> 16; +} + float bfloat16_to_float(uint16_t b) { // Assume little-endian floats uint16_t bits[2] = {0, b}; @@ -362,7 +378,17 @@ float16_t::float16_t(double value) } float16_t::float16_t(int value) - : data(float_to_float16(value)) { + : data(float_to_float16((float)value)) { + // integers of any size that map to finite float16s are all representable as + // float, so we can go via the float conversion method. +} + +float16_t::float16_t(int64_t value) + : data(float_to_float16((float)value)) { +} + +float16_t::float16_t(uint64_t value) + : data(float_to_float16((float)value)) { } float16_t::operator float() const { @@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value) } bfloat16_t::bfloat16_t(int value) - : data(float_to_bfloat16(value)) { + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(int64_t value) + : data(float_to_bfloat16((double)value)) { +} + +bfloat16_t::bfloat16_t(uint64_t value) + : data(float_to_bfloat16((double)value)) { } bfloat16_t::operator float() const { diff --git a/src/Float16.h b/src/Float16.h index d3c285d6c09f..376813cbd507 100644 --- a/src/Float16.h +++ b/src/Float16.h @@ -32,6 +32,8 @@ struct float16_t { explicit float16_t(float value); explicit float16_t(double value); explicit float16_t(int value); + explicit float16_t(int64_t value); + explicit float16_t(uint64_t value); // @} /** Construct a float16_t with the bits initialised to 0. This represents @@ -175,6 +177,8 @@ struct bfloat16_t { explicit bfloat16_t(float value); explicit bfloat16_t(double value); explicit bfloat16_t(int value); + explicit bfloat16_t(int64_t value); + explicit bfloat16_t(uint64_t value); // @} /** Construct a bfloat16_t with the bits initialised to 0. This represents diff --git a/src/IR.cpp b/src/IR.cpp index c844c672656a..5ea0193908bb 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -678,8 +678,10 @@ const char *const intrinsic_op_names[] = { "sliding_window_marker", "sorted_avg", "strict_add", + "strict_cast", "strict_div", "strict_eq", + "strict_fma", "strict_le", "strict_lt", "strict_max", diff --git a/src/IR.h b/src/IR.h index 6dc0204b89ec..43cbbf4fb7c1 100644 --- a/src/IR.h +++ b/src/IR.h @@ -626,8 +626,10 @@ struct Call : public ExprNode { // them as reals and ignoring the existence of nan and inf. Using these // intrinsics instead prevents any such optimizations. 
strict_add, + strict_cast, strict_div, strict_eq, + strict_fma, strict_le, strict_lt, strict_max, @@ -792,14 +794,16 @@ struct Call : public ExprNode { bool is_strict_float_intrinsic() const { return is_intrinsic( {Call::strict_add, + Call::strict_cast, Call::strict_div, + Call::strict_eq, + Call::strict_fma, + Call::strict_lt, + Call::strict_le, Call::strict_max, Call::strict_min, Call::strict_mul, - Call::strict_sub, - Call::strict_lt, - Call::strict_le, - Call::strict_eq}); + Call::strict_sub}); } static const IRNodeType _node_type = IRNodeType::Call; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index f1d2254abb27..285744ba6eef 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2280,6 +2280,19 @@ Expr erf(const Expr &x) { return halide_erf(x); } +Expr fma(const Expr &a, const Expr &b, const Expr &c) { + user_assert(a.type().is_float()) << "fma requires floating-point arguments."; + user_assert(a.type() == b.type() && a.type() == c.type()) + << "All arguments to fma must have the same type."; + + // TODO: Once we use LLVM's native bfloat type instead of treating them as + // ints, we should be able to remove this assert. Currently, it tries to + // codegen an integer fma. + user_assert(!a.type().is_bfloat()) << "fma does not yet support bfloat types."; + + return Call::make(a.type(), Call::strict_fma, {a, b, c}, Call::PureIntrinsic); +} + Expr fast_pow(Expr x, Expr y) { if (auto i = as_const_int(y)) { return raise_to_integer_power(std::move(x), *i); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..b9e40b898ec1 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -978,6 +978,13 @@ Expr pow(Expr x, Expr y); * mantissa. Vectorizes cleanly. */ Expr erf(const Expr &x); +/** Fused multiply-add. fma(a, b, c) is equivalent to a * b + c, but only + * rounded once at the end. For most targets, when not in a strict_float + * context, Halide will already generate fma instructions from a * b + c. This + * intrinsic's main purpose is to request a true fma inside a strict_float + * context. A true fma will be emulated on targets without one. */ +Expr fma(const Expr &, const Expr &, const Expr &); + /** Fast vectorizable approximation to some trigonometric functions for * Float(32). Absolute approximation error is less than 1e-5. Slow on x86 if * you don't have at least sse 4.1. 
*/ diff --git a/src/Lower.cpp b/src/Lower.cpp index fcbc66747242..32b64e83a2bd 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -493,7 +493,7 @@ void lower_impl(const vector &output_funcs, if (t.has_gpu_feature()) { debug(1) << "Offloading GPU loops...\n"; - s = inject_gpu_offload(s, t); + s = inject_gpu_offload(s, t, any_strict_float); debug(2) << "Lowering after splitting off GPU loops:\n" << s << "\n\n"; } else { diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index e93f67e8bfef..11b8c3ccecf3 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -245,7 +245,7 @@ class InjectGpuOffload : public IRMutator { } public: - InjectGpuOffload(const Target &target) + InjectGpuOffload(const Target &target, bool any_strict_float) : target(target) { Target device_target = target; // For the GPU target we just want to pass the flags, to avoid the @@ -266,12 +266,16 @@ class InjectGpuOffload : public IRMutator { cgdev[DeviceAPI::D3D12Compute] = new_CodeGen_D3D12Compute_Dev(device_target); } if (target.has_feature(Target::Vulkan)) { - cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(target); + cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(device_target); } if (target.has_feature(Target::WebGPU)) { cgdev[DeviceAPI::WebGPU] = new_CodeGen_WebGPU_Dev(device_target); } + for (auto &i : cgdev) { + i.second->set_any_strict_float(any_strict_float); + } + internal_assert(!cgdev.empty()) << "Requested unknown GPU target: " << target.to_string() << "\n"; } @@ -315,8 +319,8 @@ class InjectGpuOffload : public IRMutator { } // namespace -Stmt inject_gpu_offload(const Stmt &s, const Target &host_target) { - return InjectGpuOffload(host_target).inject(s); +Stmt inject_gpu_offload(const Stmt &s, const Target &host_target, bool any_strict_float) { + return InjectGpuOffload(host_target, any_strict_float).inject(s); } } // namespace Internal diff --git a/src/OffloadGPULoops.h b/src/OffloadGPULoops.h index d927f1a8b780..97cd7737271f 100644 --- a/src/OffloadGPULoops.h +++ b/src/OffloadGPULoops.h @@ -17,7 +17,7 @@ namespace Internal { /** Pull loops marked with GPU device APIs to a separate * module, and call them through the appropriate host runtime module. 
*/ -Stmt inject_gpu_offload(const Stmt &s, const Target &host_target); +Stmt inject_gpu_offload(const Stmt &s, const Target &host_target, bool any_strict_float); } // namespace Internal } // namespace Halide diff --git a/src/Parameter.h b/src/Parameter.h index 56a441f5ba35..ffe203241599 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -128,7 +128,7 @@ class Parameter { static_assert(sizeof(T) <= sizeof(halide_scalar_value_t)); const auto sv = scalar_data_checked(type_of()); T t; - memcpy(&t, &sv.u.u64, sizeof(t)); + memcpy((char *)(&t), &sv.u.u64, sizeof(t)); return t; } diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 13dd0873bb12..37263d00c89b 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -83,6 +83,16 @@ class Strictify : public IRMutator { return IRMutator::visit(op); } } + + Expr visit(const Cast *op) override { + if (op->value.type().is_float() && + op->type.is_float()) { + return Call::make(op->type, Call::strict_cast, + {mutate(op->value)}, Call::PureIntrinsic); + } else { + return IRMutator::visit(op); + } + } }; const std::set strict_externs = { @@ -142,6 +152,10 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_fma)) { + return op->args[0] * op->args[1] + op->args[2]; + } else if (op->is_intrinsic(Call::strict_cast)) { + return cast(op->type, op->args[0]); } else { internal_error << "Missing lowering of strict float intrinsic: " << Expr(op) << "\n"; diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index 55ffa7325e22..0aa50ff54ae4 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -914,6 +914,20 @@ wabt::Result wabt_posix_math_2(wabt::interp::Thread &thread, return wabt::Result::Ok; } +template +wabt::Result wabt_posix_math_3(wabt::interp::Thread &thread, + const wabt::interp::Values &args, + wabt::interp::Values &results, + wabt::interp::Trap::Ptr *trap) { + internal_assert(args.size() == 3); + const T in1 = args[0].Get(); + const T in2 = args[1].Get(); + const T in3 = args[2].Get(); + const T out = some_func(in1, in2, in3); + results[0] = wabt::interp::Value::Make(out); + return wabt::Result::Ok; +} + #define WABT_HOST_CALLBACK(x) \ wabt::Result wabt_jit_##x##_callback(wabt::interp::Thread &thread, \ const wabt::interp::Values &args, \ @@ -1998,6 +2012,20 @@ void wasm_jit_posix_math2_callback(const v8::FunctionCallbackInfo &ar args.GetReturnValue().Set(load_scalar(context, out)); } +template +void wasm_jit_posix_math3_callback(const v8::FunctionCallbackInfo &args) { + Isolate *isolate = args.GetIsolate(); + Local context = isolate->GetCurrentContext(); + HandleScope scope(isolate); + + const T in1 = args[0]->NumberValue(context).ToChecked(); + const T in2 = args[1]->NumberValue(context).ToChecked(); + const T in3 = args[2]->NumberValue(context).ToChecked(); + const T out = some_func(in1, in2, in3); + + args.GetReturnValue().Set(load_scalar(context, out)); +} + enum ExternWrapperFieldSlots { kTrampolineWrap, kArgTypesWrap @@ -2122,6 +2150,7 @@ using HostCallbackMap = std::unordered_map} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wabt_posix_math_2} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wabt_posix_math_3} #endif @@ -2131,6 +2160,7 @@ using HostCallbackMap = std::unordered_map; #define DEFINE_CALLBACK(f) {#f, wasm_jit_##f##_callback} #define DEFINE_POSIX_MATH_CALLBACK(t, f) {#f, wasm_jit_posix_math_callback} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) 
{#f, wasm_jit_posix_math2_callback} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wasm_jit_posix_math3_callback} #endif const HostCallbackMap &get_host_callback_map() { @@ -2199,7 +2229,11 @@ const HostCallbackMap &get_host_callback_map() { DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf), DEFINE_POSIX_MATH_CALLBACK2(double, fmax), DEFINE_POSIX_MATH_CALLBACK2(float, powf), - DEFINE_POSIX_MATH_CALLBACK2(double, pow)}; + DEFINE_POSIX_MATH_CALLBACK2(double, pow), + + DEFINE_POSIX_MATH_CALLBACK3(float, fmaf), + DEFINE_POSIX_MATH_CALLBACK3(double, fma), + }; return m; } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 4bce8789875e..144a7829f80f 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -317,6 +317,7 @@ tests(GROUPS correctness store_in.cpp strict_float.cpp strict_float_bounds.cpp + strict_fma.cpp strided_load.cpp target.cpp target_query.cpp diff --git a/test/correctness/float16_t.cpp b/test/correctness/float16_t.cpp index 9e917a6216e6..120682cd86d8 100644 --- a/test/correctness/float16_t.cpp +++ b/test/correctness/float16_t.cpp @@ -301,6 +301,72 @@ int run_test() { } } + { + for (double f : {1.0, -1.0, 0.235, -0.235, 1e-7, -1e-7}) { + { + // Test double -> float16 doesn't have double-rounding issues + float16_t k{f}; + float16_t k_plus_eps = float16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + float16_t to_even = k_is_odd ? k_plus_eps : k; + float16_t to_odd = k_is_odd ? k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + // We expect ties to round to even + assert(float16_t(halfway) == to_even); + // Now let's construct a case where it *should* have rounded to + // odd if rounding directly from double, but rounding via + // float does the wrong thing. + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 + assert(_Float16(halfway_plus_eps) == _Float16(float(to_odd))); +#endif + assert(float16_t(halfway_plus_eps) == to_odd); + + // Now test the same thing in generated code. We need strict float to + // prevent Halide from fusing multiple float casts into one. + Param p; + p.set(halfway_plus_eps); + // halfway plus epsilon rounds to exactly halfway as a float + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + // So if we go via float we get the even outcome, because + // exactly halfway rounds to even + assert(evaluate(strict_float(cast(Float(16), cast(Float(32), p)))) == to_even); + // But if we go direct, we should go to odd, because it's closer + assert(evaluate(strict_float(cast(Float(16), p))) == to_odd); + } + + { + // Test the same things for bfloat + bfloat16_t k{f}; + bfloat16_t k_plus_eps = bfloat16_t::make_from_bits(k.to_bits() + 1); + const bool k_is_odd = k.to_bits() & 1; + + bfloat16_t to_even = k_is_odd ? k_plus_eps : k; + bfloat16_t to_odd = k_is_odd ? 
k : k_plus_eps; + float halfway = (float(k) + float(k_plus_eps)) / 2.f; + + assert(bfloat16_t(halfway) == to_even); + double halfway_plus_eps = std::nextafter(halfway, (double)to_odd); + assert(std::abs(halfway_plus_eps - (double)to_odd) < + std::abs(halfway_plus_eps - (double)to_even)); + + assert(float(halfway_plus_eps) == halfway); + assert(bfloat16_t(halfway_plus_eps) == to_odd); + + Param p; + p.set(halfway_plus_eps); + assert(evaluate(strict_float(cast(Float(32), p))) == halfway); + assert(evaluate(strict_float(cast(BFloat(16), cast(Float(32), p)))) == to_even); + assert(evaluate(strict_float(cast(BFloat(16), p))) == to_odd); + } + } + } + // Enable to read assembly generated by the conversion routines if ((false)) { // Intentional dead code. Extra parens to pacify clang-tidy. Func src, to_f16, from_f16; @@ -309,11 +375,11 @@ int run_test() { to_f16(x) = cast(src(x)); from_f16(x) = cast(to_f16(x)); - src.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - to_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); - from_f16.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + src.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + to_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); + from_f16.compute_root().vectorize(x, 16, TailStrategy::RoundUp); - from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime-disable_llvm_loop_unroll-disable_llvm_loop_vectorize")); + from_f16.compile_to_assembly("/dev/stdout", {}, Target("host-no_asserts-no_bounds_query-no_runtime")); } // Check infinity handling for both float16_t and Halide codegen. diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index c3021578beb5..b3238d023dbe 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -388,16 +388,13 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmla.i16" : "mla", 4 * w, u16_1 + u16_2 * u16_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, i32_1 + i32_2 * i32_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, u32_1 + u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmla.f32" : "fmla", 2*w, f32_1 + f32_2*f32_3); - if (!arm32) - check(arm32 ? "vmla.f32" : "fmla", 2 * w, f32_1 + f32_2 * f32_3); - } - if (!arm32 && target.has_feature(Target::ARMFp16)) { - check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); - check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + if (!arm32) { + check("fmla", 2 * w, f32_1 * f32_2 + f32_3); + check("fmla", 2 * w, fma(f32_1, f32_2, f32_3)); + if (target.has_feature(Target::ARMFp16)) { + check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); + check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + } } // VMLS I, F F, D Multiply Subtract @@ -407,12 +404,8 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmls.i16" : "mls", 4 * w, u16_1 - u16_2 * u16_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, i32_1 - i32_2 * i32_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, u32_1 - u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmls.f32" : "fmls", 2*w, f32_1 - f32_2*f32_3); - if (!arm32) - check(arm32 ? 
"vmls.f32" : "fmls", 2 * w, f32_1 - f32_2 * f32_3); + if (!arm32) { + check("fmls", 2 * w, f32_1 - f32_2 * f32_3); } // VMLAL I - Multiply Accumulate Long diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index b5f4c0fa9f64..0850d8f02bbc 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -411,23 +411,25 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(use_avx512 ? "vrsqrt*ps" : "vrsqrtps*ymm", 8, fast_inverse_sqrt(f32_1)); check(use_avx512 ? "vrcp*ps" : "vrcpps*ymm", 8, fast_inverse(f32_1)); -#if 0 - // Not implemented in the front end. - check("vandnps", 8, bool1 & (!bool2)); - check("vandps", 8, bool1 & bool2); - check("vorps", 8, bool1 | bool2); - check("vxorps", 8, bool1 ^ bool2); -#endif + // Some llvm's don't use kandw, but instead predicate the computation of bool_2 + // using the result of bool_1 + // check(use_avx512 ? "kandw" : "vandps", 8, bool_1 & bool_2); + check(use_avx512 ? "korw" : "vorps", 8, bool_1 | bool_2); + check(use_avx512 ? "kxorw" : "vxorps", 8, bool_1 ^ bool_2); check("vaddps*ymm", 8, f32_1 + f32_2); check("vaddpd*ymm", 4, f64_1 + f64_2); check("vmulps*ymm", 8, f32_1 * f32_2); check("vmulpd*ymm", 4, f64_1 * f64_2); + check("vfmadd*ps*ymm", 8, f32_1 * f32_2 + f32_3); + check("vfmadd*pd*ymm", 4, f64_1 * f64_2 + f64_3); + check("vfmadd*ps*ymm", 8, fma(f32_1, f32_2, f32_3)); + check("vfmadd*pd*ymm", 4, fma(f64_1, f64_2, f64_3)); check("vsubps*ymm", 8, f32_1 - f32_2); check("vsubpd*ymm", 4, f64_1 - f64_2); - // LLVM no longer generates division instruction when fast-math is on - // check("vdivps", 8, f32_1 / f32_2); - // check("vdivpd", 4, f64_1 / f64_2); + + check("vdivps", 8, strict_float(f32_1 / f32_2)); + check("vdivpd", 4, strict_float(f64_1 / f64_2)); check("vminps*ymm", 8, min(f32_1, f32_2)); check("vminpd*ymm", 4, min(f64_1, f64_2)); check("vmaxps*ymm", 8, max(f32_1, f32_2)); diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp new file mode 100644 index 000000000000..f5f736cd5940 --- /dev/null +++ b/test/correctness/strict_fma.cpp @@ -0,0 +1,106 @@ +#include "Halide.h" + +using namespace Halide; + +template +int test() { + std::cout << "Testing " << type_of() << "\n"; + Func f{"f"}, g{"g"}; + Param b{"b"}, c{"c"}; + Var x{"x"}; + + f(x) = fma(cast(x), b, c); + g(x) = strict_float(cast(x) * b + c); + + Target t = get_jit_target_from_environment(); + + if (std::is_same_v && + t.has_gpu_feature() && + // Metal on x86 does not seem to respect strict float despite setting + // the appropriate pragma. + !(t.arch == Target::X86 && t.has_feature(Target::Metal)) && + // TODO: Vulkan does not respect strict_float yet: + // https://github.com/halide/Halide/issues/7239 + !t.has_feature(Target::Vulkan) && + // WebGPU does not and may never respect strict_float. There's no way to + // ask for it in the language. 
+        !t.has_feature(Target::WebGPU)) {
+        Var xo{"xo"}, xi{"xi"};
+        f.gpu_tile(x, xo, xi, 32);
+        g.gpu_tile(x, xo, xi, 32);
+    } else {
+        // Use a non-native vector width, to also test legalization
+        f.vectorize(x, 5);
+        g.vectorize(x, 5);
+    }
+
+    b.set((T)1.111111111);
+    c.set((T)1.0101010101);
+
+    Buffer with_fma = f.realize({1024});
+    Buffer without_fma = g.realize({1024});
+
+    with_fma.copy_to_host();
+    without_fma.copy_to_host();
+
+    bool saw_error = false;
+    for (int i = 0; i < with_fma.width(); i++) {
+
+        Bits fma_bits = Internal::reinterpret_bits(with_fma(i));
+        Bits no_fma_bits = Internal::reinterpret_bits(without_fma(i));
+
+        if constexpr (sizeof(T) >= 4) {
+            T correct_fma = std::fma((T)i, b.get(), c.get());
+            if (with_fma(i) != correct_fma) {
+                printf("fma result does not match std::fma:\n"
+                       " fma(%d, %10.10g, %10.10g) = %10.10g (0x%llx)\n"
+                       " but std::fma gives %10.10g (0x%llx)\n",
+                       i,
+                       (double)b.get(), (double)c.get(),
+                       (double)with_fma(i),
+                       (long long unsigned)fma_bits,
+                       (double)correct_fma,
+                       (long long unsigned)Internal::reinterpret_bits(correct_fma));
+                return -1;
+            }
+        }
+
+        if (with_fma(i) == without_fma(i)) {
+            continue;
+        }
+
+        saw_error = true;
+        // For the specific positive numbers picked above, the rounding error is
+        // at most 1 ULP. Note that it's possible to make much larger rounding
+        // errors if you introduce some catastrophic cancellation.
+        if (fma_bits + 1 != no_fma_bits &&
+            fma_bits - 1 != no_fma_bits) {
+            printf("Difference greater than 1 ULP: %10.10g (0x%llx) vs %10.10g (0x%llx)!\n",
+                   (double)with_fma(i), (long long unsigned)fma_bits,
+                   (double)without_fma(i), (long long unsigned)no_fma_bits);
+            return -1;
+        }
+    }
+
+    if (!saw_error) {
+        printf("There should have occasionally been a 1 ULP difference between fma "
+               "and non-fma results. strict_float may not be respected on this target.\n");
+        // Uncomment to inspect assembly
+        // g.compile_to_assembly("/dev/stdout", {b, c}, get_jit_target_from_environment());
+        return -1;
+    }
+
+    return 0;
+}
+
+int main(int argc, char **argv) {
+
+    if (test() ||
+        test() ||
+        test()) {
+        return -1;
+    }
+
+    printf("Success!\n");
+    return 0;
+}
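
For readers of this patch, here is a minimal usage sketch of the new front-end fma() intrinsic (added in src/IROperator.h and src/IROperator.cpp above), condensed from the strict_fma test. It is illustrative only and not part of the diff: the Func/Param names, constants, schedule, and printed index are assumptions for the example; only fma(), strict_float(), and the "rounded once at the end" semantics come from the change itself.

#include "Halide.h"
#include <cstdio>

using namespace Halide;

int main() {
    Param<float> b, c;
    Var x;

    // fma() lowers to the new strict_fma intrinsic: a single fused
    // multiply-add, rounded once at the end, even under strict_float.
    Func fused, unfused;
    fused(x) = fma(cast<float>(x), b, c);

    // strict_float without fma() keeps the multiply and the add as two
    // separately rounded operations, so the two pipelines may differ by 1 ULP.
    unfused(x) = strict_float(cast<float>(x) * b + c);

    fused.vectorize(x, 8);
    unfused.vectorize(x, 8);

    b.set(1.111111111f);
    c.set(1.0101010101f);
    Buffer<float> with_fma = fused.realize({1024});
    Buffer<float> without_fma = unfused.realize({1024});

    // Expect occasional 1 ULP differences between the two outputs.
    printf("%.10g vs %.10g\n", with_fma(3), without_fma(3));
    return 0;
}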