diff --git a/Makefile b/Makefile index 54c61a622ae8..96ab3a29f80d 100644 --- a/Makefile +++ b/Makefile @@ -520,6 +520,7 @@ SOURCE_FILES = \ HexagonOffload.cpp \ HexagonOptimize.cpp \ ImageParam.cpp \ + Inductive.cpp \ InferArguments.cpp \ InjectHostDevBufferCopies.cpp \ Inline.cpp \ @@ -720,6 +721,7 @@ HEADER_FILES = \ HexagonOffload.h \ HexagonOptimize.h \ ImageParam.h \ + Inductive.h \ InferArguments.h \ InjectHostDevBufferCopies.h \ Inline.h \ diff --git a/python_bindings/test/correctness/basics.py b/python_bindings/test/correctness/basics.py index 03d9d86220bb..e95f99d9418c 100644 --- a/python_bindings/test/correctness/basics.py +++ b/python_bindings/test/correctness/basics.py @@ -489,15 +489,6 @@ def test_unevaluated_funcref(): f[x] += 1 assert f.realize([1])[0] == 1 - with assert_throws( - hl.HalideError, - r"Error: Can't call Func \"f(\$\d+)?\" because it has not yet been defined\.", - ): - # This is invalid because we only allow unevaluated func refs on the LHS of a - # binary operator. - f = hl.Func("f") - f[x] = 1 + f[x] - with assert_throws( hl.HalideError, r"Cannot use an unevaluated reference to 'f(\$\d+)?' to define an update at a different location\.", diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 670ccf11d177..59de794aa0f0 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -3299,7 +3299,7 @@ FuncValueBounds compute_function_value_bounds(const vector &order, Interval result; - if (f.is_pure()) { + if (f.is_pure() && !f.is_inductive()) { // Make a scope that says the args could be anything. Scope arg_scope; diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 9ba1f1af2019..0047f54fb8fe 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -7,6 +7,7 @@ #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" +#include "Inductive.h" #include "Inline.h" #include "Qualify.h" #include "Scope.h" @@ -1021,6 +1022,21 @@ class BoundsInference : public IRMutator { } } + // For any inductively defined functions, make sure their + // bounds include the base case. + for (Stage &s : stages) { + if (!s.func.is_pure() || !s.func.is_inductive()) { + continue; + } + debug(4) << "Expanding bounds for inductively defined function " << s.func.name() << "\n"; + for (const auto &b1 : s.bounds) { + const Box &b = b1.second; + for (const auto &cval : s.exprs) { + s.bounds[b1.first] = expand_to_include_base_case(s.func.args(), cval.value, s.func.name(), b); + } + } + } + // The region required of the each output is expanded to include the size of the output buffer. for (const Function &output : outputs) { Box output_box; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index af419323b24e..376fceff3960 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -128,6 +128,7 @@ target_sources( HexagonOffload.h HexagonOptimize.h ImageParam.h + Inductive.h InferArguments.h InjectHostDevBufferCopies.h Inline.h @@ -305,6 +306,7 @@ target_sources( HexagonOffload.cpp HexagonOptimize.cpp ImageParam.cpp + Inductive.cpp InferArguments.cpp InjectHostDevBufferCopies.cpp Inline.cpp diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b6f49cb1bf43..92d76c3ea4b1 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -382,6 +382,8 @@ DimType Deserializer::deserialize_dim_type(Serialize::DimType dim_type) { return DimType::PureRVar; case Serialize::DimType::ImpureRVar: return DimType::ImpureRVar; + case Serialize::DimType::InductiveVar: + return DimType::InductiveVar; default: user_error << "unknown dim type " << (int)dim_type << "\n"; return DimType::PureVar; diff --git a/src/Func.cpp b/src/Func.cpp index cf8904ec2d31..3985cbc08179 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -66,6 +66,10 @@ Func::Func(const string &name) : func(unique_name(name)) { } +Func::Func(const Type &required_type, const string &name) + : func({required_type}, AnyDims, unique_name(name)) { +} + Func::Func(const Type &required_type, int required_dims, const string &name) : func({required_type}, required_dims, unique_name(name)) { } @@ -491,7 +495,7 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) { // If it's an rvar and the for type is parallel, we need to // validate that this doesn't introduce a race condition, // unless it is flagged explicitly or is a associative atomic operation. - if (!dim.is_pure() && var.is_rvar && is_parallel(t)) { + if (!dim.is_pure() && (var.is_rvar || dim.is_inductive()) && is_parallel(t)) { if (!definition.schedule().allow_race_conditions() && definition.schedule().atomic()) { if (!definition.schedule().override_atomic_associativity_test()) { @@ -1342,6 +1346,9 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV } else if (dims[i].dim_type == DimType::PureRVar || outer_type == DimType::PureRVar) { dims[i].dim_type = DimType::PureRVar; + } else if (dims[i].dim_type == DimType::InductiveVar || + outer_type == DimType::InductiveVar) { + dims[i].dim_type = DimType::InductiveVar; } else { dims[i].dim_type = DimType::PureVar; } @@ -3304,8 +3311,19 @@ Stage FuncRef::operator/=(const FuncRef &e) { } FuncRef::operator Expr() const { + /* user_assert(func.has_pure_definition() || func.has_extern_definition()) << "Can't call Func \"" << func.name() << "\" because it has not yet been defined.\n"; + */ + + if (!(func.has_pure_definition() || func.has_extern_definition())) { + Type t = Type(Type::Unknown, 0, 1); + if (!func.required_types().empty()) { + t = func.required_types()[0]; + } + return Call::make(t, func.name(), args, Call::Halide, + func.get_contents(), 0, Buffer<>(), Parameter()); + } user_assert(func.outputs() == 1) << "Can't convert a reference Func \"" << func.name() diff --git a/src/Func.h b/src/Func.h index 517b725e4bb6..fafdc3a5c07e 100644 --- a/src/Func.h +++ b/src/Func.h @@ -734,6 +734,10 @@ class Func { /** Declare a new undefined function with the given name */ explicit Func(const std::string &name); + /** Declare a new undefined function with the given name. + * The function will be constrained to represent Exprs of required_type. */ + explicit Func(const Type &required_type, const std::string &name); + /** Declare a new undefined function with the given name. * The function will be constrained to represent Exprs of required_type. * If required_dims is not AnyDims, the function will be constrained to exactly diff --git a/src/Function.cpp b/src/Function.cpp index f66b886cdafd..3358c335b927 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -207,9 +207,10 @@ struct CheckVars : public IRGraphVisitor { Scope<> defined_internally; const std::string name; bool unbound_reduction_vars_ok = false; + bool pure; - CheckVars(const std::string &n) - : name(n) { + CheckVars(const std::string &n, bool pure) + : name(n), pure(pure) { } using IRVisitor::visit; @@ -222,15 +223,34 @@ struct CheckVars : public IRGraphVisitor { void visit(const Call *op) override { IRGraphVisitor::visit(op); - if (op->name == name && op->call_type == Call::Halide) { - for (size_t i = 0; i < op->args.size(); i++) { - const Variable *var = op->args[i].as(); - if (!pure_args[i].empty()) { - user_assert(var && var->name == pure_args[i]) - << "In definition of Func \"" << name << "\":\n" - << "All of a function's recursive references to itself" - << " must contain the same pure variables in the same" - << " places as on the left-hand-side.\n"; + bool proper_func = (name == op->name || op->call_type != Call::Halide); + if (op->call_type == Call::Halide && + op->func.defined() && + op->name != name) { + Function func = Function(op->func); + proper_func = (name == op->name || func.has_pure_definition() || func.has_extern_definition()); + } + if (pure) { + user_assert(proper_func) + << "In pure definition of Func \"" << name << "\":\n" + << "Can't call Func \"" << op->name + << "\" because it has not yet been defined," + << " and it is not a recursive call.\n"; + } else { + user_assert(proper_func) + << "In update definition of Func \"" << name << "\":\n" + << "Can't call Func \"" << op->name + << "\" because it has not yet been defined.\n"; + if (op->name == name && op->call_type == Call::Halide) { + for (size_t i = 0; i < op->args.size(); i++) { + const Variable *var = op->args[i].as(); + if (!pure_args[i].empty()) { + user_assert(var && var->name == pure_args[i]) + << "In update definition of Func \"" << name << "\":\n" + << "All of a function's recursive references to itself" + << " in update definitions must contain the same pure" + << " variables in the same places as on the left-hand-side.\n"; + } } } } @@ -562,7 +582,7 @@ void Function::define(const vector &args, vector values) { // Make sure all the vars in the value are either args or are // attached to some parameter - CheckVars check(name()); + CheckVars check(name(), true); check.pure_args = args; for (const auto &value : values) { value.accept(&check); @@ -570,6 +590,7 @@ void Function::define(const vector &args, vector values) { // Freeze all called functions FreezeFunctions freezer(name()); + // TODO: Check for calls to undefined Funcs for (const auto &value : values) { value.accept(&freezer); } @@ -629,11 +650,31 @@ void Function::define(const vector &args, vector values) { init_def_args[i] = Var(args[i]); } + // If the function is inductive, + // the value and args might refer back to the + // function itself, introducing circular references and hence + // memory leaks. We need to break these cycles. + WeakenFunctionPtrs weakener(contents.get()); + for (auto &arg : init_def_args) { + arg = weakener.mutate(arg); + } + for (auto &value : values) { + value = weakener.mutate(value); + } + if (check.reduction_domain.defined()) { + check.reduction_domain.set_predicate( + weakener.mutate(check.reduction_domain.predicate())); + } + ReductionDomain rdom; contents->init_def = Definition(init_def_args, values, rdom, true); for (const auto &arg : args) { - Dim d = {arg, ForType::Serial, DeviceAPI::None, DimType::PureVar}; + DimType dtype = DimType::PureVar; + if (is_inductive(arg)) { + dtype = DimType::InductiveVar; + } + Dim d = {arg, ForType::Serial, DeviceAPI::None, dtype}; contents->init_def.schedule().dims().push_back(d); StorageDim sd = {arg}; contents->func_schedule.storage_dims().push_back(sd); @@ -689,6 +730,9 @@ void Function::define_update(const vector &_args, vector values, con user_assert(!frozen()) << "Func " << name() << " cannot be given a new update definition, " << "because it has already been realized or used in the definition of another Func.\n"; + user_assert(!is_inductive()) + << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" + << "Inductive functions cannot have update definitions.\n"; for (auto &value : values) { user_assert(value.defined()) @@ -759,7 +803,7 @@ void Function::define_update(const vector &_args, vector values, con // pure args, in the reduction domain, or a parameter. Also checks // that recursive references to the function contain all the pure // vars in the LHS in the correct places. - CheckVars check(name()); + CheckVars check(name(), false); check.pure_args = pure_args; for (const auto &arg : args) { arg.accept(&check); @@ -1066,8 +1110,89 @@ bool Function::has_pure_definition() const { return contents->init_def.defined(); } +bool Function::is_inductive() const { + class RecursiveHelper : public IRVisitor { + using IRVisitor::visit; + const string &func; + void visit(const Call *op) override { + if (op->name == func) { + recursive = true; + } + IRVisitor::visit(op); + } + + public: + bool recursive = false; + RecursiveHelper(const string &func) + : func(func) { + } + }; + + if (!has_pure_definition()) { + return false; + } + + RecursiveHelper r(name()); + for (const Expr &e : definition().values()) { + e.accept(&r); + } + + return r.recursive; +} + +bool Function::is_inductive(const string &var) const { + class RecursiveHelper : public IRVisitor { + using IRVisitor::visit; + const string &func; + const string &var; + const int &pos; + void visit(const Call *op) override { + if (op->name == func) { + recursive = true; + if (const auto &v = op->args[pos].as()) { + if (v->name != var) { + inductive_in_var = true; + } + } else { + inductive_in_var = true; + } + } + IRVisitor::visit(op); + } + + public: + bool recursive = false; + bool inductive_in_var = false; + RecursiveHelper(const string &func, const string &var, const int &pos) + : func(func), var(var), pos(pos) { + } + }; + + if (!has_pure_definition()) { + return false; + } + + int pos = -1; + for (size_t i = 0; i < definition().args().size(); i++) { + if (const auto &v = definition().args()[i].as()) { + if (v->name == var) { + pos = i; + } + } + } + if (pos == -1) { + return false; + } + RecursiveHelper r(name(), var, pos); + for (const Expr &e : definition().values()) { + e.accept(&r); + } + + return r.inductive_in_var; +} + bool Function::can_be_inlined() const { - return is_pure() && definition().specializations().empty(); + return is_pure() && definition().specializations().empty() && !is_inductive(); } bool Function::has_update_definition() const { diff --git a/src/Function.h b/src/Function.h index 5305f4f058be..b413f7e4145d 100644 --- a/src/Function.h +++ b/src/Function.h @@ -187,6 +187,12 @@ class Function { !has_extern_definition()); } + /** Does this function have an inductive pure definition? */ + bool is_inductive() const; + + /** Is this function inductive in the given variable? */ + bool is_inductive(const std::string &var) const; + /** Is it legal to inline this function? */ bool can_be_inlined() const; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index ec832404e0eb..8404280bf6a4 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -370,6 +370,11 @@ Expr make_const_helper(Type t, T val) { return UIntImm::make(t, (uint64_t)val); } else if (t.is_float()) { return FloatImm::make(t, (double)val); + } else if (t.is_unknown()) { + user_error << "Can't make a constant of unknown type.\n" + << "This is likely caused by a failure of type inference inside an inductive function definition.\n" + << "If you are trying to create an inductive function, you must define its type explicitly.\n"; + return Expr(); } else { internal_error << "Can't make a constant of type " << t << "\n"; return Expr(); @@ -670,6 +675,16 @@ void match_types(Expr &a, Expr &b) { return; } + if (a.type().is_unknown() && !b.type().is_unknown()) { + b = cast(a.type(), b); + return; + } + + if (b.type().is_unknown() && !a.type().is_unknown()) { + a = cast(b.type(), a); + return; + } + user_assert(!a.type().is_handle() && !b.type().is_handle()) << "Can't do arithmetic on opaque pointer types: " << a << ", " << b << "\n"; @@ -1494,12 +1509,47 @@ Expr saturating_cast(Type t, Expr e) { return Call::make(t, Call::saturating_cast, {std::move(e)}, Call::PureIntrinsic); } +Expr declare_type(Type t, const Expr &e) { + // TODO: This may be called on unsanitized exprs. May need CSE. + internal_assert(e.type().is_unknown()); + if (const Call *op = e.as()) { + internal_assert(op->call_type == Call::Halide); + return Call::make(t, op->name, op->args, op->call_type, + op->func, op->value_index, op->image, op->param); + } else if (const Add *op = e.as()) { + return Add::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Sub *op = e.as()) { + return Sub::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Mul *op = e.as()) { + return Mul::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Div *op = e.as
()) { + return Div::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Min *op = e.as()) { + return Min::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Max *op = e.as()) { + return Max::make(declare_type(t, op->a), declare_type(t, op->b)); + } else if (const Cast *op = e.as()) { + // Must be a cast to unknown + return Cast::make(t, op->value); + } else { + user_error << "Can't do top-down type inference on " << e; + } +} + Expr select(Expr condition, Expr true_value, Expr false_value) { if (as_const_int(condition)) { // Why are you doing this? We'll preserve the select node until constant folding for you. condition = cast(Bool(true_value.type().lanes()), std::move(condition)); } + if (true_value.type().is_unknown() && !false_value.type().is_unknown()) { + true_value = declare_type(false_value.type(), true_value); + } + + if (false_value.type().is_unknown() && !true_value.type().is_unknown()) { + false_value = declare_type(true_value.type(), false_value); + } + // Coerce int literals to the type of the other argument if (as_const_int(true_value)) { true_value = cast(false_value.type(), std::move(true_value)); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..0c94747fa040 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -389,6 +389,10 @@ inline Expr cast(Expr a) { /** Cast an expression to a new type. */ Expr cast(Type t, Expr a); +/** Declare an Expr with unknown type has a given type. Useful when defining + * Funcs inductively. */ +Expr declare_type(Type t, const Expr &e); + /** Return the sum of two expressions, doing any necessary type * coercion using \ref Internal::match_types */ Expr operator+(Expr a, Expr b); diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 97f6a409d6c9..76d3b59881f0 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -46,6 +46,8 @@ ostream &operator<<(ostream &out, const Type &type) { case Type::BFloat: out << "bfloat"; break; + case Type::Unknown: + out << "unknown"; } if (!type.is_handle()) { out << type.bits(); @@ -496,6 +498,9 @@ std::ostream &operator<<(std::ostream &out, const DimType &t) { case DimType::PureRVar: out << "PureRVar"; break; + case DimType::InductiveVar: + out << "InductiveVar"; + break; case DimType::ImpureRVar: out << "ImpureRVar"; break; diff --git a/src/Inductive.cpp b/src/Inductive.cpp new file mode 100644 index 000000000000..b45482fdbc50 --- /dev/null +++ b/src/Inductive.cpp @@ -0,0 +1,126 @@ +#include "Inductive.h" + +#include "Bounds.h" +#include "ConciseCasts.h" +#include "Error.h" +#include "Function.h" +#include "IR.h" +#include "IREquality.h" +#include "IRVisitor.h" +#include "Simplify.h" +#include "Substitute.h" + +namespace Halide { +namespace Internal { + +using std::string; +using std::vector; + +class BaseCaseSolver : public IRVisitor { + using IRVisitor::visit; + const vector &vars; + const string &func; + + const vector &start_box; + + vector condition_intervals; + + Scope bounds; + + int nested_select = 0; + + void visit(const Call *op) override { + if (op->is_intrinsic(Call::if_then_else)) { + nested_select += 1; + vector old_intervals = condition_intervals; + for (size_t i = 0; i < vars.size(); i++) { + condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(op->args[0]), vars[i])); + bounds.push(vars[i], condition_intervals[i]); + } + + op->args[1].accept(this); + for (size_t i = 0; i < vars.size(); i++) { + condition_intervals[i] = Interval::make_intersection(old_intervals[i], solve_for_outer_interval(simplify(!op->args[0]), vars[i])); + bounds.pop(vars[i]); + bounds.push(vars[i], condition_intervals[i]); + } + op->args[2].accept(this); + condition_intervals = old_intervals; + for (const auto &var : vars) { + bounds.pop(var); + } + nested_select -= 1; + } else if (op->name == func) { + user_assert(nested_select > 0) << "Function " << func << " contains an inductive function reference outside of a select operation.\n"; + user_assert(nested_select == 1) << "Function " << func << " contains an inductive function reference inside a nested select operation.\n"; + bool found_inductive = false; + for (size_t position = 0; position < vars.size(); position++) { + const Expr inductive_expr = op->args[position]; + const Expr new_v = Variable::make(inductive_expr.type(), vars[position]); + const Expr gets_lower = simplify(new_v - inductive_expr > 0, true, bounds); + const Interval i_lower = solve_for_inner_interval(gets_lower, vars[position]); + + Interval new_interval; + if (equal(new_v, inductive_expr)) { + new_interval = start_box[position]; + } else if (i_lower.is_everything()) { + found_inductive = true; + new_interval = Interval(Interval::neg_inf(), start_box[position].max); + } else { + new_interval = Interval::everything(); + } + new_interval = Interval::make_intersection(new_interval, condition_intervals[position]); + Scope i_scope; + i_scope.push(vars[position], new_interval); + result_intervals[position] = Interval::make_union(result_intervals[position], Interval::make_union(new_interval, bounds_of_expr_in_scope(inductive_expr, i_scope))); + } + user_assert(found_inductive) << "Unable to prove in inductive function " << func << " that the inductive step is monotonically decreasing.\n"; + + IRVisitor::visit(op); + + } else { + IRVisitor::visit(op); + } + } + +public: + vector result_intervals; + + BaseCaseSolver(const vector &v, const string &func, const vector &con) + : vars(v), func(func), start_box(con) { + condition_intervals = vector(start_box.size()); + result_intervals = vector(start_box.size(), Interval::nothing()); + } +}; + +// anonymous namespace + +Box expand_to_include_base_case(const vector &vars, const Expr &RHS, const string &func, const Box &box_required) { + Expr substed = substitute_in_all_lets(RHS); + Box box2 = box_required; + BaseCaseSolver b(vars, func, box_required.bounds); + substed.accept(&b); + for (size_t i = 0; i < vars.size(); i++) { + user_assert(b.result_intervals[i].is_bounded()) << "Unable to prove that the inductive function " << func << " uses a bounded interval"; + Interval new_interval(min(b.result_intervals[i].min, box_required[i].min), box_required[i].max); + box2[i] = new_interval; + } + + return box2; +} + +Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos) { + return expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); +} + +Box expand_to_include_base_case(const Function &fn, const Box &box_required) { + Box b = expand_to_include_base_case(fn.args(), fn.values()[0], fn.name(), box_required); + for (size_t pos = 1; pos < fn.values().size(); pos++) { + Box b2 = expand_to_include_base_case(fn.args(), fn.values()[pos], fn.name(), box_required); + merge_boxes(b, b2); + } + return b; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/Inductive.h b/src/Inductive.h new file mode 100644 index 000000000000..90bd27e86b7e --- /dev/null +++ b/src/Inductive.h @@ -0,0 +1,59 @@ +#ifndef INDUCTIVE_H +#define INDUCTIVE_H + +/** \file + * + * Utilities for processing inductively defined functions. + * + * A simple example of an inductively defined function is + * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); + * The purpose of inductive functions is to allow execution patterns that are + * impossible with reduction domains. For example, in the following code: + * + * f(x) = select(x <= 0, input(0), input(x) + f(x - 1)); + * g(x) = f(x) / 4; + * f.compute_at(g, x).store_root(); + * + * The resulting program computes a single value of f(x) at each value of g(x), + * thanks to Halide's sliding window optimization. As a result of storage folding, + * only the two most recent values of f(x) are stored at any given time. This is + * impossible if f(x) is defined using a reduction domain, since every value of f(x) + * must be computed and stored before g(x) is computed. + * + * If Halide is unable to perform the sliding window optimization, computing the + * inductive function is generally inefficient. + * + * In inductive functions, any recursive references must be inside a select statement, + * and cannot be inside nested select statements. The inductive arguments in the + * recursive reference must be monotonically decreasing. Currently, only single-valued + * functions are supported. Inductive functions cannot be inlined, and cannot have + * update definitions. + * + * In some cases, the inductive function's type cannot be inferred and must be declared + * explicitly. This occurs when constants appear in operations with a recursive reference. + * For example, in the following code, Halide cannot infer the type of f: + * f(x) = select(x <= 0, 0, f(x - 1) + 1); + * + * To fix this, declare f with an explicit type: + * Func f = Func(Int(32), "f"); + */ + +#include "Bounds.h" +#include "Expr.h" +#include "Interval.h" +#include "Scope.h" +#include "Solve.h" + +namespace Halide { +namespace Internal { + +/** Given an initial box for an inductively defined function, + returns an expanded box that includes the function's non-inductive base case. */ +Box expand_to_include_base_case(const std::vector &vars, const Expr &RHS, const std::string &func, const Box &box_required); +Box expand_to_include_base_case(const Function &fn, const Box &box_required, const int &pos = 0); +Box expand_to_include_base_case(const Function &fn, const Box &box_required); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Schedule.h b/src/Schedule.h index 906dbe6c7b5e..0017b10c80e2 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -356,6 +356,11 @@ enum class DimType { * definitions you can even redundantly re-evaluate points. */ PureVar = 0, + /** The dim originated from a Var in an inductively defined pure + * definition. InductiveVars cannot be reordered, parallelized, + * or vectorized. */ + InductiveVar, + /** The dim originated from an RVar. You can evaluate a Func at * distinct values of this RVar in any order (including in * parallel) over exactly the interval specified in the @@ -466,6 +471,10 @@ struct Dim { return (dim_type == DimType::PureRVar) || (dim_type == DimType::ImpureRVar); } + bool is_inductive() const { + return dim_type == DimType::InductiveVar; + } + /** Could multiple iterations of this loop happen at the same * time, with reads and writes interleaved in arbitrary ways * according to the memory model of the underlying compiler and diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 19c3de055001..6c1bf9812ddb 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1361,7 +1361,7 @@ class InjectFunctionRealization : public IRMutator { // none of the functions in a fused group can be inlined, so this will only // happen when we're lowering a single func. if (provide_name != funcs[0].name() && - !funcs[0].is_pure() && + (!funcs[0].is_pure() || funcs[0].is_inductive()) && funcs[0].schedule().compute_level().is_inlined() && function_is_used_in_stmt(funcs[0], provide_op)) { @@ -2348,7 +2348,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ // will get lowered into compute_at innermost and thus can be treated // similarly as a non-inlined Func. if (store_at.is_inlined() && compute_at.is_inlined() && hoist_storage_at.is_inlined()) { - if (f.is_pure()) { + if (f.is_pure() && !f.is_inductive()) { validate_schedule_inlined_function(f); } return true; @@ -2589,6 +2589,9 @@ Stmt schedule_functions(const vector &outputs, const map &env, const Target &target, bool &any_memoized) { + for (const Function &o : outputs) { + user_assert(!o.is_inductive()) << "Function" << o.name() << " is an inductively defined output buffer, which is unsupported.\n"; + } string root_var = LoopLevel::root().lock().to_string(); Stmt s = For::make(root_var, 0, 1, ForType::Serial, Partition::Never, DeviceAPI::Host, Evaluate::make(0)); diff --git a/src/Serialization.cpp b/src/Serialization.cpp index d731d9c9d85c..2d8bc6407a28 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -352,6 +352,8 @@ Serialize::DimType Serializer::serialize_dim_type(const DimType &dim_type) { return Serialize::DimType::PureRVar; case DimType::ImpureRVar: return Serialize::DimType::ImpureRVar; + case DimType::InductiveVar: + return Serialize::DimType::InductiveVar; default: user_error << "Unsupported dim type\n"; return Serialize::DimType::PureVar; diff --git a/src/Type.h b/src/Type.h index d6143f38b6de..e857de2354e7 100644 --- a/src/Type.h +++ b/src/Type.h @@ -293,6 +293,7 @@ struct Type { static constexpr halide_type_code_t Float = halide_type_float; static constexpr halide_type_code_t BFloat = halide_type_bfloat; static constexpr halide_type_code_t Handle = halide_type_handle; + static constexpr halide_type_code_t Unknown = halide_type_unknown; // @} /** The number of bytes required to store a single scalar value of this type. Ignores vector lanes. */ @@ -454,6 +455,12 @@ struct Type { return code() == Handle; } + /** Is this type a floating point type (float or double). */ + HALIDE_ALWAYS_INLINE + bool is_unknown() const { + return code() == Unknown; + } + // Returns true iff type is a signed integral type where overflow is defined. HALIDE_ALWAYS_INLINE bool can_overflow_int() const { @@ -567,6 +574,11 @@ inline Type Handle(int lanes = 1, const halide_handle_cplusplus_type *handle_typ return Type(Type::Handle, 64, lanes, handle_type); } +/** Construct an unknown type */ +inline Type Unknown() { + return Type(Type::Unknown, 0, 1); +} + /** Construct the halide equivalent of a C type */ template inline Type type_of() { diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index 499488ce8b95..1fdf0e5f1144 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -562,6 +562,7 @@ table Split { enum DimType: ubyte { PureVar, + InductiveVar, PureRVar, ImpureRVar, } diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 1558b9f3397f..508ae7582b8b 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -478,11 +478,12 @@ typedef enum halide_type_code_t : uint8_t #endif { - halide_type_int = 0, ///< signed integers - halide_type_uint = 1, ///< unsigned integers - halide_type_float = 2, ///< IEEE floating point numbers - halide_type_handle = 3, ///< opaque pointer type (void *) - halide_type_bfloat = 4, ///< floating point numbers in the bfloat format + halide_type_int = 0, ///< signed integers + halide_type_uint = 1, ///< unsigned integers + halide_type_float = 2, ///< IEEE floating point numbers + halide_type_handle = 3, ///< opaque pointer type (void *) + halide_type_bfloat = 4, ///< floating point numbers in the bfloat format + halide_type_unknown = 5, ///< an expression of unknown type, to be determined later } halide_type_code_t; // Note that while __attribute__ can go before or after the declaration, diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 586db8d1db8e..3c3d57b8345c 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -182,6 +182,7 @@ tests(GROUPS correctness implicit_args_tests.cpp in_place.cpp indexing_access_undef.cpp + inductive.cpp infer_arguments.cpp inline_reduction.cpp inlined_generator.cpp diff --git a/test/correctness/inductive.cpp b/test/correctness/inductive.cpp new file mode 100644 index 000000000000..67ab9c2fce7f --- /dev/null +++ b/test/correctness/inductive.cpp @@ -0,0 +1,256 @@ +#include "Halide.h" +#include "check_call_graphs.h" +#include "test_sharding.h" + +#include +#include + +namespace { + +using std::map; +using std::string; + +using namespace Halide; +using namespace Halide::Internal; + +int simple_inductive_test() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); + + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({600, 5}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int reorder_test() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + Var xi("xi"), xii("xii"), xo("xo"); + + g(x, y) = select(x <= 0, 0, g(max(0, x - 1), y) + x + y); + + h(x, y) = g(x + 5, y) / 4; + h.split(x, xo, xi, 24).reorder(xi, y, xo); + + g.compute_at(h, xo).store_root(); + + g.split(x, xi, xii, 5).reorder(xii, y, xi).vectorize(y, 8); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int summed_area_table() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + g(x, y) = f(x, y) + select(x <= 0, 0, g(x - 1, y)) + select(y <= 0, 0, g(x, y - 1)) - select(x <= 0 || y <= 0, 0, g(x - 1, y - 1)); + h(x, y) = g(x, y) / 8; + g.compute_at(h, x).store_root(); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (x * (x + 1) / 2 * (y + 1) + y * (y + 1) / 2 * (x + 1)) / 8; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int large_baseline() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + g(x, y) = select(x <= 8, (y * x + x * (x + 1) / 2) - 1, g(x - 1, y) + x + y); + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + return (y * (x + 5) + (x + 5) * (x + 6) / 2 - 1) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int fibonacci() { + Func g("g"), h("h"); + Var x("x"), y("y"); + + g(x, y) = select(x <= 1, 1, g(x - 1, y) + g(x - 2, y)); + h(x, y) = g(x, y); + + h.bound(x, 0, 80); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int a = 1; + int b = 1; + for (int i = 2; i <= x; i++) { + int c = a + b; + b = a; + a = c; + } + return a; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int sum_2d_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = select(x <= 0, 0, x + f(x - 1, y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 80).vectorize(x, 8); + g.compute_at(h, x).store_root().vectorize(x, 8); + f.compute_at(h, x).store_root(); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int ans = 0; + for (int a = 0; a <= x; a++) { + for (int b = 0; b <= y; b++) { + ans += a; + } + } + return ans; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int sum_1d_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + f(x, y) += x; // select(x<=0, 0, x+f(x-1,y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 80); + // stress-testing bounds inference for dependent non-inlined funcs + f.compute_at(h, x); + Buffer im = h.realize({80, 80}); + auto func = [](int x, int y) { + int ans = 0; + for (int a = 0; a <= y; a++) { + ans += 2 * x + a; + } + return ans; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int multi_baseline_test() { + Func f("f"), g("g"), h("h"); + Var x("x"), y("y"); + f(x, y) = x + y; + f(x, y) += x; // select(x<=0, 0, x+f(x-1,y)); + g(x, y) = select(y <= 0, f(x, 0), f(x, y) + g(x, y - 1)) + select(y <= 3, f(x, 0), f(x, y) + g(x, y - 1)); + h(x, y) = g(x, y); + h.bound(x, 0, 80).bound(y, 0, 20); + f.compute_at(h, x); + Buffer im = h.realize({80, 20}); + auto func = [](int x, int y) { + std::vector ans; + + for (int a = 0; a <= y; a++) { + if (a <= 0) { + ans.emplace_back(4 * x); + } else if (a <= 3) { + ans.emplace_back(2 * x + (2 * x + a) + ans[a - 1]); + } else { + ans.emplace_back(2 * x + a + ans[a - 1] + (2 * x + a) + ans[a - 1]); + } + } + return ans[y]; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +int type_declare_test() { + Func g = Func(Int(32), "g"); + Func h("h"); + Var x("x"), y("y"); + + g(x, y) = select(x <= 0, 0, 1 + g(max(0, x - 1), y) + x + 2); + + h(x, y) = g(x + 5, y) / 4; + + g.compute_at(h, x).store_at(h, y); + + Buffer im = h.realize({600, 5}); + auto func = [](int x, int y) { + return (3 * (x + 5) + (x + 5) * (x + 6) / 2) / 4; + }; + if (check_image(im, func)) { + return 1; + } + return 0; +} + +} // namespace + +int main(int argc, char **argv) { + struct Task { + std::string desc; + std::function fn; + }; + + std::vector tasks = { + {"simple inductive test", simple_inductive_test}, + {"reordering test", reorder_test}, + {"summed area table test", summed_area_table}, + {"large baseline test", large_baseline}, + {"fibonacci test", fibonacci}, + {"2d sum test", sum_2d_test}, + {"1d sum test", sum_1d_test}, + {"multi-baseline test", multi_baseline_test}, + {"type declaration test", type_declare_test}, + }; + + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + std::cout << task.desc << "\n"; + if (task.fn() != 0) { + return 1; + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 41816d5ba36b..7301aad941a5 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -70,6 +70,15 @@ tests(GROUPS error implicit_args.cpp impossible_constraints.cpp incomplete_target.cpp + inductive_loop.cpp + inductive_loop_2.cpp + inductive_loop_3.cpp + inductive_nested_select.cpp + inductive_no_select.cpp + inductive_reorder.cpp + inductive_update.cpp + inductive_var_swap.cpp + inductive_vectorize.cpp init_def_should_be_all_vars.cpp inspect_loop_level.cpp lerp_float_weight_out_of_range.cpp diff --git a/test/error/inductive_loop.cpp b/test/error/inductive_loop.cpp new file mode 100644 index 000000000000..cb894256a355 --- /dev/null +++ b/test/error/inductive_loop.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 1, 0, f(x)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_loop_2.cpp b/test/error/inductive_loop_2.cpp new file mode 100644 index 000000000000..f6086e239eef --- /dev/null +++ b/test/error/inductive_loop_2.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 2, 0, f(x - 1) + f(x) + f(x - 2)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_loop_3.cpp b/test/error/inductive_loop_3.cpp new file mode 100644 index 000000000000..c281480d5fa7 --- /dev/null +++ b/test/error/inductive_loop_3.cpp @@ -0,0 +1,18 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 2, 0, f(x - 1) + f(min(2, x)) + f(x - 2)); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_nested_select.cpp b/test/error/inductive_nested_select.cpp new file mode 100644 index 000000000000..c9abf663008f --- /dev/null +++ b/test/error/inductive_nested_select.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + // Nested select operations are currently unsupported. + f(x) = select(x < 1, 0, select(x < 3, 1, f(x - 1))); + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_no_select.cpp b/test/error/inductive_no_select.cpp new file mode 100644 index 000000000000..5fdc3f430232 --- /dev/null +++ b/test/error/inductive_no_select.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = cast(x + f(x - 1)); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_reorder.cpp b/test/error/inductive_reorder.cpp new file mode 100644 index 000000000000..01dbe642645b --- /dev/null +++ b/test/error/inductive_reorder.cpp @@ -0,0 +1,21 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"), xi("xi"), xo("xo"); + + f(x) = select(x < 1, 0, x + f(x - 1)); + f.split(x, xo, xi, 8); + f.reorder(xo, xi); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_update.cpp b/test/error/inductive_update.cpp new file mode 100644 index 000000000000..acb16183f552 --- /dev/null +++ b/test/error/inductive_update.cpp @@ -0,0 +1,20 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 1, 0, x + f(x - 1)); + f(x) += 1; + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +} \ No newline at end of file diff --git a/test/error/inductive_var_swap.cpp b/test/error/inductive_var_swap.cpp new file mode 100644 index 000000000000..4a8453eb0901 --- /dev/null +++ b/test/error/inductive_var_swap.cpp @@ -0,0 +1,19 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"), y("y"); + + f(x, y) = select(x < 1, 0, x + f(y - 1, x - 1)); + + g(x, y) = f(x, y) * 2; + + g.realize({10, 10}); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/inductive_vectorize.cpp b/test/error/inductive_vectorize.cpp new file mode 100644 index 000000000000..bc6dd1e4bc95 --- /dev/null +++ b/test/error/inductive_vectorize.cpp @@ -0,0 +1,20 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + Func f("f"), g("g"); + + Var x("x"); + + f(x) = select(x < 1, 0, x + f(x - 1)); + f.vectorize(x, 8); + + g(x) = f(x) * 2; + + g.realize({10}); + + printf("Success!\n"); + return 0; +}