Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {

// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {

// The optimization below rewrites expressions involving `-a_min + (b - 1)`.
// Without proper bounds checking, this expression may overflow the dtype
// maximum, leading to non-equivalent transformations.
// To ensure safety, we require:
// b_max - a_min <= max_value_of_dtype + 1
// This provides a conservative upper bound that prevents overflow and
// preserves the original semantics.
arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b);
const int64_t max_value_of_dtype =
Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
if (const_int_bound_a->min_value < 0 &&
const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
Expand Down Expand Up @@ -152,7 +162,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
Expand Down Expand Up @@ -214,10 +224,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {

// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {

// The optimization below rewrites expressions involving `-a_min + (b - 1)`.
// Without proper bounds checking, this expression may overflow the dtype
// maximum, leading to non-equivalent transformations.
// To ensure safety, we require:
// b_max - a_min <= max_value_of_dtype + 1
// This provides a conservative upper bound that prevents overflow and
// preserves the original semantics.
arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b);
const int64_t max_value_of_dtype =
Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
if (const_int_bound_a->min_value < 0 &&
const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
Expand Down Expand Up @@ -247,7 +267,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
Expand Down
12 changes: 12 additions & 0 deletions tests/python/tir-transform/test_tir_transform_lower_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def test_lower_floordiv():
# const power of two
res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
# floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)),
)
check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b)


@tvm.testing.requires_llvm
Expand All @@ -115,6 +121,12 @@ def test_lower_floormod():
# const power of two
res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
# floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)),
)
check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b)


if __name__ == "__main__":
Expand Down