diff --git a/python/tvm/relax/backend/cpu_generic/pipeline.py b/python/tvm/relax/backend/cpu_generic/pipeline.py index 74d951b817b1..527cda28d8cc 100644 --- a/python/tvm/relax/backend/cpu_generic/pipeline.py +++ b/python/tvm/relax/backend/cpu_generic/pipeline.py @@ -52,6 +52,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/cuda/pipeline.py b/python/tvm/relax/backend/cuda/pipeline.py index d5c4c0856165..3861036c383b 100644 --- a/python/tvm/relax/backend/cuda/pipeline.py +++ b/python/tvm/relax/backend/cuda/pipeline.py @@ -64,6 +64,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/gpu_generic/pipeline.py b/python/tvm/relax/backend/gpu_generic/pipeline.py index 86c60114c699..f3df2510ad51 100644 --- a/python/tvm/relax/backend/gpu_generic/pipeline.py +++ b/python/tvm/relax/backend/gpu_generic/pipeline.py @@ -63,6 +63,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/rocm/pipeline.py b/python/tvm/relax/backend/rocm/pipeline.py index e74039ca8634..fa1da7cde689 100644 --- 
a/python/tvm/relax/backend/rocm/pipeline.py +++ b/python/tvm/relax/backend/rocm/pipeline.py @@ -63,6 +63,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index dacbc667be2b..72e23e089519 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -28,6 +28,7 @@ BundleModelParams, CallTIRRewrite, CanonicalizeBindings, + CanonicalizeShapeExpr, CombineParallelMatmul, ComputePrimValue, ConvertLayout, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index bfd7dbf87d70..2babf0c9ba90 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -735,6 +735,32 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass: + """Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables. + + VMShapeLower can only handle ShapeExpr where each dimension is either: + - IntImm (concrete integer constant) + - tir::Var (symbolic variable from function parameters or match_cast) + + This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) by: + 1. Creating a fresh tir::Var for each compound expression + 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression + 3. 
Replacing the compound expression in ShapeExpr with the fresh var + + Example transformation: + Before: y = R.zeros(R.shape([n + 1]), dtype="float32") + After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n+1), R.Prim(value=_s0)) + y = R.zeros(R.shape([_s0]), dtype="float32") + + This pass should be applied before ComputePrimValue and before VMShapeLower. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeShapeExpr() # type: ignore + + + def ExpandTupleArguments() -> tvm.ir.transform.Pass: + """Expand tuple arguments to internal functions + + diff --git a/src/relax/transform/canonicalize_shape_expr.cc b/src/relax/transform/canonicalize_shape_expr.cc new file mode 100644 index 000000000000..846d27b03577 --- /dev/null +++ b/src/relax/transform/canonicalize_shape_expr.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_shape_expr.cc + * \brief Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables. 
+ * + * VMShapeLower can only handle expressions where each PrimExpr dimension is either: + * - IntImm (concrete integer constant) + * - tir::Var (symbolic variable from function parameters or match_cast) + * + * This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) in ShapeExpr and struct_info by: + * 1. Creating a fresh tir::Var for each compound expression + * 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression + * 3. Replacing the compound expression with the fresh var everywhere (ShapeExpr and struct_info) + * + * Example transformation: + * Before: y = R.Tensor((n + 1,)) = R.zeros(R.shape([n + 1]), dtype="float32") + * After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n + 1), R.Prim(value=_s0)) + * y = R.Tensor((_s0,)) = R.zeros(R.shape([_s0]), dtype="float32") + * + * This ensures VMShapeLower only sees simple tir::Var references, which it can resolve + * through the MatchCast bindings. + */ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +namespace { + +/*! + * \brief Check if a PrimExpr is trivial (already canonical for VMShapeLower) + * + * Trivial expressions are: + * - IntImm: concrete integer constants + * - tir::Var: symbolic variables + * + * Any other expression (arithmetic, casts, etc.) is compound and needs canonicalization. + */ +bool IsTrivialPrimExpr(const PrimExpr& expr) { + return expr->IsInstance() || expr->IsInstance(); +} + +/*! + * \brief Collector for compound PrimExpr in an expression tree. + * + * Scans ShapeExpr nodes and collects all compound (non-trivial) PrimExpr. + */ +class CompoundExprCollector : public ExprVisitor { + public: + void VisitExpr_(const ShapeExprNode* op) override { + for (const PrimExpr& dim : op->values) { + if (!IsTrivialPrimExpr(dim)) { + compound_exprs_.insert(dim); + } + } + ExprVisitor::VisitExpr_(op); + } + + std::unordered_set compound_exprs_; +}; + +/*! 
+ * \brief StructInfo mutator that substitutes PrimExpr according to a mapping. + */ +class StructInfoPrimExprMutator : public StructInfoMutator { + public: + explicit StructInfoPrimExprMutator( + const std::unordered_map& expr_map) + : expr_map_(expr_map) {} + + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { + auto it = expr_map_.find(expr); + if (it != expr_map_.end()) { + return it->second; + } + return expr; + } + + private: + const std::unordered_map& expr_map_; +}; + +/*! + * \brief Mutator to canonicalize ShapeExpr and struct_info by replacing compound PrimExpr + * with fresh symbolic variables bound via MatchCast. + */ +class ShapeExprCanonicalizer : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const FunctionNode* func) override { + // Reset state for each function + sym_var_counter_ = 0; + expr_to_var_.clear(); + + // First pass: collect all compound expressions in the function body + // so we can emit MatchCast bindings at the beginning + CollectCompoundExprsInFunction(func); + + // Visit params + ffi::Array params; + bool all_params_unchanged = true; + for (Var param : func->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + if (!param.same_as(new_param)) { + var_remap_[param->vid] = new_param; + all_params_unchanged = false; + } + } + + // Process the function body + Expr new_body = this->VisitWithNewScope(func->body, params); + + // Also substitute in the return struct_info + StructInfo new_ret_sinfo = SubstituteStructInfo(func->ret_struct_info); + + bool ret_sinfo_changed = !StructuralEqual()(new_ret_sinfo, func->ret_struct_info); + bool body_changed = !new_body.same_as(func->body); + + if (all_params_unchanged && !ret_sinfo_changed && !body_changed) { + return ffi::GetRef(func); + } + + return Function(params, new_body, new_ret_sinfo, func->is_pure, func->attrs, func->span); + } + + void VisitBinding_(const VarBindingNode* binding) override { + // First, 
emit MatchCast bindings for any compound PrimExpr in ShapeExpr + // This populates expr_to_var_ with mappings from compound expr to fresh vars + EmitMatchCastForCompoundExprs(binding->value); + + // Now visit the binding with substitution + Expr new_value = this->VisitExpr(binding->value); + + // Get the struct_info from the new value and substitute compound exprs + StructInfo new_sinfo = SubstituteStructInfo(GetStructInfo(new_value)); + + // Create a new relax::Var with the substituted struct_info + Var new_var(binding->var->name_hint(), new_sinfo, binding->var->span); + + // Remap the old var to the new var + var_remap_[binding->var->vid] = new_var; + + // Emit the new binding + builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + // Emit MatchCast bindings for compound PrimExpr in ShapeExpr first + EmitMatchCastForCompoundExprs(binding->value); + + // Visit the value + Expr new_value = this->VisitExpr(binding->value); + + // Substitute in the struct_info + StructInfo new_sinfo = SubstituteStructInfo(binding->struct_info); + + // Create a new relax::Var with the substituted struct_info + Var new_var(binding->var->name_hint(), new_sinfo, binding->var->span); + + var_remap_[binding->var->vid] = new_var; + + builder_->EmitNormalized(MatchCast(new_var, new_value, new_sinfo)); + } + + Expr VisitExpr_(const ShapeExprNode* op) override { + // Rewrite ShapeExpr to replace compound PrimExpr with fresh symbolic variables + ffi::Array new_values; + bool changed = false; + + for (const PrimExpr& dim : op->values) { + if (IsTrivialPrimExpr(dim)) { + new_values.push_back(dim); + } else { + auto it = expr_to_var_.find(dim); + if (it != expr_to_var_.end()) { + new_values.push_back(it->second); + changed = true; + } else { + new_values.push_back(dim); + } + } + } + + if (changed) { + return ShapeExpr(new_values, op->span); + } + return ffi::GetRef(op); + } + + private: + /*! 
+ * \brief Collect all compound expressions in a function body. + */ + void CollectCompoundExprsInFunction(const FunctionNode* func) { + CompoundExprCollector collector; + collector.VisitExpr(func->body); + } + + /*! + * \brief Scan an expression for ShapeExpr nodes and emit MatchCast bindings + * for any compound PrimExpr dimensions. + */ + void EmitMatchCastForCompoundExprs(const Expr& expr) { + CompoundExprCollector collector; + collector.VisitExpr(expr); + + for (const PrimExpr& compound_expr : collector.compound_exprs_) { + EmitMatchCastIfNeeded(compound_expr); + } + } + + /*! + * \brief Substitute compound PrimExpr in a StructInfo with fresh variables. + */ + StructInfo SubstituteStructInfo(const StructInfo& sinfo) { + if (expr_to_var_.empty()) { + return sinfo; + } + StructInfoPrimExprMutator mutator(expr_to_var_); + return mutator.VisitStructInfo(sinfo); + } + + /*! + * \brief Emit a MatchCast binding for a compound PrimExpr if not already done. + */ + void EmitMatchCastIfNeeded(const PrimExpr& expr) { + if (IsTrivialPrimExpr(expr)) { + return; + } + + if (expr_to_var_.count(expr)) { + return; + } + + // Create a fresh tir::Var to hold the computed value + std::string var_name = "_s" + std::to_string(sym_var_counter_++); + tir::Var fresh_tir_var(var_name, expr->dtype); + + // Record the mapping for substitution + expr_to_var_[expr] = fresh_tir_var; + + // Create a PrimValue that computes the compound expression + PrimValue prim_value(expr); + + // Create a PrimStructInfo that declares the fresh variable as the value + PrimStructInfo target_sinfo(fresh_tir_var); + + // Create a Relax Var to hold the MatchCast result + std::string relax_var_name = var_name + "_pv"; + relax::Var match_var(relax_var_name, target_sinfo); + + // Emit the MatchCast binding + builder_->EmitNormalized(MatchCast(match_var, prim_value, target_sinfo)); + } + + int sym_var_counter_ = 0; + std::unordered_map expr_to_var_; +}; + +} // namespace + +Expr CanonicalizeShapeExpr(Expr expr) { 
return ShapeExprCanonicalizer()(std::move(expr)); } + +namespace transform { + +Pass CanonicalizeShapeExpr() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(relax::CanonicalizeShapeExpr(f)); + }; + return CreateFunctionPass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"CanonicalizeShapeExpr", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CanonicalizeShapeExpr", CanonicalizeShapeExpr); +} + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_canonicalize_shape_expr.py b/tests/python/relax/test_transform_canonicalize_shape_expr.py new file mode 100644 index 000000000000..2511620b2524 --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_shape_expr.py @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Unit tests for the CanonicalizeShapeExpr pass""" + +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_simple_compound_shape(): + """Test canonicalization of simple compound shape expression""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + n = T.int64() + # Compound expression: n + 1 + y: R.Tensor((n + 1,), "float32") = R.zeros(R.shape([n + 1]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.VMShapeLower()(mod) + + assert "compute_symbolic_expr" in [str(gv) for gv in mod.get_global_vars()] + + +def test_compound_shape_in_constant(): + """Test canonicalization when compound shape appears in constant variable struct_info""" + + @R.function + def before(x: R.Tensor(("n", "m"), "float32")): + n = T.int64() + m = T.int64() + # This pattern can occur after FoldConstant inlines shapes + # The constant variable has compound expression in its struct_info + y: R.Tensor((n * m,), "float32") = R.zeros(R.shape([n * m]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + print("=== Before CanonicalizeShapeExpr ===") + print(mod) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + print("=== After CanonicalizeShapeExpr ===") + print(mod) + + # Check well-formed immediately after CanonicalizeShapeExpr + is_wf = relax.analysis.well_formed(mod) + print(f"=== Well-formed after CanonicalizeShapeExpr: {is_wf} ===") + if not is_wf: + raise RuntimeError("IR is not well-formed after CanonicalizeShapeExpr") + + mod = relax.transform.Normalize()(mod) + print("=== After Normalize ===") + print(mod) + + # Check well-formed immediately after Normalize + is_wf = relax.analysis.well_formed(mod) + print(f"=== Well-formed after Normalize: {is_wf} ===") + if not is_wf: + raise RuntimeError("IR is not 
well-formed after Normalize") + + mod = relax.transform.ComputePrimValue()(mod) + print("=== After ComputePrimValue ===") + print(mod) + + # Check well-formed immediately after ComputePrimValue + is_wf = relax.analysis.well_formed(mod) + print(f"=== Well-formed after ComputePrimValue: {is_wf} ===") + if not is_wf: + raise RuntimeError("IR is not well-formed after ComputePrimValue") + + mod = relax.transform.VMShapeLower()(mod) + print("=== After VMShapeLower ===") + print(mod) + + # Verify a compute function was generated for the compound expression + assert "compute_symbolic_expr" in [str(gv) for gv in mod.get_global_vars()] + + +def test_multiply_compound_shape(): + """Test the original issue case: 4 * x_0 * x_1 * x_2 * x_3""" + + @R.function + def before(x: R.Tensor(("n", "m", "p", "q"), "float32")): + n = T.int64() + m = T.int64() + p = T.int64() + q = T.int64() + # Compound expression: 4 * n * m * p * q + y: R.Tensor((4 * n * m * p * q,), "float32") = R.zeros( + R.shape([4 * n * m * p * q]), dtype="float32" + ) + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.VMShapeLower()(mod) + + # Verify a compute function was generated for the compound expression + assert "compute_symbolic_expr" in [str(gv) for gv in mod.get_global_vars()] + + +def test_no_change_for_canonical_shape(): + """Test that already canonical shapes are not modified""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + n = T.int64() + # Already canonical shape + y: R.Tensor((n,), "float32") = R.zeros(R.shape([n]), dtype="float32") + return y + + mod_before = tvm.IRModule.from_expr(before) + mod_after = relax.transform.CanonicalizeShapeExpr()(mod_before) + + # The mod should be unchanged (or minimally changed) + # Both should work with VMShapeLower + mod_before_lower = relax.transform.ComputePrimValue()(mod_before) + mod_before_lower = 
relax.transform.VMShapeLower()(mod_before_lower) + mod_after_lower = relax.transform.ComputePrimValue()(mod_after) + mod_after_lower = relax.transform.VMShapeLower()(mod_after_lower) + + # For canonical shapes, no compute_symbolic_expr should be generated + # (only compound expressions need computation) + global_var_names = [str(gv) for gv in mod_after_lower.get_global_vars()] + assert "compute_symbolic_expr" not in global_var_names + + +def test_no_change_for_concrete_shape(): + """Test that concrete integer shapes are not modified""" + + @R.function + def before(x: R.Tensor((10,), "float32")): + # Concrete shape + y: R.Tensor((10,), "float32") = R.zeros(R.shape([10]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.VMShapeLower()(mod) + + # For concrete shapes, no compute_symbolic_expr should be generated + global_var_names = [str(gv) for gv in mod.get_global_vars()] + assert "compute_symbolic_expr" not in global_var_names + + +def test_tuple_struct_info(): + """Test canonicalization with tuple struct info containing compound shapes""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + n = T.int64() + # Tuple with compound shapes + y: R.Tuple(R.Tensor((n + 1,), "float32"), R.Tensor((n * 2,), "float32")) = ( + R.zeros(R.shape([n + 1]), dtype="float32"), + R.zeros(R.shape([n * 2]), dtype="float32"), + ) + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.VMShapeLower()(mod) + + # Verify compute functions were generated for the compound expressions + assert "compute_symbolic_expr" in [str(gv) for gv in mod.get_global_vars()] + + +def test_full_pipeline_with_opt_level_1(): + """Test the full pipeline with opt_level=1""" + + @R.function + def before(x: R.Tensor(("n", "m"), "float32")): + n 
= T.int64() + m = T.int64() + y: R.Tensor((n * m,), "float32") = R.reshape(x, R.shape([n * m])) + return y + + mod = tvm.IRModule.from_expr(before) + + with tvm.transform.PassContext(opt_level=1): + # Apply the passes in order + mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FoldConstant()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.VMShapeLower()(mod) + + # Verify a compute function was generated for the compound expression + assert "compute_symbolic_expr" in [str(gv) for gv in mod.get_global_vars()] + + +if __name__ == "__main__": + import sys + + print("Running CanonicalizeShapeExpr unit tests...") + print("=" * 80) + + tests = [ + ("Simple compound shape", test_simple_compound_shape), + ("Compound shape in constant", test_compound_shape_in_constant), + ("Multiply compound shape", test_multiply_compound_shape), + ("No change for canonical shape", test_no_change_for_canonical_shape), + ("No change for concrete shape", test_no_change_for_concrete_shape), + ("Tuple struct info", test_tuple_struct_info), + ("Full pipeline with opt_level=1", test_full_pipeline_with_opt_level_1), + ] + + passed = 0 + failed = 0 + + for name, test_func in tests: + try: + print(f"\nTest: {name}") + test_func() + print("Result: PASSED") + passed += 1 + except Exception as e: + print(f"Result: FAILED: {e}") + import traceback + + traceback.print_exc() + failed += 1 + + print("\n" + "=" * 80) + print(f"Total tests run: {passed + failed}, Passed: {passed}, Failed: {failed}") + + sys.exit(0 if failed == 0 else 1)