Skip to content

Implements a new constraints system#802

Merged
pranavm-nvidia merged 32 commits intomainfrom
trt-rtx
Feb 10, 2026
Merged

Implements a new constraints system#802
pranavm-nvidia merged 32 commits intomainfrom
trt-rtx

Conversation

@pranavm-nvidia
Copy link
Collaborator

No description provided.

pranavm-nvidia and others added 29 commits February 6, 2026 18:51
…erties of outputs

See the docstring in `frontend/constraints/base.py` for more details.
Adds a `_doc_str()` function which will generate human-readable text
for a given constraint.
- Refactors `doc_str` such that the Constraints classes provide implementations.
     The free function is primarily used for handling non-Constraints types and
     performing the dispatch to `Constraints.doc_str()`.

- Adds an `info()` method to the `Constraints` class that allows setting
  additional human-readable information about a constraint.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com>
- Ports `ones` and `ones_like` to use the new constraint system.
- Updates `merge_function_arguments` to return default values for omitted arguments.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com>
- Migrated 11 reduction operators: all, any, sum, prod, mean, max, min, var, argmax, argmin, topk
- Type-preserving operators (sum, prod, mean, max, min): maintain input dtype for numeric types
- Bool-only operators (all, any): only accept bool dtype with explicit runtime validation
- Type-changing operators (argmax, argmin): multiple input types → int32
- Tuple return operator (topk): returns (values, indices) with coordinated types
- all/any now properly reject non-bool inputs to avoid incorrect casting behavior
- 216 tests passed for all reduction operators
- Migrated zeros(), zeros_like(), full(), full_like(), iota(), iota_like(), arange() (2 overloads)
- Applied input_requirements and output_guarantees patterns
- Pattern: Optional dtype uses If(GetInput('dtype') != None, ...)
- Pattern: Non-tensor params accessed with GetInput('dtype') not .value
- Pattern: Value dtype constraints exclude float8 and int8+bool combinations
- Fixed constraint system: Modified Constraints.find() to accept skip_within parameter
- Fixed auto-casting: _find_known_datatypes now skips If conditions to avoid treating
  conditional checks (e.g., If(GetInput('value').dtype == dt.int8, ...)) as type requirements
- Documented limitation: Auto-casting won't work for conditionally-dependent datatypes
Migrated 11 operators that preserve input dtype in outputs:
- matmul: Combined constraints for both inputs using & operator
- cumsum: Float types only
- copy: All dtypes supported
- pad: Specific dtype subset
- softmax: Float types only
- tril/triu: Excluded bool (where() backend limitation - legacy was incorrect)
- avgpool/maxpool: Pooling-specific dtypes
- resize: Both overloads with int8/float support

Key findings:
- Legacy dtype_constraints claimed tril/triu supported bool, but this was never enforced
- tril/triu use where() internally, which doesn't support bool
- New constraint system is enforced and must accurately reflect backend capabilities
- full_like constraints correctly reject dtype=float8 (broadcast doesn't support it)

Test results: 136 passed, 26 skipped for type-preserving operators
All full/full_like tests passing: 576 passed, 234 skipped
Migrated 7 operators with specialized constraint patterns:
- gather: input/output preserve dtype, separate indices constraint (int32/int64)
- where: condition=bool, input/other must match, output preserves input dtype
- masked_fill: mask=bool, input/output preserve dtype (excluding bool - where() limitation)
- quantize: float inputs → quantized outputs (int4/int8/float8), scale matches input
- dequantize: quantized inputs → float outputs, scale/dtype must match
- equal: inputs match, returns Python bool (no output_guarantees)
- allclose: float inputs match, returns Python bool (no output_guarantees)

Key changes:
- Fixed masked_fill to exclude bool (uses where() which doesn't support bool)
- Updated test framework to handle None output_guarantees for non-Tensor returns
- Quantize/dequantize have coordinated input/output dtype constraints

Test results: 1792 passed, 719 skipped
Migrated last 3 operators to new constraint system:
- outer: Both vectors must match dtype (float types only)
- shape: Accepts all dtypes, returns tuple of IntLike (no output_guarantees)
- split: Preserves input dtype, supports all dtypes

All 81 operators with dtype constraints have now been migrated:
- Steps 1-5: 60 operators (unary, shape, binary, reduction, initializers)
- Step 6: 11 type-preserving operators
- Step 7: 7 operators with complex dtype relationships
- Step 8: 3 final operators

Migration complete! All operators now use the new enforced constraint system
with input_requirements and output_guarantees alongside legacy dtype_constraints.

Test results: 88 passed, 20 skipped for final operators
…ystem

- Remove deprecated dtype_constraints/dtype_variables/dtype_exceptions system
- Remove enable_dtype_checking config option (replaced by enable_input_validation)
- Remove DATA_TYPE_CONSTRAINTS registry and DataTypeConstraints dataclass
- Remove legacy get_arg_dtype and type_var handling from wrappers.py

- Migrate remaining operators to input_requirements/output_guarantees:
  - Module ops: conv, conv_transpose, instancenorm, layernorm
  - Frontend ops: allclose, dequantize, equal, gather, masked_fill,
    outer, quantize, shape, split, tril, triu, where

- Enhance GetDataType fetcher to support Python scalar types (int, float, bool)
- Fix Equal/NotEqual constraint logic for None comparisons
- Improve dtype autocasting to prefer trusted sources (Tensors/dtypes)
  over inferred scalar literal types

- Update all tests to use new error message format
- Replace legacy test_datatype_constraints.py with integration tests
- Update documentation examples to use new constraint syntax
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Introduces a new operator constraints system to replace the previous datatype-constraints registry, enabling richer input validation, output guarantees, and improved operator-side documentation and testing.

Changes:

  • Replaces DATA_TYPE_CONSTRAINTS with OPERATOR_CONSTRAINTS and adds a composable constraints DSL (Fetcher, Logic, Constraints).
  • Updates many front-end ops/modules to declare input_requirements and output_guarantees, and renames config flag to enable_input_validation.
  • Adds new unit + integration tests for constraints and updates existing tests/error-message expectations accordingly.

Reviewed changes

Copilot reviewed 125 out of 125 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
tripy/tests/utils/wrappers/test_datatype_constraints.py Removes old datatype-constraint exhaustive tests.
tripy/tests/utils/test_utils.py Adds tests for argument merging; switches to new constant_fields import.
tripy/tests/integration/test_operator_constraints.py Adds integration tests validating constraints vs backend acceptance.
tripy/tests/helper.py Clarifies documentation marker behavior.
tripy/tests/frontend/wrappers/test_wrappers.py Migrates wrapper tests to new constraints DSL and registry; adds _find_known_datatypes tests.
tripy/tests/frontend/ops/test_where.py Updates expected validation error message.
tripy/tests/frontend/ops/test_quantize.py Updates expected validation error message.
tripy/tests/frontend/ops/test_matmul.py Updates expected validation error message.
tripy/tests/frontend/ops/test_equal.py Updates expected validation error message.
tripy/tests/frontend/ops/test_dequantize.py Updates expected validation error message.
tripy/tests/frontend/ops/test_binary.py Updates expected validation error message.
tripy/tests/frontend/module/test_embedding.py Updates expected validation error message.
tripy/tests/frontend/module/test_conv.py Updates expected validation error message.
tripy/tests/frontend/constraints/test_logic.py Adds unit tests for logic constraint operators and formatting.
tripy/tests/frontend/constraints/test_fetcher.py Adds unit tests for fetchers (GetInput, GetReturn, GetDataType).
tripy/tests/frontend/constraints/test_doc_str.py Adds unit tests for documentation string generation.
tripy/tests/frontend/constraints/test_base.py Adds unit tests for constraint tree pattern matching (find).
tripy/tests/frontend/constraints/init.py Adds test package marker file.
tripy/tests/common/test_exception.py Updates wrappers module exclusion import to new location.
tripy/nvtripy/utils/utils.py Renames positional-arg helper and expands merge_function_arguments to return omitted defaults.
tripy/nvtripy/utils/stack_info.py Improves source-line retrieval via linecache and updates excluded modules.
tripy/nvtripy/frontend/wrappers.py Replaces datatype-constraints decorator with operator constraints registry + validation and doc injection.
tripy/nvtripy/frontend/ops/zeros.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/where.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unsqueeze.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/tanh.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/sqrt.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/sin.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/silu.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/sigmoid.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/rsqrt.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/relu.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/neg.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/log.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/invert.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/gelu.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/exp.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/cos.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/unary/abs.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/triu.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/tril.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/transpose.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/stack.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/squeeze.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/split.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/softmax.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/slice.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/shape.py Migrates to constraints-based input validation for property.
tripy/nvtripy/frontend/ops/resize.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reshape.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/repeat.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/var.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/topk.py Migrates to input_requirements/output_guarantees (multi-return).
tripy/nvtripy/frontend/ops/reduce/sum.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/prod.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/min.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/mean.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/max.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/argmin.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/argmax.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/reduce/any.py Migrates to constraints and adds runtime bool enforcement.
tripy/nvtripy/frontend/ops/reduce/all.py Migrates to constraints and adds runtime bool enforcement.
tripy/nvtripy/frontend/ops/quantize.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/pooling/maxpool.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/pooling/avgpool.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/permute.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/pad.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/outer.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/ones.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/matmul.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/masked_fill.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/iota.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/gather.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/full.py Migrates to input_requirements/output_guarantees including conditional dtype logic.
tripy/nvtripy/frontend/ops/flip.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/flatten.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/expand.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/equal.py Migrates to constraints-based dtype matching.
tripy/nvtripy/frontend/ops/dequantize.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/cumsum.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/copy.py Uses output_guarantees to express dtype preservation.
tripy/nvtripy/frontend/ops/concatenate.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/cast.py Replaces exception list with logical constraints for invalid cast pairs.
tripy/nvtripy/frontend/ops/binary/sub.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/pow.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/not_equal.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/mul.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/mod.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/minimum.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/maximum.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/logical_or.py Migrates to input_requirements/output_guarantees (bool-only).
tripy/nvtripy/frontend/ops/binary/less_equal.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/less.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/greater_equal.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/greater.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/floor_div.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/equal.py Migrates to input_requirements/output_guarantees (bool return).
tripy/nvtripy/frontend/ops/binary/div.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/binary/add.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/arange.py Migrates to input_requirements/output_guarantees.
tripy/nvtripy/frontend/ops/allclose.py Migrates to constraints-based dtype matching.
tripy/nvtripy/frontend/module/linear.py Switches to new constant_fields import.
tripy/nvtripy/frontend/module/layernorm.py Migrates layernorm op constraints; switches to new constant_fields.
tripy/nvtripy/frontend/module/instancenorm.py Migrates instancenorm op constraints; switches to new constant_fields.
tripy/nvtripy/frontend/module/groupnorm.py Switches to new constant_fields import.
tripy/nvtripy/frontend/module/embedding.py Switches to new constant_fields import.
tripy/nvtripy/frontend/module/conv/conv_transpose.py Migrates deconvolution constraints to new DSL.
tripy/nvtripy/frontend/module/conv/conv.py Migrates convolution constraints to new DSL.
tripy/nvtripy/frontend/module/conv/base.py Switches to new constant_fields import.
tripy/nvtripy/frontend/module/batchnorm.py Switches to new constant_fields import; minor import ordering.
tripy/nvtripy/frontend/constraints/logic.py Adds constraints logic core implementation.
tripy/nvtripy/frontend/constraints/fetcher.py Adds fetcher implementations used by constraints.
tripy/nvtripy/frontend/constraints/doc_str.py Adds doc formatting for constraints/dtypes.
tripy/nvtripy/frontend/constraints/base.py Adds constraints base class and tree matcher (find).
tripy/nvtripy/frontend/constraints/init.py Exposes constraints public API.
tripy/nvtripy/config.py Renames flag to enable_input_validation.
tripy/nvtripy/backend/api/compile.py Updates to renamed positional-args helper.
tripy/examples/nanogpt/README.md Updates ModelOpt documentation link.
tripy/docs/pre0_user_guides/02-quantization.md Updates ModelOpt documentation links.
tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md Updates examples to use new constraints DSL and wrappers path.
tripy/docs/post0_developer_guides/00-architecture.md Updates wrappers path reference.
tripy/docs/README.md Updates wrappers usage example and adds documentation philosophy section.
tripy/Dockerfile Makes container username configurable; installs zsh.
tripy/CONTRIBUTING.md Notes VS Code devcontainer alternative.
tripy/.devcontainer/devcontainer.json Adds devcontainer configuration for VS Code workflow.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

github-actions[bot]

This comment was marked as outdated.

@pranavm-nvidia pranavm-nvidia merged commit 8ca9598 into main Feb 10, 2026
5 checks passed
@pranavm-nvidia pranavm-nvidia deleted the trt-rtx branch February 10, 2026 19:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant