From a9a12a248dfc40951b45a74cd5da8fcaafb2d9a4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 11 Nov 2025 13:28:09 -0800 Subject: [PATCH 01/17] Top-down type inference support --- src/Func.cpp | 7 ++++++ src/Function.cpp | 3 +++ src/IROperator.cpp | 45 +++++++++++++++++++++++++++++++++++++ src/IROperator.h | 4 ++++ src/IRPrinter.cpp | 2 ++ src/Type.h | 14 +++++++++++- src/runtime/HalideRuntime.h | 11 ++++----- 7 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index cf8904ec2d31..3380d4d34c1c 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -3304,8 +3304,15 @@ Stage FuncRef::operator/=(const FuncRef &e) { } FuncRef::operator Expr() const { + /* user_assert(func.has_pure_definition() || func.has_extern_definition()) << "Can't call Func \"" << func.name() << "\" because it has not yet been defined.\n"; + */ + + if (!(func.has_pure_definition() || func.has_extern_definition())) { + return Call::make(Type{}, func.name(), args, Call::Halide, + FunctionPtr(), 0, Buffer<>(), Parameter()); + } user_assert(func.outputs() == 1) << "Can't convert a reference Func \"" << func.name() diff --git a/src/Function.cpp b/src/Function.cpp index f66b886cdafd..35f2313a1d2f 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -222,6 +222,7 @@ struct CheckVars : public IRGraphVisitor { void visit(const Call *op) override { IRGraphVisitor::visit(op); + /* if (op->name == name && op->call_type == Call::Halide) { for (size_t i = 0; i < op->args.size(); i++) { const Variable *var = op->args[i].as(); @@ -234,6 +235,7 @@ struct CheckVars : public IRGraphVisitor { } } } + */ } void visit(const Variable *var) override { @@ -570,6 +572,7 @@ void Function::define(const vector &args, vector values) { // Freeze all called functions FreezeFunctions freezer(name()); + // TODO: Check for calls to undefined Funcs for (const auto &value : values) { value.accept(&freezer); } diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 
3eae3ccbc788..e33495cd8f2b 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -670,6 +670,16 @@ void match_types(Expr &a, Expr &b) { return; } + if (a.type().is_unknown() && !b.type().is_unknown()) { + b = cast(Type{}, b); + return; + } + + if (b.type().is_unknown() && !a.type().is_unknown()) { + a = cast(Type{}, a); + return; + } + user_assert(!a.type().is_handle() && !b.type().is_handle()) << "Can't do arithmetic on opaque pointer types: " << a << ", " << b << "\n"; @@ -1480,12 +1490,47 @@ Expr saturating_cast(Type t, Expr e) { return Call::make(t, Call::saturating_cast, {std::move(e)}, Call::PureIntrinsic); } +Expr declare_type(Type t, const Expr &e) { + // TODO: This may be called on unsanitized exprs. May need CSE. + internal_assert(e.type().is_unknown()); + if (const Call *op = e.as()) { + internal_assert(op->call_type == Call::Halide); + return Call::make(t, op->name, op->args, op->call_type, + op->func, op->value_index, op->image, op->param); + } else if (const Add *op = e.as()) { + return Add::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Sub *op = e.as()) { + return Sub::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Mul *op = e.as()) { + return Mul::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Div *op = e.as
()) { + return Div::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Min *op = e.as()) { + return Min::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Max *op = e.as()) { + return Max::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Cast *op = e.as()) { + // Must be a cast to unknown + return Cast::make(t, op->value); + } else { + user_error << "Can't do top-down type inference on " << e; + } +} + Expr select(Expr condition, Expr true_value, Expr false_value) { if (as_const_int(condition)) { // Why are you doing this? We'll preserve the select node until constant folding for you. condition = cast(Bool(true_value.type().lanes()), std::move(condition)); } + if (true_value.type().is_unknown() && !false_value.type().is_unknown()) { + true_value = declare_type(false_value.type(), true_value); + } + + if (false_value.type().is_unknown() && !true_value.type().is_unknown()) { + false_value = declare_type(true_value.type(), false_value); + } + // Coerce int literals to the type of the other argument if (as_const_int(true_value)) { true_value = cast(false_value.type(), std::move(true_value)); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..0c94747fa040 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -389,6 +389,10 @@ inline Expr cast(Expr a) { /** Cast an expression to a new type. */ Expr cast(Type t, Expr a); +/** Declare an Expr with unknown type has a given type. Useful when defining + * Funcs inductively. 
*/ +Expr declare_type(Type t, const Expr &e); + /** Return the sum of two expressions, doing any necessary type * coercion using \ref Internal::match_types */ Expr operator+(Expr a, Expr b); diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 68c61acaab2b..d59f2319efea 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -46,6 +46,8 @@ ostream &operator<<(ostream &out, const Type &type) { case Type::BFloat: out << "bfloat"; break; + case Type::Unknown: + out << "unknown"; } if (!type.is_handle()) { out << type.bits(); diff --git a/src/Type.h b/src/Type.h index d6143f38b6de..bda2ba4c7e21 100644 --- a/src/Type.h +++ b/src/Type.h @@ -293,6 +293,7 @@ struct Type { static constexpr halide_type_code_t Float = halide_type_float; static constexpr halide_type_code_t BFloat = halide_type_bfloat; static constexpr halide_type_code_t Handle = halide_type_handle; + static constexpr halide_type_code_t Unknown = halide_type_unknown; // @} /** The number of bytes required to store a single scalar value of this type. Ignores vector lanes. */ @@ -302,7 +303,7 @@ struct Type { // Default ctor initializes everything to predictable-but-unlikely values Type() - : type(Handle, 0, 0) { + : type(Unknown, 0, 1) { } /** Construct a runtime representation of a Halide type from: @@ -454,6 +455,12 @@ struct Type { return code() == Handle; } + /** Is this an unknown type, i.e. one that is still to be determined later? */ + HALIDE_ALWAYS_INLINE + bool is_unknown() const { + return code() == Unknown; + } + // Returns true iff type is a signed integral type where overflow is defined. 
HALIDE_ALWAYS_INLINE bool can_overflow_int() const { @@ -567,6 +574,11 @@ inline Type Handle(int lanes = 1, const halide_handle_cplusplus_type *handle_typ return Type(Type::Handle, 64, lanes, handle_type); } +/** Construct an unknown type */ +inline Type Unknown() { + return Type{}; +} + /** Construct the halide equivalent of a C type */ template inline Type type_of() { diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 1558b9f3397f..508ae7582b8b 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -478,11 +478,12 @@ typedef enum halide_type_code_t : uint8_t #endif { - halide_type_int = 0, ///< signed integers - halide_type_uint = 1, ///< unsigned integers - halide_type_float = 2, ///< IEEE floating point numbers - halide_type_handle = 3, ///< opaque pointer type (void *) - halide_type_bfloat = 4, ///< floating point numbers in the bfloat format + halide_type_int = 0, ///< signed integers + halide_type_uint = 1, ///< unsigned integers + halide_type_float = 2, ///< IEEE floating point numbers + halide_type_handle = 3, ///< opaque pointer type (void *) + halide_type_bfloat = 4, ///< floating point numbers in the bfloat format + halide_type_unknown = 5, ///< an expression of unknown type, to be determined later } halide_type_code_t; // Note that while __attribute__ can go before or after the declaration, From 84f0dff135e48f11c2bbdda9511c1490640c8d5c Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Thu, 20 Nov 2025 19:57:17 -0500 Subject: [PATCH 02/17] Inductive functions --- src/Bounds.cpp | 2 +- src/BoundsInference.cpp | 16 ++ src/CMakeLists.txt | 2 + src/Func.cpp | 3 + src/Function.cpp | 89 ++++++++- src/Function.h | 6 + src/IRPrinter.cpp | 3 + src/Inductive.cpp | 126 +++++++++++++ src/Inductive.h | 24 +++ src/Schedule.h | 9 + src/ScheduleFunctions.cpp | 7 +- test/correctness/CMakeLists.txt | 1 + test/correctness/inductive.cpp | 243 +++++++++++++++++++++++++ test/error/CMakeLists.txt | 8 + 
test/error/inductive_loop.cpp | 18 ++ test/error/inductive_loop_2.cpp | 18 ++ test/error/inductive_loop_3.cpp | 18 ++ test/error/inductive_nested_select.cpp | 19 ++ test/error/inductive_no_select.cpp | 19 ++ test/error/inductive_reorder.cpp | 21 +++ test/error/inductive_var_swap.cpp | 19 ++ test/error/inductive_vectorize.cpp | 20 ++ 22 files changed, 686 insertions(+), 5 deletions(-) create mode 100644 src/Inductive.cpp create mode 100644 src/Inductive.h create mode 100644 test/correctness/inductive.cpp create mode 100644 test/error/inductive_loop.cpp create mode 100644 test/error/inductive_loop_2.cpp create mode 100644 test/error/inductive_loop_3.cpp create mode 100644 test/error/inductive_nested_select.cpp create mode 100644 test/error/inductive_no_select.cpp create mode 100644 test/error/inductive_reorder.cpp create mode 100644 test/error/inductive_var_swap.cpp create mode 100644 test/error/inductive_vectorize.cpp diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 670ccf11d177..59de794aa0f0 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -3299,7 +3299,7 @@ FuncValueBounds compute_function_value_bounds(const vector &order, Interval result; - if (f.is_pure()) { + if (f.is_pure() && !f.is_inductive()) { // Make a scope that says the args could be anything. Scope arg_scope; diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 9ba1f1af2019..0047f54fb8fe 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -7,6 +7,7 @@ #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" +#include "Inductive.h" #include "Inline.h" #include "Qualify.h" #include "Scope.h" @@ -1021,6 +1022,21 @@ class BoundsInference : public IRMutator { } } + // For any inductively defined functions, make sure their + // bounds include the base case. 
+ for (Stage &s : stages) { + if (!s.func.is_pure() || !s.func.is_inductive()) { + continue; + } + debug(4) << "Expanding bounds for inductively defined function " << s.func.name() << "\n"; + for (const auto &b1 : s.bounds) { + const Box &b = b1.second; + for (const auto &cval : s.exprs) { + s.bounds[b1.first] = expand_to_include_base_case(s.func.args(), cval.value, s.func.name(), b); + } + } + } + // The region required of the each output is expanded to include the size of the output buffer. for (const Function &output : outputs) { Box output_box; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index af419323b24e..376fceff3960 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -128,6 +128,7 @@ target_sources( HexagonOffload.h HexagonOptimize.h ImageParam.h + Inductive.h InferArguments.h InjectHostDevBufferCopies.h Inline.h @@ -305,6 +306,7 @@ target_sources( HexagonOffload.cpp HexagonOptimize.cpp ImageParam.cpp + Inductive.cpp InferArguments.cpp InjectHostDevBufferCopies.cpp Inline.cpp diff --git a/src/Func.cpp b/src/Func.cpp index 3380d4d34c1c..a67f6993c485 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1342,6 +1342,9 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV } else if (dims[i].dim_type == DimType::PureRVar || outer_type == DimType::PureRVar) { dims[i].dim_type = DimType::PureRVar; + } else if (dims[i].dim_type == DimType::InductiveVar || + outer_type == DimType::InductiveVar) { + dims[i].dim_type = DimType::InductiveVar; } else { dims[i].dim_type = DimType::PureVar; } diff --git a/src/Function.cpp b/src/Function.cpp index 35f2313a1d2f..8ef9f6279638 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -636,7 +636,11 @@ void Function::define(const vector &args, vector values) { contents->init_def = Definition(init_def_args, values, rdom, true); for (const auto &arg : args) { - Dim d = {arg, ForType::Serial, DeviceAPI::None, DimType::PureVar}; + DimType dtype = DimType::PureVar; + if (is_inductive(arg)) { 
+ dtype = DimType::InductiveVar; + } + Dim d = {arg, ForType::Serial, DeviceAPI::None, dtype}; contents->init_def.schedule().dims().push_back(d); StorageDim sd = {arg}; contents->func_schedule.storage_dims().push_back(sd); @@ -1069,8 +1073,89 @@ bool Function::has_pure_definition() const { return contents->init_def.defined(); } +bool Function::is_inductive() const { + class RecursiveHelper : public IRVisitor { + using IRVisitor::visit; + const string &func; + void visit(const Call *op) override { + if (op->name == func) { + recursive = true; + } + IRVisitor::visit(op); + } + + public: + bool recursive = false; + RecursiveHelper(const string &func) + : func(func) { + } + }; + + if (!has_pure_definition()) { + return false; + } + + RecursiveHelper r(name()); + for (const Expr &e : definition().values()) { + e.accept(&r); + } + + return r.recursive; +} + +bool Function::is_inductive(const string &var) const { + class RecursiveHelper : public IRVisitor { + using IRVisitor::visit; + const string &func; + const string &var; + const int &pos; + void visit(const Call *op) override { + if (op->name == func) { + recursive = true; + if (const auto &v = op->args[pos].as()) { + if (v->name != var) { + inductive_in_var = true; + } + } else { + inductive_in_var = true; + } + } + IRVisitor::visit(op); + } + + public: + bool recursive = false; + bool inductive_in_var = false; + RecursiveHelper(const string &func, const string &var, const int &pos) + : func(func), var(var), pos(pos) { + } + }; + + if (!has_pure_definition()) { + return false; + } + + int pos = -1; + for (int i = 0; i < definition().args().size(); i++) { + if (const auto &v = definition().args()[i].as()) { + if (v->name == var) { + pos = i; + } + } + } + if (pos == -1) { + return false; + } + RecursiveHelper r(name(), var, pos); + for (const Expr &e : definition().values()) { + e.accept(&r); + } + + return r.inductive_in_var; +} + bool Function::can_be_inlined() const { - return is_pure() && 
definition().specializations().empty(); + return is_pure() && definition().specializations().empty() && !is_inductive(); } bool Function::has_update_definition() const { diff --git a/src/Function.h b/src/Function.h index 5305f4f058be..b413f7e4145d 100644 --- a/src/Function.h +++ b/src/Function.h @@ -187,6 +187,12 @@ class Function { !has_extern_definition()); } + /** Does this function have an inductive pure definition? */ + bool is_inductive() const; + + /** Is this function inductive in the given variable? */ + bool is_inductive(const std::string &var) const; + /** Is it legal to inline this function? */ bool can_be_inlined() const; diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 92d14dc24399..76d3b59881f0 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -498,6 +498,9 @@ std::ostream &operator<<(std::ostream &out, const DimType &t) { case DimType::PureRVar: out << "PureRVar"; break; + case DimType::InductiveVar: + out << "InductiveVar"; + break; case DimType::ImpureRVar: out << "ImpureRVar"; break; diff --git a/src/Inductive.cpp b/src/Inductive.cpp new file mode 100644 index 000000000000..15e413fd0e48 --- /dev/null +++ b/src/Inductive.cpp @@ -0,0 +1,126 @@ +#include "Inductive.h" + +#include "Bounds.h" +#include "ConciseCasts.h" +#include "Error.h" +#include "Function.h" +#include "IR.h" +#include "IREquality.h" +#include "IRVisitor.h" +#include "Simplify.h" +#include "Substitute.h" + +namespace Halide { +namespace Internal { + +using std::string; +using std::vector; + +class BaseCaseSolver : public IRVisitor { + using IRVisitor::visit; + const vector &vars; + const string &func; + + const vector &start_box; + + vector condition_intervals; + + Scope bounds; + + int nested_select = 0; + + void visit(const Call *op) override { + if (op->is_intrinsic(Call::if_then_else)) { + nested_select += 1; + vector old_intervals = condition_intervals; + for (int i = 0; i < vars.size(); i++) { + condition_intervals[i] = 
Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(op->args[0]), vars[i])); + bounds.push(vars[i], condition_intervals[i]); + } + + op->args[1].accept(this); + for (int i = 0; i < vars.size(); i++) { + condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(!op->args[0]), vars[i])); + bounds.pop(vars[i]); + bounds.push(vars[i], condition_intervals[i]); + } + op->args[2].accept(this); + condition_intervals = old_intervals; + for (int i = 0; i < vars.size(); i++) { + bounds.pop(vars[i]); + } + nested_select -= 1; + } else if (op->name == func) { + user_assert(nested_select > 0) << "Function " << func << " contains an inductive function reference outside of a select operation.\n"; + user_assert(nested_select == 1) << "Function " << func << " contains an inductive function reference inside a nested select operation.\n"; + bool found_inductive = false; + for (int position = 0; position < vars.size(); position++) { + const Expr inductive_expr = op->args[position]; + const Expr new_v = Variable::make(inductive_expr.type(), vars[position]); + const Expr gets_lower = simplify(new_v - inductive_expr > 0, true, bounds); + const Interval i_lower = solve_for_inner_interval(gets_lower, vars[position]); + + Interval new_interval; + if (equal(new_v, inductive_expr)) { + new_interval = start_box[position]; + } else if (i_lower.is_everything()) { + found_inductive = true; + new_interval = Interval(Interval::neg_inf(), start_box[position].max); + } else { + new_interval = Interval::everything(); + } + new_interval = Interval::make_intersection(new_interval, condition_intervals[position]); + Scope i_scope; + i_scope.push(vars[position], new_interval); + result_intervals[position] = Interval::make_union(result_intervals[position], Interval::make_union(new_interval, bounds_of_expr_in_scope(inductive_expr, i_scope))); + } + user_assert(found_inductive) << "Unable to prove in inductive function " << func << " 
that the inductive step is monotonically decreasing.\n"; + + IRVisitor::visit(op); + + } else { + IRVisitor::visit(op); + } + } + +public: + vector result_intervals; + + BaseCaseSolver(const vector &v, const string &func, const vector &con) + : vars(v), func(func), start_box(con) { + condition_intervals = vector(start_box.size()); + result_intervals = vector(start_box.size(), Interval::nothing()); + } +}; + +// anonymous namespace + +const Box expand_to_include_base_case(const vector &vars, const Expr &RHS, const string &func, const Box &box_required) { + Expr substed = substitute_in_all_lets(RHS); + Box box2 = box_required; + BaseCaseSolver b(vars, func, box_required.bounds); + substed.accept(&b); + for (uint i = 0; i < vars.size(); i++) { + user_assert(b.result_intervals[i].is_bounded()) << "Unable to prove that the inductive function " << func << " uses a bounded interval"; + Interval new_interval(min(b.result_intervals[i].min, box_required[i].min), box_required[i].max); + box2[i] = new_interval; + } + + return box2; +} + +const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos) { + return expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); +} + +const Box expand_to_include_base_case(const Function &fn, const Box &box_required) { + Box b = expand_to_include_base_case(fn.args(), fn.values()[0], fn.name(), box_required); + for (int pos = 1; pos < fn.values().size(); pos++) { + Box b2 = expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); + merge_boxes(b, b2); + } + return b; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/Inductive.h b/src/Inductive.h new file mode 100644 index 000000000000..12ee0143abfc --- /dev/null +++ b/src/Inductive.h @@ -0,0 +1,24 @@ +#ifndef INDUCTIVE_H +#define INDUCTIVE_H + +#include "Bounds.h" +#include "Expr.h" +#include "Interval.h" +#include "Scope.h" +#include "Solve.h" + +namespace Halide { +namespace Internal 
{ + +/** Given an initial box for an inductively defined function, + returns an expanded box that includes the function's non-inductive base case. */ +const Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); + +const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos=0); + +const Box expand_to_include_base_case(const Function &fn, const Box &box_required); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Schedule.h b/src/Schedule.h index 906dbe6c7b5e..0017b10c80e2 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -356,6 +356,11 @@ enum class DimType { * definitions you can even redundantly re-evaluate points. */ PureVar = 0, + /** The dim originated from a Var in an inductively defined pure + * definition. InductiveVars cannot be reordered, parallelized, + * or vectorized. */ + InductiveVar, + /** The dim originated from an RVar. You can evaluate a Func at * distinct values of this RVar in any order (including in * parallel) over exactly the interval specified in the @@ -466,6 +471,10 @@ struct Dim { return (dim_type == DimType::PureRVar) || (dim_type == DimType::ImpureRVar); } + bool is_inductive() const { + return dim_type == DimType::InductiveVar; + } + /** Could multiple iterations of this loop happen at the same * time, with reads and writes interleaved in arbitrary ways * according to the memory model of the underlying compiler and diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 19c3de055001..6c1bf9812ddb 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1361,7 +1361,7 @@ class InjectFunctionRealization : public IRMutator { // none of the functions in a fused group can be inlined, so this will only // happen when we're lowering a single func. 
if (provide_name != funcs[0].name() && - !funcs[0].is_pure() && + (!funcs[0].is_pure() || funcs[0].is_inductive()) && funcs[0].schedule().compute_level().is_inlined() && function_is_used_in_stmt(funcs[0], provide_op)) { @@ -2348,7 +2348,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ // will get lowered into compute_at innermost and thus can be treated // similarly as a non-inlined Func. if (store_at.is_inlined() && compute_at.is_inlined() && hoist_storage_at.is_inlined()) { - if (f.is_pure()) { + if (f.is_pure() && !f.is_inductive()) { validate_schedule_inlined_function(f); } return true; @@ -2589,6 +2589,9 @@ Stmt schedule_functions(const vector &outputs, const map &env, const Target &target, bool &any_memoized) { + for (const Function &o : outputs) { + user_assert(!o.is_inductive()) << "Function" << o.name() << " is an inductively defined output buffer, which is unsupported.\n"; + } string root_var = LoopLevel::root().lock().to_string(); Stmt s = For::make(root_var, 0, 1, ForType::Serial, Partition::Never, DeviceAPI::Host, Evaluate::make(0)); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 586db8d1db8e..3c3d57b8345c 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -182,6 +182,7 @@ tests(GROUPS correctness implicit_args_tests.cpp in_place.cpp indexing_access_undef.cpp + inductive.cpp infer_arguments.cpp inline_reduction.cpp inlined_generator.cpp diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp new file mode 100644 index 000000000000..f95761e812ae --- /dev/null +++ b/test/correctness/inductive.cpp @@ -0,0 +1,243 @@ +#include "Halide.h" +#include "check_call_graphs.h" +#include "test_sharding.h" + +#include +#include + +namespace { + +using std::map; +using std::string; + +using namespace Halide; +using namespace Halide::Internal; + +int simple_inductive_test() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + // g(x, y) = x + 
y; + // g(r.x, r.y) = g(r.x, r.y); + g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); + + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({600, 5}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int reorder_test() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + Var xi("xi"), xii("xii"), xo("xo"); + + // g(x, y) = x + y; + // g(r.x, r.y) = g(r.x, r.y); + g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); + + h(x, y) = g(x + 5, y) / 4; + h.split(x, xo, xi, 24).reorder(xi, y, xo); + + g.compute_at(h, xo).store_root(); + + g.split(x, xi, xii, 5).reorder(xii, y, xi).vectorize(y, 8); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int summed_area_table() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + g(x, y) = f(x, y) + select(x <= 0, 0, g(x - 1, y)) + select(y <= 0, 0, g(x, y - 1)) - select(x <= 0 || y <= 0, 0, g(x - 1, y - 1)); + h(x, y) = g(x, y) / 8; + g.compute_at(h, x).store_root(); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (x * (x + 1) / 2 * (y + 1) + y * (y + 1) / 2 * (x + 1)) / 8; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int large_baseline() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + // g(x, y) = x + y; + // g(r.x, r.y) = g(r.x, r.y); + g(x, y) = select(x <= 8, (y * x + x * (x + 1) / 2) - 1, g(x - 1, y) + x + y); + + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2 - 1) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int fibonacci() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + // g(x, y) 
= x + y; + // g(r.x, r.y) = g(r.x, r.y); + g(x, y) = select(x <= 1, 1, g(x - 1, y) + g(x - 2, y)); + h(x, y) = g(x, y); + + h.bound(x, 0, 80); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int a = 1; + int b = 1; + for (int i = 2; i <= x; i++) { + int c = a + b; + b = a; + a = c; + } + return a; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int sum_2d_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = select(x <= 0, 0, x + f(x - 1, y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 80).vectorize(x, 8); + g.compute_at(h, x).store_root().vectorize(x, 8); + f.compute_at(h, x).store_root(); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int ans = 0; + for (int a = 0; a <= x; a++) { + for (int b = 0; b <= y; b++) { + ans += a; + } + } + return ans; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int sum_1d_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + f(x, y) += x; // select(x<=0, 0, x+f(x-1,y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 80); + // stress-testing bounds inference for dependent non-inlined funcs + f.compute_at(h, x); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int ans = 0; + for (int a = 0; a <= y; a++) { + ans += 2 * x + a; + } + return ans; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int multi_baseline_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + f(x, y) += x; // select(x<=0, 0, x+f(x-1,y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)) + select(y <= 3, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 20); + f.compute_at(h, x); + Buffer im = h.realize({80, 20}); + auto func = [](int x, int y) { + std::vector ans; + + for (int a = 
0; a <= y; a++) { + int b = 2 * x + a; + if (a <= 0) { + ans.emplace_back(4 * x); + } else if (a <= 3) { + ans.emplace_back(2 * x + (2 * x + a) + ans[a - 1]); + } else { + ans.emplace_back(2 * x + a + ans[a - 1] + (2 * x + a) + ans[a - 1]); + } + } + return ans[y]; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +} // namespace + +int main(int argc, char **argv) { + struct Task { + std::string desc; + std::function fn; + }; + + std::vector tasks = { + {"simple inductive test", simple_inductive_test}, + {"reordering test", reorder_test}, + {"summed area table test", summed_area_table}, + {"large baseline test", large_baseline}, + {"fibonacci test", fibonacci}, + {"2d sum test", sum_2d_test}, + {"1d sum test", sum_1d_test}, + {"multi-baseline tset", multi_baseline_test}}; + + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + std::cout << task.desc << "\n"; + if (task.fn() != 0) { + return 1; + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 41816d5ba36b..a21c3f907af8 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -70,6 +70,14 @@ tests(GROUPS error implicit_args.cpp impossible_constraints.cpp incomplete_target.cpp + inductive_loop.cpp + inductive_loop_2.cpp + inductive_loop_3.cpp + inductive_nested_select.cpp + inductive_no_select.cpp + inductive_reorder.cpp + inductive_var_swap.cpp + inductive_vectorize.cpp init_def_should_be_all_vars.cpp inspect_loop_level.cpp lerp_float_weight_out_of_range.cpp diff --git a/test/error/inductive_loop.cpp b/test/error/inductive_loop.cpp new file mode 100644 index 000000000000..cb894256a355 --- /dev/null +++ b/test/error/inductive_loop.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var 
x("x"); + + f(x) = select(x < 1, 0, f(x)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_loop_2.cpp b/test/error/inductive_loop_2.cpp new file mode 100644 index 000000000000..f6086e239eef --- /dev/null +++ b/test/error/inductive_loop_2.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 2, 0, f(x - 1) + f(x) + f(x - 2)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_loop_3.cpp b/test/error/inductive_loop_3.cpp new file mode 100644 index 000000000000..c281480d5fa7 --- /dev/null +++ b/test/error/inductive_loop_3.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 2, 0, f(x - 1) + f(min(2, x)) + f(x - 2)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_nested_select.cpp b/test/error/inductive_nested_select.cpp new file mode 100644 index 000000000000..c9abf663008f --- /dev/null +++ b/test/error/inductive_nested_select.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + // Nested select operations are currently unsupported. 
+ f(x) = select(x < 1, 0, select(x < 3, 1, f(x - 1))); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_no_select.cpp b/test/error/inductive_no_select.cpp new file mode 100644 index 000000000000..d8d395be9e03 --- /dev/null +++ b/test/error/inductive_no_select.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = x + f(x - 1); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_reorder.cpp b/test/error/inductive_reorder.cpp new file mode 100644 index 000000000000..01dbe642645b --- /dev/null +++ b/test/error/inductive_reorder.cpp @@ -0,0 +1,21 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"), xi("xi"), xo("xo"); + + f(x) = select(x < 1, 0, x + f(x - 1)); + f.split(x, xo, xi, 8); + f.reorder(xo, xi); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_var_swap.cpp b/test/error/inductive_var_swap.cpp new file mode 100644 index 000000000000..4a8453eb0901 --- /dev/null +++ b/test/error/inductive_var_swap.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"), y("y"); + + f(x, y) = select(x < 1, 0, x + f(y - 1, x - 1)); + + g(x, y) = f(x, y) * 2; + + g.realize({10, 10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_vectorize.cpp b/test/error/inductive_vectorize.cpp new file mode 100644 index 000000000000..bc6dd1e4bc95 --- /dev/null +++ b/test/error/inductive_vectorize.cpp @@ -0,0 +1,20 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 1, 0, x + f(x - 
1)); + f.vectorize(x, 8); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} From 4cab180f553bb0a90825a5faab63e99a9a18c895 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Fri, 21 Nov 2025 00:02:44 -0500 Subject: [PATCH 03/17] fixes --- src/Func.cpp | 4 ++-- test/error/inductive_no_select.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index a67f6993c485..8b5dfaf8006d 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -491,7 +491,7 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) { // If it's an rvar and the for type is parallel, we need to // validate that this doesn't introduce a race condition, // unless it is flagged explicitly or is a associative atomic operation. - if (!dim.is_pure() && var.is_rvar && is_parallel(t)) { + if (!dim.is_pure() && (var.is_rvar || dim.is_inductive()) && is_parallel(t)) { if (!definition.schedule().allow_race_conditions() && definition.schedule().atomic()) { if (!definition.schedule().override_atomic_associativity_test()) { @@ -3314,7 +3314,7 @@ FuncRef::operator Expr() const { if (!(func.has_pure_definition() || func.has_extern_definition())) { return Call::make(Type{}, func.name(), args, Call::Halide, - FunctionPtr(), 0, Buffer<>(), Parameter()); + func.get_contents(), 0, Buffer<>(), Parameter()); } user_assert(func.outputs() == 1) diff --git a/test/error/inductive_no_select.cpp b/test/error/inductive_no_select.cpp index d8d395be9e03..5fdc3f430232 100644 --- a/test/error/inductive_no_select.cpp +++ b/test/error/inductive_no_select.cpp @@ -8,7 +8,7 @@ int main(int argc, char **argv) { Var x("x"); - f(x) = x + f(x - 1); + f(x) = cast(x + f(x - 1)); g(x) = f(x) * 2; From cbda584e86315ac283f87b9bca6fc93b5c1e4c54 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sat, 22 Nov 2025 01:44:51 -0500 Subject: [PATCH 04/17] additional fixes --- src/Function.cpp | 2 +- src/Inductive.cpp | 10 +++--- src/Type.h | 2 +- test/correctness/inductive.cpp 
| 1 - test/correctness/tracing.cpp | 56 +++++++++++++++++----------------- 5 files changed, 35 insertions(+), 36 deletions(-) diff --git a/src/Function.cpp b/src/Function.cpp index 8ef9f6279638..1594689d6bdf 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -1136,7 +1136,7 @@ bool Function::is_inductive(const string &var) const { } int pos = -1; - for (int i = 0; i < definition().args().size(); i++) { + for (size_t i = 0; i < definition().args().size(); i++) { if (const auto &v = definition().args()[i].as()) { if (v->name == var) { pos = i; diff --git a/src/Inductive.cpp b/src/Inductive.cpp index 15e413fd0e48..56ac4f9aa1ae 100644 --- a/src/Inductive.cpp +++ b/src/Inductive.cpp @@ -33,20 +33,20 @@ class BaseCaseSolver : public IRVisitor { if (op->is_intrinsic(Call::if_then_else)) { nested_select += 1; vector old_intervals = condition_intervals; - for (int i = 0; i < vars.size(); i++) { + for (size_t i = 0; i < vars.size(); i++) { condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(op->args[0]), vars[i])); bounds.push(vars[i], condition_intervals[i]); } op->args[1].accept(this); - for (int i = 0; i < vars.size(); i++) { + for (size_t i = 0; i < vars.size(); i++) { condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(!op->args[0]), vars[i])); bounds.pop(vars[i]); bounds.push(vars[i], condition_intervals[i]); } op->args[2].accept(this); condition_intervals = old_intervals; - for (int i = 0; i < vars.size(); i++) { + for (size_t i = 0; i < vars.size(); i++) { bounds.pop(vars[i]); } nested_select -= 1; @@ -54,7 +54,7 @@ class BaseCaseSolver : public IRVisitor { user_assert(nested_select > 0) << "Function " << func << " contains an inductive function reference outside of a select operation.\n"; user_assert(nested_select == 1) << "Function " << func << " contains an inductive function reference inside a nested select operation.\n"; bool found_inductive = false; - 
for (int position = 0; position < vars.size(); position++) { + for (size_t position = 0; position < vars.size(); position++) { const Expr inductive_expr = op->args[position]; const Expr new_v = Variable::make(inductive_expr.type(), vars[position]); const Expr gets_lower = simplify(new_v - inductive_expr > 0, true, bounds); @@ -100,7 +100,7 @@ const Box expand_to_include_base_case(const vector &vars, const Expr &RH Box box2 = box_required; BaseCaseSolver b(vars, func, box_required.bounds); substed.accept(&b); - for (uint i = 0; i < vars.size(); i++) { + for (size_t i = 0; i < vars.size(); i++) { user_assert(b.result_intervals[i].is_bounded()) << "Unable to prove that the inductive function " << func << " uses a bounded interval"; Interval new_interval(min(b.result_intervals[i].min, box_required[i].min), box_required[i].max); box2[i] = new_interval; diff --git a/src/Type.h b/src/Type.h index bda2ba4c7e21..d6f8e330dcd0 100644 --- a/src/Type.h +++ b/src/Type.h @@ -303,7 +303,7 @@ struct Type { // Default ctor initializes everything to predictable-but-unlikely values Type() - : type(Unknown, 0, 1) { + : type(Unknown, 0, 0) { } /** Construct a runtime representation of a Halide type from: diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp index f95761e812ae..29b4719dbe8b 100644 --- a/test/correctness/inductive.cpp +++ b/test/correctness/inductive.cpp @@ -192,7 +192,6 @@ int multi_baseline_test() { std::vector ans; for (int a = 0; a <= y; a++) { - int b = 2 * x + a; if (a <= 0) { ans.emplace_back(4 * x); } else if (a <= 3) { diff --git a/test/correctness/tracing.cpp b/test/correctness/tracing.cpp index 27fa4a3f6e4f..d703c2856644 100644 --- a/test/correctness/tracing.cpp +++ b/test/correctness/tracing.cpp @@ -201,8 +201,8 @@ int main(int argc, char **argv) { // The golden trace, recorded when this test was written event correct_pipeline_trace[] = { - {102, 0, 8, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - 
{102, 1, 9, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 0, 8, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 1, 9, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, }; if (!check_trace_correct(correct_pipeline_trace, 2)) { return 1; @@ -226,52 +226,52 @@ int main(int argc, char **argv) { // The golden trace, recorded when this test was written event correct_trace[] = { - {102, 0, 8, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {105, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, - {103, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 2 2 32 1 2 32 1 1 0 11"}, - {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, - {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "arbitrary data on f"}, - {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "more:arbitrary \xff data on f?"}, - {103, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "g whiz"}, - {102, 1, 2, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 1, 2, 3, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 8, 4, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 0, 8, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {105, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, + {103, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 
0.000000f}, "func_type_and_dim: 2 2 32 1 2 32 1 1 0 11"}, + {102, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, + {102, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "arbitrary data on f"}, + {102, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "more:arbitrary \xff data on f?"}, + {103, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "g whiz"}, + {102, 1, 2, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 1, 2, 5, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 8, 4, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 11, 1, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.099833f, 0.198669f, 0.295520f}, ""}, {103, 11, 1, 2, 32, 4, 1, 4, {0, 1, 2, 3}, {1.000000f, 0.995004f, 0.980067f, 0.955337f}, ""}, {103, 11, 1, 2, 32, 4, 0, 4, {1, 2, 3, 4}, {0.099833f, 0.198669f, 0.295520f, 0.389418f}, ""}, {103, 11, 1, 2, 32, 4, 1, 4, {1, 2, 3, 4}, {0.995004f, 0.980067f, 0.955337f, 0.921061f}, ""}, - {103, 11, 5, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 6, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 11, 5, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 6, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 17, 0, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.099833f, 0.198669f, 0.295520f}, ""}, {103, 17, 0, 2, 32, 4, 1, 4, {1, 2, 3, 4}, {0.995004f, 0.980067f, 0.955337f, 0.921061f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.995004f, 
1.079900f, 1.154006f, 1.216581f}, ""}, - {103, 17, 7, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 17, 7, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 23, 1, 2, 32, 4, 0, 4, {5, 6, 7, 8}, {0.479426f, 0.564642f, 0.644218f, 0.717356f}, ""}, {103, 23, 1, 2, 32, 4, 1, 4, {5, 6, 7, 8}, {0.877583f, 0.825336f, 0.764842f, 0.696707f}, ""}, - {103, 23, 5, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 6, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 23, 5, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 6, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 27, 0, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {0.389418f, 0.479426f, 0.564642f, 0.644218f}, ""}, {103, 27, 0, 2, 32, 4, 1, 4, {5, 6, 7, 8}, {0.877583f, 0.825336f, 0.764842f, 0.696707f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {1.267001f, 1.304761f, 1.329485f, 1.340924f}, ""}, - {103, 27, 7, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 27, 7, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 33, 1, 2, 32, 4, 0, 4, {7, 8, 9, 10}, {0.644218f, 0.717356f, 0.783327f, 0.841471f}, ""}, {103, 33, 1, 2, 32, 4, 1, 4, {7, 8, 9, 10}, {0.764842f, 0.696707f, 0.621610f, 0.540302f}, ""}, - {103, 33, 5, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 
0.000000f, 0.000000f}, ""}, - {103, 9, 6, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 33, 5, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 6, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 37, 0, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {0.564642f, 0.644218f, 0.717356f, 0.783327f}, ""}, {103, 37, 0, 2, 32, 4, 1, 4, {7, 8, 9, 10}, {0.764842f, 0.696707f, 0.621610f, 0.540302f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {1.329485f, 1.340924f, 1.338966f, 1.323629f}, ""}, - {103, 37, 7, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 10, 5, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 3, 3, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 8, 3, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 1, 9, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 37, 7, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 10, 5, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 3, 5, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 8, 3, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 1, 9, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, }; int correct_trace_length = sizeof(correct_trace) / sizeof(correct_trace[0]); From 2b83090f821c20a00efe4ea6c62e834a73d39491 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sat, 22 Nov 2025 01:49:58 -0500 Subject: [PATCH 05/17] size_t fix --- src/Inductive.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Inductive.cpp 
b/src/Inductive.cpp index 56ac4f9aa1ae..332d38f341e6 100644 --- a/src/Inductive.cpp +++ b/src/Inductive.cpp @@ -115,7 +115,7 @@ const Box expand_to_include_base_case(const Function &fn, const Box &box_require const Box expand_to_include_base_case(const Function &fn, const Box &box_required) { Box b = expand_to_include_base_case(fn.args(), fn.values()[0], fn.name(), box_required); - for (int pos = 1; pos < fn.values().size(); pos++) { + for (size_t pos = 1; pos < fn.values().size(); pos++) { Box b2 = expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); merge_boxes(b, b2); } From 83c88bcc71a0447cfbbfef25435f380aee7e7b51 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sat, 22 Nov 2025 20:01:07 -0500 Subject: [PATCH 06/17] get default types back to normal; add new test --- python_bindings/test/correctness/basics.py | 9 ---- src/Func.cpp | 6 ++- src/IROperator.cpp | 4 +- src/Type.h | 4 +- test/correctness/inductive.cpp | 25 +++++++++- test/correctness/tracing.cpp | 56 +++++++++++----------- 6 files changed, 61 insertions(+), 43 deletions(-) diff --git a/python_bindings/test/correctness/basics.py b/python_bindings/test/correctness/basics.py index 03d9d86220bb..e95f99d9418c 100644 --- a/python_bindings/test/correctness/basics.py +++ b/python_bindings/test/correctness/basics.py @@ -489,15 +489,6 @@ def test_unevaluated_funcref(): f[x] += 1 assert f.realize([1])[0] == 1 - with assert_throws( - hl.HalideError, - r"Error: Can't call Func \"f(\$\d+)?\" because it has not yet been defined\.", - ): - # This is invalid because we only allow unevaluated func refs on the LHS of a - # binary operator. - f = hl.Func("f") - f[x] = 1 + f[x] - with assert_throws( hl.HalideError, r"Cannot use an unevaluated reference to 'f(\$\d+)?' 
to define an update at a different location\.", diff --git a/src/Func.cpp b/src/Func.cpp index 8b5dfaf8006d..128ad912cfc0 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -3313,7 +3313,11 @@ FuncRef::operator Expr() const { */ if (!(func.has_pure_definition() || func.has_extern_definition())) { - return Call::make(Type{}, func.name(), args, Call::Halide, + Type t = Type(Type::Unknown, 0, 1); + if (!func.required_types().empty()) { + t = func.required_types()[0]; + } + return Call::make(t, func.name(), args, Call::Halide, func.get_contents(), 0, Buffer<>(), Parameter()); } diff --git a/src/IROperator.cpp b/src/IROperator.cpp index a5b1930b42cd..8c0a0970b426 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -671,12 +671,12 @@ void match_types(Expr &a, Expr &b) { } if (a.type().is_unknown() && !b.type().is_unknown()) { - b = cast(Type{}, b); + b = cast(a.type(), b); return; } if (b.type().is_unknown() && !a.type().is_unknown()) { - a = cast(Type{}, a); + a = cast(b.type(), a); return; } diff --git a/src/Type.h b/src/Type.h index d6f8e330dcd0..e857de2354e7 100644 --- a/src/Type.h +++ b/src/Type.h @@ -303,7 +303,7 @@ struct Type { // Default ctor initializes everything to predictable-but-unlikely values Type() - : type(Unknown, 0, 0) { + : type(Handle, 0, 0) { } /** Construct a runtime representation of a Halide type from: @@ -576,7 +576,7 @@ inline Type Handle(int lanes = 1, const halide_handle_cplusplus_type *handle_typ /** Construct an unknown type */ inline Type Unknown() { - return Type{}; + return Type(Type::Unknown, 0, 1); } /** Construct the halide equivalent of a C type */ diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp index 29b4719dbe8b..e10646bc91cb 100644 --- a/test/correctness/inductive.cpp +++ b/test/correctness/inductive.cpp @@ -208,6 +208,27 @@ int multi_baseline_test() { return 0; } +int type_declare_test() { + Func g = Func(Int(32), AnyDims, "g"); + Func h("h"); + Var x("x"), y("y"); + + g(x, y) = select(x <= 0, 
0, 1 + g(max(0, x - 1), y) + x + 2); + + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({600, 5}); + auto func = [](int x, int y) { + return (3 * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + } // namespace int main(int argc, char **argv) { @@ -224,7 +245,9 @@ int main(int argc, char **argv) { {"fibonacci test", fibonacci}, {"2d sum test", sum_2d_test}, {"1d sum test", sum_1d_test}, - {"multi-baseline tset", multi_baseline_test}}; + {"multi-baseline test", multi_baseline_test}, + {"type declaration test", type_declare_test}, + }; using Sharder = Halide::Internal::Test::Sharder; Sharder sharder; diff --git a/test/correctness/tracing.cpp b/test/correctness/tracing.cpp index d703c2856644..27fa4a3f6e4f 100644 --- a/test/correctness/tracing.cpp +++ b/test/correctness/tracing.cpp @@ -201,8 +201,8 @@ int main(int argc, char **argv) { // The golden trace, recorded when this test was written event correct_pipeline_trace[] = { - {102, 0, 8, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 1, 9, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 0, 8, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 1, 9, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, }; if (!check_trace_correct(correct_pipeline_trace, 2)) { return 1; @@ -226,52 +226,52 @@ int main(int argc, char **argv) { // The golden trace, recorded when this test was written event correct_trace[] = { - {102, 0, 8, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {105, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, - {103, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 2 2 32 1 2 32 1 1 0 11"}, - {102, 1, 10, 5, 0, 
0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, - {102, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "arbitrary data on f"}, - {102, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "more:arbitrary \xff data on f?"}, - {103, 1, 10, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "g whiz"}, - {102, 1, 2, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 1, 2, 5, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 8, 4, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 0, 8, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {105, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, + {103, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 2 2 32 1 2 32 1 1 0 11"}, + {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "func_type_and_dim: 1 2 32 1 1 0 10"}, + {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "arbitrary data on f"}, + {102, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "more:arbitrary \xff data on f?"}, + {103, 1, 10, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, "g whiz"}, + {102, 1, 2, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 1, 2, 3, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 8, 4, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 
0.000000f, 0.000000f, 0.000000f}, ""}, {103, 11, 1, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.099833f, 0.198669f, 0.295520f}, ""}, {103, 11, 1, 2, 32, 4, 1, 4, {0, 1, 2, 3}, {1.000000f, 0.995004f, 0.980067f, 0.955337f}, ""}, {103, 11, 1, 2, 32, 4, 0, 4, {1, 2, 3, 4}, {0.099833f, 0.198669f, 0.295520f, 0.389418f}, ""}, {103, 11, 1, 2, 32, 4, 1, 4, {1, 2, 3, 4}, {0.995004f, 0.980067f, 0.955337f, 0.921061f}, ""}, - {103, 11, 5, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 6, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 11, 5, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 6, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 17, 0, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.000000f, 0.099833f, 0.198669f, 0.295520f}, ""}, {103, 17, 0, 2, 32, 4, 1, 4, {1, 2, 3, 4}, {0.995004f, 0.980067f, 0.955337f, 0.921061f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {0, 1, 2, 3}, {0.995004f, 1.079900f, 1.154006f, 1.216581f}, ""}, - {103, 17, 7, 5, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 17, 7, 3, 0, 0, 0, 2, {0, 5, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 23, 1, 2, 32, 4, 0, 4, {5, 6, 7, 8}, {0.479426f, 0.564642f, 0.644218f, 0.717356f}, ""}, {103, 23, 1, 2, 32, 4, 1, 4, {5, 6, 7, 8}, {0.877583f, 0.825336f, 0.764842f, 0.696707f}, ""}, - {103, 23, 5, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 6, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 23, 5, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 
0.000000f, 0.000000f}, ""}, + {103, 9, 6, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 27, 0, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {0.389418f, 0.479426f, 0.564642f, 0.644218f}, ""}, {103, 27, 0, 2, 32, 4, 1, 4, {5, 6, 7, 8}, {0.877583f, 0.825336f, 0.764842f, 0.696707f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {4, 5, 6, 7}, {1.267001f, 1.304761f, 1.329485f, 1.340924f}, ""}, - {103, 27, 7, 5, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 4, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 27, 7, 3, 0, 0, 0, 2, {5, 4, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 4, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 33, 1, 2, 32, 4, 0, 4, {7, 8, 9, 10}, {0.644218f, 0.717356f, 0.783327f, 0.841471f}, ""}, {103, 33, 1, 2, 32, 4, 1, 4, {7, 8, 9, 10}, {0.764842f, 0.696707f, 0.621610f, 0.540302f}, ""}, - {103, 33, 5, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {103, 9, 6, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 33, 5, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 6, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {105, 1, 0, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, {103, 37, 0, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {0.564642f, 0.644218f, 0.717356f, 0.783327f}, ""}, {103, 37, 0, 2, 32, 4, 1, 4, {7, 8, 9, 10}, {0.764842f, 0.696707f, 0.621610f, 0.540302f}, ""}, {102, 10, 1, 2, 32, 4, 0, 4, {6, 7, 8, 9}, {1.329485f, 1.340924f, 1.338966f, 1.323629f}, ""}, - {103, 37, 7, 5, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 10, 5, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 
0.000000f}, ""}, - {103, 9, 3, 5, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 8, 3, 5, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, - {102, 1, 9, 5, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 37, 7, 3, 0, 0, 0, 2, {9, 2, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 10, 5, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {103, 9, 3, 3, 0, 0, 0, 2, {0, 11, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 8, 3, 3, 0, 0, 0, 2, {0, 10, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, + {102, 1, 9, 3, 0, 0, 0, 0, {0, 0, 0, 0}, {0.000000f, 0.000000f, 0.000000f, 0.000000f}, ""}, }; int correct_trace_length = sizeof(correct_trace) / sizeof(correct_trace[0]); From 1dbd7f5d3ebb3d2a1a0e8fed7c9f69c575a77973 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sat, 22 Nov 2025 20:12:53 -0500 Subject: [PATCH 07/17] clang-format --- src/Inductive.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Inductive.h b/src/Inductive.h index 12ee0143abfc..09ba5c491430 100644 --- a/src/Inductive.h +++ b/src/Inductive.h @@ -10,15 +10,15 @@ namespace Halide { namespace Internal { -/** Given an initial box for an inductively defined function, +/** Given an initial box for an inductively defined function, returns an expanded box that includes the function's non-inductive base case. 
*/ const Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); -const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos=0); +const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0); const Box expand_to_include_base_case(const Function &fn, const Box &box_required); -} // namespace Internal -} // namespace Halide +} // namespace Internal +} // namespace Halide #endif From 0836105a50b72c9f160b5af6fa4d5d262d7c5151 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sun, 23 Nov 2025 01:01:26 -0500 Subject: [PATCH 08/17] add test and documentation --- src/Function.cpp | 5 ++++- src/Inductive.h | 31 +++++++++++++++++++++++++++++-- test/error/CMakeLists.txt | 1 + test/error/inductive_update.cpp | 20 ++++++++++++++++++++ 4 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 test/error/inductive_update.cpp diff --git a/src/Function.cpp b/src/Function.cpp index 1594689d6bdf..c9312dd07fe6 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -696,7 +696,10 @@ void Function::define_update(const vector &_args, vector values, con user_assert(!frozen()) << "Func " << name() << " cannot be given a new update definition, " << "because it has already been realized or used in the definition of another Func.\n"; - + user_assert(!is_inductive()) + << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" + << "Inductive functions cannot have update definitions.\n"; + for (auto &value : values) { user_assert(value.defined()) << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" diff --git a/src/Inductive.h b/src/Inductive.h index 09ba5c491430..707a7fc8a2ac 100644 --- a/src/Inductive.h +++ b/src/Inductive.h @@ -1,6 +1,35 @@ #ifndef INDUCTIVE_H #define INDUCTIVE_H +/** \file + * + * Utilities for processing inductively defined functions. 
+ * + * A simple example of an inductively defined function is + * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); + * The purpose of inductive functions is to allow execution patterns that are + * impossible with reduction domains. For example, in the following code: + * + * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); + * g(x) = f(x) / 4; + * f.compute_at(g, x).store_root(); + * + * The resulting program computes a single value of f(x) at each value of g(x), + * thanks to Halide's sliding window optimization. As a result of storage folding, + * only the two most recent values of f(x) are stored at any given time. This is + * impossible if f(x) is defined using a reduction domain, since every value of f(x) + * must be computed and stored before g(x) is computed. + * + * If Halide is unable to perform the sliding window optimization, computing the + * inductive function is generally inefficient. + * + * In inductive functions, any recursive references must be inside a select statement, + * and cannot be inside nested select statements. The inductive arguments in the + * recursive reference must be monotonically decreasing. Currently, only single-valued + * functions are supported. Inductive functions cannot be inlined, and cannot have + * update definitions. + */ + #include "Bounds.h" #include "Expr.h" #include "Interval.h" @@ -13,9 +42,7 @@ namespace Internal { /** Given an initial box for an inductively defined function, returns an expanded box that includes the function's non-inductive base case. 
*/ const Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); - const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0); - const Box expand_to_include_base_case(const Function &fn, const Box &box_required); } // namespace Internal diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index a21c3f907af8..7301aad941a5 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -76,6 +76,7 @@ tests(GROUPS error inductive_nested_select.cpp inductive_no_select.cpp inductive_reorder.cpp + inductive_update.cpp inductive_var_swap.cpp inductive_vectorize.cpp init_def_should_be_all_vars.cpp diff --git a/test/error/inductive_update.cpp b/test/error/inductive_update.cpp new file mode 100644 index 000000000000..acb16183f552 --- /dev/null +++ b/test/error/inductive_update.cpp @@ -0,0 +1,20 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 1, 0, x + f(x - 1)); + f(x) += 1; + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} \ No newline at end of file From c498b4b5ac298920493e463e15d3dd81671758fd Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sun, 23 Nov 2025 19:30:14 -0500 Subject: [PATCH 09/17] add user error and additional support for function declarations --- src/Func.cpp | 4 ++++ src/Func.h | 4 ++++ src/IROperator.cpp | 5 +++++ src/Inductive.h | 8 ++++++++ test/correctness/inductive.cpp | 2 +- 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/Func.cpp b/src/Func.cpp index 128ad912cfc0..95bb20e52a69 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -66,6 +66,10 @@ Func::Func(const string &name) : func(unique_name(name)) { } +Func::Func(const Type &required_type, const string &name) + : func({required_type}, AnyDims, unique_name(name)) { +} + Func::Func(const Type 
&required_type, int required_dims, const string &name) : func({required_type}, required_dims, unique_name(name)) { } diff --git a/src/Func.h b/src/Func.h index 517b725e4bb6..fafdc3a5c07e 100644 --- a/src/Func.h +++ b/src/Func.h @@ -734,6 +734,10 @@ class Func { /** Declare a new undefined function with the given name */ explicit Func(const std::string &name); + /** Declare a new undefined function with the given name. + * The function will be constrained to represent Exprs of required_type. */ + explicit Func(const Type &required_type, const std::string &name); + /** Declare a new undefined function with the given name. * The function will be constrained to represent Exprs of required_type. * If required_dims is not AnyDims, the function will be constrained to exactly diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 8c0a0970b426..c0b5f8c157c9 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -370,6 +370,11 @@ Expr make_const_helper(Type t, T val) { return UIntImm::make(t, (uint64_t)val); } else if (t.is_float()) { return FloatImm::make(t, (double)val); + } else if(t.is_unknown()){ + user_error << "Can't make a constant of unknown type.\n" + << "This is likely caused by a failure of type inference inside an inductive function definition.\n" + << "If you are trying to create an inductive function, you must define its type explicitly.\n"; + return Expr(); } else { internal_error << "Can't make a constant of type " << t << "\n"; return Expr(); diff --git a/src/Inductive.h b/src/Inductive.h index 707a7fc8a2ac..33598de38b60 100644 --- a/src/Inductive.h +++ b/src/Inductive.h @@ -28,6 +28,14 @@ * recursive reference must be monotonically decreasing. Currently, only single-valued * functions are supported. Inductive functions cannot be inlined, and cannot have * update definitions. + * + * In some cases, the inductive function's type cannot be inferred and must be declared + * explicitly. 
This occurs when constants appear in operations with a recursive reference. + * For example, in the following code, Halide cannot infer the type of f: + * f(x) = select(x <= 0, 0, f(x - 1) + 1); + * + * To fix this, declare f with an explicit type: + * Func f = Func(Int(32), "f"); */ #include "Bounds.h" diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp index e10646bc91cb..23c3ab568544 100644 --- a/test/correctness/inductive.cpp +++ b/test/correctness/inductive.cpp @@ -209,7 +209,7 @@ int multi_baseline_test() { } int type_declare_test() { - Func g = Func(Int(32), AnyDims, "g"); + Func g = Func(Int(32), "g"); Func h("h"); Var x("x"), y("y"); From b92e6c96beec88b5d375358c04619528c4698ed3 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Sun, 23 Nov 2025 19:31:50 -0500 Subject: [PATCH 10/17] clang-format --- src/Func.cpp | 2 +- src/Function.cpp | 2 +- src/IROperator.cpp | 2 +- src/Inductive.h | 26 +++++++++++++------------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 95bb20e52a69..3985cbc08179 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -66,7 +66,7 @@ Func::Func(const string &name) : func(unique_name(name)) { } -Func::Func(const Type &required_type, const string &name) +Func::Func(const Type &required_type, const string &name) : func({required_type}, AnyDims, unique_name(name)) { } diff --git a/src/Function.cpp b/src/Function.cpp index c9312dd07fe6..a0770bbc7394 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -699,7 +699,7 @@ void Function::define_update(const vector &_args, vector values, con user_assert(!is_inductive()) << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" << "Inductive functions cannot have update definitions.\n"; - + for (auto &value : values) { user_assert(value.defined()) << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 
c0b5f8c157c9..8404280bf6a4 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -370,7 +370,7 @@ Expr make_const_helper(Type t, T val) { return UIntImm::make(t, (uint64_t)val); } else if (t.is_float()) { return FloatImm::make(t, (double)val); - } else if(t.is_unknown()){ + } else if (t.is_unknown()) { user_error << "Can't make a constant of unknown type.\n" << "This is likely caused by a failure of type inference inside an inductive function definition.\n" << "If you are trying to create an inductive function, you must define its type explicitly.\n"; diff --git a/src/Inductive.h b/src/Inductive.h index 33598de38b60..f0053c08b257 100644 --- a/src/Inductive.h +++ b/src/Inductive.h @@ -4,33 +4,33 @@ /** \file * * Utilities for processing inductively defined functions. - * - * A simple example of an inductively defined function is + * + * A simple example of an inductively defined function is * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); - * The purpose of inductive functions is to allow execution patterns that are + * The purpose of inductive functions is to allow execution patterns that are * impossible with reduction domains. For example, in the following code: * * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); * g(x) = f(x) / 4; * f.compute_at(g, x).store_root(); * - * The resulting program computes a single value of f(x) at each value of g(x), - * thanks to Halide's sliding window optimization. As a result of storage folding, - * only the two most recent values of f(x) are stored at any given time. This is - * impossible if f(x) is defined using a reduction domain, since every value of f(x) + * The resulting program computes a single value of f(x) at each value of g(x), + * thanks to Halide's sliding window optimization. As a result of storage folding, + * only the two most recent values of f(x) are stored at any given time. 
This is + * impossible if f(x) is defined using a reduction domain, since every value of f(x) * must be computed and stored before g(x) is computed. * - * If Halide is unable to perform the sliding window optimization, computing the + * If Halide is unable to perform the sliding window optimization, computing the * inductive function is generally inefficient. * - * In inductive functions, any recursive references must be inside a select statement, - * and cannot be inside nested select statements. The inductive arguments in the - * recursive reference must be monotonically decreasing. Currently, only single-valued - * functions are supported. Inductive functions cannot be inlined, and cannot have + * In inductive functions, any recursive references must be inside a select statement, + * and cannot be inside nested select statements. The inductive arguments in the + * recursive reference must be monotonically decreasing. Currently, only single-valued + * functions are supported. Inductive functions cannot be inlined, and cannot have * update definitions. * * In some cases, the inductive function's type cannot be inferred and must be declared - * explicitly. This occurs when constants appear in operations with a recursive reference. + * explicitly. This occurs when constants appear in operations with a recursive reference. * For example, in the following code, Halide cannot infer the type of f: * f(x) = select(x <= 0, 0, f(x - 1) + 1); * From dfbc8db486b8b51a3208051cd9e3aa674317a0d9 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Thu, 4 Dec 2025 16:06:25 -0500 Subject: [PATCH 11/17] clang-tidy and additional safety check. 
All tests pass --- src/Function.cpp | 48 ++++++++++++++++++++++++++++++++--------------- src/Inductive.cpp | 10 +++++----- src/Inductive.h | 6 +++--- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/Function.cpp b/src/Function.cpp index a0770bbc7394..612648438c9a 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -207,9 +207,10 @@ struct CheckVars : public IRGraphVisitor { Scope<> defined_internally; const std::string name; bool unbound_reduction_vars_ok = false; + bool pure; - CheckVars(const std::string &n) - : name(n) { + CheckVars(const std::string &n, bool pure) + : name(n), pure(pure) { } using IRVisitor::visit; @@ -222,20 +223,37 @@ struct CheckVars : public IRGraphVisitor { void visit(const Call *op) override { IRGraphVisitor::visit(op); - /* - if (op->name == name && op->call_type == Call::Halide) { - for (size_t i = 0; i < op->args.size(); i++) { - const Variable *var = op->args[i].as(); - if (!pure_args[i].empty()) { - user_assert(var && var->name == pure_args[i]) - << "In definition of Func \"" << name << "\":\n" - << "All of a function's recursive references to itself" - << " must contain the same pure variables in the same" - << " places as on the left-hand-side.\n"; + bool proper_func = (name == op->name || op->call_type != Call::Halide); + if (op->call_type == Call::Halide && + op->func.defined() && + op->name != name) { + Function func = Function(op->func); + proper_func = (name == op->name || func.has_pure_definition() || func.has_extern_definition()); + } + if (pure) { + user_assert(proper_func) + << "In pure definition of Func \"" << name << "\":\n" + << "Can't call Func \"" << op->name + << "\" because it has not yet been defined," + << " and it is not a recursive call.\n"; + } else { + user_assert(proper_func) + << "In update definition of Func \"" << name << "\":\n" + << "Can't call Func \"" << op->name + << "\" because it has not yet been defined.\n"; + if (op->name == name && op->call_type == Call::Halide) { + for 
(size_t i = 0; i < op->args.size(); i++) { + const Variable *var = op->args[i].as(); + if (!pure_args[i].empty()) { + user_assert(var && var->name == pure_args[i]) + << "In update definition of Func \"" << name << "\":\n" + << "All of a function's recursive references to itself" + << " in update definitions must contain the same pure" + << " variables in the same places as on the left-hand-side.\n"; + } } } } - */ } void visit(const Variable *var) override { @@ -564,7 +582,7 @@ void Function::define(const vector &args, vector values) { // Make sure all the vars in the value are either args or are // attached to some parameter - CheckVars check(name()); + CheckVars check(name(), true); check.pure_args = args; for (const auto &value : values) { value.accept(&check); @@ -769,7 +787,7 @@ void Function::define_update(const vector &_args, vector values, con // pure args, in the reduction domain, or a parameter. Also checks // that recursive references to the function contain all the pure // vars in the LHS in the correct places. 
- CheckVars check(name()); + CheckVars check(name(), false); check.pure_args = pure_args; for (const auto &arg : args) { arg.accept(&check); diff --git a/src/Inductive.cpp b/src/Inductive.cpp index 332d38f341e6..b45482fdbc50 100644 --- a/src/Inductive.cpp +++ b/src/Inductive.cpp @@ -46,8 +46,8 @@ class BaseCaseSolver : public IRVisitor { } op->args[2].accept(this); condition_intervals = old_intervals; - for (size_t i = 0; i < vars.size(); i++) { - bounds.pop(vars[i]); + for (const auto &var : vars) { + bounds.pop(var); } nested_select -= 1; } else if (op->name == func) { @@ -95,7 +95,7 @@ class BaseCaseSolver : public IRVisitor { // anonymous namespace -const Box expand_to_include_base_case(const vector &vars, const Expr &RHS, const string &func, const Box &box_required) { +Box expand_to_include_base_case(const vector &vars, const Expr &RHS, const string &func, const Box &box_required) { Expr substed = substitute_in_all_lets(RHS); Box box2 = box_required; BaseCaseSolver b(vars, func, box_required.bounds); @@ -109,11 +109,11 @@ const Box expand_to_include_base_case(const vector &vars, const Expr &RH return box2; } -const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos) { +Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos) { return expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); } -const Box expand_to_include_base_case(const Function &fn, const Box &box_required) { +Box expand_to_include_base_case(const Function &fn, const Box &box_required) { Box b = expand_to_include_base_case(fn.args(), fn.values()[0], fn.name(), box_required); for (size_t pos = 1; pos < fn.values().size(); pos++) { Box b2 = expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); diff --git a/src/Inductive.h b/src/Inductive.h index f0053c08b257..90bd27e86b7e 100644 --- a/src/Inductive.h +++ b/src/Inductive.h @@ -49,9 +49,9 @@ namespace Internal { /** 
Given an initial box for an inductively defined function, returns an expanded box that includes the function's non-inductive base case. */ -const Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); -const Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0); -const Box expand_to_include_base_case(const Function &fn, const Box &box_required); +Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); +Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0); +Box expand_to_include_base_case(const Function &fn, const Box &box_required); } // namespace Internal } // namespace Halide From 472ad8ac6dba651620d453c3411d9ac53ad2673c Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Thu, 4 Dec 2025 20:02:34 -0500 Subject: [PATCH 12/17] patched makefile --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 54c61a622ae8..96ab3a29f80d 100644 --- a/Makefile +++ b/Makefile @@ -520,6 +520,7 @@ SOURCE_FILES = \ HexagonOffload.cpp \ HexagonOptimize.cpp \ ImageParam.cpp \ + Inductive.cpp \ InferArguments.cpp \ InjectHostDevBufferCopies.cpp \ Inline.cpp \ @@ -720,6 +721,7 @@ HEADER_FILES = \ HexagonOffload.h \ HexagonOptimize.h \ ImageParam.h \ + Inductive.h \ InferArguments.h \ InjectHostDevBufferCopies.h \ Inline.h \ From fa55c0aed2421251311d70608541a85e6f625193 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Fri, 5 Dec 2025 17:05:21 -0500 Subject: [PATCH 13/17] Serialization fix --- src/Serialization.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Serialization.cpp b/src/Serialization.cpp index d731d9c9d85c..2d8bc6407a28 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -352,6 +352,8 @@ Serialize::DimType Serializer::serialize_dim_type(const DimType &dim_type) { return 
Serialize::DimType::PureRVar; case DimType::ImpureRVar: return Serialize::DimType::ImpureRVar; + case DimType::InductiveVar: + return Serialize::DimType::InductiveVar; default: user_error << "Unsupported dim type\n"; return Serialize::DimType::PureVar; From 46d8d7fe81bdbe315802debf2f8868f736c1691e Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Fri, 5 Dec 2025 17:31:36 -0500 Subject: [PATCH 14/17] more dimtype fixes --- src/Deserialization.cpp | 2 ++ src/halide_ir.fbs | 1 + 2 files changed, 3 insertions(+) diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b6f49cb1bf43..92d76c3ea4b1 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -382,6 +382,8 @@ DimType Deserializer::deserialize_dim_type(Serialize::DimType dim_type) { return DimType::PureRVar; case Serialize::DimType::ImpureRVar: return DimType::ImpureRVar; + case Serialize::DimType::InductiveVar: + return DimType::InductiveVar; default: user_error << "unknown dim type " << (int)dim_type << "\n"; return DimType::PureVar; diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index 499488ce8b95..1fdf0e5f1144 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -562,6 +562,7 @@ table Split { enum DimType: ubyte { PureVar, + InductiveVar, PureRVar, ImpureRVar, } From ba8725d9f03b9a585619983bf0fd9fbb579b7e72 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Thu, 11 Dec 2025 04:13:06 -0500 Subject: [PATCH 15/17] fix memory leaks --- src/Function.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Function.cpp b/src/Function.cpp index 612648438c9a..9a426c861745 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -649,6 +649,22 @@ void Function::define(const vector &args, vector values) { for (size_t i = 0; i < args.size(); i++) { init_def_args[i] = Var(args[i]); } + + // If the function is inductive, + // the value and args might refer back to the + // function itself, introducing circular references and hence + // memory leaks. 
We need to break these cycles. + WeakenFunctionPtrs weakener(contents.get()); + for (auto &arg : init_def_args) { + arg = weakener.mutate(arg); + } + for (auto &value : values) { + value = weakener.mutate(value); + } + if (check.reduction_domain.defined()) { + check.reduction_domain.set_predicate( + weakener.mutate(check.reduction_domain.predicate())); + } ReductionDomain rdom; contents->init_def = Definition(init_def_args, values, rdom, true); From cc437c982fb4bfbeced6eb6f6f63c8b648a44591 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Fri, 12 Dec 2025 21:22:48 -0500 Subject: [PATCH 16/17] remove commented-out code --- test/correctness/inductive.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp index 23c3ab568544..67ab9c2fce7f 100644 --- a/test/correctness/inductive.cpp +++ b/test/correctness/inductive.cpp @@ -17,8 +17,6 @@ int simple_inductive_test() { Func g("g"), h("h"); Var x("x"), y("y"); - // g(x, y) = x + y; - // g(r.x, r.y) = g(r.x, r.y); g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); h(x, y) = g(x + 5, y) / 4; @@ -41,8 +39,6 @@ int reorder_test() { Var xi("xi"), xii("xii"), xo("xo"); - // g(x, y) = x + y; - // g(r.x, r.y) = g(r.x, r.y); g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); h(x, y) = g(x + 5, y) / 4; @@ -84,10 +80,7 @@ int large_baseline() { Func g("g"), h("h"); Var x("x"), y("y"); - // g(x, y) = x + y; - // g(r.x, r.y) = g(r.x, r.y); g(x, y) = select(x <= 8, (y * x + x * (x + 1) / 2) - 1, g(x - 1, y) + x + y); - h(x, y) = g(x + 5, y) / 4; g.compute_at(h, x).store_at(h, y); @@ -106,8 +99,6 @@ int fibonacci() { Func g("g"), h("h"); Var x("x"), y("y"); - // g(x, y) = x + y; - // g(r.x, r.y) = g(r.x, r.y); g(x, y) = select(x <= 1, 1, g(x - 1, y) + g(x - 2, y)); h(x, y) = g(x, y); From 44faa756d14c9dc7e971ba8a72e75d013433bb21 Mon Sep 17 00:00:00 2001 From: Steven Raphael Date: Fri, 12 Dec 2025 21:23:40 -0500 Subject: [PATCH 17/17] clang-format --- 
src/Function.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Function.cpp b/src/Function.cpp index 9a426c861745..3358c335b927 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -649,7 +649,7 @@ void Function::define(const vector &args, vector values) { for (size_t i = 0; i < args.size(); i++) { init_def_args[i] = Var(args[i]); } - + // If the function is inductive, // the value and args might refer back to the // function itself, introducing circular references and hence