From b6cd281071b745cd3fb2aacda4a151521d994ee7 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 22 Oct 2025 16:39:46 -0700 Subject: [PATCH 01/32] Adds a constraints system for validating inputs and guaranteeing properties of outputs See the docstring in `frontend/constraints/base.py` for more details. --- .../nvtripy/frontend/constraints/__init__.py | 19 +++ tripy/nvtripy/frontend/constraints/base.py | 130 ++++++++++++++++ tripy/nvtripy/frontend/constraints/fetcher.py | 111 ++++++++++++++ tripy/nvtripy/frontend/constraints/logic.py | 125 +++++++++++++++ tripy/tests/frontend/constraints/__init__.py | 16 ++ tripy/tests/frontend/constraints/test_base.py | 116 ++++++++++++++ .../frontend/constraints/test_fetcher.py | 104 +++++++++++++ .../tests/frontend/constraints/test_logic.py | 145 ++++++++++++++++++ 8 files changed, 766 insertions(+) create mode 100644 tripy/nvtripy/frontend/constraints/__init__.py create mode 100644 tripy/nvtripy/frontend/constraints/base.py create mode 100644 tripy/nvtripy/frontend/constraints/fetcher.py create mode 100644 tripy/nvtripy/frontend/constraints/logic.py create mode 100644 tripy/tests/frontend/constraints/__init__.py create mode 100644 tripy/tests/frontend/constraints/test_base.py create mode 100644 tripy/tests/frontend/constraints/test_fetcher.py create mode 100644 tripy/tests/frontend/constraints/test_logic.py diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py new file mode 100644 index 000000000..f1f037e98 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -0,0 +1,19 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher +from nvtripy.frontend.constraints.logic import And, Equal, Logic, Not, NotEqual, OneOf diff --git a/tripy/nvtripy/frontend/constraints/base.py b/tripy/nvtripy/frontend/constraints/base.py new file mode 100644 index 000000000..ffecad43b --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/base.py @@ -0,0 +1,130 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +The constraints system has two purposes: + +1. Imposing input requirements. +2. Describing output guarantees. + +Constraints are specified by composing one or more `Constraints` subclasses: + +```py +constraint = And( + Equal(GetDataType(GetInput("input0")), GetInput("dtype")), + Equal(GetDataType(GetInput("input1")), GetInput("dtype")), + ) +``` + +We also override several bitwise operators and properties to provide a convenient shorthand. +For example, the above can be written as: + +```py +constraint = (GetInput("input0").dtype == GetInput("dtype")) & (GetInput("input1").dtype == GetInput("dtype")) +``` + +The constraints class also provides a pattern matcher. +For example, we may want to find all constraints that check the data type of an input (`None` is a wildcard). + +```py +matches = constraint.find(Equal(GetDataType(GetInput), None)) +``` +""" + +from abc import ABC +from typing import List + + +class Constraints(ABC): + """ + Base class for the entire constraints system. + """ + + def get_children(self) -> List["Constraints"]: + children = [] + for attr_value in vars(self).values(): + if isinstance(attr_value, Constraints): + children.append(attr_value) + elif isinstance(attr_value, (list, tuple)): + children.extend(v for v in attr_value if isinstance(v, Constraints)) + return children + + def find(self, pattern: "Constraints") -> List["Constraints"]: + """ + Find all constraints in the tree that match the given pattern. + + Performs a depth-first search through the constraint tree to find all + constraints that structurally match the given pattern, using the current + constraint as the root node. + + Args: + pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)). + Use None as a wildcard to match anything. + + Returns: + A list of all matching constraints found in the tree. + + Example: + pattern = Equal(GetDataType(TensorFetcher), None) # None matches any second argument + matches = constraint_tree.find(pattern) + """ + + def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool: + # None is a wildcard that matches anything + if pattern is None: + return True + + if isinstance(pattern, type): + return isinstance(constraint, pattern) + + if type(pattern) != type(constraint): + return False + + # Need to index into sequences rather than comparing directly since there may be patterns in the sequence. + if isinstance(pattern, (list, tuple)) and isinstance(constraint, (list, tuple)): + if len(pattern) != len(constraint): + return False + return all(matches_pattern(p_val, c_val) for p_val, c_val in zip(pattern, constraint)) + + if not isinstance(pattern, Constraints): + return pattern == constraint + + # Compare attributes + pattern_attrs = vars(pattern) + constraint_attrs = vars(constraint) + + for key, pattern_value in pattern_attrs.items(): + if key not in constraint_attrs: + return False + + constraint_value = constraint_attrs[key] + + if not matches_pattern(pattern_value, constraint_value): + return False + + return True + + matches = [] + + if matches_pattern(pattern, self): + matches.append(self) + + # Recursively search children + for child in self.get_children(): + matches.extend(child.find(pattern)) + + return matches diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py new file mode 100644 index 000000000..fa1638fd7 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -0,0 +1,111 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import abstractmethod +from typing import Any, List, Sequence, Tuple + +from nvtripy.common.datatype import dtype as tp_dtype +from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints.base import Constraints + + +class Fetcher(Constraints): + """ + Fetches a value based on the function parameters or return value. + """ + + @abstractmethod + def __call__(self, args: List[Tuple[str, Any]]) -> Any: ... + + def __eq__(self, other: "Fetcher") -> "Equal": + from nvtripy.frontend.constraints.logic import Equal + + return Equal(self, other) + + def __ne__(self, other: "Fetcher") -> "NotEqual": + from nvtripy.frontend.constraints.logic import NotEqual + + return NotEqual(self, other) + + +class ValueFetcher(Fetcher): + @property + def dtype(self) -> "GetDataType": + return GetDataType(self) + + +class GetInput(ValueFetcher): + def __init__(self, name: str): + self.name = name + + def __call__(self, args: List[Tuple[str, Any]]) -> Any: + for name, value in args: + if name == self.name: + return value + assert False, f"Input '{self.name}' not found in arguments." + + def __str__(self): + return self.name + + +class GetReturn(ValueFetcher): + def __init__(self, index: int): + self.index = index + + def __call__(self, args: List[Tuple[str, Any]]) -> Any: + raise NotImplementedError( + "GetReturn is only used to describe output guarantees and must not be called for input validation purposes." + ) + + def __str__(self): + return f"return[{self.index}]" + + +class GetDataType(Fetcher): + def __init__(self, value_fetcher: ValueFetcher): + self.value_fetcher = value_fetcher + + def __call__(self, args: List[Tuple[str, Any]]) -> Any: + from nvtripy.frontend.tensor import Tensor + + def get_arg_dtype(arg: Any) -> tp_dtype: + if isinstance(arg, Sequence): + arg_dtypes = [get_arg_dtype(elem) for elem in arg] + + if len(set(arg_dtypes)) != 1: + raise_error( + f"Could not determine data type of {self.value_fetcher}", + [ + f"Mismatched data types in sequence argument.\n", + f"For parameter: '{self.value_fetcher}', all arguments must have the same data type, but got: " + f"{arg_dtypes}", + ], + ) + arg_dtype = arg_dtypes[0] + elif isinstance(arg, Tensor): + arg_dtype = arg.dtype + else: + raise_error( + f"Could not determine data type of {self.value_fetcher}", + [f"Expected a tensor or data type argument for {self.value_fetcher}, but got: {arg}"], + ) + return arg_dtype + + tensor = self.value_fetcher(args) + return get_arg_dtype(tensor) + + def __str__(self): + return f"{self.value_fetcher}.dtype" diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py new file mode 100644 index 000000000..0b53599e6 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -0,0 +1,125 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import abstractmethod +from typing import Any, List, Sequence, Tuple + +from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.fetcher import Fetcher +from nvtripy.utils.result import Result + + +class Logic(Constraints): + """ + Represents logical operations on constraints. + """ + + @abstractmethod + def __call__(self, args: List[Tuple[str, Any]]) -> Result: ... + + def __and__(self, other: "Logic") -> "Logic": + if isinstance(self, And): + return And(*self.constraints, other) + elif isinstance(other, And): + return And(self, *other.constraints) + return And(self, other) + + def __invert__(self) -> "Logic": + if isinstance(self, Equal): + return NotEqual(self.fetcher1, self.fetcher2) + return Not(self) + + +class OneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + self.fetcher = fetcher + self.options = options + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + value = self.fetcher(args) + if value in self.options: + return Result.ok() + + return Result.err([f"Expected {self.fetcher} to be one of {self.options}, but got {value}."]) + + def __str__(self): + return f"{self.fetcher} is one of {self.options}" + + +class Equal(Logic): + def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher): + self.fetcher1 = fetcher1 + self.fetcher2 = fetcher2 + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + value1 = self.fetcher1(args) + value2 = self.fetcher2(args) + if value1 == value2: + return Result.ok() + + return Result.err([f"Expected {self.fetcher1} to be equal to {self.fetcher2}, but got {value1} and {value2}."]) + + def __str__(self): + return f"{self.fetcher1} == {self.fetcher2}" + + +class NotEqual(Logic): + def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher): + self.fetcher1 = fetcher1 + self.fetcher2 = fetcher2 + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + value1 = self.fetcher1(args) + value2 = self.fetcher2(args) + if value1 != value2: + return Result.ok() + + return Result.err([f"Expected {self.fetcher1} to be not equal to {self.fetcher2}, but both were {value1}."]) + + def __str__(self): + return f"{self.fetcher1} != {self.fetcher2}" + + +class And(Logic): + def __init__(self, *constraints: Logic): + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + errors = [] + for constraint in self.constraints: + result = constraint(args) + if not result: + errors.extend(result.error_details) + if errors: + return Result.err(errors) + return Result.ok() + + def __str__(self): + return " and ".join(str(constraint) for constraint in self.constraints) + + +class Not(Logic): + def __init__(self, constraint: Logic): + self.constraint = constraint + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + result = self.constraint(args) + if result: + return Result.err([f"Expected NOT {self.constraint}, but it was satisfied."]) + return Result.ok() + + def __str__(self): + return f"NOT ({self.constraint})" diff --git a/tripy/tests/frontend/constraints/__init__.py b/tripy/tests/frontend/constraints/__init__.py new file mode 100644 index 000000000..8bb95d5cb --- /dev/null +++ b/tripy/tests/frontend/constraints/__init__.py @@ -0,0 +1,16 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tripy/tests/frontend/constraints/test_base.py b/tripy/tests/frontend/constraints/test_base.py new file mode 100644 index 000000000..ac3b03cb1 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_base.py @@ -0,0 +1,116 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints import And, Equal, GetDataType, GetInput, OneOf + + +class TestConstraints: + def test_find_exact_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_no_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = OneOf(GetInput("a"), [1, 2, 3]) + assert len(constraint.find(pattern)) == 0 + + def test_find_in_nested_and(self): + inner_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(inner_constraint, OneOf(GetInput("c"), [1, 2, 3])) + pattern = Equal(GetInput, GetInput) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is inner_constraint + + def test_find_multiple_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2, OneOf(GetInput("e"), [1, 2, 3])) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_dtype_pattern(self): + constraint = Equal(GetDataType(GetInput("tensor1")), GetDataType(GetInput("tensor2"))) + pattern = Equal(GetDataType(GetInput), GetDataType(GetInput)) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_deeply_nested_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("d"), GetInput("e")) + constraint = And( + And(equal1, OneOf(GetInput("c"), [1, 2, 3])), + And(equal2, OneOf(GetInput("f"), [4, 5, 6])), + ) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_specific_names(self): + match_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(match_constraint, Equal(GetInput("c"), GetInput("d"))) + matches = constraint.find(Equal(GetInput("a"), GetInput("b"))) + assert len(matches) == 1 and matches[0] is match_constraint + + def test_find_with_multiple_children(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + oneof1 = OneOf(GetInput("e"), [1, 2, 3]) + equal3 = Equal(GetInput("f"), GetInput("g")) + constraint = And(equal1, equal2, oneof1, equal3) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 3 + assert equal1 in matches + assert equal2 in matches + assert equal3 in matches + + def test_find_and_constraint(self): + and1 = And(Equal(GetInput("a"), GetInput("b")), OneOf(GetInput("c"), [1, 2, 3])) + and2 = And(Equal(GetInput("d"), GetInput("e")), OneOf(GetInput("f"), [4, 5, 6])) + constraint = And(and1, and2) + matches = constraint.find(And(Equal, OneOf)) + assert len(matches) == 2 + assert and1 in matches + assert and2 in matches + + def test_find_with_none_wildcard_second_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_first_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(None, GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_in_nested(self): + equal1 = Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2) + pattern = Equal(GetDataType(GetInput), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is equal1 + + def test_find_with_none_wildcard_matches_different_types(self): + equal = Equal(GetInput("a"), GetInput("b")) + oneof = OneOf(GetInput("c"), [1, 2, 3]) + constraint = And(equal, oneof) + pattern = None + matches = constraint.find(pattern) + assert len(matches) == 6 + assert constraint in matches diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py new file mode 100644 index 000000000..d0f1c111a --- /dev/null +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -0,0 +1,104 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import nvtripy as tp +from nvtripy.common.exception import TripyException +from nvtripy.frontend.constraints import Equal, GetDataType, GetInput, GetReturn, NotEqual +from tests import helper + + +class TestFetcher: + def test_eq_operator_returns_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 == fetcher2 + assert isinstance(constraint, Equal) + assert constraint.fetcher1 == fetcher1 + assert constraint.fetcher2 == fetcher2 + + def test_ne_operator_returns_not_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 != fetcher2 + assert isinstance(constraint, NotEqual) + assert constraint.fetcher1 == fetcher1 + assert constraint.fetcher2 == fetcher2 + + +class TestValueFetcher: + def test_dtype_property(self): + fetcher = GetInput("tensor") + dtype_fetcher = fetcher.dtype + assert isinstance(dtype_fetcher, GetDataType) + assert dtype_fetcher.value_fetcher == fetcher + + +class TestGetInput: + def test_call(self): + fetcher = GetInput("data") + args = [("data", 42), ("other", "hello")] + assert fetcher(args) == 42 + + def test_str(self): + fetcher = GetInput("data") + assert str(fetcher) == "data" + + +class TestGetReturn: + def test_init(self): + fetcher = GetReturn(0) + assert fetcher.index == 0 + + def test_call_raises_not_implemented(self): + fetcher = GetReturn(0) + args = [("input", 42)] + with helper.raises(NotImplementedError, match="GetReturn is only used to describe output guarantees"): + fetcher(args) + + def test_str(self): + fetcher = GetReturn(0) + assert str(fetcher) == "return[0]" + + fetcher2 = GetReturn(2) + assert str(fetcher2) == "return[2]" + + +class TestGetDataType: + def test_call(self): + tensor = tp.ones((2, 3), dtype=tp.float32) + fetcher = GetDataType(GetInput("input_tensor")) + assert fetcher([("input_tensor", tensor)]) == tp.float32 + + def test_call_with_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32)] * 2 + fetcher = GetDataType(GetInput("input_tensors")) + assert fetcher([("input_tensors", tensors)]) == tp.float32 + + def test_call_with_mismatched_dtypes_in_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.int32)] + fetcher = GetDataType(GetInput("input_tensors")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_tensors", tensors)]) + + def test_call_with_non_tensor_argument(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Expected a tensor or data type argument"): + fetcher([("input_data", 42)]) + + def test_call_with_nested_sequence_error(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_data", [tp.ones((2, 3), dtype=tp.float32), [42]])]) diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py new file mode 100644 index 000000000..f898048a6 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -0,0 +1,145 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints import And, Equal, GetInput, Not, NotEqual, OneOf + + +class TestLogic: + def test_operator_and_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 & constraint2 + assert isinstance(combined, And) + assert combined([("param1", 2), ("param2", "b")]) + + def test_operator_and_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 & constraint2 & constraint3 + assert isinstance(combined, And) + assert len(combined.constraints) == 3 + assert combined([("param1", 2), ("param2", "b"), ("param3", True)]) + + def test_operator_not_basic(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + negated = ~constraint + assert isinstance(negated, Not) + assert negated([("param", 5)]) + assert not negated([("param", 2)]) + + +class TestOneOf: + def test_call_success(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + result = constraint([("param", 2)]) + assert result + + def test_call_failure(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + result = constraint([("param", 5)]) + assert not result + assert "Expected param to be one of [1, 2, 3], but got 5." in result.error_details + + def test_str(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + assert str(constraint) == "param is one of [1, 2, 3]" + + +class TestAnd: + def test_call_all_pass(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert and_constraint([("param1", 2), ("param2", "b")]) + + def test_call_one_fails(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "b")]) + assert not result + assert "Expected param1 to be one of [1, 2, 3], but got 5." in result.error_details + + def test_call_all_fail(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "z")]) + assert not result + errors = result.error_details + assert len(errors) == 2 + assert any("Expected param1 to be one of [1, 2, 3], but got 5" in err for err in errors) + assert any("Expected param2 to be one of ['a', 'b', 'c'], but got z" in err for err in errors) + + def test_str(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(and_constraint) == "param1 is one of [1, 2, 3] and param2 is one of ['a', 'b']" + + +class TestEqual: + def test_call_success(self): + constraint = Equal(GetInput("param1"), GetInput("param2")) + assert constraint([("param1", 5), ("param2", 5)]) + + def test_call_failure(self): + constraint = Equal(GetInput("param1"), GetInput("param2")) + result = constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "Expected param1 to be equal to param2, but got 5 and 10." in result.error_details + + def test_str(self): + constraint = Equal(GetInput("param1"), GetInput("param2")) + assert str(constraint) == "param1 == param2" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") == GetInput("param2") + assert isinstance(constraint, Equal) + + +class TestNotEqual: + def test_call_success(self): + constraint = GetInput("param1") != GetInput("param2") + assert constraint([("param1", 5), ("param2", 10)]) + + def test_call_failure(self): + constraint = GetInput("param1") != GetInput("param2") + result = constraint([("param1", 5), ("param2", 5)]) + assert not result + assert "Expected param1 to be not equal to param2, but both were 5." in result.error_details + + def test_str(self): + constraint = GetInput("param1") != GetInput("param2") + assert str(constraint) == "param1 != param2" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") != GetInput("param2") + assert isinstance(constraint, NotEqual) + + +class TestNot: + def test_call_success(self): + constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) + assert constraint([("param", 5)]) + + def test_call_failure(self): + constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) + result = constraint([("param", 2)]) + assert not result + assert "Expected NOT param is one of [1, 2, 3], but it was satisfied." in result.error_details + + def test_str(self): + constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) + assert str(constraint) == "NOT (param is one of [1, 2, 3])" + + def test_double_negation(self): + constraint = Not(Not(OneOf(GetInput("param"), [1, 2, 3]))) + assert constraint([("param", 2)]) + assert not constraint([("param", 5)]) From d91bce7f63059504ab2d5b6db4db3638ae06b247 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 24 Oct 2025 16:05:36 -0700 Subject: [PATCH 02/32] Moves `wrappers.py` into the frontend, which is the only place it's used --- tripy/docs/README.md | 2 +- .../post0_developer_guides/00-architecture.md | 2 +- .../01-how-to-add-new-ops.md | 2 +- tripy/nvtripy/config.py | 9 ++++++++- tripy/nvtripy/frontend/module/batchnorm.py | 5 +++-- tripy/nvtripy/frontend/module/conv/base.py | 3 ++- tripy/nvtripy/frontend/module/conv/conv.py | 2 +- .../frontend/module/conv/conv_transpose.py | 2 +- tripy/nvtripy/frontend/module/embedding.py | 3 ++- tripy/nvtripy/frontend/module/groupnorm.py | 3 ++- tripy/nvtripy/frontend/module/instancenorm.py | 10 +++++----- tripy/nvtripy/frontend/module/layernorm.py | 8 ++++---- tripy/nvtripy/frontend/module/linear.py | 9 +++++---- tripy/nvtripy/frontend/ops/allclose.py | 2 +- tripy/nvtripy/frontend/ops/arange.py | 2 +- tripy/nvtripy/frontend/ops/binary/add.py | 2 +- tripy/nvtripy/frontend/ops/binary/div.py | 2 +- tripy/nvtripy/frontend/ops/binary/equal.py | 2 +- tripy/nvtripy/frontend/ops/binary/floor_div.py | 2 +- tripy/nvtripy/frontend/ops/binary/greater.py | 2 +- .../frontend/ops/binary/greater_equal.py | 2 +- tripy/nvtripy/frontend/ops/binary/less.py | 2 +- .../nvtripy/frontend/ops/binary/less_equal.py | 2 +- .../nvtripy/frontend/ops/binary/logical_or.py | 2 +- tripy/nvtripy/frontend/ops/binary/maximum.py | 2 +- tripy/nvtripy/frontend/ops/binary/minimum.py | 2 +- tripy/nvtripy/frontend/ops/binary/mod.py | 2 +- tripy/nvtripy/frontend/ops/binary/mul.py | 2 +- tripy/nvtripy/frontend/ops/binary/not_equal.py | 2 +- tripy/nvtripy/frontend/ops/binary/pow.py | 2 +- tripy/nvtripy/frontend/ops/binary/sub.py | 2 +- tripy/nvtripy/frontend/ops/cast.py | 18 +++++++++++++++++- tripy/nvtripy/frontend/ops/concatenate.py | 6 +++++- tripy/nvtripy/frontend/ops/copy.py | 2 +- tripy/nvtripy/frontend/ops/cumsum.py | 2 +- tripy/nvtripy/frontend/ops/dequantize.py | 2 +- tripy/nvtripy/frontend/ops/equal.py | 2 +- tripy/nvtripy/frontend/ops/expand.py | 2 +- tripy/nvtripy/frontend/ops/flatten.py | 2 +- tripy/nvtripy/frontend/ops/flip.py | 2 +- tripy/nvtripy/frontend/ops/full.py | 2 +- tripy/nvtripy/frontend/ops/gather.py | 2 +- tripy/nvtripy/frontend/ops/iota.py | 2 +- tripy/nvtripy/frontend/ops/masked_fill.py | 2 +- tripy/nvtripy/frontend/ops/matmul.py | 2 +- tripy/nvtripy/frontend/ops/ones.py | 2 +- tripy/nvtripy/frontend/ops/outer.py | 2 +- tripy/nvtripy/frontend/ops/pad.py | 2 +- tripy/nvtripy/frontend/ops/permute.py | 2 +- tripy/nvtripy/frontend/ops/pooling/avgpool.py | 2 +- tripy/nvtripy/frontend/ops/pooling/maxpool.py | 2 +- tripy/nvtripy/frontend/ops/quantize.py | 2 +- tripy/nvtripy/frontend/ops/reduce/all.py | 2 +- tripy/nvtripy/frontend/ops/reduce/any.py | 2 +- tripy/nvtripy/frontend/ops/reduce/argmax.py | 2 +- tripy/nvtripy/frontend/ops/reduce/argmin.py | 2 +- tripy/nvtripy/frontend/ops/reduce/max.py | 2 +- tripy/nvtripy/frontend/ops/reduce/mean.py | 2 +- tripy/nvtripy/frontend/ops/reduce/min.py | 2 +- tripy/nvtripy/frontend/ops/reduce/prod.py | 2 +- tripy/nvtripy/frontend/ops/reduce/sum.py | 2 +- tripy/nvtripy/frontend/ops/reduce/topk.py | 8 ++++++-- tripy/nvtripy/frontend/ops/reduce/var.py | 2 +- tripy/nvtripy/frontend/ops/repeat.py | 2 +- tripy/nvtripy/frontend/ops/reshape.py | 2 +- tripy/nvtripy/frontend/ops/resize.py | 2 +- tripy/nvtripy/frontend/ops/shape.py | 2 +- tripy/nvtripy/frontend/ops/slice.py | 2 +- tripy/nvtripy/frontend/ops/softmax.py | 2 +- tripy/nvtripy/frontend/ops/split.py | 2 +- tripy/nvtripy/frontend/ops/squeeze.py | 2 +- tripy/nvtripy/frontend/ops/stack.py | 2 +- tripy/nvtripy/frontend/ops/transpose.py | 2 +- tripy/nvtripy/frontend/ops/tril.py | 4 ++-- tripy/nvtripy/frontend/ops/triu.py | 4 ++-- tripy/nvtripy/frontend/ops/unary/abs.py | 2 +- tripy/nvtripy/frontend/ops/unary/cos.py | 4 ++-- tripy/nvtripy/frontend/ops/unary/exp.py | 4 ++-- tripy/nvtripy/frontend/ops/unary/gelu.py | 2 +- tripy/nvtripy/frontend/ops/unary/invert.py | 2 +- tripy/nvtripy/frontend/ops/unary/log.py | 4 ++-- tripy/nvtripy/frontend/ops/unary/neg.py | 2 +- tripy/nvtripy/frontend/ops/unary/relu.py | 2 +- tripy/nvtripy/frontend/ops/unary/rsqrt.py | 2 +- tripy/nvtripy/frontend/ops/unary/sigmoid.py | 2 +- tripy/nvtripy/frontend/ops/unary/silu.py | 2 +- tripy/nvtripy/frontend/ops/unary/sin.py | 4 ++-- tripy/nvtripy/frontend/ops/unary/sqrt.py | 2 +- tripy/nvtripy/frontend/ops/unary/tanh.py | 2 +- tripy/nvtripy/frontend/ops/unsqueeze.py | 2 +- tripy/nvtripy/frontend/ops/where.py | 2 +- tripy/nvtripy/frontend/ops/zeros.py | 2 +- tripy/nvtripy/{utils => frontend}/wrappers.py | 18 ++++++++++++++++-- tripy/nvtripy/utils/stack_info.py | 2 +- tripy/tests/common/test_exception.py | 2 +- .../wrappers/test_datatype_constraints.py | 7 +++++-- .../wrappers/test_interface.py | 4 ++-- tripy/tests/utils/test_utils.py | 3 ++- 98 files changed, 173 insertions(+), 119 deletions(-) rename tripy/nvtripy/{utils => frontend}/wrappers.py (96%) rename tripy/tests/{utils => frontend}/wrappers/test_datatype_constraints.py (98%) rename tripy/tests/{utils => frontend}/wrappers/test_interface.py (98%) diff --git a/tripy/docs/README.md b/tripy/docs/README.md index a107706ed..612f0f326 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -45,7 +45,7 @@ which specifies doc metadata for each API (e.g. location). - Docstring must include *at least* **one [code example](#code-examples)**. - If the function accepts `tp.Tensor`s, must indicate **data type constraints** - with the [`wrappers.interface`](../nvtripy/utils/wrappers.py) decorator. + with the [`wrappers.interface`](../nvtripy/frontend/wrappers.py) decorator. **Example:** diff --git a/tripy/docs/post0_developer_guides/00-architecture.md b/tripy/docs/post0_developer_guides/00-architecture.md index 055dd0b3b..880cf5441 100644 --- a/tripy/docs/post0_developer_guides/00-architecture.md +++ b/tripy/docs/post0_developer_guides/00-architecture.md @@ -76,7 +76,7 @@ and various operations, e.g. {class}`nvtripy.resize`. :::{admonition} Info Most operations are decorated with: 1. [`@export.public_api`](source:/nvtripy/export.py): Enables documentation, type checking, and overloading. -2. [`@wrappers.interface`](source:/nvtripy/utils/wrappers.py): Enforces (and generates tests for) data type constraints. +2. [`@wrappers.interface`](source:/nvtripy/frontend/wrappers.py): Enforces (and generates tests for) data type constraints. ::: Operations are **lazily evaluated**. diff --git a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md index bd354d582..8a16af3a6 100644 --- a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md @@ -129,7 +129,7 @@ from typing import Tuple from nvtripy import export from nvtripy.trace.ops.topn import TopN -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.frontend.ops import utils as op_utils @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/config.py b/tripy/nvtripy/config.py index ba9d8d10c..f274aa124 100644 --- a/tripy/nvtripy/config.py +++ b/tripy/nvtripy/config.py @@ -50,7 +50,14 @@ module=sys.modules[__name__], symbol="enable_dtype_checking", )(True) -"""Whether to enable data type checking in API functions.""" +"""[DEPRECATED - use enable_input_validation] Whether to enable data type checking in API functions.""" + +enable_input_validation: bool = export.public_api( + document_under="config.rst", + module=sys.modules[__name__], + symbol="enable_input_validation", +)(True) +"""Whether to enable input validation in API functions.""" extra_error_information: List[str] = export.public_api( document_under="config.rst", diff --git a/tripy/nvtripy/frontend/module/batchnorm.py b/tripy/nvtripy/frontend/module/batchnorm.py index f7c1380a2..d08a2ff78 100644 --- a/tripy/nvtripy/frontend/module/batchnorm.py +++ b/tripy/nvtripy/frontend/module/batchnorm.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_features"]) +@constant_fields(["num_features"]) class BatchNorm(Module): r""" Applies batch normalization over an N-dimensional input tensor using precomputed statistics: @@ -105,8 +106,8 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": Returns: A tensor of the same shape as the input. """ - from nvtripy.frontend.ops.unary.rsqrt import rsqrt from nvtripy.frontend.ops.reshape import reshape + from nvtripy.frontend.ops.unary.rsqrt import rsqrt x_shape = (1, self.num_features, *([1] * (x.rank - 2))) diff --git a/tripy/nvtripy/frontend/module/conv/base.py b/tripy/nvtripy/frontend/module/conv/base.py index 812e08fef..fcebf79f0 100644 --- a/tripy/nvtripy/frontend/module/conv/base.py +++ b/tripy/nvtripy/frontend/module/conv/base.py @@ -23,10 +23,11 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @dataclass -@utils.wrappers.constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) +@constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) class ConvBase(Module): r"""Base class for sharing common functionality between Conv and ConvTranspose.""" diff --git a/tripy/nvtripy/frontend/module/conv/conv.py b/tripy/nvtripy/frontend/module/conv/conv.py index 2c322fdf8..978cc8394 100644 --- a/tripy/nvtripy/frontend/module/conv/conv.py +++ b/tripy/nvtripy/frontend/module/conv/conv.py @@ -26,7 +26,7 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.convolution import Convolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers # This function is added so that we can do dtype checking. diff --git a/tripy/nvtripy/frontend/module/conv/conv_transpose.py b/tripy/nvtripy/frontend/module/conv/conv_transpose.py index 3c5b5666e..4becc1103 100644 --- a/tripy/nvtripy/frontend/module/conv/conv_transpose.py +++ b/tripy/nvtripy/frontend/module/conv/conv_transpose.py @@ -26,7 +26,7 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.deconvolution import Deconvolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers # This function is added so that we can do dtype checking. diff --git a/tripy/nvtripy/frontend/module/embedding.py b/tripy/nvtripy/frontend/module/embedding.py index b8e5b72f0..d72e4134c 100644 --- a/tripy/nvtripy/frontend/module/embedding.py +++ b/tripy/nvtripy/frontend/module/embedding.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype"]) +@constant_fields(["dtype"]) class Embedding(Module): """ A lookup table for embedding vectors of a fixed size. diff --git a/tripy/nvtripy/frontend/module/groupnorm.py b/tripy/nvtripy/frontend/module/groupnorm.py index 5f98807ca..390f07fa8 100644 --- a/tripy/nvtripy/frontend/module/groupnorm.py +++ b/tripy/nvtripy/frontend/module/groupnorm.py @@ -25,11 +25,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_groups", "num_channels", "dtype"]) +@constant_fields(["num_groups", "num_channels", "dtype"]) class GroupNorm(Module): r""" Applies group normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/instancenorm.py b/tripy/nvtripy/frontend/module/instancenorm.py index 54b29a250..5a9a41ee4 100644 --- a/tripy/nvtripy/frontend/module/instancenorm.py +++ b/tripy/nvtripy/frontend/module/instancenorm.py @@ -17,15 +17,15 @@ from dataclasses import dataclass -from nvtripy import constants, export, utils +from nvtripy import constants, export from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.instancenorm import InstanceNorm as InstanceNormOp @@ -81,7 +81,7 @@ def instancenorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_channels", "dtype", "eps"]) +@constant_fields(["num_channels", "dtype", "eps"]) class InstanceNorm(Module): r""" Applies Instance Normalization over a mini-batch of inputs: diff --git a/tripy/nvtripy/frontend/module/layernorm.py b/tripy/nvtripy/frontend/module/layernorm.py index be2627ed1..1128d41cf 100644 --- a/tripy/nvtripy/frontend/module/layernorm.py +++ b/tripy/nvtripy/frontend/module/layernorm.py @@ -21,12 +21,12 @@ from nvtripy import export, utils from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp @@ -70,7 +70,7 @@ def layernorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "normalized_shape"]) +@constant_fields(["dtype", "normalized_shape"]) class LayerNorm(Module): r""" Applies layer normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/linear.py b/tripy/nvtripy/frontend/module/linear.py index 6c3c06eba..919c9e1f1 100644 --- a/tripy/nvtripy/frontend/module/linear.py +++ b/tripy/nvtripy/frontend/module/linear.py @@ -23,11 +23,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter, OptionalParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "quant_dtype"]) +@constant_fields(["dtype", "quant_dtype"]) class Linear(Module): r""" Applies a linear transformation to the input: @@ -117,11 +118,11 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": A tensor of shape :math:`[*, \text{out_features}]`. """ from nvtripy.common.exception import raise_error - from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops.transpose import transpose - from nvtripy.frontend.ops.unsqueeze import unsqueeze from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize + from nvtripy.frontend.ops.transpose import transpose + from nvtripy.frontend.ops.unsqueeze import unsqueeze + from nvtripy.frontend.tensor import Tensor if self.quant_dtype is not None: if isinstance(self.input_scale, Tensor): diff --git a/tripy/nvtripy/frontend/ops/allclose.py b/tripy/nvtripy/frontend/ops/allclose.py index 28ac163ed..0b4d756e2 100644 --- a/tripy/nvtripy/frontend/ops/allclose.py +++ b/tripy/nvtripy/frontend/ops/allclose.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/arange.py b/tripy/nvtripy/frontend/ops/arange.py index e20be7917..6aa07ca8f 100644 --- a/tripy/nvtripy/frontend/ops/arange.py +++ b/tripy/nvtripy/frontend/ops/arange.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops.cast import cast from nvtripy.frontend.ops.reshape import reshape from nvtripy.trace.ops.linspace import Linspace -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/binary/add.py b/tripy/nvtripy/frontend/ops/binary/add.py index f673c452c..e936129cd 100644 --- a/tripy/nvtripy/frontend/ops/binary/add.py +++ b/tripy/nvtripy/frontend/ops/binary/add.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Add from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__add__") diff --git a/tripy/nvtripy/frontend/ops/binary/div.py b/tripy/nvtripy/frontend/ops/binary/div.py index 125b42b46..482a06a59 100644 --- a/tripy/nvtripy/frontend/ops/binary/div.py +++ b/tripy/nvtripy/frontend/ops/binary/div.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Div from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__truediv__") diff --git a/tripy/nvtripy/frontend/ops/binary/equal.py b/tripy/nvtripy/frontend/ops/binary/equal.py index 55c899b36..699882a74 100644 --- a/tripy/nvtripy/frontend/ops/binary/equal.py +++ b/tripy/nvtripy/frontend/ops/binary/equal.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Equal from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__eq__") diff --git a/tripy/nvtripy/frontend/ops/binary/floor_div.py b/tripy/nvtripy/frontend/ops/binary/floor_div.py index 4ddbe0a5b..e221e0fb8 100644 --- a/tripy/nvtripy/frontend/ops/binary/floor_div.py +++ b/tripy/nvtripy/frontend/ops/binary/floor_div.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import FloorDiv from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__floordiv__") diff --git a/tripy/nvtripy/frontend/ops/binary/greater.py b/tripy/nvtripy/frontend/ops/binary/greater.py index c82ded594..cf5504d7a 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater.py +++ b/tripy/nvtripy/frontend/ops/binary/greater.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Greater from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__gt__") diff --git a/tripy/nvtripy/frontend/ops/binary/greater_equal.py b/tripy/nvtripy/frontend/ops/binary/greater_equal.py index 6be66c9db..c75c3cb12 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/greater_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ge__") diff --git a/tripy/nvtripy/frontend/ops/binary/less.py b/tripy/nvtripy/frontend/ops/binary/less.py index 6495319fa..6b8431f39 100644 --- a/tripy/nvtripy/frontend/ops/binary/less.py +++ b/tripy/nvtripy/frontend/ops/binary/less.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Less from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__lt__") diff --git a/tripy/nvtripy/frontend/ops/binary/less_equal.py b/tripy/nvtripy/frontend/ops/binary/less_equal.py index 14bf35c9b..9694fa4e7 100644 --- a/tripy/nvtripy/frontend/ops/binary/less_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/less_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__le__") diff --git a/tripy/nvtripy/frontend/ops/binary/logical_or.py b/tripy/nvtripy/frontend/ops/binary/logical_or.py index 2302415ce..268f42b73 100644 --- a/tripy/nvtripy/frontend/ops/binary/logical_or.py +++ b/tripy/nvtripy/frontend/ops/binary/logical_or.py @@ -15,7 +15,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import LogicalOr -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__or__") diff --git a/tripy/nvtripy/frontend/ops/binary/maximum.py b/tripy/nvtripy/frontend/ops/binary/maximum.py index 05023378d..c25d8152a 100644 --- a/tripy/nvtripy/frontend/ops/binary/maximum.py +++ b/tripy/nvtripy/frontend/ops/binary/maximum.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Max from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/binary/minimum.py b/tripy/nvtripy/frontend/ops/binary/minimum.py index 0a1954b1b..19a7e5a74 100644 --- a/tripy/nvtripy/frontend/ops/binary/minimum.py +++ b/tripy/nvtripy/frontend/ops/binary/minimum.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Min from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/binary/mod.py b/tripy/nvtripy/frontend/ops/binary/mod.py index f38203f43..f22ee1295 100644 --- a/tripy/nvtripy/frontend/ops/binary/mod.py +++ b/tripy/nvtripy/frontend/ops/binary/mod.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def mod_impl(lhs, rhs): diff --git a/tripy/nvtripy/frontend/ops/binary/mul.py b/tripy/nvtripy/frontend/ops/binary/mul.py index e6bcd9940..0fc2c9811 100644 --- a/tripy/nvtripy/frontend/ops/binary/mul.py +++ b/tripy/nvtripy/frontend/ops/binary/mul.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Mul from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__mul__") diff --git a/tripy/nvtripy/frontend/ops/binary/not_equal.py b/tripy/nvtripy/frontend/ops/binary/not_equal.py index 6e175dc84..d4a8b4fbb 100644 --- a/tripy/nvtripy/frontend/ops/binary/not_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/not_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ne__") diff --git a/tripy/nvtripy/frontend/ops/binary/pow.py b/tripy/nvtripy/frontend/ops/binary/pow.py index 285e73141..17b55e1d3 100644 --- a/tripy/nvtripy/frontend/ops/binary/pow.py +++ b/tripy/nvtripy/frontend/ops/binary/pow.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Pow from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__pow__") diff --git a/tripy/nvtripy/frontend/ops/binary/sub.py b/tripy/nvtripy/frontend/ops/binary/sub.py index 94b51ec9e..7b4bc4d6e 100644 --- a/tripy/nvtripy/frontend/ops/binary/sub.py +++ b/tripy/nvtripy/frontend/ops/binary/sub.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Sub from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__sub__") diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index 9c197af57..b2a50fbf4 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -24,7 +24,23 @@ from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize from nvtripy.trace.ops.cast import Cast -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +# constraints = ( +# OneOf( +# GetInput("input").dtype, +# [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool], +# ) +# & OneOf( +# GetInput("dtype"), +# [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool], +# ) +# & ~(GetInput("input").dtype == tp.float8 & OneOf(GetInput("dtype"), [tp.int4, tp.int8])) +# & ~(GetInput("input").dtype == tp.int8 & GetInput("dtype") == tp.float8) +# & ~(GetInput("input").dtype == tp.int4 & OneOf(GetInput("dtype"), [tp.float8, tp.int8, tp.int64])) +# ) +# output_guarantees = GetReturn(0).dtype == GetInput("dtype") @register_tensor_method("cast") diff --git a/tripy/nvtripy/frontend/ops/concatenate.py b/tripy/nvtripy/frontend/ops/concatenate.py index de9d95de7..4f1826db6 100644 --- a/tripy/nvtripy/frontend/ops/concatenate.py +++ b/tripy/nvtripy/frontend/ops/concatenate.py @@ -21,7 +21,11 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.concatenate import Concatenate -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +# constraints = OneOf(GetInput("tensors").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool]) +# output_guarantees = GetReturn(0).dtype == GetInput("tensors").dtype @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index 4e5c10d6a..488cde55a 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -22,7 +22,7 @@ from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("copy") diff --git a/tripy/nvtripy/frontend/ops/cumsum.py b/tripy/nvtripy/frontend/ops/cumsum.py index c0dd76902..711d717ae 100644 --- a/tripy/nvtripy/frontend/ops/cumsum.py +++ b/tripy/nvtripy/frontend/ops/cumsum.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index 816efa28c..cc0d6a39b 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -22,7 +22,7 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.dequantize import Dequantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/quantization") diff --git a/tripy/nvtripy/frontend/ops/equal.py b/tripy/nvtripy/frontend/ops/equal.py index 7d5fc3a7b..93ac7256f 100644 --- a/tripy/nvtripy/frontend/ops/equal.py +++ b/tripy/nvtripy/frontend/ops/equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy import export from nvtripy.common.datatype import DATA_TYPES -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/expand.py b/tripy/nvtripy/frontend/ops/expand.py index 6cced90bc..41b975bd3 100644 --- a/tripy/nvtripy/frontend/ops/expand.py +++ b/tripy/nvtripy/frontend/ops/expand.py @@ -21,7 +21,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): diff --git a/tripy/nvtripy/frontend/ops/flatten.py b/tripy/nvtripy/frontend/ops/flatten.py index a356bf123..85b3b175c 100644 --- a/tripy/nvtripy/frontend/ops/flatten.py +++ b/tripy/nvtripy/frontend/ops/flatten.py @@ -18,7 +18,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("flatten") diff --git a/tripy/nvtripy/frontend/ops/flip.py b/tripy/nvtripy/frontend/ops/flip.py index 16912f97b..c2e3d71d1 100644 --- a/tripy/nvtripy/frontend/ops/flip.py +++ b/tripy/nvtripy/frontend/ops/flip.py @@ -19,7 +19,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/full.py b/tripy/nvtripy/frontend/ops/full.py index 044cfe1e5..e8882250a 100644 --- a/tripy/nvtripy/frontend/ops/full.py +++ b/tripy/nvtripy/frontend/ops/full.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike, TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index 7af68d070..bced20763 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -19,7 +19,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.gather import Gather -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/iota.py b/tripy/nvtripy/frontend/ops/iota.py index fa5b247f1..6df4d2668 100644 --- a/tripy/nvtripy/frontend/ops/iota.py +++ b/tripy/nvtripy/frontend/ops/iota.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.linspace import Linspace from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/masked_fill.py b/tripy/nvtripy/frontend/ops/masked_fill.py index 3bfb54226..540198fc4 100644 --- a/tripy/nvtripy/frontend/ops/masked_fill.py +++ b/tripy/nvtripy/frontend/ops/masked_fill.py @@ -15,7 +15,7 @@ import numbers from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/matmul.py b/tripy/nvtripy/frontend/ops/matmul.py index 1f323987f..43210a79a 100644 --- a/tripy/nvtripy/frontend/ops/matmul.py +++ b/tripy/nvtripy/frontend/ops/matmul.py @@ -19,7 +19,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.matmul import MatrixMultiply -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__matmul__") diff --git a/tripy/nvtripy/frontend/ops/ones.py b/tripy/nvtripy/frontend/ops/ones.py index 2abcd81cc..154d5863c 100644 --- a/tripy/nvtripy/frontend/ops/ones.py +++ b/tripy/nvtripy/frontend/ops/ones.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common import datatype from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/outer.py b/tripy/nvtripy/frontend/ops/outer.py index fae200134..c25652023 100644 --- a/tripy/nvtripy/frontend/ops/outer.py +++ b/tripy/nvtripy/frontend/ops/outer.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/pad.py b/tripy/nvtripy/frontend/ops/pad.py index f58afdfd7..42512cf49 100644 --- a/tripy/nvtripy/frontend/ops/pad.py +++ b/tripy/nvtripy/frontend/ops/pad.py @@ -23,7 +23,7 @@ from nvtripy.trace.ops.shape import Shape from nvtripy.trace.ops.slice import SliceFill, SliceReflect from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index a367ce939..8f798ec15 100644 --- a/tripy/nvtripy/frontend/ops/permute.py +++ b/tripy/nvtripy/frontend/ops/permute.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.permute import Permute -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("permute") diff --git a/tripy/nvtripy/frontend/ops/pooling/avgpool.py b/tripy/nvtripy/frontend/ops/pooling/avgpool.py index 50eb3a18d..68038cbd5 100644 --- a/tripy/nvtripy/frontend/ops/pooling/avgpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/avgpool.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import AvgPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/pooling/maxpool.py b/tripy/nvtripy/frontend/ops/pooling/maxpool.py index c38ab13f6..dd177e780 100644 --- a/tripy/nvtripy/frontend/ops/pooling/maxpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/maxpool.py @@ -21,7 +21,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import MaxPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 69c062a0b..3d229789d 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -22,7 +22,7 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.quantize import Quantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/quantization") diff --git a/tripy/nvtripy/frontend/ops/reduce/all.py b/tripy/nvtripy/frontend/ops/reduce/all.py index fc3116287..d7dc507e4 100644 --- a/tripy/nvtripy/frontend/ops/reduce/all.py +++ b/tripy/nvtripy/frontend/ops/reduce/all.py @@ -15,7 +15,7 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/reduce/any.py b/tripy/nvtripy/frontend/ops/reduce/any.py index 01630dd60..c9f4af47b 100644 --- a/tripy/nvtripy/frontend/ops/reduce/any.py +++ b/tripy/nvtripy/frontend/ops/reduce/any.py @@ -15,7 +15,7 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/reduce/argmax.py b/tripy/nvtripy/frontend/ops/reduce/argmax.py index e3abdb4c0..7d5ee84ab 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmax.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmax.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/argmin.py b/tripy/nvtripy/frontend/ops/reduce/argmin.py index 69d616162..bc09187a5 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmin.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmin.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/max.py b/tripy/nvtripy/frontend/ops/reduce/max.py index 5b28f99a4..b6b569b2d 100644 --- a/tripy/nvtripy/frontend/ops/reduce/max.py +++ b/tripy/nvtripy/frontend/ops/reduce/max.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Max -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/mean.py b/tripy/nvtripy/frontend/ops/reduce/mean.py index 5a85e6357..e9dd4d5a1 100644 --- a/tripy/nvtripy/frontend/ops/reduce/mean.py +++ b/tripy/nvtripy/frontend/ops/reduce/mean.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Avg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/min.py b/tripy/nvtripy/frontend/ops/reduce/min.py index 1bd15785b..b08c5a494 100644 --- a/tripy/nvtripy/frontend/ops/reduce/min.py +++ b/tripy/nvtripy/frontend/ops/reduce/min.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Min -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/prod.py b/tripy/nvtripy/frontend/ops/reduce/prod.py index 75ab67f21..016bf20e8 100644 --- a/tripy/nvtripy/frontend/ops/reduce/prod.py +++ b/tripy/nvtripy/frontend/ops/reduce/prod.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Prod -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/sum.py b/tripy/nvtripy/frontend/ops/reduce/sum.py index e118eb397..9e85e4197 100644 --- a/tripy/nvtripy/frontend/ops/reduce/sum.py +++ b/tripy/nvtripy/frontend/ops/reduce/sum.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Sum -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/topk.py b/tripy/nvtripy/frontend/ops/reduce/topk.py index 57e1e2dfe..e5713c82a 100644 --- a/tripy/nvtripy/frontend/ops/reduce/topk.py +++ b/tripy/nvtripy/frontend/ops/reduce/topk.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,11 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import topk_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +# constraints = OneOf(GetInput("input").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.int32, tp.int64]) +# output_guarantees = (GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == tp.int32)) @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/var.py b/tripy/nvtripy/frontend/ops/reduce/var.py index ad8cd4f8d..585600034 100644 --- a/tripy/nvtripy/frontend/ops/reduce/var.py +++ b/tripy/nvtripy/frontend/ops/reduce/var.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/repeat.py b/tripy/nvtripy/frontend/ops/repeat.py index d7e531b41..9dc498e4c 100644 --- a/tripy/nvtripy/frontend/ops/repeat.py +++ b/tripy/nvtripy/frontend/ops/repeat.py @@ -18,7 +18,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py index 262dc78b8..ae0eb52a0 100644 --- a/tripy/nvtripy/frontend/ops/reshape.py +++ b/tripy/nvtripy/frontend/ops/reshape.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.reshape import Reshape from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: diff --git a/tripy/nvtripy/frontend/ops/resize.py b/tripy/nvtripy/frontend/ops/resize.py index 477623062..31fba7ad3 100644 --- a/tripy/nvtripy/frontend/ops/resize.py +++ b/tripy/nvtripy/frontend/ops/resize.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.resize import ResizeCubic, ResizeLinear, ResizeNearest from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers SUPPORTED_MODES = ("cubic", "linear", "nearest") diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 94e3fa19c..08795206d 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -25,7 +25,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.shape import GetDimensionSize, Shape from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("shape") diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index 010e04dee..032fc70e5 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.slice import Slice from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.utils.types import type_str_from_arg from nvtripy.utils.utils import make_list diff --git a/tripy/nvtripy/frontend/ops/softmax.py b/tripy/nvtripy/frontend/ops/softmax.py index f195eac81..002332319 100644 --- a/tripy/nvtripy/frontend/ops/softmax.py +++ b/tripy/nvtripy/frontend/ops/softmax.py @@ -20,7 +20,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.softmax import Softmax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 76a3d4a8d..829f74b52 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -21,7 +21,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py index 5ec260e15..132fd7836 100644 --- a/tripy/nvtripy/frontend/ops/squeeze.py +++ b/tripy/nvtripy/frontend/ops/squeeze.py @@ -17,7 +17,7 @@ from nvtripy import export, utils from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("squeeze") diff --git a/tripy/nvtripy/frontend/ops/stack.py b/tripy/nvtripy/frontend/ops/stack.py index 0f2df667e..e46c8b69f 100644 --- a/tripy/nvtripy/frontend/ops/stack.py +++ b/tripy/nvtripy/frontend/ops/stack.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/transpose.py b/tripy/nvtripy/frontend/ops/transpose.py index 1d4b8cf3d..9119f6937 100644 --- a/tripy/nvtripy/frontend/ops/transpose.py +++ b/tripy/nvtripy/frontend/ops/transpose.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("transpose") diff --git a/tripy/nvtripy/frontend/ops/tril.py b/tripy/nvtripy/frontend/ops/tril.py index 563400e16..a90cc8bb4 100644 --- a/tripy/nvtripy/frontend/ops/tril.py +++ b/tripy/nvtripy/frontend/ops/tril.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +18,7 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.zeros import zeros_like from nvtripy.frontend.ops.where import where -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/triu.py b/tripy/nvtripy/frontend/ops/triu.py index 80dc08103..95a3bc02f 100644 --- a/tripy/nvtripy/frontend/ops/triu.py +++ b/tripy/nvtripy/frontend/ops/triu.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +18,7 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.where import where from nvtripy.frontend.ops.zeros import zeros_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/unary/abs.py b/tripy/nvtripy/frontend/ops/unary/abs.py index 9447cecbb..6ae52b523 100644 --- a/tripy/nvtripy/frontend/ops/unary/abs.py +++ b/tripy/nvtripy/frontend/ops/unary/abs.py @@ -19,7 +19,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Abs -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__abs__") diff --git a/tripy/nvtripy/frontend/ops/unary/cos.py b/tripy/nvtripy/frontend/ops/unary/cos.py index cabe9f697..fbf9ba35d 100644 --- a/tripy/nvtripy/frontend/ops/unary/cos.py +++ b/tripy/nvtripy/frontend/ops/unary/cos.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Cos -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/exp.py b/tripy/nvtripy/frontend/ops/unary/exp.py index 4895a4556..837f89051 100644 --- a/tripy/nvtripy/frontend/ops/unary/exp.py +++ b/tripy/nvtripy/frontend/ops/unary/exp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Exp -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/gelu.py b/tripy/nvtripy/frontend/ops/unary/gelu.py index a99dd240c..44a9d60bd 100644 --- a/tripy/nvtripy/frontend/ops/unary/gelu.py +++ b/tripy/nvtripy/frontend/ops/unary/gelu.py @@ -17,7 +17,7 @@ from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import GeluErf from nvtripy.frontend.ops import utils as op_utils diff --git a/tripy/nvtripy/frontend/ops/unary/invert.py b/tripy/nvtripy/frontend/ops/unary/invert.py index 343f7525f..8fea29d6a 100644 --- a/tripy/nvtripy/frontend/ops/unary/invert.py +++ b/tripy/nvtripy/frontend/ops/unary/invert.py @@ -15,7 +15,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Not -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__invert__") diff --git a/tripy/nvtripy/frontend/ops/unary/log.py b/tripy/nvtripy/frontend/ops/unary/log.py index 74257a948..8d21af689 100644 --- a/tripy/nvtripy/frontend/ops/unary/log.py +++ b/tripy/nvtripy/frontend/ops/unary/log.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Log -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/neg.py b/tripy/nvtripy/frontend/ops/unary/neg.py index 3849364b4..0e08130a0 100644 --- a/tripy/nvtripy/frontend/ops/unary/neg.py +++ b/tripy/nvtripy/frontend/ops/unary/neg.py @@ -17,7 +17,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Neg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__neg__") diff --git a/tripy/nvtripy/frontend/ops/unary/relu.py b/tripy/nvtripy/frontend/ops/unary/relu.py index 2db6aac9f..531cff1dc 100644 --- a/tripy/nvtripy/frontend/ops/unary/relu.py +++ b/tripy/nvtripy/frontend/ops/unary/relu.py @@ -18,7 +18,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Relu -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/unary/rsqrt.py b/tripy/nvtripy/frontend/ops/unary/rsqrt.py index 5cd215073..9041deef6 100644 --- a/tripy/nvtripy/frontend/ops/unary/rsqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/rsqrt.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Recip -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sigmoid.py b/tripy/nvtripy/frontend/ops/unary/sigmoid.py index be7ea05a7..efee06568 100644 --- a/tripy/nvtripy/frontend/ops/unary/sigmoid.py +++ b/tripy/nvtripy/frontend/ops/unary/sigmoid.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import Sigmoid from nvtripy.frontend.ops import utils as op_utils diff --git a/tripy/nvtripy/frontend/ops/unary/silu.py b/tripy/nvtripy/frontend/ops/unary/silu.py index 3813f06e3..6e359bec5 100644 --- a/tripy/nvtripy/frontend/ops/unary/silu.py +++ b/tripy/nvtripy/frontend/ops/unary/silu.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sin.py b/tripy/nvtripy/frontend/ops/unary/sin.py index 7078a30b7..bf7ae4e1b 100644 --- a/tripy/nvtripy/frontend/ops/unary/sin.py +++ b/tripy/nvtripy/frontend/ops/unary/sin.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sqrt.py b/tripy/nvtripy/frontend/ops/unary/sqrt.py index 6a67ed9db..b8681357c 100644 --- a/tripy/nvtripy/frontend/ops/unary/sqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/sqrt.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sqrt -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/tanh.py b/tripy/nvtripy/frontend/ops/unary/tanh.py index fac66c675..67f9bc595 100644 --- a/tripy/nvtripy/frontend/ops/unary/tanh.py +++ b/tripy/nvtripy/frontend/ops/unary/tanh.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Tanh -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unsqueeze.py b/tripy/nvtripy/frontend/ops/unsqueeze.py index fa3e25045..e14612d49 100644 --- a/tripy/nvtripy/frontend/ops/unsqueeze.py +++ b/tripy/nvtripy/frontend/ops/unsqueeze.py @@ -18,7 +18,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("unsqueeze") diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index 7c41a4ae0..277a61490 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -20,7 +20,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.where import Where from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/zeros.py b/tripy/nvtripy/frontend/ops/zeros.py index 177d20cf0..4f3028a21 100644 --- a/tripy/nvtripy/frontend/ops/zeros.py +++ b/tripy/nvtripy/frontend/ops/zeros.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common import datatype from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/utils/wrappers.py b/tripy/nvtripy/frontend/wrappers.py similarity index 96% rename from tripy/nvtripy/utils/wrappers.py rename to tripy/nvtripy/frontend/wrappers.py index 7f84427c3..fa9cdaa17 100644 --- a/tripy/nvtripy/utils/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -15,17 +15,19 @@ # limitations under the License. # +# TODO (pranavm): Move into frontend - only used there. + import functools import inspect -import types from dataclasses import dataclass from textwrap import indent from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union from nvtripy import config, utils +from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints import Constraints from nvtripy.utils import result -from nvtripy.common.datatype import DATA_TYPES @dataclass @@ -295,6 +297,9 @@ def sorted_types(dtypes): def interface( + # TODO (pranavm): These should be required arguments eventually. + input_requirements: Constraints = None, + output_guarantees: Constraints = None, dtype_constraints: Dict[str, str] = {}, dtype_variables: Dict[str, List[str]] = {}, dtype_exceptions: List[Dict[str, str]] = [], @@ -386,6 +391,15 @@ def wrapper(*args, **kwargs): shape_likes, ) + if config.enable_input_validation: + if input_requirements is not None: + result = input_requirements(func, merged_args) + if not result: + raise_error( + f"Input requirements not met for function: '{func.__qualname__}'.", + result.error_details, + ) + if config.enable_dtype_checking: from nvtripy.common.datatype import dtype from nvtripy.frontend.tensor import Tensor diff --git a/tripy/nvtripy/utils/stack_info.py b/tripy/nvtripy/utils/stack_info.py index 19fab4871..f8557552e 100644 --- a/tripy/nvtripy/utils/stack_info.py +++ b/tripy/nvtripy/utils/stack_info.py @@ -139,6 +139,6 @@ def get_module_names_to_exclude_from_stack_info(): or trying to retrieve column information from code. """ import nvtripy.utils.function_registry as function_registry - import nvtripy.utils.wrappers as wrappers + import nvtripy.frontend.wrappers as wrappers return {mod.__name__ for mod in [function_registry, wrappers]} diff --git a/tripy/tests/common/test_exception.py b/tripy/tests/common/test_exception.py index 0c9fbdd87..b8a8a0aec 100644 --- a/tripy/tests/common/test_exception.py +++ b/tripy/tests/common/test_exception.py @@ -122,7 +122,7 @@ def test_can_determine_column_range(self): ) def test_wrappers_is_excluded(self): - from nvtripy.utils import wrappers + from nvtripy.frontend import wrappers tensor = tp.ones((2, 3)) diff --git a/tripy/tests/utils/wrappers/test_datatype_constraints.py b/tripy/tests/frontend/wrappers/test_datatype_constraints.py similarity index 98% rename from tripy/tests/utils/wrappers/test_datatype_constraints.py rename to tripy/tests/frontend/wrappers/test_datatype_constraints.py index 548023fe2..cccb14590 100644 --- a/tripy/tests/utils/wrappers/test_datatype_constraints.py +++ b/tripy/tests/frontend/wrappers/test_datatype_constraints.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# TODO (pranavm): Move into integration tests + import contextlib import inspect import itertools @@ -22,10 +25,10 @@ import nvtripy as tp import pytest from nvtripy.common.datatype import DATA_TYPES -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.utils.types import str_from_type_annotation from nvtripy.utils.utils import make_list -from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS from tests import helper from tests.conftest import skip_if_older_than_sm89 diff --git a/tripy/tests/utils/wrappers/test_interface.py b/tripy/tests/frontend/wrappers/test_interface.py similarity index 98% rename from tripy/tests/utils/wrappers/test_interface.py rename to tripy/tests/frontend/wrappers/test_interface.py index 7a9466afe..1ca042bbc 100755 --- a/tripy/tests/utils/wrappers/test_interface.py +++ b/tripy/tests/frontend/wrappers/test_interface.py @@ -22,8 +22,8 @@ import nvtripy as tp import pytest from nvtripy.export import PUBLIC_APIS -from nvtripy.utils import wrappers -from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend import wrappers +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS from tests import helper # Get all functions/methods which have tensors in the type signature diff --git a/tripy/tests/utils/test_utils.py b/tripy/tests/utils/test_utils.py index 54725b4e1..568e6d223 100644 --- a/tripy/tests/utils/test_utils.py +++ b/tripy/tests/utils/test_utils.py @@ -20,6 +20,7 @@ import nvtripy as tp import pytest from nvtripy import utils +from nvtripy.frontend.wrappers import constant_fields from tests import helper @@ -46,7 +47,7 @@ def test_hash_equivalence(self, func): def make_with_constant_field(): - @utils.wrappers.constant_fields("field") + @constant_fields("field") class WithConstField: def __init__(self): self.custom_setter_called_count = defaultdict(int) From 54cc68008a4187db6884d7c0466020a0e3bd339b Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 24 Oct 2025 17:02:52 -0700 Subject: [PATCH 03/32] Enables input validation based on input constraints --- tripy/nvtripy/frontend/constraints/logic.py | 77 ++++++++++++----- tripy/nvtripy/frontend/ops/cast.py | 39 ++++----- tripy/nvtripy/frontend/wrappers.py | 8 +- .../tests/frontend/constraints/test_logic.py | 82 ++++++++++++++++++- 4 files changed, 157 insertions(+), 49 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 0b53599e6..512a031e6 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -21,6 +21,9 @@ from nvtripy.frontend.constraints.fetcher import Fetcher from nvtripy.utils.result import Result +# TODO (pranavm): Maybe error details should say what went wrong, not what was expected? +# TODO (pranavm): Maybe return a list of violated constraints? Then we can construct error messages outside. + class Logic(Constraints): """ @@ -37,62 +40,77 @@ def __and__(self, other: "Logic") -> "Logic": return And(self, *other.constraints) return And(self, other) + def __or__(self, other: "Logic") -> "Logic": + if isinstance(self, Or): + return Or(*self.constraints, other) + elif isinstance(other, Or): + return Or(self, *other.constraints) + return Or(self, other) + def __invert__(self) -> "Logic": if isinstance(self, Equal): - return NotEqual(self.fetcher1, self.fetcher2) + return NotEqual(self.fetcher, self.fetcher_or_value) return Not(self) class OneOf(Logic): def __init__(self, fetcher: Fetcher, options: Sequence[Any]): self.fetcher = fetcher - self.options = options + # Need to convert generator expressions so we can use them more than once + self.options = list(options) def __call__(self, args: List[Tuple[str, Any]]) -> Result: value = self.fetcher(args) if value in self.options: return Result.ok() - return Result.err([f"Expected {self.fetcher} to be one of {self.options}, but got {value}."]) + return Result.err([f"'{self.fetcher}' to be one of {self.options}, but got {value}"]) def __str__(self): return f"{self.fetcher} is one of {self.options}" +def get_val_or_call_fetcher(fetcher_or_value: Any, args: List[Tuple[str, Any]]) -> Any: + if isinstance(fetcher_or_value, Fetcher): + return fetcher_or_value(args) + return fetcher_or_value + + class Equal(Logic): - def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher): - self.fetcher1 = fetcher1 - self.fetcher2 = fetcher2 + def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): + self.fetcher = fetcher + self.fetcher_or_value = fetcher_or_value def __call__(self, args: List[Tuple[str, Any]]) -> Result: - value1 = self.fetcher1(args) - value2 = self.fetcher2(args) + value1 = self.fetcher(args) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args) if value1 == value2: return Result.ok() - return Result.err([f"Expected {self.fetcher1} to be equal to {self.fetcher2}, but got {value1} and {value2}."]) + return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}'"]) def __str__(self): - return f"{self.fetcher1} == {self.fetcher2}" + return f"{self.fetcher} == {self.fetcher_or_value}" class NotEqual(Logic): - def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher): - self.fetcher1 = fetcher1 - self.fetcher2 = fetcher2 + def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): + self.fetcher = fetcher + self.fetcher_or_value = fetcher_or_value def __call__(self, args: List[Tuple[str, Any]]) -> Result: - value1 = self.fetcher1(args) - value2 = self.fetcher2(args) + value1 = self.fetcher(args) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args) if value1 != value2: return Result.ok() - return Result.err([f"Expected {self.fetcher1} to be not equal to {self.fetcher2}, but both were {value1}."]) + return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}'"]) def __str__(self): - return f"{self.fetcher1} != {self.fetcher2}" + return f"{self.fetcher} != {self.fetcher_or_value}" +# TODO (pranavm): Make And and Or combine errors and include a prefix like: "{...} but condition XYZ failed." class And(Logic): def __init__(self, *constraints: Logic): self.constraints = constraints @@ -102,13 +120,30 @@ def __call__(self, args: List[Tuple[str, Any]]) -> Result: for constraint in self.constraints: result = constraint(args) if not result: - errors.extend(result.error_details) + errors.extend(([" and "] if errors else []) + result.error_details) if errors: return Result.err(errors) return Result.ok() def __str__(self): - return " and ".join(str(constraint) for constraint in self.constraints) + return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" + + +class Or(Logic): + def __init__(self, *constraints: Logic): + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]]) -> Result: + all_errors = [] + for constraint in self.constraints: + result = constraint(args) + if result: + return Result.ok() + all_errors.extend(([" or "] if all_errors else []) + result.error_details) + return Result.err(all_errors) + + def __str__(self): + return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" class Not(Logic): @@ -118,8 +153,8 @@ def __init__(self, constraint: Logic): def __call__(self, args: List[Tuple[str, Any]]) -> Result: result = self.constraint(args) if result: - return Result.err([f"Expected NOT {self.constraint}, but it was satisfied."]) + return Result.err([str(self)]) return Result.ok() def __str__(self): - return f"NOT ({self.constraint})" + return f"not {self.constraint}" diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index b2a50fbf4..c2bab5ba0 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -17,35 +17,26 @@ from nvtripy import export -from nvtripy.common.datatype import bool as tp_bool -from nvtripy.common.datatype import float32, int8 +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize from nvtripy.trace.ops.cast import Cast -from nvtripy.frontend import wrappers - - -# constraints = ( -# OneOf( -# GetInput("input").dtype, -# [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool], -# ) -# & OneOf( -# GetInput("dtype"), -# [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool], -# ) -# & ~(GetInput("input").dtype == tp.float8 & OneOf(GetInput("dtype"), [tp.int4, tp.int8])) -# & ~(GetInput("input").dtype == tp.int8 & GetInput("dtype") == tp.float8) -# & ~(GetInput("input").dtype == tp.int4 & OneOf(GetInput("dtype"), [tp.float8, tp.int8, tp.int64])) -# ) -# output_guarantees = GetReturn(0).dtype == GetInput("dtype") @register_tensor_method("cast") @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=( + ~((GetInput("input").dtype == dt.float8) & OneOf(GetInput("dtype"), [dt.int4, dt.int8])) + & ~((GetInput("input").dtype == dt.int8) & (GetInput("dtype") == dt.float8)) + & ~((GetInput("input").dtype == dt.int4) & OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), + # TODO (pranavm): Remove old dtype constraints system: dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], @@ -95,14 +86,14 @@ def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor": # If given a quantized input, dequantize before converting. If bool is the target dtype, # we do still need to quantize int8s because it compiles into an MLIR-TRT *comparison* op - if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != int8 or dtype == tp_bool): - dequant_dtype = float32 + if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != dt.int8 or dtype == dt.bool): + dequant_dtype = dt.float32 input = dequantize(input, 1.0, dequant_dtype) if dtype == dequant_dtype: return input - if op_utils.is_quantized_dtype(dtype) and dtype != int8: - if input.dtype != float32: - input = op_utils.create_op(Cast, [input], float32) + if op_utils.is_quantized_dtype(dtype) and dtype != dt.int8: + if input.dtype != dt.float32: + input = op_utils.create_op(Cast, [input], dt.float32) return quantize(input, 1.0, dtype) return op_utils.create_op(Cast, [input], dtype) diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index fa9cdaa17..1906a384d 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -393,11 +393,13 @@ def wrapper(*args, **kwargs): if config.enable_input_validation: if input_requirements is not None: - result = input_requirements(func, merged_args) + result = input_requirements(merged_args) if not result: raise_error( - f"Input requirements not met for function: '{func.__qualname__}'.", - result.error_details, + f"Invalid inputs for function: '{func.__qualname__}'.", + ["Expected: "] + + result.error_details + + [f"\n\nNote: Requirements are:\n {input_requirements}."], ) if config.enable_dtype_checking: diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index f898048a6..58bd9a1bf 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from nvtripy.frontend.constraints import And, Equal, GetInput, Not, NotEqual, OneOf +from nvtripy.frontend.constraints import And, Equal, GetInput, Not, NotEqual, OneOf, Or class TestLogic: @@ -34,6 +34,22 @@ def test_operator_and_chaining(self): assert len(combined.constraints) == 3 assert combined([("param1", 2), ("param2", "b"), ("param3", True)]) + def test_operator_or_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 | constraint2 + assert isinstance(combined, Or) + assert combined([("param1", 5), ("param2", "b")]) + + def test_operator_or_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 | constraint2 | constraint3 + assert isinstance(combined, Or) + assert len(combined.constraints) == 3 + assert combined([("param1", 5), ("param2", "z"), ("param3", True)]) + def test_operator_not_basic(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) negated = ~constraint @@ -84,6 +100,42 @@ def test_str(self): assert str(and_constraint) == "param1 is one of [1, 2, 3] and param2 is one of ['a', 'b']" +class TestOr: + def test_call_first_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "z")]) + + def test_call_second_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 5), ("param2", "b")]) + + def test_call_all_pass(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "b")]) + + def test_call_all_fail(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = or_constraint([("param1", 5), ("param2", "z")]) + assert not result + errors = result.error_details + assert len(errors) == 2 + assert any("Expected param1 to be one of [1, 2, 3], but got 5" in err for err in errors) + assert any("Expected param2 to be one of ['a', 'b', 'c'], but got z" in err for err in errors) + + def test_str(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(or_constraint) == "param1 is one of [1, 2, 3] or param2 is one of ['a', 'b']" + + def test_call_multiple_constraints(self): + or_constraint = Or( + OneOf(GetInput("param1"), [1, 2, 3]), + OneOf(GetInput("param2"), ["a", "b", "c"]), + OneOf(GetInput("param3"), [True, False]), + ) + assert or_constraint([("param1", 5), ("param2", "z"), ("param3", True)]) + assert not or_constraint([("param1", 5), ("param2", "z"), ("param3", None)]) + + class TestEqual: def test_call_success(self): constraint = Equal(GetInput("param1"), GetInput("param2")) @@ -103,6 +155,20 @@ def test_operator_on_fetcher(self): constraint = GetInput("param1") == GetInput("param2") assert isinstance(constraint, Equal) + def test_call_success_with_constant(self): + constraint = Equal(GetInput("param1"), 5) + assert constraint([("param1", 5)]) + + def test_call_failure_with_constant(self): + constraint = Equal(GetInput("param1"), 5) + result = constraint([("param1", 10)]) + assert not result + assert "Expected param1 to be equal to 5, but got 10 and 5." in result.error_details + + def test_str_with_constant(self): + constraint = Equal(GetInput("param1"), 5) + assert str(constraint) == "param1 == 5" + class TestNotEqual: def test_call_success(self): @@ -123,6 +189,20 @@ def test_operator_on_fetcher(self): constraint = GetInput("param1") != GetInput("param2") assert isinstance(constraint, NotEqual) + def test_call_success_with_constant(self): + constraint = NotEqual(GetInput("param1"), 5) + assert constraint([("param1", 10)]) + + def test_call_failure_with_constant(self): + constraint = NotEqual(GetInput("param1"), 5) + result = constraint([("param1", 5)]) + assert not result + assert "Expected param1 to be not equal to 5, but both were 5." in result.error_details + + def test_str_with_constant(self): + constraint = NotEqual(GetInput("param1"), 5) + assert str(constraint) == "param1 != 5" + class TestNot: def test_call_success(self): From f36f75560df54a98b4e52686926329f203f0aee1 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 28 Oct 2025 11:23:29 -0700 Subject: [PATCH 04/32] Improves constraint error messages --- .../nvtripy/frontend/constraints/__init__.py | 2 +- tripy/nvtripy/frontend/constraints/fetcher.py | 17 ++++--- tripy/nvtripy/frontend/constraints/logic.py | 48 +++++++++---------- tripy/nvtripy/frontend/wrappers.py | 1 - .../frontend/constraints/test_fetcher.py | 18 +++---- .../tests/frontend/constraints/test_logic.py | 36 +++++++------- 6 files changed, 60 insertions(+), 62 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py index f1f037e98..98a0ddd2f 100644 --- a/tripy/nvtripy/frontend/constraints/__init__.py +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -16,4 +16,4 @@ # from nvtripy.frontend.constraints.base import Constraints from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher -from nvtripy.frontend.constraints.logic import And, Equal, Logic, Not, NotEqual, OneOf +from nvtripy.frontend.constraints.logic import And, Equal, Logic, Not, NotEqual, OneOf, Or diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py index fa1638fd7..4987426c8 100644 --- a/tripy/nvtripy/frontend/constraints/fetcher.py +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -15,7 +15,7 @@ # limitations under the License. # from abc import abstractmethod -from typing import Any, List, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple from nvtripy.common.datatype import dtype as tp_dtype from nvtripy.common.exception import raise_error @@ -28,7 +28,7 @@ class Fetcher(Constraints): """ @abstractmethod - def __call__(self, args: List[Tuple[str, Any]]) -> Any: ... + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: ... def __eq__(self, other: "Fetcher") -> "Equal": from nvtripy.frontend.constraints.logic import Equal @@ -51,7 +51,7 @@ class GetInput(ValueFetcher): def __init__(self, name: str): self.name = name - def __call__(self, args: List[Tuple[str, Any]]) -> Any: + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: for name, value in args: if name == self.name: return value @@ -65,10 +65,9 @@ class GetReturn(ValueFetcher): def __init__(self, index: int): self.index = index - def __call__(self, args: List[Tuple[str, Any]]) -> Any: - raise NotImplementedError( - "GetReturn is only used to describe output guarantees and must not be called for input validation purposes." - ) + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + assert returns is not None, "No return values available." + return returns[self.index] def __str__(self): return f"return[{self.index}]" @@ -78,7 +77,7 @@ class GetDataType(Fetcher): def __init__(self, value_fetcher: ValueFetcher): self.value_fetcher = value_fetcher - def __call__(self, args: List[Tuple[str, Any]]) -> Any: + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: from nvtripy.frontend.tensor import Tensor def get_arg_dtype(arg: Any) -> tp_dtype: @@ -104,7 +103,7 @@ def get_arg_dtype(arg: Any) -> tp_dtype: ) return arg_dtype - tensor = self.value_fetcher(args) + tensor = self.value_fetcher(args, returns) return get_arg_dtype(tensor) def __str__(self): diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 512a031e6..a71a47685 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -15,15 +15,12 @@ # limitations under the License. # from abc import abstractmethod -from typing import Any, List, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple from nvtripy.frontend.constraints.base import Constraints from nvtripy.frontend.constraints.fetcher import Fetcher from nvtripy.utils.result import Result -# TODO (pranavm): Maybe error details should say what went wrong, not what was expected? -# TODO (pranavm): Maybe return a list of violated constraints? Then we can construct error messages outside. - class Logic(Constraints): """ @@ -31,7 +28,7 @@ class Logic(Constraints): """ @abstractmethod - def __call__(self, args: List[Tuple[str, Any]]) -> Result: ... + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: ... def __and__(self, other: "Logic") -> "Logic": if isinstance(self, And): @@ -59,20 +56,22 @@ def __init__(self, fetcher: Fetcher, options: Sequence[Any]): # Need to convert generator expressions so we can use them more than once self.options = list(options) - def __call__(self, args: List[Tuple[str, Any]]) -> Result: - value = self.fetcher(args) + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value = self.fetcher(args, returns) if value in self.options: return Result.ok() - return Result.err([f"'{self.fetcher}' to be one of {self.options}, but got {value}"]) + return Result.err([f"'{self.fetcher}' to be one of {self.options} (but it was '{value}')"]) def __str__(self): return f"{self.fetcher} is one of {self.options}" -def get_val_or_call_fetcher(fetcher_or_value: Any, args: List[Tuple[str, Any]]) -> Any: +def get_val_or_call_fetcher( + fetcher_or_value: Any, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None +) -> Any: if isinstance(fetcher_or_value, Fetcher): - return fetcher_or_value(args) + return fetcher_or_value(args, returns) return fetcher_or_value @@ -81,13 +80,13 @@ def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): self.fetcher = fetcher self.fetcher_or_value = fetcher_or_value - def __call__(self, args: List[Tuple[str, Any]]) -> Result: - value1 = self.fetcher(args) - value2 = get_val_or_call_fetcher(self.fetcher_or_value, args) + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) if value1 == value2: return Result.ok() - return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}'"]) + return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) def __str__(self): return f"{self.fetcher} == {self.fetcher_or_value}" @@ -98,27 +97,26 @@ def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): self.fetcher = fetcher self.fetcher_or_value = fetcher_or_value - def __call__(self, args: List[Tuple[str, Any]]) -> Result: - value1 = self.fetcher(args) - value2 = get_val_or_call_fetcher(self.fetcher_or_value, args) + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) if value1 != value2: return Result.ok() - return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}'"]) + return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) def __str__(self): return f"{self.fetcher} != {self.fetcher_or_value}" -# TODO (pranavm): Make And and Or combine errors and include a prefix like: "{...} but condition XYZ failed." class And(Logic): def __init__(self, *constraints: Logic): self.constraints = constraints - def __call__(self, args: List[Tuple[str, Any]]) -> Result: + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: errors = [] for constraint in self.constraints: - result = constraint(args) + result = constraint(args, returns) if not result: errors.extend(([" and "] if errors else []) + result.error_details) if errors: @@ -133,10 +131,10 @@ class Or(Logic): def __init__(self, *constraints: Logic): self.constraints = constraints - def __call__(self, args: List[Tuple[str, Any]]) -> Result: + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: all_errors = [] for constraint in self.constraints: - result = constraint(args) + result = constraint(args, returns) if result: return Result.ok() all_errors.extend(([" or "] if all_errors else []) + result.error_details) @@ -150,8 +148,8 @@ class Not(Logic): def __init__(self, constraint: Logic): self.constraint = constraint - def __call__(self, args: List[Tuple[str, Any]]) -> Result: - result = self.constraint(args) + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + result = self.constraint(args, returns) if result: return Result.err([str(self)]) return Result.ok() diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index 1906a384d..eb734c5e9 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -15,7 +15,6 @@ # limitations under the License. # -# TODO (pranavm): Move into frontend - only used there. import functools import inspect diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py index d0f1c111a..dbfcc3b40 100644 --- a/tripy/tests/frontend/constraints/test_fetcher.py +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -26,16 +26,16 @@ def test_eq_operator_returns_equal(self): fetcher2 = GetInput("param2") constraint = fetcher1 == fetcher2 assert isinstance(constraint, Equal) - assert constraint.fetcher1 == fetcher1 - assert constraint.fetcher2 == fetcher2 + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 def test_ne_operator_returns_not_equal(self): fetcher1 = GetInput("param1") fetcher2 = GetInput("param2") constraint = fetcher1 != fetcher2 assert isinstance(constraint, NotEqual) - assert constraint.fetcher1 == fetcher1 - assert constraint.fetcher2 == fetcher2 + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 class TestValueFetcher: @@ -62,11 +62,13 @@ def test_init(self): fetcher = GetReturn(0) assert fetcher.index == 0 - def test_call_raises_not_implemented(self): + def test_call(self): fetcher = GetReturn(0) - args = [("input", 42)] - with helper.raises(NotImplementedError, match="GetReturn is only used to describe output guarantees"): - fetcher(args) + returns = (42, "hello", 3.14) + assert fetcher([], returns) == 42 + + fetcher2 = GetReturn(2) + assert fetcher2([], returns) == 3.14 def test_str(self): fetcher = GetReturn(0) diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index 58bd9a1bf..a565a7709 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -68,7 +68,7 @@ def test_call_failure(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) result = constraint([("param", 5)]) assert not result - assert "Expected param to be one of [1, 2, 3], but got 5." in result.error_details + assert "'param' to be one of [1, 2, 3] (but it was '5')" in result.error_details def test_str(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) @@ -84,20 +84,20 @@ def test_call_one_fails(self): and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) result = and_constraint([("param1", 5), ("param2", "b")]) assert not result - assert "Expected param1 to be one of [1, 2, 3], but got 5." in result.error_details + assert "'param1' to be one of [1, 2, 3] (but it was '5')" in result.error_details def test_call_all_fail(self): and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) result = and_constraint([("param1", 5), ("param2", "z")]) assert not result - errors = result.error_details - assert len(errors) == 2 - assert any("Expected param1 to be one of [1, 2, 3], but got 5" in err for err in errors) - assert any("Expected param2 to be one of ['a', 'b', 'c'], but got z" in err for err in errors) + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') and 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) def test_str(self): and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) - assert str(and_constraint) == "param1 is one of [1, 2, 3] and param2 is one of ['a', 'b']" + assert str(and_constraint) == "(param1 is one of [1, 2, 3] and param2 is one of ['a', 'b'])" class TestOr: @@ -117,14 +117,14 @@ def test_call_all_fail(self): or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) result = or_constraint([("param1", 5), ("param2", "z")]) assert not result - errors = result.error_details - assert len(errors) == 2 - assert any("Expected param1 to be one of [1, 2, 3], but got 5" in err for err in errors) - assert any("Expected param2 to be one of ['a', 'b', 'c'], but got z" in err for err in errors) + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') or 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) def test_str(self): or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) - assert str(or_constraint) == "param1 is one of [1, 2, 3] or param2 is one of ['a', 'b']" + assert str(or_constraint) == "(param1 is one of [1, 2, 3] or param2 is one of ['a', 'b'])" def test_call_multiple_constraints(self): or_constraint = Or( @@ -145,7 +145,7 @@ def test_call_failure(self): constraint = Equal(GetInput("param1"), GetInput("param2")) result = constraint([("param1", 5), ("param2", 10)]) assert not result - assert "Expected param1 to be equal to param2, but got 5 and 10." in result.error_details + assert "'param1' to be equal to 'param2' (but it was '5')" in result.error_details def test_str(self): constraint = Equal(GetInput("param1"), GetInput("param2")) @@ -163,7 +163,7 @@ def test_call_failure_with_constant(self): constraint = Equal(GetInput("param1"), 5) result = constraint([("param1", 10)]) assert not result - assert "Expected param1 to be equal to 5, but got 10 and 5." in result.error_details + assert "'param1' to be equal to '5' (but it was '10')" in result.error_details def test_str_with_constant(self): constraint = Equal(GetInput("param1"), 5) @@ -179,7 +179,7 @@ def test_call_failure(self): constraint = GetInput("param1") != GetInput("param2") result = constraint([("param1", 5), ("param2", 5)]) assert not result - assert "Expected param1 to be not equal to param2, but both were 5." in result.error_details + assert "'param1' to be not equal to 'param2' (but it was '5')" in result.error_details def test_str(self): constraint = GetInput("param1") != GetInput("param2") @@ -197,7 +197,7 @@ def test_call_failure_with_constant(self): constraint = NotEqual(GetInput("param1"), 5) result = constraint([("param1", 5)]) assert not result - assert "Expected param1 to be not equal to 5, but both were 5." in result.error_details + assert "'param1' to be not equal to '5' (but it was '5')" in result.error_details def test_str_with_constant(self): constraint = NotEqual(GetInput("param1"), 5) @@ -213,11 +213,11 @@ def test_call_failure(self): constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) result = constraint([("param", 2)]) assert not result - assert "Expected NOT param is one of [1, 2, 3], but it was satisfied." in result.error_details + assert "not param is one of [1, 2, 3]" == result.error_details[0] def test_str(self): constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) - assert str(constraint) == "NOT (param is one of [1, 2, 3])" + assert str(constraint) == "not param is one of [1, 2, 3]" def test_double_negation(self): constraint = Not(Not(OneOf(GetInput("param"), [1, 2, 3]))) From 0cf6dcb8a7b68be4883c0ee4783f5d90ffe4244f Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 28 Oct 2025 13:15:38 -0700 Subject: [PATCH 05/32] Removes `Not` constraint --- .../nvtripy/frontend/constraints/__init__.py | 2 +- tripy/nvtripy/frontend/constraints/logic.py | 59 +++++--- tripy/nvtripy/frontend/ops/cast.py | 6 +- tripy/nvtripy/frontend/wrappers.py | 2 +- .../tests/frontend/constraints/test_logic.py | 128 +++++++++--------- 5 files changed, 109 insertions(+), 88 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py index 98a0ddd2f..2889b36bb 100644 --- a/tripy/nvtripy/frontend/constraints/__init__.py +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -16,4 +16,4 @@ # from nvtripy.frontend.constraints.base import Constraints from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher -from nvtripy.frontend.constraints.logic import And, Equal, Logic, Not, NotEqual, OneOf, Or +from nvtripy.frontend.constraints.logic import And, Equal, Logic, NotEqual, NotOneOf, OneOf, Or diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index a71a47685..36969fbed 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -30,6 +30,13 @@ class Logic(Constraints): @abstractmethod def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: ... + @abstractmethod + def inverse(self) -> "Logic": + """ + Returns the logical inverse of this constraint. + """ + ... + def __and__(self, other: "Logic") -> "Logic": if isinstance(self, And): return And(*self.constraints, other) @@ -45,9 +52,7 @@ def __or__(self, other: "Logic") -> "Logic": return Or(self, other) def __invert__(self) -> "Logic": - if isinstance(self, Equal): - return NotEqual(self.fetcher, self.fetcher_or_value) - return Not(self) + return self.inverse() class OneOf(Logic): @@ -66,6 +71,28 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} is one of {self.options}" + def inverse(self) -> "Logic": + return NotOneOf(self.fetcher, self.options) + + +class NotOneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + self.fetcher = fetcher + self.options = list(options) + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value = self.fetcher(args, returns) + if value not in self.options: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to not be one of {self.options} (but it was '{value}')"]) + + def __str__(self): + return f"{self.fetcher} is not one of {self.options}" + + def inverse(self) -> "Logic": + return OneOf(self.fetcher, self.options) + def get_val_or_call_fetcher( fetcher_or_value: Any, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None @@ -91,6 +118,9 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} == {self.fetcher_or_value}" + def inverse(self) -> "Logic": + return NotEqual(self.fetcher, self.fetcher_or_value) + class NotEqual(Logic): def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): @@ -108,6 +138,9 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} != {self.fetcher_or_value}" + def inverse(self) -> "Logic": + return Equal(self.fetcher, self.fetcher_or_value) + class And(Logic): def __init__(self, *constraints: Logic): @@ -126,6 +159,10 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" + def inverse(self) -> "Logic": + # Apply De Morgan's law: not (A and B) = (not A) or (not B) + return Or(*(constraint.inverse() for constraint in self.constraints)) + class Or(Logic): def __init__(self, *constraints: Logic): @@ -143,16 +180,6 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" - -class Not(Logic): - def __init__(self, constraint: Logic): - self.constraint = constraint - - def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: - result = self.constraint(args, returns) - if result: - return Result.err([str(self)]) - return Result.ok() - - def __str__(self): - return f"not {self.constraint}" + def inverse(self) -> "Logic": + # Apply De Morgan's law: not (A or B) = (not A) and (not B) + return And(*(constraint.inverse() for constraint in self.constraints)) diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index c2bab5ba0..ecc4e44a1 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -31,9 +31,9 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( input_requirements=( - ~((GetInput("input").dtype == dt.float8) & OneOf(GetInput("dtype"), [dt.int4, dt.int8])) - & ~((GetInput("input").dtype == dt.int8) & (GetInput("dtype") == dt.float8)) - & ~((GetInput("input").dtype == dt.int4) & OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) + ((GetInput("input").dtype != dt.float8) | ~OneOf(GetInput("dtype"), [dt.int4, dt.int8])) + & ((GetInput("input").dtype != dt.int8) | (GetInput("dtype") != dt.float8)) + & ((GetInput("input").dtype != dt.int4) | ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) ), output_guarantees=GetReturn(0).dtype == GetInput("dtype"), # TODO (pranavm): Remove old dtype constraints system: diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index eb734c5e9..f4fee7316 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -398,7 +398,7 @@ def wrapper(*args, **kwargs): f"Invalid inputs for function: '{func.__qualname__}'.", ["Expected: "] + result.error_details - + [f"\n\nNote: Requirements are:\n {input_requirements}."], + + [f".\n\nNote: Requirements are:\n {input_requirements}."], ) if config.enable_dtype_checking: diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index a565a7709..c23d500b5 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from nvtripy.frontend.constraints import And, Equal, GetInput, Not, NotEqual, OneOf, Or +from nvtripy.frontend.constraints import And, Equal, GetInput, NotEqual, NotOneOf, OneOf, Or class TestLogic: @@ -53,26 +53,47 @@ def test_operator_or_chaining(self): def test_operator_not_basic(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) negated = ~constraint - assert isinstance(negated, Not) + assert isinstance(negated, NotOneOf) assert negated([("param", 5)]) assert not negated([("param", 2)]) class TestOneOf: - def test_call_success(self): - constraint = OneOf(GetInput("param"), [1, 2, 3]) - result = constraint([("param", 2)]) - assert result - - def test_call_failure(self): + def test_call(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 2)]) result = constraint([("param", 5)]) assert not result assert "'param' to be one of [1, 2, 3] (but it was '5')" in result.error_details def test_str(self): + assert str(OneOf(GetInput("param"), [1, 2, 3])) == "param is one of [1, 2, 3]" + + def test_inverse(self): constraint = OneOf(GetInput("param"), [1, 2, 3]) - assert str(constraint) == "param is one of [1, 2, 3]" + inverse = constraint.inverse() + assert isinstance(inverse, NotOneOf) + assert inverse([("param", 5)]) + assert not inverse([("param", 2)]) + + +class TestNotOneOf: + def test_call(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 5)]) + result = constraint([("param", 2)]) + assert not result + assert "'param' to not be one of [1, 2, 3] (but it was '2')" in result.error_details + + def test_str(self): + assert str(NotOneOf(GetInput("param"), [1, 2, 3])) == "param is not one of [1, 2, 3]" + + def test_inverse(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + inverse = constraint.inverse() + assert isinstance(inverse, OneOf) + assert inverse([("param", 2)]) + assert not inverse([("param", 5)]) class TestAnd: @@ -99,6 +120,13 @@ def test_str(self): and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) assert str(and_constraint) == "(param1 is one of [1, 2, 3] and param2 is one of ['a', 'b'])" + def test_inverse(self): + # De Morgan's law: not (A and B) = (not A) or (not B) + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = and_constraint.inverse() + assert isinstance(inverse, Or) + assert str(inverse) == "(param1 is not one of [1, 2, 3] or param2 is not one of ['a', 'b'])" + class TestOr: def test_call_first_passes(self): @@ -135,91 +163,57 @@ def test_call_multiple_constraints(self): assert or_constraint([("param1", 5), ("param2", "z"), ("param3", True)]) assert not or_constraint([("param1", 5), ("param2", "z"), ("param3", None)]) + def test_inverse(self): + # De Morgan's law: not (A or B) = (not A) and (not B) + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = or_constraint.inverse() + assert isinstance(inverse, And) + assert str(inverse) == "(param1 is not one of [1, 2, 3] and param2 is not one of ['a', 'b'])" + class TestEqual: - def test_call_success(self): + def test_call(self): constraint = Equal(GetInput("param1"), GetInput("param2")) assert constraint([("param1", 5), ("param2", 5)]) - - def test_call_failure(self): - constraint = Equal(GetInput("param1"), GetInput("param2")) result = constraint([("param1", 5), ("param2", 10)]) assert not result assert "'param1' to be equal to 'param2' (but it was '5')" in result.error_details def test_str(self): - constraint = Equal(GetInput("param1"), GetInput("param2")) - assert str(constraint) == "param1 == param2" + assert str(Equal(GetInput("param1"), GetInput("param2"))) == "param1 == param2" + assert str(Equal(GetInput("param1"), 5)) == "param1 == 5" def test_operator_on_fetcher(self): constraint = GetInput("param1") == GetInput("param2") assert isinstance(constraint, Equal) - def test_call_success_with_constant(self): - constraint = Equal(GetInput("param1"), 5) - assert constraint([("param1", 5)]) - - def test_call_failure_with_constant(self): + def test_inverse(self): constraint = Equal(GetInput("param1"), 5) - result = constraint([("param1", 10)]) - assert not result - assert "'param1' to be equal to '5' (but it was '10')" in result.error_details - - def test_str_with_constant(self): - constraint = Equal(GetInput("param1"), 5) - assert str(constraint) == "param1 == 5" + inverse = constraint.inverse() + assert isinstance(inverse, NotEqual) + assert inverse([("param1", 10)]) + assert not inverse([("param1", 5)]) class TestNotEqual: - def test_call_success(self): - constraint = GetInput("param1") != GetInput("param2") + def test_call(self): + constraint = NotEqual(GetInput("param1"), GetInput("param2")) assert constraint([("param1", 5), ("param2", 10)]) - - def test_call_failure(self): - constraint = GetInput("param1") != GetInput("param2") result = constraint([("param1", 5), ("param2", 5)]) assert not result assert "'param1' to be not equal to 'param2' (but it was '5')" in result.error_details def test_str(self): - constraint = GetInput("param1") != GetInput("param2") - assert str(constraint) == "param1 != param2" + assert str(NotEqual(GetInput("param1"), GetInput("param2"))) == "param1 != param2" + assert str(NotEqual(GetInput("param1"), 5)) == "param1 != 5" def test_operator_on_fetcher(self): constraint = GetInput("param1") != GetInput("param2") assert isinstance(constraint, NotEqual) - def test_call_success_with_constant(self): - constraint = NotEqual(GetInput("param1"), 5) - assert constraint([("param1", 10)]) - - def test_call_failure_with_constant(self): - constraint = NotEqual(GetInput("param1"), 5) - result = constraint([("param1", 5)]) - assert not result - assert "'param1' to be not equal to '5' (but it was '5')" in result.error_details - - def test_str_with_constant(self): + def test_inverse(self): constraint = NotEqual(GetInput("param1"), 5) - assert str(constraint) == "param1 != 5" - - -class TestNot: - def test_call_success(self): - constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) - assert constraint([("param", 5)]) - - def test_call_failure(self): - constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) - result = constraint([("param", 2)]) - assert not result - assert "not param is one of [1, 2, 3]" == result.error_details[0] - - def test_str(self): - constraint = Not(OneOf(GetInput("param"), [1, 2, 3])) - assert str(constraint) == "not param is one of [1, 2, 3]" - - def test_double_negation(self): - constraint = Not(Not(OneOf(GetInput("param"), [1, 2, 3]))) - assert constraint([("param", 2)]) - assert not constraint([("param", 5)]) + inverse = constraint.inverse() + assert isinstance(inverse, Equal) + assert inverse([("param1", 5)]) + assert not inverse([("param1", 10)]) From b662350b9ed9a935833ae8a0f265159bb036da0e Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 29 Oct 2025 11:16:58 -0700 Subject: [PATCH 06/32] Enables constraints to be documented in a human-readable way Adds a `_doc_str()` function which will generate human-readable text for a given constraint. --- tripy/nvtripy/frontend/constraints/logic.py | 6 +- tripy/nvtripy/frontend/wrappers.py | 62 +++++++++++++++- .../{test_interface.py => test_wrappers.py} | 71 ++++++++++++++++++- 3 files changed, 134 insertions(+), 5 deletions(-) rename tripy/tests/frontend/wrappers/{test_interface.py => test_wrappers.py} (73%) mode change 100755 => 100644 diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 36969fbed..de5e6982d 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -27,6 +27,7 @@ class Logic(Constraints): Represents logical operations on constraints. """ + # When the constraint is not met, the error details should complete the sentence: "Expected ..." @abstractmethod def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: ... @@ -113,6 +114,7 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = if value1 == value2: return Result.ok() + # TODO (pranavm): If fetcher_or_value is a Fetcher, include its value in the error message. return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) def __str__(self): @@ -160,7 +162,7 @@ def __str__(self): return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" def inverse(self) -> "Logic": - # Apply De Morgan's law: not (A and B) = (not A) or (not B) + # De Morgan's law: not (A and B) = (not A) or (not B) return Or(*(constraint.inverse() for constraint in self.constraints)) @@ -181,5 +183,5 @@ def __str__(self): return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" def inverse(self) -> "Logic": - # Apply De Morgan's law: not (A or B) = (not A) and (not B) + # De Morgan's law: not (A or B) = (not A) and (not B) return And(*(constraint.inverse() for constraint in self.constraints)) diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index f4fee7316..a96d3aa7e 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -233,8 +233,63 @@ def add_arg(arg): return new_args, new_kwargs, new_merged_args +def _doc_str(obj: Any) -> str: + """ + Returns a string representation of an object for use in the public documentation. + """ + from nvtripy.common.datatype import dtype as tp_dtype + from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or + from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn + + if isinstance(obj, tp_dtype): + return f":class:`{obj.name}`" + + if isinstance(obj, GetInput): + return f"``{obj.name}``" + elif isinstance(obj, GetReturn): + return f"``return[{obj.index}]``" + elif isinstance(obj, GetDataType): + # Intentionally do not use _doc_str() on the value_fetcher so we can wrap it in backticks correctly. + return f"``{obj.value_fetcher}.dtype``" + elif isinstance(obj, OneOf): + return f"{_doc_str(obj.fetcher)} is one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" + elif isinstance(obj, NotOneOf): + return f"{_doc_str(obj.fetcher)} is not one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" + elif isinstance(obj, Equal): + return f"{_doc_str(obj.fetcher)} == {_doc_str(obj.fetcher_or_value)}" + elif isinstance(obj, NotEqual): + return f"{_doc_str(obj.fetcher)} != {_doc_str(obj.fetcher_or_value)}" + elif isinstance(obj, And): + return ", **and**\n".join("- " + indent(_doc_str(constraint), " ").lstrip() for constraint in obj.constraints) + elif isinstance(obj, Or): + return "(" + " *or* ".join(_doc_str(constraint) for constraint in obj.constraints) + ")" + + assert False, f"Unsupported object type for doc string generation: {type(obj)}. Please add handling here!" + + +# Modify the docstring to include constraints +def _update_docstring(func, input_requirements, output_guarantees): + if not func.__doc__: + return + + indentation = " " * 4 + code_block_index = func.__doc__.find(".. code-block:: python") + assert code_block_index != -1, f"No code example in docstring for {func.__name__}" + + input_requirements_str = f"\nINPUT REQUIREMENTS:\n{indent(_doc_str(input_requirements), indentation)}\n" + output_guarantees_str = f"\nOUTPUT GUARANTEES:\n{indent(_doc_str(output_guarantees), indentation)}\n" + + func.__doc__ = ( + func.__doc__[:code_block_index] + + indent(input_requirements_str + output_guarantees_str, indentation) + + "\n" + + indentation + + func.__doc__[code_block_index:] + ) + + # Modify the docstring to mention data type variables and exceptions -def _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions): +def _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions): if not func.__doc__: return @@ -371,7 +426,10 @@ def decorator(func): DataTypeConstraints(func, dtype_constraints, dtype_variables, dtype_exceptions) ) - _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions) + if input_requirements is not None: + _update_docstring(func, input_requirements, output_guarantees) + elif dtype_constraints or dtype_variables or dtype_exceptions: + _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions) @functools.wraps(func) def wrapper(*args, **kwargs): diff --git a/tripy/tests/frontend/wrappers/test_interface.py b/tripy/tests/frontend/wrappers/test_wrappers.py old mode 100755 new mode 100644 similarity index 73% rename from tripy/tests/frontend/wrappers/test_interface.py rename to tripy/tests/frontend/wrappers/test_wrappers.py index 1ca042bbc..602f9ba76 --- a/tripy/tests/frontend/wrappers/test_interface.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -23,7 +23,9 @@ import pytest from nvtripy.export import PUBLIC_APIS from nvtripy.frontend import wrappers -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn +from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str from tests import helper # Get all functions/methods which have tensors in the type signature @@ -59,6 +61,73 @@ def sequence_func(tensors: List[tp.Tensor]): return +class TestDocStr: + def test_basic_types(self): + assert _doc_str(tp.float32) == ":class:`float32`" + assert _doc_str(GetInput("x")) == "``x``" + assert _doc_str(GetReturn(0)) == "``return[0]``" + + def test_get_datatype(self): + assert _doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" + assert _doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" + + def test_one_of_and_not_one_of(self): + input_x = GetInput("x") + + assert ( + _doc_str(OneOf(input_x, [tp.float32, tp.float16])) == "``x`` is one of [:class:`float32`, :class:`float16`]" + ) + assert _doc_str(NotOneOf(input_x, [tp.int8, tp.int32])) == "``x`` is not one of [:class:`int8`, :class:`int32`]" + + def test_equal_and_not_equal(self): + input_a = GetInput("a") + input_b = GetInput("b") + + assert _doc_str(Equal(input_a, input_b)) == "``a`` == ``b``" + assert _doc_str(Equal(input_a, tp.float32)) == "``a`` == :class:`float32`" + assert _doc_str(NotEqual(input_a, input_b)) == "``a`` != ``b``" + + def test_and_constraint(self): + constraint1 = OneOf(GetInput("a"), [tp.float32]) + constraint2 = OneOf(GetInput("b"), [tp.int32]) + + assert ( + _doc_str(And(constraint1, constraint2)) + == "- ``a`` is one of [:class:`float32`]\n- ``b`` is one of [:class:`int32`]" + ) + + def test_or_constraint(self): + input_a = GetInput("a") + or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + + assert _doc_str(or_constraint) == "(``a`` == :class:`float32` or ``a`` == :class:`float16`)" + + def test_nested_constraints(self): + input_a = GetInput("a") + input_b = GetInput("b") + + or_part = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + and_constraint = And(or_part, OneOf(input_b, [tp.int32])) + + assert ( + _doc_str(and_constraint) + == "- (``a`` == :class:`float32` or ``a`` == :class:`float16`)\n- ``b`` is one of [:class:`int32`]" + ) + + def test_complex_real_world_constraint(self): + input_a = GetInput("input") + input_b = GetInput("other") + dtype_a = GetDataType(input_a) + dtype_b = GetDataType(input_b) + + and_constraint = And(Equal(dtype_a, dtype_b), OneOf(dtype_a, [tp.float32, tp.float16])) + + assert ( + _doc_str(and_constraint) + == "- ``input.dtype`` == ``other.dtype``\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" + ) + + class TestDtypes: def test_works_with_sequences(self): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)]) From 5576a7d85be8a2f0ddd265bf01f89aded49f8d52 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 31 Oct 2025 16:12:48 -0700 Subject: [PATCH 07/32] Updates automatic casting logic to use constraints system --- tripy/nvtripy/frontend/wrappers.py | 97 +++++++++++++++++-- .../tests/frontend/wrappers/test_wrappers.py | 87 ++++++++++++++++- 2 files changed, 169 insertions(+), 15 deletions(-) diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index a96d3aa7e..e9c7307e4 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -25,7 +25,7 @@ from nvtripy import config, utils from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error -from nvtripy.frontend.constraints import Constraints +from nvtripy.frontend.constraints import Constraints, Equal, GetInput, GetDataType, Fetcher from nvtripy.utils import result @@ -137,17 +137,87 @@ def get_arg_dtype(arg, func_name, arg_name) -> result.Result["nvtripy.dtype"]: return result.Result.ok(arg_dtype) +def _find_known_datatypes( + merged_args: List[Tuple[str, Any]], input_requirements: Constraints +) -> Dict[str, "nvtripy.dtype"]: + + # We perform this operation in two steps: + # 1. Identify all arguments that are expected to have equal data types. + # 2. Propagate known data types to all arguments in each equality set. + expected_equal_dtype: List[Set[str]] = [] + + def insert_pair(name1, name2): + for pair_set in expected_equal_dtype: + if name1 in pair_set or name2 in pair_set: + pair_set.update({name1, name2}) + return + expected_equal_dtype.append({name1, name2}) + + known_dtypes: Dict[str, Optional["nvtripy.dtype"]] = {} + for name, _ in merged_args: + + # If this argument already has a known dtype, populate it: + try: + known_dtypes[name] = GetDataType(GetInput(name))(merged_args) + except Exception: + pass + + def process_dtype_equality(matched_constraints, input_is_lhs): + for constraint in matched_constraints: + expected = constraint.fetcher_or_value if input_is_lhs else constraint.fetcher + if isinstance(expected, GetDataType): + # This might be too restrictive, in which case the assertion can be lifted and the logic here updated. + # However, it should generally be the case that input requirements are in terms of the inputs. + assert isinstance( + expected.value_fetcher, GetInput + ), f"Input requirements should only look at inputs" + other_name = expected.value_fetcher.name + + insert_pair(name, other_name) + + try: + known_dtypes[other_name] = expected(merged_args) + except Exception: + # dtype is not yet known (i.e. might be comparing two inputs with unknown dtypes) + pass + + else: + known_dtypes[name] = expected(merged_args) if isinstance(expected, Fetcher) else expected + + # NOTE: Because we check for the input on both sides of the equality, we do not need to do another pass over + # expected_equal_dtype to merge disjoint sets - if we have a transitive equality like: + # `a == c and b == d and b == a`, then `a`, `c`, `b` will be immediately added to the same set, which `d` will + # join when we process `b`. + process_dtype_equality(input_requirements.find(Equal(GetDataType(GetInput(name)), None)), input_is_lhs=True) + process_dtype_equality(input_requirements.find(Equal(None, GetDataType(GetInput(name)))), input_is_lhs=False) + + # We do not need to perform validation, as that will happen during constraints checking. + for dtype_set in expected_equal_dtype: + known_dtype_in_set = None + for name in dtype_set: + if name in known_dtypes: + known_dtype_in_set = known_dtypes[name] + break + + # dtype might be unknown if the arguments are all non-tensor types. + for name in dtype_set: + known_dtypes[name] = known_dtype_in_set + + return known_dtypes + + # Performs type conversions if needed. Returns updated values of args, kwargs, and merged args def convert_input_types( func, args, kwargs, - merged_args, + merged_args: List[Tuple[str, Any]], var_arg_info, conversion_targets, conversion_preprocess_func, dtype_constraints, shape_likes, + input_requirements: Constraints, ): from nvtripy.common.datatype import bool as tp_bool from nvtripy.common.datatype import floating, integer @@ -168,13 +238,16 @@ def convert_input_types( else: merged_args[index] = (name, new_args[name]) - # Materialize type variables from tensors. - type_vars = {} - for name, arg in merged_args: - if name in dtype_constraints: - dtype_result = get_arg_dtype(arg, func.__qualname__, name) - if dtype_result: - type_vars[dtype_constraints[name]] = dtype_result.value + if input_requirements is not None: + known_datatypes = _find_known_datatypes(merged_args, input_requirements) + else: + # Materialize type variables from tensors. + type_vars = {} + for name, arg in merged_args: + if name in dtype_constraints: + dtype_result = get_arg_dtype(arg, func.__qualname__, name) + if dtype_result: + type_vars[dtype_constraints[name]] = dtype_result.value new_args = [] new_kwargs = {} @@ -207,7 +280,10 @@ def add_arg(arg): ) dtype = None - if name in dtype_constraints and dtype_constraints[name] in type_vars: + if input_requirements is not None: + dtype = known_datatypes.get(name) + elif name in dtype_constraints and dtype_constraints[name] in type_vars: + # TODO (pranavm): Remove this deprecated path. dtype = type_vars[dtype_constraints[name]] if dtype is not None: @@ -446,6 +522,7 @@ def wrapper(*args, **kwargs): conversion_preprocess_func, dtype_constraints, shape_likes, + input_requirements, ) if config.enable_input_validation: diff --git a/tripy/tests/frontend/wrappers/test_wrappers.py b/tripy/tests/frontend/wrappers/test_wrappers.py index 602f9ba76..5839842fe 100644 --- a/tripy/tests/frontend/wrappers/test_wrappers.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -25,7 +25,7 @@ from nvtripy.frontend import wrappers from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str, _find_known_datatypes from tests import helper # Get all functions/methods which have tensors in the type signature @@ -93,14 +93,14 @@ def test_and_constraint(self): assert ( _doc_str(And(constraint1, constraint2)) - == "- ``a`` is one of [:class:`float32`]\n- ``b`` is one of [:class:`int32`]" + == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" ) def test_or_constraint(self): input_a = GetInput("a") or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) - assert _doc_str(or_constraint) == "(``a`` == :class:`float32` or ``a`` == :class:`float16`)" + assert _doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" def test_nested_constraints(self): input_a = GetInput("a") @@ -111,7 +111,7 @@ def test_nested_constraints(self): assert ( _doc_str(and_constraint) - == "- (``a`` == :class:`float32` or ``a`` == :class:`float16`)\n- ``b`` is one of [:class:`int32`]" + == "- (``a`` == :class:`float32` *or* ``a`` == :class:`float16`), **and**\n- ``b`` is one of [:class:`int32`]" ) def test_complex_real_world_constraint(self): @@ -124,10 +124,87 @@ def test_complex_real_world_constraint(self): assert ( _doc_str(and_constraint) - == "- ``input.dtype`` == ``other.dtype``\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" + == "- ``input.dtype`` == ``other.dtype``, **and**\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" ) +class TestFindKnownDatatypes: + def test_equal_dtypes_propagation(self): + tensor_a = tp.Tensor([1.0, 2.0]) + merged_args = [("a", tensor_a), ("b", 1.0)] + + input_requirements = And(Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b")))) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] == tp.float32 + assert result["b"] == tp.float32 + + def test_multiple_equal_dtypes_chain(self): + tensor_c = tp.Tensor([1.0, 2.0]) + merged_args = [("a", 0.0), ("b", 1.0), ("c", tensor_c)] + + input_requirements = And( + Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))), + Equal(GetDataType(GetInput("b")), GetDataType(GetInput("c"))), + ) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] == tp.float32 + assert result["b"] == tp.float32 + assert result["c"] == tp.float32 + + def test_chain_with_disjoint_sets(self): + # Make sure implementation applies transitive equality correctly when the constraints are such + # that it could form disjoint sets (in an incorrect implementation). + tensor_a = tp.Tensor([1.0, 2.0]) + merged_args = [("a", tensor_a), ("b", 1.0), ("c", 1), ("d", 1)] + + input_requirements = And( + Equal(GetDataType(GetInput("a")), GetDataType(GetInput("c"))), + Equal(GetDataType(GetInput("b")), GetDataType(GetInput("d"))), + Equal(GetDataType(GetInput("b")), GetDataType(GetInput("a"))), + ) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] is tp.float32 + assert result["b"] is tp.float32 + assert result["c"] is tp.float32 + assert result["d"] is tp.float32 + + def test_equal_to_constant_dtype(self): + merged_args = [("a", 1.0)] + + input_requirements = And(Equal(GetDataType(GetInput("a")), tp.float16)) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] == tp.float16 + + def test_multiple_separate_dtype_groups(self): + tensor_a = tp.Tensor([1.0, 2.0]) + tensor_c = tp.Tensor([1, 2]) + merged_args = [("a", tensor_a), ("b", 1.0), ("c", tensor_c), ("d", 1)] + + input_requirements = And( + Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))), + Equal(GetDataType(GetInput("c")), GetDataType(GetInput("d"))), + ) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] == tp.float32 + assert result["b"] == tp.float32 + assert result["c"] == tp.int32 + assert result["d"] == tp.int32 + + def test_unknown_dtypes(self): + merged_args = [("a", 1.0), ("b", 2.0)] + + input_requirements = And(Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b")))) + + result = _find_known_datatypes(merged_args, input_requirements) + assert result["a"] is None + assert result["b"] is None + + class TestDtypes: def test_works_with_sequences(self): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)]) From 5b2b7c5c9b338054d4316dee90a7ba453cd3e070 Mon Sep 17 00:00:00 2001 From: pranavm Date: Mon, 3 Nov 2025 21:39:16 +0000 Subject: [PATCH 08/32] Improves source code retrieval in stack info --- tripy/nvtripy/utils/stack_info.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tripy/nvtripy/utils/stack_info.py b/tripy/nvtripy/utils/stack_info.py index f8557552e..fe033e707 100644 --- a/tripy/nvtripy/utils/stack_info.py +++ b/tripy/nvtripy/utils/stack_info.py @@ -17,6 +17,9 @@ # NOTE: We avoid using `inspect` functions as much as possible because they are much slower than # working directly with the frame. + +import linecache +import os import sys from dataclasses import dataclass from typing import Optional, Tuple @@ -53,14 +56,19 @@ def fetch_source_code(self): if self.code is not None: return - # Note that in some cases, e.g. when code is being provided via the interactive shell, we may not be able to retrieve it. - # In that case we just leave it empty. - try: - lines = open(self.file, "r").readlines() - except: - self.code = "" + # Handle cases where bytecode was compiled with a different path than the current filesystem: + file_path = self.file + module = sys.modules.get(self.module) + if module and hasattr(module, "__file__") and module.__file__: + file_path = module.__file__ + + line = linecache.getline(file_path, self.line) + if line: + self.code = line.rstrip() else: - self.code = lines[self.line - 1].rstrip() + # Note that in some cases, e.g. when code is being provided via the interactive shell, we may not be able to retrieve it. + # In that case we just leave it empty. + self.code = "" class StackInfo(list): @@ -138,7 +146,7 @@ def get_module_names_to_exclude_from_stack_info(): Returns a set of module names to exclude from stack information when displaying exceptions or trying to retrieve column information from code. """ - import nvtripy.utils.function_registry as function_registry import nvtripy.frontend.wrappers as wrappers + import nvtripy.utils.function_registry as function_registry return {mod.__name__ for mod in [function_registry, wrappers]} From bb8f8a7a59c66f44b0a9c4d86376a55a21056865 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 4 Nov 2025 19:53:03 +0000 Subject: [PATCH 09/32] Adds a .devcontainer configuration for VS Code --- tripy/.devcontainer/devcontainer.json | 48 +++++++++++++++++++++++++++ tripy/CONTRIBUTING.md | 2 ++ tripy/Dockerfile | 13 ++++---- 3 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 tripy/.devcontainer/devcontainer.json diff --git a/tripy/.devcontainer/devcontainer.json b/tripy/.devcontainer/devcontainer.json new file mode 100644 index 000000000..38ca2f581 --- /dev/null +++ b/tripy/.devcontainer/devcontainer.json @@ -0,0 +1,48 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Tripy", + "build": { + "context": "${localWorkspaceFolder}", + "dockerfile": "${localWorkspaceFolder}/Dockerfile", + "args": { + "username": "${localEnv:USER}" + } + }, + "workspaceMount": "source=${localWorkspaceFolder}/..,target=/workspaces/TensorRT-Incubator,type=bind,consistency=cached", + "workspaceFolder": "/workspaces/TensorRT-Incubator/tripy", + "runArgs": [ + "--gpus", + "all", + "-it", + "--cap-add=SYS_PTRACE" + ], + "mounts": [ + "source=${localEnv:HOME}${localEnv:USERPROFILE},target=/home/${localEnv:USER},type=bind,consistency=cached" + ], + "remoteEnv": { + "SHELL": "${localEnv:SHELL:/bin/bash}", + "ZSH": "/home/${localEnv:USER}/.oh-my-zsh" + }, + "remoteUser": "${localEnv:USER}", + "forwardPorts": [ + 8080 + ], + "customizations": { + "vscode": { + "extensions": [ + "ms-python.black-formatter", + "lextudio.restructuredtext", + "github.vscode-github-actions", + "ms-python.isort", + "ms-toolsai.jupyter", + "ms-toolsai.vscode-jupyter-cell-tags", + "ms-toolsai.jupyter-renderers", + "llvm-vs-code-extensions.vscode-mlir", + "ms-python.python", + "ms-python.vscode-pylance", + "eamodio.gitlens" + ] + } + } +} diff --git a/tripy/CONTRIBUTING.md b/tripy/CONTRIBUTING.md index 85490044f..790a1fdbe 100644 --- a/tripy/CONTRIBUTING.md +++ b/tripy/CONTRIBUTING.md @@ -37,6 +37,8 @@ Thanks for your interest in contributing to Tripy! docker run --gpus all -it --cap-add=SYS_PTRACE -p 8080:8080 -v $(pwd):/tripy/ --rm tripy:latest ``` + - If you are using Visual Studio Code, you can alternatively use the included `.devcontainer` configuration. + 3. **[Optional]** Run a sanity check in the container: ```bash diff --git a/tripy/Dockerfile b/tripy/Dockerfile index cd9a796a4..8151fbbd4 100644 --- a/tripy/Dockerfile +++ b/tripy/Dockerfile @@ -9,16 +9,17 @@ ENTRYPOINT ["/bin/bash"] # Setup user account ARG uid=1000 ARG gid=1000 +ARG username=trtuser ENV DEBIAN_FRONTEND=noninteractive -RUN groupadd -r -f -g ${gid} trtuser && \ - useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash trtuser && \ - usermod -aG sudo trtuser && \ - echo 'trtuser:nvidia' | chpasswd && \ - mkdir -p /workspace && chown trtuser /workspace +RUN groupadd -r -f -g ${gid} ${username} && \ + useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash ${username} && \ + usermod -aG sudo ${username} && \ + echo "${username}:nvidia" | chpasswd && \ + mkdir -p /workspace && chown ${username} /workspace RUN apt-get update && \ - apt-get install -y sudo python3 python3-pip gdb git wget curl && \ + apt-get install -y sudo python3 python3-pip gdb git wget curl zsh && \ apt-get clean && \ python3 -m pip install --upgrade pip From 48c5754f7b842caf31a93086c45909e472b34bd4 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 5 Nov 2025 17:47:53 +0000 Subject: [PATCH 10/32] Adds automated operator constraints tests for data type constraints --- tripy/nvtripy/frontend/wrappers.py | 25 +- .../tests/frontend/wrappers/test_wrappers.py | 11 +- .../integration/test_operator_constraints.py | 247 ++++++++++++++++++ 3 files changed, 277 insertions(+), 6 deletions(-) create mode 100644 tripy/tests/integration/test_operator_constraints.py diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index e9c7307e4..2a695e2fb 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -41,6 +41,17 @@ class DataTypeConstraints: RETURN_VALUE = "__RETURN_VALUE" +@dataclass +class OperatorConstraints: + func: Callable + input_requirements: Constraints + output_guarantees: Constraints + + +# A list of tuples of (input_requirements, output_guarantees) for operators. +OPERATOR_CONSTRAINTS: List[OperatorConstraints] = [] + + # Try to include correct column offsets for non-tensor arguments. def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name): from nvtripy.frontend.tensor import Tensor @@ -320,6 +331,7 @@ def _doc_str(obj: Any) -> str: if isinstance(obj, tp_dtype): return f":class:`{obj.name}`" + # TODO (pranavm): Move these into their respective classes, make doc_str an export of the constraints module. if isinstance(obj, GetInput): return f"``{obj.name}``" elif isinstance(obj, GetReturn): @@ -428,6 +440,7 @@ def sorted_types(dtypes): def interface( # TODO (pranavm): These should be required arguments eventually. + # TODO (pranavm): Document requirements/guarantees. input_requirements: Constraints = None, output_guarantees: Constraints = None, dtype_constraints: Dict[str, str] = {}, @@ -496,15 +509,17 @@ def decorator(func): ) shape_likes = {name for name, param in signature.parameters.items() if param.annotation is ShapeLike} - # if no dtype constraints have been specified at all, do not add to the table so we don't generate invalid tests - if dtype_constraints or dtype_variables or dtype_exceptions: + # TODO (pranavm): Constraints should never be None eventually. + if input_requirements is not None: + OPERATOR_CONSTRAINTS.append(OperatorConstraints(func, input_requirements, output_guarantees)) + + _update_docstring(func, input_requirements, output_guarantees) + elif dtype_constraints or dtype_variables or dtype_exceptions: + # if no dtype constraints have been specified at all, do not add to the table so we don't generate invalid tests DATA_TYPE_CONSTRAINTS.append( DataTypeConstraints(func, dtype_constraints, dtype_variables, dtype_exceptions) ) - if input_requirements is not None: - _update_docstring(func, input_requirements, output_guarantees) - elif dtype_constraints or dtype_variables or dtype_exceptions: _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions) @functools.wraps(func) diff --git a/tripy/tests/frontend/wrappers/test_wrappers.py b/tripy/tests/frontend/wrappers/test_wrappers.py index 5839842fe..811e2f83a 100644 --- a/tripy/tests/frontend/wrappers/test_wrappers.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -25,7 +25,7 @@ from nvtripy.frontend import wrappers from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str, _find_known_datatypes +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str, _find_known_datatypes, OPERATOR_CONSTRAINTS from tests import helper # Get all functions/methods which have tensors in the type signature @@ -48,6 +48,7 @@ api.qualname + f".{func.__name__}" if func.__name__ not in api.qualname else "" ) +# TODO (pranavm): Remove old dtype constraints system: DATA_TYPE_CONSTRAINTS_FUNC_NAMES = {dtc.func.__qualname__ for dtc in DATA_TYPE_CONSTRAINTS} @@ -56,6 +57,14 @@ def test_all_public_apis_verified(api): assert api.__qualname__ in DATA_TYPE_CONSTRAINTS_FUNC_NAMES, f"Missing datatype constraints for: {api.__qualname__}" +OPERATOR_CONSTRAINTS_FUNC_NAMES = {oc.func.__qualname__ for oc in OPERATOR_CONSTRAINTS} + + +@pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES) +def test_all_public_apis_verified(api): + assert api.__qualname__ in OPERATOR_CONSTRAINTS_FUNC_NAMES, f"Missing operator constraints for: {api.__qualname__}" + + @wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]}) def sequence_func(tensors: List[tp.Tensor]): return diff --git a/tripy/tests/integration/test_operator_constraints.py b/tripy/tests/integration/test_operator_constraints.py new file mode 100644 index 000000000..271549766 --- /dev/null +++ b/tripy/tests/integration/test_operator_constraints.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for operator constraints. + +This test suite validates that Tripy's operator constraints correctly predict what the underlying +software stack (MLIR-TRT/TensorRT) will accept or reject. For each operator with constraints, we: + +1. Generate all possible combinations of data types for parameters that accept Tensors or dtypes +2. Use Tripy's input_requirements to predict whether each combination should be valid +3. Call the operator with Tripy's validation disabled, so errors come from the underlying stack +4. For valid combinations: verify the outputs match the output_guarantees +5. For invalid combinations: ensure the underlying stack raises an exception during evaluation + +This ensures Tripy's constraint system stays in sync with the underlying implementation. +""" + +import contextlib +import inspect +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, List + +import numpy as np +import pytest + +import nvtripy as tp +from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.wrappers import OPERATOR_CONSTRAINTS +from nvtripy.utils.types import str_from_type_annotation +from nvtripy.utils.utils import make_list +from tests import helper +from tests.conftest import skip_if_older_than_sm89 + + +@dataclass +class OperatorConstraintCase: + func: Callable + dtype_values: Dict[str, tp.dtype] # The specific dtype combination for this test + + def __str__(self): + param_str = "_".join(f"{key}-{val}" for key, val in self.dtype_values.items()) + return f"{self.func.__name__}-{param_str}" + + +def get_dtype_constrained_params(func: Callable) -> List[str]: + sig = inspect.signature(func) + return [ + param_name + for param_name, param in sig.parameters.items() + if param.annotation is not inspect.Parameter.empty + and (type_str := str_from_type_annotation(param.annotation)) + and ("Tensor" in type_str or "nvtripy.dtype" in type_str) + ] + + +def generate_test_cases() -> List[OperatorConstraintCase]: + cases = [] + + for op_constraint in OPERATOR_CONSTRAINTS: + func = op_constraint.func + dtype_params = get_dtype_constrained_params(func) + + if not dtype_params: + continue + + all_dtypes = list(DATA_TYPES.values()) + + for dtype_combination in itertools.product(all_dtypes, repeat=len(dtype_params)): + dtype_values = dict(zip(dtype_params, dtype_combination)) + + # Skip FP8 on older hardware + marks = [skip_if_older_than_sm89()] if any(dtype == tp.float8 for dtype in dtype_values.values()) else [] + + cases.append(pytest.param(OperatorConstraintCase(func, dtype_values), marks=marks)) + + # Sort for deterministic ordering + cases.sort(key=lambda case: str(case.values[0])) + return cases + + +CUSTOM_VALUES = { + "__getitem__": {"index": 0}, + "arange": {"start": 0, "stop": 10, "step": 1}, + "avgpool": {"kernel_dims": [2, 2]}, + "convolution": { + "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)), + "bias": tp.Tensor([1.0, 2.0]), + "padding": ((0, 0), (0, 0)), + "stride": [1, 1], + "groups": 1, + "dilation": [1, 1], + }, + "copy": { + "input": tp.ones((2, 2)), + "device": tp.device("cpu"), + }, + "deconvolution": { + "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)), + "bias": tp.Tensor([1.0, 2.0]), + "padding": ((0, 0), (0, 0)), + "stride": [1, 1], + "groups": 1, + "dilation": [1, 1], + }, + "dequantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, + "expand": {"sizes": tp.Tensor((2, 2, 5, 5))}, + "full": {"shape": tp.Tensor([2, 2]), "value": tp.Tensor(1.0)}, + "full_like": {"value": tp.Tensor(1.0)}, + "gather": {"index": tp.Tensor([1])}, + "instancenorm": { + "num_channels": 2, + "weight": tp.Tensor(np.ones((2,), dtype=np.float32)), + "bias": tp.Tensor(np.zeros((2,), dtype=np.float32)), + }, + "iota": {"shape": tp.Tensor([2, 2])}, + "maxpool": {"kernel_dims": [2, 2]}, + "ones": {"shape": [2, 2]}, + "outer": {"vec1": tp.Tensor([1, 2, 3]), "vec2": tp.Tensor([1, 2, 3])}, + "pad": {"pad": [(0, 1), (1, 0), (1, 1), (0, 0)]}, + "permute": {"perm": [1, 0, 3, 2]}, + "quantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, + "repeat": {"repeats": 2, "dim": 0}, + "reshape": {"shape": tp.Tensor([2, 25])}, + "resize": { + "mode": "nearest", + "output_shape": tp.Tensor((1, 2, 10, 10)), + "scales": [1, 1, 2, 2], + }, + "squeeze": {"dims": 0}, + "transpose": {"dim0": 0, "dim1": 1}, + "zeros": {"shape": [2, 2]}, +} + +# Arguments that must be constants on CPU +REQUIRES_CPU_CONST = { + "dequantize": {"scale"}, + "quantize": {"scale"}, +} + +# Some operations require input shapes to be known +REQUIRES_KNOWN_SHAPES = { + "convolution": {"input", "weight", "bias"}, + "deconvolution": {"input", "weight", "bias"}, +} + + +def _apply_tensor_adjustments(tensor: tp.Tensor, func_name: str, param_name: str) -> tp.Tensor: + if func_name in REQUIRES_CPU_CONST and param_name in REQUIRES_CPU_CONST[func_name]: + if tensor.device.kind != "cpu": + tensor = tp.copy(tensor, device=tp.device("cpu")) + + if func_name in REQUIRES_KNOWN_SHAPES and param_name in REQUIRES_KNOWN_SHAPES[func_name]: + if any(dim == tp.constants.DYNAMIC_DIM for dim in tensor.trace_tensor.shape): + tensor.trace_tensor.shape = tuple(map(int, tensor.shape)) + + return tensor + + +def generate_input_values(case: OperatorConstraintCase) -> Dict[str, Any]: + if tp.int4 in case.dtype_values.values(): + pytest.skip(f"#579: Cannot generate INT4 inputs") + + inputs = {} + sig = inspect.signature(case.func) + func_name = case.func.__name__ + + for param_name, param in sig.parameters.items(): + dtype = case.dtype_values.get(param_name) + param_type = str_from_type_annotation(param.annotation) + + # Handle custom values first + if func_name in CUSTOM_VALUES and param_name in CUSTOM_VALUES[func_name]: + inputs[param_name] = CUSTOM_VALUES[func_name][param_name] + if isinstance(inputs[param_name], tp.Tensor) and dtype is not None: + inputs[param_name] = tp.cast(inputs[param_name], dtype=dtype) + inputs[param_name] = _apply_tensor_adjustments(inputs[param_name], func_name, param_name) + continue + + # Skip optional parameters unless they need a specific dtype + if param.default is not inspect.Parameter.empty and dtype is None: + continue + + # Generate values based on parameter type + if "Tensor" in param_type: + assert dtype is not None, f"Tensor parameter '{param_name}' must have a dtype constraint" + base_tensor = tp.cast(tp.Tensor(np.ones((1, 2, 5, 5), dtype=np.float32)), dtype=dtype) + + if "Sequence" in param_type or "List" in param_type: + inputs[param_name] = [_apply_tensor_adjustments(base_tensor, func_name, param_name) for _ in range(2)] + else: + inputs[param_name] = _apply_tensor_adjustments(base_tensor, func_name, param_name) + elif "nvtripy.dtype" in param_type: + assert dtype is not None, f"dtype parameter '{param_name}' must have a dtype constraint" + inputs[param_name] = dtype + elif "numbers.Number" in param_type or "int" in param_type or "float" in param_type: + inputs[param_name] = 1 + + return inputs + + +OPERATOR_CONSTRAINT_CASES = generate_test_cases() + + +@pytest.mark.parametrize("case", OPERATOR_CONSTRAINT_CASES, ids=lambda case: str(case)) +def test_operator_constraints(case: OperatorConstraintCase): + op_constraint = next((oc for oc in OPERATOR_CONSTRAINTS if oc.func == case.func), None) + assert op_constraint is not None, f"Could not find constraints for {case.func.__name__}" + + # If input validation is enabled, negative tests will trivially pass (we will throw an + # error before even trying to call the underlying implementation). + with helper.config("enable_input_validation", False): + inputs = generate_input_values(case) + merged_args = list(inputs.items()) + + is_valid = bool(op_constraint.input_requirements(merged_args)) + + with contextlib.ExitStack() as stack: + if not is_valid: + stack.enter_context(helper.raises(Exception)) + + outputs = make_list(case.func(**inputs)) + + for out in outputs: + if isinstance(out, tp.Tensor): + # Avoid evaluating CPU constants since some types (e.g. FP8) don't allow them to be used outside of + # certain operations. + out._eval_for_internal_methods() + + if is_valid: + output_result = op_constraint.output_guarantees(merged_args, tuple(outputs)) + assert output_result, f"Output guarantees not met for {case.func.__name__}: " + " ".join( + output_result.error_details + ) From b36aadbd12450118c0585fee7c18f480e51731bb Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 6 Nov 2025 19:27:33 +0000 Subject: [PATCH 11/32] Refactors `doc_str`, adds field for additional constraint information - Refactors `doc_str` such that the Constraints classes provide implementations. The free function is primarily used for handling non-Constraints types and performing the dispatch to `Constraints.doc_str()`. - Adds an `info()` method to the `Constraints` class that allows setting additional human-readable information about a constraint. --- .../nvtripy/frontend/constraints/__init__.py | 1 + tripy/nvtripy/frontend/constraints/base.py | 22 ++++- tripy/nvtripy/frontend/constraints/doc_str.py | 35 ++++++++ tripy/nvtripy/frontend/constraints/fetcher.py | 13 +++ tripy/nvtripy/frontend/constraints/logic.py | 26 ++++++ tripy/nvtripy/frontend/wrappers.py | 41 +-------- tripy/tests/frontend/constraints/test_base.py | 9 ++ .../frontend/constraints/test_doc_str.py | 87 +++++++++++++++++++ .../tests/frontend/wrappers/test_wrappers.py | 69 +-------------- 9 files changed, 195 insertions(+), 108 deletions(-) create mode 100644 tripy/nvtripy/frontend/constraints/doc_str.py create mode 100644 tripy/tests/frontend/constraints/test_doc_str.py diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py index 2889b36bb..e41d1483e 100644 --- a/tripy/nvtripy/frontend/constraints/__init__.py +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -15,5 +15,6 @@ # limitations under the License. # from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.doc_str import doc_str from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher from nvtripy.frontend.constraints.logic import And, Equal, Logic, NotEqual, NotOneOf, OneOf, Or diff --git a/tripy/nvtripy/frontend/constraints/base.py b/tripy/nvtripy/frontend/constraints/base.py index ffecad43b..360c8fbea 100644 --- a/tripy/nvtripy/frontend/constraints/base.py +++ b/tripy/nvtripy/frontend/constraints/base.py @@ -45,8 +45,8 @@ ``` """ -from abc import ABC -from typing import List +from abc import ABC, abstractmethod +from typing import List, Optional class Constraints(ABC): @@ -54,6 +54,16 @@ class Constraints(ABC): Base class for the entire constraints system. """ + def __init__(self): + self._info: Optional[str] = None + + @abstractmethod + def doc_str(self) -> str: + """ + Returns a string representation for use in documentation. + """ + ... + def get_children(self) -> List["Constraints"]: children = [] for attr_value in vars(self).values(): @@ -128,3 +138,11 @@ def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool: matches.extend(child.find(pattern)) return matches + + def info(self, message: str) -> "Constraints": + """ + Sets additional information about this constraint. + For example, this might express the constraint in natural language or explain why it exists. + """ + self._info = message + return self diff --git a/tripy/nvtripy/frontend/constraints/doc_str.py b/tripy/nvtripy/frontend/constraints/doc_str.py new file mode 100644 index 000000000..d2a253b17 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/doc_str.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any + +from nvtripy.frontend.constraints.base import Constraints + + +def doc_str(obj: Any) -> str: + """ + Returns a string representation of an object for use in the public documentation. + """ + from nvtripy.common.datatype import dtype as tp_dtype + + if isinstance(obj, tp_dtype): + return f":class:`{obj.name}`" + + if isinstance(obj, Constraints): + return obj.doc_str() + + assert False, f"Unsupported object type for doc string generation: {type(obj)}. Please add handling here!" diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py index 4987426c8..713bd0649 100644 --- a/tripy/nvtripy/frontend/constraints/fetcher.py +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -49,6 +49,7 @@ def dtype(self) -> "GetDataType": class GetInput(ValueFetcher): def __init__(self, name: str): + super().__init__() self.name = name def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: @@ -60,9 +61,13 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return self.name + def doc_str(self) -> str: + return f"``{self.name}``" + class GetReturn(ValueFetcher): def __init__(self, index: int): + super().__init__() self.index = index def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: @@ -72,9 +77,13 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"return[{self.index}]" + def doc_str(self) -> str: + return f"``return[{self.index}]``" + class GetDataType(Fetcher): def __init__(self, value_fetcher: ValueFetcher): + super().__init__() self.value_fetcher = value_fetcher def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: @@ -108,3 +117,7 @@ def get_arg_dtype(arg: Any) -> tp_dtype: def __str__(self): return f"{self.value_fetcher}.dtype" + + def doc_str(self) -> str: + # Intentionally do not use doc_str() on the value_fetcher so we can wrap it in backticks correctly. + return f"``{self.value_fetcher}.dtype``" diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index de5e6982d..7eac7cdac 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -15,9 +15,11 @@ # limitations under the License. # from abc import abstractmethod +from textwrap import indent from typing import Any, List, Optional, Sequence, Tuple from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.doc_str import doc_str from nvtripy.frontend.constraints.fetcher import Fetcher from nvtripy.utils.result import Result @@ -58,6 +60,7 @@ def __invert__(self) -> "Logic": class OneOf(Logic): def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + super().__init__() self.fetcher = fetcher # Need to convert generator expressions so we can use them more than once self.options = list(options) @@ -72,12 +75,16 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} is one of {self.options}" + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} is one of [{', '.join(f'{doc_str(opt)}' for opt in self.options)}]" + def inverse(self) -> "Logic": return NotOneOf(self.fetcher, self.options) class NotOneOf(Logic): def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + super().__init__() self.fetcher = fetcher self.options = list(options) @@ -91,6 +98,9 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} is not one of {self.options}" + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} is not one of [{', '.join(f'{doc_str(opt)}' for opt in self.options)}]" + def inverse(self) -> "Logic": return OneOf(self.fetcher, self.options) @@ -105,6 +115,7 @@ def get_val_or_call_fetcher( class Equal(Logic): def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): + super().__init__() self.fetcher = fetcher self.fetcher_or_value = fetcher_or_value @@ -120,12 +131,16 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} == {self.fetcher_or_value}" + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} == {doc_str(self.fetcher_or_value)}" + def inverse(self) -> "Logic": return NotEqual(self.fetcher, self.fetcher_or_value) class NotEqual(Logic): def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): + super().__init__() self.fetcher = fetcher self.fetcher_or_value = fetcher_or_value @@ -140,12 +155,16 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return f"{self.fetcher} != {self.fetcher_or_value}" + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} != {doc_str(self.fetcher_or_value)}" + def inverse(self) -> "Logic": return Equal(self.fetcher, self.fetcher_or_value) class And(Logic): def __init__(self, *constraints: Logic): + super().__init__() self.constraints = constraints def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: @@ -161,6 +180,9 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" + def doc_str(self) -> str: + return ", **and**\n".join("- " + indent(doc_str(constraint), " ").lstrip() for constraint in self.constraints) + def inverse(self) -> "Logic": # De Morgan's law: not (A and B) = (not A) or (not B) return Or(*(constraint.inverse() for constraint in self.constraints)) @@ -168,6 +190,7 @@ def inverse(self) -> "Logic": class Or(Logic): def __init__(self, *constraints: Logic): + super().__init__() self.constraints = constraints def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: @@ -182,6 +205,9 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = def __str__(self): return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" + def doc_str(self) -> str: + return "(" + " *or* ".join(doc_str(constraint) for constraint in self.constraints) + ")" + def inverse(self) -> "Logic": # De Morgan's law: not (A or B) = (not A) and (not B) return And(*(constraint.inverse() for constraint in self.constraints)) diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index 2a695e2fb..b491b3c88 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -25,7 +25,7 @@ from nvtripy import config, utils from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error -from nvtripy.frontend.constraints import Constraints, Equal, GetInput, GetDataType, Fetcher +from nvtripy.frontend.constraints import Constraints, Equal, GetInput, GetDataType, Fetcher, doc_str from nvtripy.utils import result @@ -320,41 +320,6 @@ def add_arg(arg): return new_args, new_kwargs, new_merged_args -def _doc_str(obj: Any) -> str: - """ - Returns a string representation of an object for use in the public documentation. - """ - from nvtripy.common.datatype import dtype as tp_dtype - from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or - from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn - - if isinstance(obj, tp_dtype): - return f":class:`{obj.name}`" - - # TODO (pranavm): Move these into their respective classes, make doc_str an export of the constraints module. - if isinstance(obj, GetInput): - return f"``{obj.name}``" - elif isinstance(obj, GetReturn): - return f"``return[{obj.index}]``" - elif isinstance(obj, GetDataType): - # Intentionally do not use _doc_str() on the value_fetcher so we can wrap it in backticks correctly. - return f"``{obj.value_fetcher}.dtype``" - elif isinstance(obj, OneOf): - return f"{_doc_str(obj.fetcher)} is one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" - elif isinstance(obj, NotOneOf): - return f"{_doc_str(obj.fetcher)} is not one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" - elif isinstance(obj, Equal): - return f"{_doc_str(obj.fetcher)} == {_doc_str(obj.fetcher_or_value)}" - elif isinstance(obj, NotEqual): - return f"{_doc_str(obj.fetcher)} != {_doc_str(obj.fetcher_or_value)}" - elif isinstance(obj, And): - return ", **and**\n".join("- " + indent(_doc_str(constraint), " ").lstrip() for constraint in obj.constraints) - elif isinstance(obj, Or): - return "(" + " *or* ".join(_doc_str(constraint) for constraint in obj.constraints) + ")" - - assert False, f"Unsupported object type for doc string generation: {type(obj)}. Please add handling here!" - - # Modify the docstring to include constraints def _update_docstring(func, input_requirements, output_guarantees): if not func.__doc__: @@ -364,8 +329,8 @@ def _update_docstring(func, input_requirements, output_guarantees): code_block_index = func.__doc__.find(".. code-block:: python") assert code_block_index != -1, f"No code example in docstring for {func.__name__}" - input_requirements_str = f"\nINPUT REQUIREMENTS:\n{indent(_doc_str(input_requirements), indentation)}\n" - output_guarantees_str = f"\nOUTPUT GUARANTEES:\n{indent(_doc_str(output_guarantees), indentation)}\n" + input_requirements_str = f"\nINPUT REQUIREMENTS:\n{indent(doc_str(input_requirements), indentation)}\n" + output_guarantees_str = f"\nOUTPUT GUARANTEES:\n{indent(doc_str(output_guarantees), indentation)}\n" func.__doc__ = ( func.__doc__[:code_block_index] diff --git a/tripy/tests/frontend/constraints/test_base.py b/tripy/tests/frontend/constraints/test_base.py index ac3b03cb1..3bf12ffce 100644 --- a/tripy/tests/frontend/constraints/test_base.py +++ b/tripy/tests/frontend/constraints/test_base.py @@ -114,3 +114,12 @@ def test_find_with_none_wildcard_matches_different_types(self): matches = constraint.find(pattern) assert len(matches) == 6 assert constraint in matches + + def test_info_method(self): + constraint = Equal(GetInput("a"), GetInput("b")) + assert constraint._info is None + + result = constraint.info("This checks that a equals b") + + assert constraint._info == "This checks that a equals b" + assert result is constraint # Test method chaining diff --git a/tripy/tests/frontend/constraints/test_doc_str.py b/tripy/tests/frontend/constraints/test_doc_str.py new file mode 100644 index 000000000..907ef6268 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_doc_str.py @@ -0,0 +1,87 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import nvtripy as tp +from nvtripy.frontend.constraints import And, Equal, NotEqual, NotOneOf, OneOf, Or, doc_str +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn + + +class TestDocStr: + def test_basic_types(self): + assert doc_str(tp.float32) == ":class:`float32`" + assert doc_str(GetInput("x")) == "``x``" + assert doc_str(GetReturn(0)) == "``return[0]``" + + def test_get_datatype(self): + assert doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" + assert doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" + + def test_one_of_and_not_one_of(self): + input_x = GetInput("x") + + assert ( + doc_str(OneOf(input_x, [tp.float32, tp.float16])) == "``x`` is one of [:class:`float32`, :class:`float16`]" + ) + assert doc_str(NotOneOf(input_x, [tp.int8, tp.int32])) == "``x`` is not one of [:class:`int8`, :class:`int32`]" + + def test_equal_and_not_equal(self): + input_a = GetInput("a") + input_b = GetInput("b") + + assert doc_str(Equal(input_a, input_b)) == "``a`` == ``b``" + assert doc_str(Equal(input_a, tp.float32)) == "``a`` == :class:`float32`" + assert doc_str(NotEqual(input_a, input_b)) == "``a`` != ``b``" + + def test_and_constraint(self): + constraint1 = OneOf(GetInput("a"), [tp.float32]) + constraint2 = OneOf(GetInput("b"), [tp.int32]) + + assert ( + doc_str(And(constraint1, constraint2)) + == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" + ) + + def test_or_constraint(self): + input_a = GetInput("a") + or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + + assert doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" + + def test_nested_constraints(self): + input_a = GetInput("a") + input_b = GetInput("b") + + or_part = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + and_constraint = And(or_part, OneOf(input_b, [tp.int32])) + + assert ( + doc_str(and_constraint) + == "- (``a`` == :class:`float32` *or* ``a`` == :class:`float16`), **and**\n- ``b`` is one of [:class:`int32`]" + ) + + def test_complex_real_world_constraint(self): + input_a = GetInput("input") + input_b = GetInput("other") + dtype_a = GetDataType(input_a) + dtype_b = GetDataType(input_b) + + and_constraint = And(Equal(dtype_a, dtype_b), OneOf(dtype_a, [tp.float32, tp.float16])) + + assert ( + doc_str(and_constraint) + == "- ``input.dtype`` == ``other.dtype``, **and**\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" + ) diff --git a/tripy/tests/frontend/wrappers/test_wrappers.py b/tripy/tests/frontend/wrappers/test_wrappers.py index 811e2f83a..1aa26282f 100644 --- a/tripy/tests/frontend/wrappers/test_wrappers.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -25,7 +25,7 @@ from nvtripy.frontend import wrappers from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str, _find_known_datatypes, OPERATOR_CONSTRAINTS +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _find_known_datatypes, OPERATOR_CONSTRAINTS from tests import helper # Get all functions/methods which have tensors in the type signature @@ -70,73 +70,6 @@ def sequence_func(tensors: List[tp.Tensor]): return -class TestDocStr: - def test_basic_types(self): - assert _doc_str(tp.float32) == ":class:`float32`" - assert _doc_str(GetInput("x")) == "``x``" - assert _doc_str(GetReturn(0)) == "``return[0]``" - - def test_get_datatype(self): - assert _doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" - assert _doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" - - def test_one_of_and_not_one_of(self): - input_x = GetInput("x") - - assert ( - _doc_str(OneOf(input_x, [tp.float32, tp.float16])) == "``x`` is one of [:class:`float32`, :class:`float16`]" - ) - assert _doc_str(NotOneOf(input_x, [tp.int8, tp.int32])) == "``x`` is not one of [:class:`int8`, :class:`int32`]" - - def test_equal_and_not_equal(self): - input_a = GetInput("a") - input_b = GetInput("b") - - assert _doc_str(Equal(input_a, input_b)) == "``a`` == ``b``" - assert _doc_str(Equal(input_a, tp.float32)) == "``a`` == :class:`float32`" - assert _doc_str(NotEqual(input_a, input_b)) == "``a`` != ``b``" - - def test_and_constraint(self): - constraint1 = OneOf(GetInput("a"), [tp.float32]) - constraint2 = OneOf(GetInput("b"), [tp.int32]) - - assert ( - _doc_str(And(constraint1, constraint2)) - == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" - ) - - def test_or_constraint(self): - input_a = GetInput("a") - or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) - - assert _doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" - - def test_nested_constraints(self): - input_a = GetInput("a") - input_b = GetInput("b") - - or_part = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) - and_constraint = And(or_part, OneOf(input_b, [tp.int32])) - - assert ( - _doc_str(and_constraint) - == "- (``a`` == :class:`float32` *or* ``a`` == :class:`float16`), **and**\n- ``b`` is one of [:class:`int32`]" - ) - - def test_complex_real_world_constraint(self): - input_a = GetInput("input") - input_b = GetInput("other") - dtype_a = GetDataType(input_a) - dtype_b = GetDataType(input_b) - - and_constraint = And(Equal(dtype_a, dtype_b), OneOf(dtype_a, [tp.float32, tp.float16])) - - assert ( - _doc_str(and_constraint) - == "- ``input.dtype`` == ``other.dtype``, **and**\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" - ) - - class TestFindKnownDatatypes: def test_equal_dtypes_propagation(self): tensor_a = tp.Tensor([1.0, 2.0]) From 6d39f6bf961c2a9f344ebb5d38735ef8d8fe8cad Mon Sep 17 00:00:00 2001 From: pranavm Date: Thu, 13 Nov 2025 18:39:03 +0000 Subject: [PATCH 12/32] Adds a section on documentation philosophy, fixes path in devcontainer --- tripy/.devcontainer/devcontainer.json | 4 +- tripy/docs/README.md | 69 +++++++++++++++++++++++++++ tripy/tests/helper.py | 5 +- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/tripy/.devcontainer/devcontainer.json b/tripy/.devcontainer/devcontainer.json index 38ca2f581..a847d8c5e 100644 --- a/tripy/.devcontainer/devcontainer.json +++ b/tripy/.devcontainer/devcontainer.json @@ -22,7 +22,9 @@ ], "remoteEnv": { "SHELL": "${localEnv:SHELL:/bin/bash}", - "ZSH": "/home/${localEnv:USER}/.oh-my-zsh" + "ZSH": "/home/${localEnv:USER}/.oh-my-zsh", + "PYTHONPATH": "/workspaces/TensorRT-Incubator/tripy:${localEnv:PYTHONPATH}", + "PATH": "/usr/local/bin/:${localEnv:PATH}" }, "remoteUser": "${localEnv:USER}", "forwardPorts": [ diff --git a/tripy/docs/README.md b/tripy/docs/README.md index 612f0f326..b68e07794 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -167,3 +167,72 @@ Code blocks in docstrings/guides are **preprocessed**: - **Include** only specific variables: `# doc: print-locals ...` - **Exclude** *specific* variables: `# doc: no-print-locals ...` - **Exclude** *all* variables: `# doc: no-print-locals` (with no arguments). + + +## Documentation Philosophy: Write Less Documentation + +> “I didn’t have time to write a short letter, so I wrote a long one instead.” - Mark Twain + +How much documentation do you want to read? The answer is **none**! + +- **Best Case:** Make docs unnecessary with an intuitive API and clear errors. + +This is not always possible; sometimes, we need to write docs. + +- **Problem:** We don’t think enough about *what* *precisely* we want to convey. + +- **Suggestions**: Write discoverable, concise, but complete documentation. + + - **Highlight key points** but make it easy to find details. + + - Use bullets and **bold** to break up monotony. + Paragraphs are *so* 2024. + + - Leverage the medium - pictures, charts, emojis, markup. + We are not using printing presses! + + - Forget the rules: use contractions, don’t spell out numbers, etc. + + - **Tip:** Write like we're paying for every syllable! + If it’s too wordy to say, it’s too wordy to write. + +Writing *concisely* forces you to think about what’s **signal** and what’s **noise**. + +Below are examples from previous versions of Tripy documentation that was improved. + +### Example 1 + +> One important point is that Tripy uses a lazy evaluation model; that is, no computation is performed until a value is actually needed. + +* **Tip:** Ask: “What is this really saying?” + +> Tensors are evaluated only when they’re used. + +### Example 2 + + +> ### **Eager Mode: How Does It Work?** +> +> If you’ve used TensorRT before, you may know that it does not support an eager mode. +> In order to provide eager mode support in Tripy, we actually need to compile the graph under the hood. +> +> Although we employ several tricks to make compile times faster when using eager mode, we do still need to compile, +> and so eager mode will likely be slower than other comparable frameworks. +> +> Consequently, we suggest that you use eager mode primarily for debugging and compiled mode for deployments. + +**Problem**: We must sift through filler to find key points. + +**Ask**: + +- **“What is the ONE most important takeaway?”** + *Eager mode is only for debugging.* + +- **“What questions does this raise?” - Why?** + *Tripy always compiles since TensorRT doesn’t have eager mode.* + +Make the **one key point** stand out so skimmers can spot it: + +> **Best Practice:** Use **eager mode** only for **debugging**; compile for deployment. +> +> **Why:** Eager mode internally compiles the graph (slow\!) since TensorRT doesn’t have eager execution. diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py index 73733021a..a789e9394 100644 --- a/tripy/tests/helper.py +++ b/tripy/tests/helper.py @@ -275,6 +275,7 @@ def from_name(name: str) -> "Marker": # Marks an entire block as being expected to fail. "test: xfail": Marker.from_name("TEST: XFAIL"), # Marks that a block contains the expected output from the immediate previous block. + # Used by example tests to verify correctness. "test: expected_stdout": Marker.from_name("TEST: EXPECTED_STDOUT"), # Marks that a block should be run under pytest. "test: use_pytest": Marker.from_name("TEST: USE_PYTEST"), @@ -282,7 +283,9 @@ def from_name(name: str) -> "Marker": "example: if_fp8": Marker.from_name("EXAMPLE: IF_FP8"), # Indicates that a block should be omitted from the rendered documentation. Such blocks may still be evaluated. "doc: omit": Marker.from_name("DOC: OMIT"), - # Indicates that a block should not be evaluated for the documentation. + # Indicates that a block of Python code should not be evaluated OR formatted for the documentation. + # If you would like to format but not evaluate, add a `# doc: no-eval` tag to the code block instead. + # For non-Python code, formatting is not performed anyway, so you should always use this marker instead of `# doc: no-eval`. "doc: no_eval_or_format": Marker.from_name("DOC: NO_EVAL_OR_FORMAT"), # Indicates that local variables should not be displayed for a code block in the documentation. # Useful when the raw code block is also publicly visible and we don't want inline markers (e.g. in the main README.md). From 493491275007400c109bb44424442e8ab0a905fb Mon Sep 17 00:00:00 2001 From: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> Date: Wed, 19 Nov 2025 09:45:36 -0800 Subject: [PATCH 13/32] Update tripy/docs/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> --- tripy/docs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tripy/docs/README.md b/tripy/docs/README.md index b68e07794..e56882580 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -235,4 +235,4 @@ Make the **one key point** stand out so skimmers can spot it: > **Best Practice:** Use **eager mode** only for **debugging**; compile for deployment. > -> **Why:** Eager mode internally compiles the graph (slow\!) since TensorRT doesn’t have eager execution. +> **Why:** Eager mode internally compiles the graph (slow!) since TensorRT doesn’t have eager execution. From def6e0e1c941214696c5c58b92a1754168b1dc2b Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 18 Nov 2025 19:21:13 +0000 Subject: [PATCH 14/32] Ports concatenate to new constraints sytem --- tripy/nvtripy/frontend/ops/cast.py | 14 -------------- tripy/nvtripy/frontend/ops/concatenate.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index ecc4e44a1..32a4a9994 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -36,20 +36,6 @@ & ((GetInput("input").dtype != dt.int4) | ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) ), output_guarantees=GetReturn(0).dtype == GetInput("dtype"), - # TODO (pranavm): Remove old dtype constraints system: - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, - dtype_exceptions=[ - {"T1": "float8", "T2": "int4"}, - {"T1": "float8", "T2": "int8"}, - {"T1": "int8", "T2": "float8"}, - {"T1": "int4", "T2": "float8"}, - {"T1": "int4", "T2": "int8"}, - {"T1": "int4", "T2": "int64"}, - ], ) def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/concatenate.py b/tripy/nvtripy/frontend/ops/concatenate.py index 4f1826db6..730784745 100644 --- a/tripy/nvtripy/frontend/ops/concatenate.py +++ b/tripy/nvtripy/frontend/ops/concatenate.py @@ -18,22 +18,21 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.concatenate import Concatenate -from nvtripy.frontend import wrappers - - -# constraints = OneOf(GetInput("tensors").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool]) -# output_guarantees = GetReturn(0).dtype == GetInput("tensors").dtype @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensors").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensors").dtype, ) def concatenate(tensors: Sequence["nvtripy.Tensor"], dim: int) -> "nvtripy.Tensor": r""" From 0c73514789d81f96a352582f2282f1b4e47d45ef Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 18 Nov 2025 19:22:20 +0000 Subject: [PATCH 15/32] Removes non-ASCII character in README --- tripy/docs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tripy/docs/README.md b/tripy/docs/README.md index e56882580..b383eda0b 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -171,7 +171,7 @@ Code blocks in docstrings/guides are **preprocessed**: ## Documentation Philosophy: Write Less Documentation -> “I didn’t have time to write a short letter, so I wrote a long one instead.” - Mark Twain +> "I didn’t have time to write a short letter, so I wrote a long one instead." - Mark Twain How much documentation do you want to read? The answer is **none**! From 5d252bcf872c4d2b35bcfd4cc6c2371c9625c46e Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 19 Nov 2025 18:32:58 +0000 Subject: [PATCH 16/32] Adds an If constraint to gate other constraints --- .../nvtripy/frontend/constraints/__init__.py | 2 +- tripy/nvtripy/frontend/constraints/doc_str.py | 3 + tripy/nvtripy/frontend/constraints/logic.py | 32 ++++- .../frontend/constraints/test_doc_str.py | 53 ++------ .../frontend/constraints/test_fetcher.py | 12 +- .../tests/frontend/constraints/test_logic.py | 120 +++++++++++++++++- 6 files changed, 176 insertions(+), 46 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py index e41d1483e..8105d1c72 100644 --- a/tripy/nvtripy/frontend/constraints/__init__.py +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -17,4 +17,4 @@ from nvtripy.frontend.constraints.base import Constraints from nvtripy.frontend.constraints.doc_str import doc_str from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher -from nvtripy.frontend.constraints.logic import And, Equal, Logic, NotEqual, NotOneOf, OneOf, Or +from nvtripy.frontend.constraints.logic import And, Equal, If, Logic, NotEqual, NotOneOf, OneOf, Or diff --git a/tripy/nvtripy/frontend/constraints/doc_str.py b/tripy/nvtripy/frontend/constraints/doc_str.py index d2a253b17..fee352771 100644 --- a/tripy/nvtripy/frontend/constraints/doc_str.py +++ b/tripy/nvtripy/frontend/constraints/doc_str.py @@ -26,6 +26,9 @@ def doc_str(obj: Any) -> str: """ from nvtripy.common.datatype import dtype as tp_dtype + if obj is None: + return "``None``" + if isinstance(obj, tp_dtype): return f":class:`{obj.name}`" diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 7eac7cdac..8755a10d8 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -132,7 +132,7 @@ def __str__(self): return f"{self.fetcher} == {self.fetcher_or_value}" def doc_str(self) -> str: - return f"{doc_str(self.fetcher)} == {doc_str(self.fetcher_or_value)}" + return f"{doc_str(self.fetcher)} is {doc_str(self.fetcher_or_value)}" def inverse(self) -> "Logic": return NotEqual(self.fetcher, self.fetcher_or_value) @@ -156,7 +156,7 @@ def __str__(self): return f"{self.fetcher} != {self.fetcher_or_value}" def doc_str(self) -> str: - return f"{doc_str(self.fetcher)} != {doc_str(self.fetcher_or_value)}" + return f"{doc_str(self.fetcher)} is not {doc_str(self.fetcher_or_value)}" def inverse(self) -> "Logic": return Equal(self.fetcher, self.fetcher_or_value) @@ -211,3 +211,31 @@ def doc_str(self) -> str: def inverse(self) -> "Logic": # De Morgan's law: not (A or B) = (not A) and (not B) return And(*(constraint.inverse() for constraint in self.constraints)) + + +class If(Logic): + def __init__(self, condition: Logic, then_branch: Logic, else_branch: Optional[Logic] = None): + super().__init__() + self.condition = condition + self.then_branch = then_branch + self.else_branch = else_branch + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + condition_result = self.condition(args, returns) + if condition_result: + return self.then_branch(args, returns) + else: + return self.else_branch(args, returns) if self.else_branch else Result.ok() + + def __str__(self): + if self.else_branch: + return f"if ({self.condition}) then ({self.then_branch}) else ({self.else_branch})" + return f"if ({self.condition}) then ({self.then_branch})" + + def doc_str(self) -> str: + if self.else_branch: + return f"{doc_str(self.then_branch)} **if** {doc_str(self.condition)}, **otherwise** {doc_str(self.else_branch)}" + return f"if {doc_str(self.condition)}, then {doc_str(self.then_branch)}" + + def inverse(self) -> "Logic": + return If(self.condition, self.then_branch.inverse(), self.else_branch.inverse() if self.else_branch else None) diff --git a/tripy/tests/frontend/constraints/test_doc_str.py b/tripy/tests/frontend/constraints/test_doc_str.py index 907ef6268..d199ae0f7 100644 --- a/tripy/tests/frontend/constraints/test_doc_str.py +++ b/tripy/tests/frontend/constraints/test_doc_str.py @@ -16,50 +16,21 @@ # import nvtripy as tp -from nvtripy.frontend.constraints import And, Equal, NotEqual, NotOneOf, OneOf, Or, doc_str -from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn +import pytest +from nvtripy.frontend.constraints import And, Equal, OneOf, Or, doc_str +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput class TestDocStr: - def test_basic_types(self): - assert doc_str(tp.float32) == ":class:`float32`" - assert doc_str(GetInput("x")) == "``x``" - assert doc_str(GetReturn(0)) == "``return[0]``" - - def test_get_datatype(self): - assert doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" - assert doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" - - def test_one_of_and_not_one_of(self): - input_x = GetInput("x") - - assert ( - doc_str(OneOf(input_x, [tp.float32, tp.float16])) == "``x`` is one of [:class:`float32`, :class:`float16`]" - ) - assert doc_str(NotOneOf(input_x, [tp.int8, tp.int32])) == "``x`` is not one of [:class:`int8`, :class:`int32`]" - - def test_equal_and_not_equal(self): - input_a = GetInput("a") - input_b = GetInput("b") - - assert doc_str(Equal(input_a, input_b)) == "``a`` == ``b``" - assert doc_str(Equal(input_a, tp.float32)) == "``a`` == :class:`float32`" - assert doc_str(NotEqual(input_a, input_b)) == "``a`` != ``b``" - - def test_and_constraint(self): - constraint1 = OneOf(GetInput("a"), [tp.float32]) - constraint2 = OneOf(GetInput("b"), [tp.int32]) - - assert ( - doc_str(And(constraint1, constraint2)) - == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" - ) - - def test_or_constraint(self): - input_a = GetInput("a") - or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) - - assert doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" + @pytest.mark.parametrize( + "obj, expected", + [ + (tp.float32, ":class:`float32`"), + (None, "``None``"), + ], + ) + def test_basic_types(self, obj, expected): + assert doc_str(obj) == expected def test_nested_constraints(self): input_a = GetInput("a") diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py index dbfcc3b40..ac67660f1 100644 --- a/tripy/tests/frontend/constraints/test_fetcher.py +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -16,7 +16,7 @@ # import nvtripy as tp from nvtripy.common.exception import TripyException -from nvtripy.frontend.constraints import Equal, GetDataType, GetInput, GetReturn, NotEqual +from nvtripy.frontend.constraints import Equal, GetDataType, GetInput, GetReturn, NotEqual, doc_str from tests import helper @@ -56,6 +56,9 @@ def test_str(self): fetcher = GetInput("data") assert str(fetcher) == "data" + def test_doc_str(self): + assert doc_str(GetInput("x")) == "``x``" + class TestGetReturn: def test_init(self): @@ -77,6 +80,9 @@ def test_str(self): fetcher2 = GetReturn(2) assert str(fetcher2) == "return[2]" + def test_doc_str(self): + assert doc_str(GetReturn(0)) == "``return[0]``" + class TestGetDataType: def test_call(self): @@ -104,3 +110,7 @@ def test_call_with_nested_sequence_error(self): fetcher = GetDataType(GetInput("input_data")) with helper.raises(TripyException, match="Could not determine data type"): fetcher([("input_data", [tp.ones((2, 3), dtype=tp.float32), [42]])]) + + def test_doc_str(self): + assert doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" + assert doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index c23d500b5..c0ea769f4 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from nvtripy.frontend.constraints import And, Equal, GetInput, NotEqual, NotOneOf, OneOf, Or +import nvtripy as tp +from nvtripy.frontend.constraints import And, Equal, GetInput, If, NotEqual, NotOneOf, OneOf, Or, doc_str +from nvtripy.frontend.constraints.fetcher import GetDataType class TestLogic: @@ -76,6 +78,10 @@ def test_inverse(self): assert inverse([("param", 5)]) assert not inverse([("param", 2)]) + def test_doc_str(self): + constraint = OneOf(GetInput("x"), [tp.float32, tp.float16]) + assert doc_str(constraint) == "``x`` is one of [:class:`float32`, :class:`float16`]" + class TestNotOneOf: def test_call(self): @@ -95,6 +101,10 @@ def test_inverse(self): assert inverse([("param", 2)]) assert not inverse([("param", 5)]) + def test_doc_str(self): + constraint = NotOneOf(GetInput("x"), [tp.int8, tp.int32]) + assert doc_str(constraint) == "``x`` is not one of [:class:`int8`, :class:`int32`]" + class TestAnd: def test_call_all_pass(self): @@ -127,6 +137,13 @@ def test_inverse(self): assert isinstance(inverse, Or) assert str(inverse) == "(param1 is not one of [1, 2, 3] or param2 is not one of ['a', 'b'])" + def test_doc_str(self): + and_constraint = And(OneOf(GetInput("a"), [tp.float32]), OneOf(GetInput("b"), [tp.int32])) + assert ( + doc_str(and_constraint) + == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" + ) + class TestOr: def test_call_first_passes(self): @@ -170,6 +187,10 @@ def test_inverse(self): assert isinstance(inverse, And) assert str(inverse) == "(param1 is not one of [1, 2, 3] and param2 is not one of ['a', 'b'])" + def test_doc_str(self): + or_constraint = Or(Equal(GetInput("a"), tp.float32), Equal(GetInput("a"), tp.float16)) + assert doc_str(or_constraint) == "(``a`` is :class:`float32` *or* ``a`` is :class:`float16`)" + class TestEqual: def test_call(self): @@ -194,6 +215,10 @@ def test_inverse(self): assert inverse([("param1", 10)]) assert not inverse([("param1", 5)]) + def test_doc_str(self): + assert doc_str(Equal(GetInput("a"), GetInput("b"))) == "``a`` is ``b``" + assert doc_str(Equal(GetInput("a"), tp.float32)) == "``a`` is :class:`float32`" + class TestNotEqual: def test_call(self): @@ -217,3 +242,96 @@ def test_inverse(self): assert isinstance(inverse, Equal) assert inverse([("param1", 5)]) assert not inverse([("param1", 10)]) + + def test_doc_str(self): + assert doc_str(NotEqual(GetInput("a"), GetInput("b"))) == "``a`` is not ``b``" + + +class TestIf: + def test_call(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20, 30]), + ) + # Condition true, then branch passes + assert if_constraint([("param1", 5), ("param2", 2)]) + # Condition true, then branch fails + result = if_constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param2' to be one of [1, 2, 3] (but it was '10')" in result.error_details + # Condition false, else branch passes + assert if_constraint([("param1", 10), ("param2", 20)]) + # Condition false, else branch fails + result = if_constraint([("param1", 10), ("param2", 100)]) + assert not result + assert "'param2' to be one of [10, 20, 30] (but it was '100')" in result.error_details + + def test_str(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20]), + ) + assert ( + str(if_constraint) == "if (param1 == 5) then (param2 is one of [1, 2, 3]) else (param2 is one of [10, 20])" + ) + + def test_inverse(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20]), + ) + inverse = ~if_constraint + assert isinstance(inverse.then_branch, NotOneOf) + assert isinstance(inverse.else_branch, NotOneOf) + # When param1 == 5, param2 should NOT be in [1, 2, 3] + assert inverse([("param1", 5), ("param2", 10)]) + assert not inverse([("param1", 5), ("param2", 2)]) + + def test_doc_str(self): + if_constraint = If( + Equal(GetDataType(GetInput("a")), tp.float32), + OneOf(GetInput("b"), [tp.float32, tp.float16]), + OneOf(GetInput("b"), [tp.int32, tp.int64]), + ) + assert ( + doc_str(if_constraint) + == "``b`` is one of [:class:`float32`, :class:`float16`] **if** ``a.dtype`` is :class:`float32`, **otherwise** ``b`` is one of [:class:`int32`, :class:`int64`]" + ) + + def test_call_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + # Condition true, then branch passes + assert if_constraint([("param1", 5), ("param2", 2)]) + # Condition true, then branch fails + result = if_constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param2' to be one of [1, 2, 3] (but it was '10')" in result.error_details + # Condition false, no else branch - should always pass + assert if_constraint([("param1", 10), ("param2", 999)]) + + def test_str_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + assert str(if_constraint) == "if (param1 == 5) then (param2 is one of [1, 2, 3])" + + def test_doc_str_without_else_branch(self): + if_constraint = If( + Equal(GetDataType(GetInput("a")), tp.float32), OneOf(GetInput("b"), [tp.float32, tp.float16]) + ) + assert ( + doc_str(if_constraint) + == "if ``a.dtype`` is :class:`float32`, then ``b`` is one of [:class:`float32`, :class:`float16`]" + ) + + def test_inverse_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + inverse = ~if_constraint + assert isinstance(inverse.then_branch, NotOneOf) + assert inverse.else_branch is None + # When param1 == 5, param2 should NOT be in [1, 2, 3] + assert inverse([("param1", 5), ("param2", 10)]) + assert not inverse([("param1", 5), ("param2", 2)]) + # When param1 != 5, should always pass + assert inverse([("param1", 10), ("param2", 2)]) From d52222d8bb16d5e2fa50858723e130c175482abe Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 19 Nov 2025 18:33:36 +0000 Subject: [PATCH 17/32] Ports ones to new constraint system, updates `merge_function_arguments` - Ports `ones` and `ones_like` to use the new constraint system. - Updates `merge_function_arguments` to return default values for omitted arguments. --- tripy/nvtripy/backend/api/compile.py | 2 +- tripy/nvtripy/frontend/ops/ones.py | 38 ++++++++++++++++++---------- tripy/nvtripy/frontend/wrappers.py | 8 ++++-- tripy/nvtripy/utils/utils.py | 27 +++++++++++++++----- tripy/tests/utils/test_utils.py | 32 +++++++++++++++++++++++ 5 files changed, 85 insertions(+), 22 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index d9296f53a..befc2386f 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -248,7 +248,7 @@ def process_arg(name, arg): compiled_arg_names = [] new_args = [] - positional_arg_info, variadic_info = utils.utils.get_positional_arg_names(func, *args) + positional_arg_info, variadic_info = utils.utils.get_positional_args_with_names(func, *args) varargs_name = None varargs_index = None diff --git a/tripy/nvtripy/frontend/ops/ones.py b/tripy/nvtripy/frontend/ops/ones.py index 154d5863c..bc5ad5ba5 100644 --- a/tripy/nvtripy/frontend/ops/ones.py +++ b/tripy/nvtripy/frontend/ops/ones.py @@ -15,21 +15,22 @@ from typing import Optional from nvtripy import export -from nvtripy.common import datatype -from nvtripy.frontend.ops.full import full, full_like +from nvtripy.common import datatype as dt from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf +from nvtripy.frontend.ops.full import full, full_like @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def ones( shape: "nvtripy.types.ShapeLike", - dtype: datatype.dtype = datatype.float32, + dtype: dt.dtype = dt.float32, ) -> "nvtripy.Tensor": """ Creates a Tensor of the specified shape and dtype with all elements set to 1. @@ -55,13 +56,24 @@ def ones( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & If( + GetInput("dtype") != None, + OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool], + ), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def ones_like(input: "nvtripy.Tensor", dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def ones_like(input: "nvtripy.Tensor", dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Creates a tensor with all elements set to 1 of the same shape as the input tensor. diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index b491b3c88..208507fb0 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -116,6 +116,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name): source_info.column_range = candidates[0] +# TODO (pranavm): Remove this: def get_arg_dtype(arg, func_name, arg_name) -> result.Result["nvtripy.dtype"]: from nvtripy.common.datatype import dtype from nvtripy.frontend.tensor import Tensor @@ -489,7 +490,9 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - merged_args, var_arg_info = utils.utils.merge_function_arguments(func, *args, **kwargs) + merged_args, omitted_default_args, var_arg_info = utils.utils.merge_function_arguments( + func, *args, **kwargs + ) if convert_to_tensors: args, kwargs, merged_args = convert_input_types( @@ -507,7 +510,8 @@ def wrapper(*args, **kwargs): if config.enable_input_validation: if input_requirements is not None: - result = input_requirements(merged_args) + # Input validation needs to know values for arguments that were not provided but have default values: + result = input_requirements(merged_args + omitted_default_args) if not result: raise_error( f"Invalid inputs for function: '{func.__qualname__}'.", diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py index 572cfb13d..4d83f99f6 100644 --- a/tripy/nvtripy/utils/utils.py +++ b/tripy/nvtripy/utils/utils.py @@ -333,7 +333,7 @@ def gen_uid(inputs=None, outputs=None): ## ## Functions ## -def get_positional_arg_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]: +def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]: # Returns the names of positional arguments by inspecting the function signature. # In the case of variadic positional arguments, we cannot determine names, so we use # None instead. To assist in further processing, this function also returns the name @@ -357,9 +357,24 @@ def get_positional_arg_names(func, *args) -> Tuple[List[Tuple[str, Any]], Option return list(zip(arg_names, args)), (varargs_name, variadic_start_idx) if num_variadic_args > 0 else None -def merge_function_arguments(func, *args, **kwargs) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]: - # Merge positional and keyword arguments, trying to determine names where possible. - # Also returns a pair containing the variadic arg name and start index if present (None otherwise). - all_args, var_arg_info = get_positional_arg_names(func, *args) +def merge_function_arguments( + func, *args, **kwargs +) -> Tuple[List[Tuple[str, Any]], Tuple[str, Any], Optional[Tuple[str, int]]]: + # Returns 3 things: + # 1. A list of all arguments (positional and keyword) as (name, value) pairs. + # 2. A list of omitted arguments with default values filled in. + # 3. A pair containing the variadic arg name and start index if present (None otherwise). + all_args, var_arg_info = get_positional_args_with_names(func, *args) all_args.extend(kwargs.items()) - return all_args, var_arg_info + + signature = inspect.signature(func) + provided_arg_names = {name for name, _ in all_args} + + omitted_args = [] + for name, param in signature.parameters.items(): + if name in provided_arg_names: + continue + if param.default is not inspect.Parameter.empty: + omitted_args.append((name, param.default)) + + return all_args, omitted_args, var_arg_info diff --git a/tripy/tests/utils/test_utils.py b/tripy/tests/utils/test_utils.py index 568e6d223..298e4f166 100644 --- a/tripy/tests/utils/test_utils.py +++ b/tripy/tests/utils/test_utils.py @@ -109,3 +109,35 @@ def test_gen_uid(self, inputs, outputs, expected_prefix): def test_uniqueness(self): uids = [utils.utils.UniqueNameGen.gen_uid() for _ in range(100)] assert len(set(uids)) == 100 + + +class TestMergeFunctionArguments: + def test_defaults_filled_in_with_mixed_args(self): + # Tests that defaults are filled in and positioned correctly before kwargs + def func(a, b=2, c=3, d=4): + pass + + all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1, d=99) + assert all_args == [("a", 1), ("d", 99)] + assert omitted_args == [("b", 2), ("c", 3)] + assert var_arg_info is None + + def test_variadic_positional_args(self): + # Tests variadic args tracking and defaults after variadic args + def func(a, *args, b=10): + pass + + all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1, 2, 3) + assert all_args == [("a", 1), ("args", 2), ("args", 3)] + assert omitted_args == [("b", 10)] + assert var_arg_info == ("args", 1) + + def test_no_variadic_args_provided(self): + # Tests that var_arg_info is None when no variadic args are provided + def func(a, *args, b=10): + pass + + all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1) + assert all_args == [("a", 1)] + assert omitted_args == [("b", 10)] + assert var_arg_info is None From de02edf1fd7b986603c02a275a9bb923f5a6f06e Mon Sep 17 00:00:00 2001 From: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:43:59 -0800 Subject: [PATCH 18/32] Update tripy/nvtripy/utils/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> --- tripy/nvtripy/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py index 4d83f99f6..6a4ed8cf4 100644 --- a/tripy/nvtripy/utils/utils.py +++ b/tripy/nvtripy/utils/utils.py @@ -359,7 +359,7 @@ def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], def merge_function_arguments( func, *args, **kwargs -) -> Tuple[List[Tuple[str, Any]], Tuple[str, Any], Optional[Tuple[str, int]]]: +) -> Tuple[List[Tuple[str, Any]], List[Tuple[str, Any]], Optional[Tuple[str, int]]]: # Returns 3 things: # 1. A list of all arguments (positional and keyword) as (name, value) pairs. # 2. A list of omitted arguments with default values filled in. From 635729b56ad47e22e8e55d836429c28d7e0de96a Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 19 Nov 2025 19:26:45 +0000 Subject: [PATCH 19/32] Migrate unary operators to new constraint system --- ...teOperatorsToNewConstraintSystem.prompt.md | 22 +++++++++++++++++++ tripy/nvtripy/frontend/ops/unary/abs.py | 8 +++++-- tripy/nvtripy/frontend/ops/unary/cos.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/exp.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/gelu.py | 8 +++---- tripy/nvtripy/frontend/ops/unary/invert.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/log.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/neg.py | 8 +++++-- tripy/nvtripy/frontend/ops/unary/relu.py | 10 +++++---- tripy/nvtripy/frontend/ops/unary/rsqrt.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/sigmoid.py | 8 +++---- tripy/nvtripy/frontend/ops/unary/silu.py | 8 +++---- tripy/nvtripy/frontend/ops/unary/sin.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/sqrt.py | 6 +++-- tripy/nvtripy/frontend/ops/unary/tanh.py | 6 +++-- 15 files changed, 84 insertions(+), 36 deletions(-) create mode 100644 tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md diff --git a/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md b/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md new file mode 100644 index 000000000..dc206ebe8 --- /dev/null +++ b/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md @@ -0,0 +1,22 @@ +## Plan: Migrate Operators to New Constraint System + +This plan will systematically port 86 operators across 66 files from the legacy `dtype_constraints`/`dtype_variables`/`dtype_exceptions` system to the new `input_requirements`/`output_guarantees` constraint system. The migration follows proven patterns from [`cast.py`](nvtripy/frontend/ops/cast.py), [`concatenate.py`](nvtripy/frontend/ops/concatenate.py), and [`ones.py`](nvtripy/frontend/ops/ones.py). + +### Steps + +1. **Migrate simple type-preserving unary operators** (~16 ops in [`nvtripy/frontend/ops/unary/`](nvtripy/frontend/ops/unary/)): Convert operators like `exp`, `log`, `sqrt`, `sin`, `cos`, `tanh`, `sigmoid`, `relu`, `gelu`, `silu`, `abs`, `neg`, `invert` using pattern: `input_requirements=OneOf(GetInput("input").dtype, [dtypes])` and `output_guarantees=GetReturn(0).dtype == GetInput("input").dtype`. Run tests with `pytest tests/integration/test_operator_constraints.py -k "exp|log|sqrt" -s` and corresponding sanity tests. **Commit changes** with message like "Migrate unary operators to new constraint system". + +2. **Migrate shape manipulation operators** (~12 ops in [`nvtripy/frontend/ops/`](nvtripy/frontend/ops/)): Convert `reshape`, `transpose`, `permute`, `flatten`, `squeeze`, `unsqueeze`, `expand`, `repeat`, `stack`, `slice`, `flip` using the same type-preserving pattern. Test with `pytest tests/integration/test_operator_constraints.py -k "reshape|transpose|permute" -s`. **Commit changes** with message like "Migrate shape manipulation operators to new constraint system". + +3. **Migrate binary arithmetic and comparison operators** (~16 ops in [`nvtripy/frontend/ops/binary/`](nvtripy/frontend/ops/binary/)): Convert `add`, `sub`, `mul`, `div`, `floor_div`, `mod`, `pow`, `maximum`, `minimum`, `equal`, `not_equal`, `greater`, `greater_equal`, `less`, `less_equal`, `logical_or` using pattern: `OneOf(GetInput("self").dtype, [...]) & (GetInput("other").dtype == GetInput("self").dtype)`. Comparison ops return `dt.bool`. Test with `pytest tests/integration/test_operator_constraints.py -k "add|mul|equal" -s`. **Commit changes** with message like "Migrate binary operators to new constraint system". + +4. **Migrate reduction operators** (~11 ops in [`nvtripy/frontend/ops/reduce/`](nvtripy/frontend/ops/reduce/)): Convert `sum`, `prod`, `mean`, `max`, `min`, `var`, `all`, `any`, plus multi-type ops `argmax`, `argmin`, `topk` (note: `topk` has multiple returns). Use `If` constraints for multi-type variables. Test with `pytest tests/integration/test_operator_constraints.py -k "sum|mean|argmax" -s`. **Commit changes** with message like "Migrate reduction operators to new constraint system". + +5. **Migrate initializers with optional dtype parameters** (~8 ops in [`nvtripy/frontend/ops/`](nvtripy/frontend/ops/)): Convert `zeros`, `zeros_like`, `full`, `full_like`, `iota`, `arange` using `If(GetInput("dtype") != None, ...)` pattern for conditional output guarantees. Test with `pytest tests/integration/test_operator_constraints.py -k "zeros|full|iota|arange" -s`. **Commit changes** with message like "Migrate initializer operators to new constraint system". + +6. **Migrate advanced operations and special cases** (~15 ops): Convert `matmul`, `gather`, `where`, `masked_fill`, `outer`, `pad`, `softmax`, `cumsum`, `copy`, `resize`, `triu`, `tril`, `avgpool`, `maxpool` in their respective files. Handle multi-type variables and complex logic. Migrate `quantize`/`dequantize` in [`plugin_qdp.py`](nvtripy/frontend/ops/plugin_qdp.py) and [`quantize.py`](nvtripy/frontend/ops/quantize.py)/[`dequantize.py`](nvtripy/frontend/ops/dequantize.py) with coordinated type pairs. Test each with `-k` filtering. **Commit changes** with message like "Migrate advanced operators to new constraint system". + +### Implementation Notes + +- Translate `dtype_exceptions` directly to boolean logic using `&`, `|`, `~` operators without creating helper functions +- No modifications to `CUSTOM_VALUES` in test file needed - it's already complete for dtype constraints diff --git a/tripy/nvtripy/frontend/ops/unary/abs.py b/tripy/nvtripy/frontend/ops/unary/abs.py index 6ae52b523..86d45672f 100644 --- a/tripy/nvtripy/frontend/ops/unary/abs.py +++ b/tripy/nvtripy/frontend/ops/unary/abs.py @@ -16,16 +16,20 @@ # +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Abs from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__abs__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __abs__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/cos.py b/tripy/nvtripy/frontend/ops/unary/cos.py index fbf9ba35d..9871e6d2f 100644 --- a/tripy/nvtripy/frontend/ops/unary/cos.py +++ b/tripy/nvtripy/frontend/ops/unary/cos.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Cos from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def cos(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/exp.py b/tripy/nvtripy/frontend/ops/unary/exp.py index 837f89051..22112a1f1 100644 --- a/tripy/nvtripy/frontend/ops/unary/exp.py +++ b/tripy/nvtripy/frontend/ops/unary/exp.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Exp from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def exp(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/gelu.py b/tripy/nvtripy/frontend/ops/unary/gelu.py index 44a9d60bd..5416282b8 100644 --- a/tripy/nvtripy/frontend/ops/unary/gelu.py +++ b/tripy/nvtripy/frontend/ops/unary/gelu.py @@ -17,17 +17,17 @@ from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import GeluErf from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def gelu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/invert.py b/tripy/nvtripy/frontend/ops/unary/invert.py index 8fea29d6a..b191cfe7f 100644 --- a/tripy/nvtripy/frontend/ops/unary/invert.py +++ b/tripy/nvtripy/frontend/ops/unary/invert.py @@ -12,16 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Not from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__invert__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.bool]), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __invert__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/log.py b/tripy/nvtripy/frontend/ops/unary/log.py index 8d21af689..1e1fd87b9 100644 --- a/tripy/nvtripy/frontend/ops/unary/log.py +++ b/tripy/nvtripy/frontend/ops/unary/log.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Log from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def log(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/neg.py b/tripy/nvtripy/frontend/ops/unary/neg.py index 0e08130a0..a446a9372 100644 --- a/tripy/nvtripy/frontend/ops/unary/neg.py +++ b/tripy/nvtripy/frontend/ops/unary/neg.py @@ -14,16 +14,20 @@ # limitations under the License. +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Neg from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__neg__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __neg__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/relu.py b/tripy/nvtripy/frontend/ops/unary/relu.py index 531cff1dc..61526001f 100644 --- a/tripy/nvtripy/frontend/ops/unary/relu.py +++ b/tripy/nvtripy/frontend/ops/unary/relu.py @@ -20,14 +20,16 @@ from nvtripy.trace.ops.unary import Relu from nvtripy.frontend import wrappers from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "int8"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int32, dt.int64, dt.int8] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def relu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/rsqrt.py b/tripy/nvtripy/frontend/ops/unary/rsqrt.py index 9041deef6..5fe2e095f 100644 --- a/tripy/nvtripy/frontend/ops/unary/rsqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/rsqrt.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Recip from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def rsqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/sigmoid.py b/tripy/nvtripy/frontend/ops/unary/sigmoid.py index efee06568..c50ed7d2c 100644 --- a/tripy/nvtripy/frontend/ops/unary/sigmoid.py +++ b/tripy/nvtripy/frontend/ops/unary/sigmoid.py @@ -16,17 +16,17 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import Sigmoid from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sigmoid(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/silu.py b/tripy/nvtripy/frontend/ops/unary/silu.py index 6e359bec5..14a35c612 100644 --- a/tripy/nvtripy/frontend/ops/unary/silu.py +++ b/tripy/nvtripy/frontend/ops/unary/silu.py @@ -16,15 +16,15 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def silu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/sin.py b/tripy/nvtripy/frontend/ops/unary/sin.py index bf7ae4e1b..eaa905cf9 100644 --- a/tripy/nvtripy/frontend/ops/unary/sin.py +++ b/tripy/nvtripy/frontend/ops/unary/sin.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sin from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sin(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/sqrt.py b/tripy/nvtripy/frontend/ops/unary/sqrt.py index b8681357c..c95fafdc6 100644 --- a/tripy/nvtripy/frontend/ops/unary/sqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/sqrt.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sqrt from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/tanh.py b/tripy/nvtripy/frontend/ops/unary/tanh.py index 67f9bc595..85c187aeb 100644 --- a/tripy/nvtripy/frontend/ops/unary/tanh.py +++ b/tripy/nvtripy/frontend/ops/unary/tanh.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Tanh from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def tanh(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ From a57e59f12387f1c4c6abab1828094d1bdd8604f5 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 19 Nov 2025 19:29:15 +0000 Subject: [PATCH 20/32] Migrate shape manipulation operators to new constraint system --- tripy/nvtripy/frontend/ops/expand.py | 10 ++++++---- tripy/nvtripy/frontend/ops/flatten.py | 8 ++++++-- tripy/nvtripy/frontend/ops/flip.py | 10 ++++++---- tripy/nvtripy/frontend/ops/permute.py | 8 ++++++-- tripy/nvtripy/frontend/ops/repeat.py | 10 ++++++---- tripy/nvtripy/frontend/ops/reshape.py | 8 ++++++-- tripy/nvtripy/frontend/ops/slice.py | 8 ++++++-- tripy/nvtripy/frontend/ops/squeeze.py | 8 ++++++-- tripy/nvtripy/frontend/ops/stack.py | 10 ++++++---- tripy/nvtripy/frontend/ops/transpose.py | 8 ++++++-- tripy/nvtripy/frontend/ops/unsqueeze.py | 8 ++++++-- 11 files changed, 66 insertions(+), 30 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/expand.py b/tripy/nvtripy/frontend/ops/expand.py index 41b975bd3..14fde14ad 100644 --- a/tripy/nvtripy/frontend/ops/expand.py +++ b/tripy/nvtripy/frontend/ops/expand.py @@ -17,11 +17,13 @@ from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): @@ -53,10 +55,10 @@ def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, conversion_preprocess_func=process_sizes, ) diff --git a/tripy/nvtripy/frontend/ops/flatten.py b/tripy/nvtripy/frontend/ops/flatten.py index 85b3b175c..e4db23128 100644 --- a/tripy/nvtripy/frontend/ops/flatten.py +++ b/tripy/nvtripy/frontend/ops/flatten.py @@ -15,17 +15,21 @@ import math from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("flatten") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def flatten(input: "nvtripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/flip.py b/tripy/nvtripy/frontend/ops/flip.py index c2e3d71d1..beb771851 100644 --- a/tripy/nvtripy/frontend/ops/flip.py +++ b/tripy/nvtripy/frontend/ops/flip.py @@ -18,16 +18,18 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def flip(input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index 8f798ec15..cf6c56fe4 100644 --- a/tripy/nvtripy/frontend/ops/permute.py +++ b/tripy/nvtripy/frontend/ops/permute.py @@ -18,18 +18,22 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.permute import Permute from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("permute") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def permute(input: "nvtripy.Tensor", perm: Sequence[int]) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/repeat.py b/tripy/nvtripy/frontend/ops/repeat.py index 9dc498e4c..5c3f8aea9 100644 --- a/tripy/nvtripy/frontend/ops/repeat.py +++ b/tripy/nvtripy/frontend/ops/repeat.py @@ -15,18 +15,20 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def repeat(input: "nvtripy.Tensor", repeats: IntLike, dim: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py index ae0eb52a0..7843f5074 100644 --- a/tripy/nvtripy/frontend/ops/reshape.py +++ b/tripy/nvtripy/frontend/ops/reshape.py @@ -18,12 +18,14 @@ import math from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.reshape import Reshape from nvtripy.types import ShapeLike from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: @@ -46,8 +48,10 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: @register_tensor_method("reshape") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, conversion_preprocess_func=infer_dimensions, ) diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index 032fc70e5..09e1b72b7 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -18,12 +18,14 @@ from typing import Sequence, Union from nvtripy import utils +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.slice import Slice from nvtripy.types import IntLike from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.utils.types import type_str_from_arg from nvtripy.utils.utils import make_list @@ -32,8 +34,10 @@ @register_tensor_method("__getitem__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __getitem__( self: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py index 132fd7836..2aa22911c 100644 --- a/tripy/nvtripy/frontend/ops/squeeze.py +++ b/tripy/nvtripy/frontend/ops/squeeze.py @@ -15,16 +15,20 @@ from typing import Sequence, Union from nvtripy import export, utils +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("squeeze") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def squeeze(input: "nvtripy.Tensor", dims: Union[Sequence[int], int]) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/stack.py b/tripy/nvtripy/frontend/ops/stack.py index e46c8b69f..959923d5c 100644 --- a/tripy/nvtripy/frontend/ops/stack.py +++ b/tripy/nvtripy/frontend/ops/stack.py @@ -16,16 +16,18 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensors").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensors").dtype, ) def stack(tensors: Sequence["nvtripy.Tensor"], dim: int = 0) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/transpose.py b/tripy/nvtripy/frontend/ops/transpose.py index 9119f6937..7c019bd2d 100644 --- a/tripy/nvtripy/frontend/ops/transpose.py +++ b/tripy/nvtripy/frontend/ops/transpose.py @@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("transpose") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def transpose(input: "nvtripy.Tensor", dim0: int, dim1: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unsqueeze.py b/tripy/nvtripy/frontend/ops/unsqueeze.py index e14612d49..3a6c3bf69 100644 --- a/tripy/nvtripy/frontend/ops/unsqueeze.py +++ b/tripy/nvtripy/frontend/ops/unsqueeze.py @@ -16,16 +16,20 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("unsqueeze") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def unsqueeze(input: "nvtripy.Tensor", dim: int) -> "nvtripy.Tensor": """ From 4a9c8b4aff36aa0c71f7f2bb0d971315b4e5461f Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 19 Nov 2025 20:24:54 +0000 Subject: [PATCH 21/32] Migrate binary operators to new constraint system --- tripy/nvtripy/frontend/ops/binary/add.py | 9 +++++++-- tripy/nvtripy/frontend/ops/binary/div.py | 16 ++++++++++++---- tripy/nvtripy/frontend/ops/binary/equal.py | 12 +++++++----- tripy/nvtripy/frontend/ops/binary/floor_div.py | 16 ++++++++++++---- tripy/nvtripy/frontend/ops/binary/greater.py | 12 +++++++----- .../nvtripy/frontend/ops/binary/greater_equal.py | 12 +++++++----- tripy/nvtripy/frontend/ops/binary/less.py | 12 +++++++----- tripy/nvtripy/frontend/ops/binary/less_equal.py | 12 +++++++----- tripy/nvtripy/frontend/ops/binary/logical_or.py | 6 ++++-- tripy/nvtripy/frontend/ops/binary/maximum.py | 9 +++++++-- tripy/nvtripy/frontend/ops/binary/minimum.py | 9 +++++++-- tripy/nvtripy/frontend/ops/binary/mod.py | 16 ++++++++++++---- tripy/nvtripy/frontend/ops/binary/mul.py | 9 +++++++-- tripy/nvtripy/frontend/ops/binary/not_equal.py | 12 +++++++----- tripy/nvtripy/frontend/ops/binary/pow.py | 12 ++++++++---- tripy/nvtripy/frontend/ops/binary/sub.py | 16 ++++++++++++---- 16 files changed, 130 insertions(+), 60 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/binary/add.py b/tripy/nvtripy/frontend/ops/binary/add.py index e936129cd..f6e7dd008 100644 --- a/tripy/nvtripy/frontend/ops/binary/add.py +++ b/tripy/nvtripy/frontend/ops/binary/add.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Add @@ -22,8 +24,11 @@ @register_tensor_method("__add__") @register_tensor_method("__radd__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __add__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/div.py b/tripy/nvtripy/frontend/ops/binary/div.py index 482a06a59..812e3b523 100644 --- a/tripy/nvtripy/frontend/ops/binary/div.py +++ b/tripy/nvtripy/frontend/ops/binary/div.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Div @@ -21,8 +23,11 @@ @register_tensor_method("__truediv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __truediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __truediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rtruediv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rtruediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/equal.py b/tripy/nvtripy/frontend/ops/binary/equal.py index 699882a74..32e38fe5a 100644 --- a/tripy/nvtripy/frontend/ops/binary/equal.py +++ b/tripy/nvtripy/frontend/ops/binary/equal.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Equal @@ -21,11 +23,11 @@ @register_tensor_method("__eq__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __eq__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/floor_div.py b/tripy/nvtripy/frontend/ops/binary/floor_div.py index e221e0fb8..0e71e2a1b 100644 --- a/tripy/nvtripy/frontend/ops/binary/floor_div.py +++ b/tripy/nvtripy/frontend/ops/binary/floor_div.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import FloorDiv @@ -21,8 +23,11 @@ @register_tensor_method("__floordiv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "bool", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.bool, dt.int8, dt.int32, dt.int64] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __floordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __floordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rfloordiv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "bool", "int4", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.bool, dt.int4, dt.int8, dt.int32, dt.int64] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rfloordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/greater.py b/tripy/nvtripy/frontend/ops/binary/greater.py index cf5504d7a..e9d95b1e8 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater.py +++ b/tripy/nvtripy/frontend/ops/binary/greater.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Greater @@ -21,11 +23,11 @@ @register_tensor_method("__gt__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __gt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/greater_equal.py b/tripy/nvtripy/frontend/ops/binary/greater_equal.py index c75c3cb12..e4db549b6 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/greater_equal.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike from nvtripy.frontend import wrappers @@ -19,11 +21,11 @@ @register_tensor_method("__ge__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __ge__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/less.py b/tripy/nvtripy/frontend/ops/binary/less.py index 6b8431f39..d882bf105 100644 --- a/tripy/nvtripy/frontend/ops/binary/less.py +++ b/tripy/nvtripy/frontend/ops/binary/less.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Less @@ -21,11 +23,11 @@ @register_tensor_method("__lt__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __lt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/less_equal.py b/tripy/nvtripy/frontend/ops/binary/less_equal.py index 9694fa4e7..d3b4678bc 100644 --- a/tripy/nvtripy/frontend/ops/binary/less_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/less_equal.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike from nvtripy.frontend import wrappers @@ -19,11 +21,11 @@ @register_tensor_method("__le__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __le__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/logical_or.py b/tripy/nvtripy/frontend/ops/binary/logical_or.py index 268f42b73..00e9dbd7a 100644 --- a/tripy/nvtripy/frontend/ops/binary/logical_or.py +++ b/tripy/nvtripy/frontend/ops/binary/logical_or.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import LogicalOr @@ -20,8 +22,8 @@ @register_tensor_method("__or__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=(GetInput("self").dtype == dt.bool) & (GetInput("other").dtype == dt.bool), + output_guarantees=GetReturn(0).dtype == dt.bool, ) def __or__(self: "nvtripy.Tensor", other: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/binary/maximum.py b/tripy/nvtripy/frontend/ops/binary/maximum.py index c25d8152a..91a4a9fae 100644 --- a/tripy/nvtripy/frontend/ops/binary/maximum.py +++ b/tripy/nvtripy/frontend/ops/binary/maximum.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Max from nvtripy.types import TensorLike @@ -21,8 +23,11 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("lhs").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("rhs").dtype == GetInput("lhs").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("lhs").dtype, convert_to_tensors=True, ) def maximum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/minimum.py b/tripy/nvtripy/frontend/ops/binary/minimum.py index 19a7e5a74..454c3016f 100644 --- a/tripy/nvtripy/frontend/ops/binary/minimum.py +++ b/tripy/nvtripy/frontend/ops/binary/minimum.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Min from nvtripy.types import TensorLike @@ -21,8 +23,11 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("lhs").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("rhs").dtype == GetInput("lhs").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("lhs").dtype, convert_to_tensors=True, ) def minimum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/mod.py b/tripy/nvtripy/frontend/ops/binary/mod.py index f22ee1295..10142484e 100644 --- a/tripy/nvtripy/frontend/ops/binary/mod.py +++ b/tripy/nvtripy/frontend/ops/binary/mod.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike from nvtripy.frontend import wrappers @@ -23,8 +25,11 @@ def mod_impl(lhs, rhs): @register_tensor_method("__mod__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __mod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -53,8 +58,11 @@ def __mod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rmod__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rmod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/mul.py b/tripy/nvtripy/frontend/ops/binary/mul.py index 0fc2c9811..77a4b101d 100644 --- a/tripy/nvtripy/frontend/ops/binary/mul.py +++ b/tripy/nvtripy/frontend/ops/binary/mul.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Mul @@ -22,8 +24,11 @@ @register_tensor_method("__mul__") @register_tensor_method("__rmul__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __mul__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/not_equal.py b/tripy/nvtripy/frontend/ops/binary/not_equal.py index d4a8b4fbb..e48e454be 100644 --- a/tripy/nvtripy/frontend/ops/binary/not_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/not_equal.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike from nvtripy.frontend import wrappers @@ -19,11 +21,11 @@ @register_tensor_method("__ne__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __ne__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/pow.py b/tripy/nvtripy/frontend/ops/binary/pow.py index 17b55e1d3..cbea9bd4b 100644 --- a/tripy/nvtripy/frontend/ops/binary/pow.py +++ b/tripy/nvtripy/frontend/ops/binary/pow.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Pow @@ -21,8 +23,9 @@ @register_tensor_method("__pow__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int64"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int64]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __pow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +54,9 @@ def __pow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rpow__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int64"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int64]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rpow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/sub.py b/tripy/nvtripy/frontend/ops/binary/sub.py index 7b4bc4d6e..24cf69135 100644 --- a/tripy/nvtripy/frontend/ops/binary/sub.py +++ b/tripy/nvtripy/frontend/ops/binary/sub.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Sub @@ -21,8 +23,11 @@ @register_tensor_method("__sub__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __sub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __sub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rsub__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rsub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": From f1d90d37299f94a4ae8d4c9f7fdf88e8053b9ed4 Mon Sep 17 00:00:00 2001 From: pranavm Date: Mon, 22 Dec 2025 17:43:38 +0000 Subject: [PATCH 22/32] Migrate reduction operators to new constraint system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Migrated 11 reduction operators: all, any, sum, prod, mean, max, min, var, argmax, argmin, topk - Type-preserving operators (sum, prod, mean, max, min): maintain input dtype for numeric types - Bool-only operators (all, any): only accept bool dtype with explicit runtime validation - Type-changing operators (argmax, argmin): multiple input types → int32 - Tuple return operator (topk): returns (values, indices) with coordinated types - all/any now properly reject non-bool inputs to avoid incorrect casting behavior - 216 tests passed for all reduction operators --- tripy/nvtripy/frontend/ops/reduce/all.py | 23 +++++++++++++++++---- tripy/nvtripy/frontend/ops/reduce/any.py | 23 +++++++++++++++++---- tripy/nvtripy/frontend/ops/reduce/argmax.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/argmin.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/max.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/mean.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/min.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/prod.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/sum.py | 6 ++++-- tripy/nvtripy/frontend/ops/reduce/topk.py | 10 ++++----- tripy/nvtripy/frontend/ops/reduce/var.py | 6 ++++-- 11 files changed, 74 insertions(+), 30 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/reduce/all.py b/tripy/nvtripy/frontend/ops/reduce/all.py index d7dc507e4..a8cb8f0f8 100644 --- a/tripy/nvtripy/frontend/ops/reduce/all.py +++ b/tripy/nvtripy/frontend/ops/reduce/all.py @@ -15,14 +15,15 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn from nvtripy.frontend import wrappers -from nvtripy.common import datatype @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=GetInput("input").dtype == dt.bool, + output_guarantees=GetReturn(0).dtype == dt.bool, ) def all( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -50,5 +51,19 @@ def all( """ from nvtripy.frontend.ops.reduce.prod import prod from nvtripy.frontend.ops.cast import cast + from nvtripy.common.exception import raise_error - return cast(prod(cast(input, dtype=datatype.int32), dim, keepdim), dtype=datatype.bool) + # Validate that input is bool - constraint system has already checked this + # but we need to enforce it at runtime when validation is disabled + if input.dtype != dt.bool: + raise_error( + f"Input must have bool dtype for all(), but got {input.dtype}.", + [ + "This function only accepts bool tensors. ", + "Note: If you need to check if all elements are non-zero, first compare with zero: ", + "tp.all(input != 0)", + ], + ) + + # Cast to int32 since prod doesn't accept bool, then cast back to bool + return cast(prod(cast(input, dtype=dt.int32), dim, keepdim), dtype=dt.bool) diff --git a/tripy/nvtripy/frontend/ops/reduce/any.py b/tripy/nvtripy/frontend/ops/reduce/any.py index c9f4af47b..ddbff8f59 100644 --- a/tripy/nvtripy/frontend/ops/reduce/any.py +++ b/tripy/nvtripy/frontend/ops/reduce/any.py @@ -15,14 +15,15 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn from nvtripy.frontend import wrappers -from nvtripy.common import datatype @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=GetInput("input").dtype == dt.bool, + output_guarantees=GetReturn(0).dtype == dt.bool, ) def any( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -50,5 +51,19 @@ def any( """ from nvtripy.frontend.ops.reduce.sum import sum from nvtripy.frontend.ops.cast import cast + from nvtripy.common.exception import raise_error - return cast(sum(cast(input, dtype=datatype.int32), dim, keepdim), dtype=datatype.bool) + # Validate that input is bool - constraint system has already checked this + # but we need to enforce it at runtime when validation is disabled + if input.dtype != dt.bool: + raise_error( + f"Input must have bool dtype for any(), but got {input.dtype}.", + [ + "This function only accepts bool tensors. ", + "Note: If you need to check if any elements are non-zero, first compare with zero: ", + "tp.any(input != 0)", + ], + ) + + # Cast to int32 since sum doesn't accept bool, then cast back to bool + return cast(sum(cast(input, dtype=dt.int32), dim, keepdim), dtype=dt.bool) diff --git a/tripy/nvtripy/frontend/ops/reduce/argmax.py b/tripy/nvtripy/frontend/ops/reduce/argmax.py index 7d5ee84ab..d2adcdf85 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmax.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmax.py @@ -15,6 +15,8 @@ from typing import Optional from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMax from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == dt.int32, ) def argmax(input: "nvtripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reduce/argmin.py b/tripy/nvtripy/frontend/ops/reduce/argmin.py index bc09187a5..84c9c175c 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmin.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmin.py @@ -15,6 +15,8 @@ from typing import Optional from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMin from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == dt.int32, ) def argmin(input: "nvtripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reduce/max.py b/tripy/nvtripy/frontend/ops/reduce/max.py index b6b569b2d..dc2e68052 100644 --- a/tripy/nvtripy/frontend/ops/reduce/max.py +++ b/tripy/nvtripy/frontend/ops/reduce/max.py @@ -15,6 +15,8 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Max from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def max( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/mean.py b/tripy/nvtripy/frontend/ops/reduce/mean.py index e9dd4d5a1..fdca8a967 100644 --- a/tripy/nvtripy/frontend/ops/reduce/mean.py +++ b/tripy/nvtripy/frontend/ops/reduce/mean.py @@ -15,6 +15,8 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Avg from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def mean( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/min.py b/tripy/nvtripy/frontend/ops/reduce/min.py index b08c5a494..8a43c9c18 100644 --- a/tripy/nvtripy/frontend/ops/reduce/min.py +++ b/tripy/nvtripy/frontend/ops/reduce/min.py @@ -15,6 +15,8 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Min from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def min( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/prod.py b/tripy/nvtripy/frontend/ops/reduce/prod.py index 016bf20e8..d70d9ce06 100644 --- a/tripy/nvtripy/frontend/ops/reduce/prod.py +++ b/tripy/nvtripy/frontend/ops/reduce/prod.py @@ -15,6 +15,8 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Prod from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def prod( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/sum.py b/tripy/nvtripy/frontend/ops/reduce/sum.py index 9e85e4197..efaa47ae7 100644 --- a/tripy/nvtripy/frontend/ops/reduce/sum.py +++ b/tripy/nvtripy/frontend/ops/reduce/sum.py @@ -15,6 +15,8 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Sum from nvtripy.frontend import wrappers @@ -22,8 +24,8 @@ @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sum( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/topk.py b/tripy/nvtripy/frontend/ops/reduce/topk.py index e5713c82a..511e7f156 100644 --- a/tripy/nvtripy/frontend/ops/reduce/topk.py +++ b/tripy/nvtripy/frontend/ops/reduce/topk.py @@ -15,19 +15,17 @@ from typing import Tuple from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import topk_impl from nvtripy.trace.ops.topk import TopKMax from nvtripy.frontend import wrappers -# constraints = OneOf(GetInput("input").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.int32, tp.int64]) -# output_guarantees = (GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == tp.int32)) - - @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: ["T1", "T2"]}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=(GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == dt.int32), ) def topk(input: "nvtripy.Tensor", k: int, dim: int) -> Tuple["nvtripy.Tensor", "nvtripy.Tensor"]: """ diff --git a/tripy/nvtripy/frontend/ops/reduce/var.py b/tripy/nvtripy/frontend/ops/reduce/var.py index 585600034..cf17b3d9d 100644 --- a/tripy/nvtripy/frontend/ops/reduce/var.py +++ b/tripy/nvtripy/frontend/ops/reduce/var.py @@ -16,14 +16,16 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def var( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False, correction: int = 1 From 5fb7414f25ae7ee1eb72c4ba9dbfbe7ec67a0f81 Mon Sep 17 00:00:00 2001 From: pranavm Date: Mon, 22 Dec 2025 19:15:28 +0000 Subject: [PATCH 23/32] Migrate initializer operators to new constraint system - Migrated zeros(), zeros_like(), full(), full_like(), iota(), iota_like(), arange() (2 overloads) - Applied input_requirements and output_guarantees patterns - Pattern: Optional dtype uses If(GetInput('dtype') != None, ...) - Pattern: Non-tensor params accessed with GetInput('dtype') not .value - Pattern: Value dtype constraints exclude float8 and int8+bool combinations - Fixed constraint system: Modified Constraints.find() to accept skip_within parameter - Fixed auto-casting: _find_known_datatypes now skips If conditions to avoid treating conditional checks (e.g., If(GetInput('value').dtype == dt.int8, ...)) as type requirements - Documented limitation: Auto-casting won't work for conditionally-dependent datatypes --- tripy/nvtripy/frontend/constraints/base.py | 15 ++++++- tripy/nvtripy/frontend/ops/arange.py | 25 ++++++------ tripy/nvtripy/frontend/ops/full.py | 46 ++++++++++++++++------ tripy/nvtripy/frontend/ops/iota.py | 39 ++++++++++++------ tripy/nvtripy/frontend/ops/zeros.py | 32 +++++++++------ tripy/nvtripy/frontend/wrappers.py | 22 ++++++++++- 6 files changed, 126 insertions(+), 53 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/base.py b/tripy/nvtripy/frontend/constraints/base.py index 360c8fbea..c254ea177 100644 --- a/tripy/nvtripy/frontend/constraints/base.py +++ b/tripy/nvtripy/frontend/constraints/base.py @@ -73,7 +73,7 @@ def get_children(self) -> List["Constraints"]: children.extend(v for v in attr_value if isinstance(v, Constraints)) return children - def find(self, pattern: "Constraints") -> List["Constraints"]: + def find(self, pattern: "Constraints", skip_within: Optional[type] = None) -> List["Constraints"]: """ Find all constraints in the tree that match the given pattern. @@ -84,6 +84,10 @@ def find(self, pattern: "Constraints") -> List["Constraints"]: Args: pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)). Use None as a wildcard to match anything. + skip_within: Optional constraint type. If provided, the search will not recurse + into the children of constraints of this type. This allows skipping + specific nested structures (e.g., skip_within=If to avoid searching + inside If condition branches). Returns: A list of all matching constraints found in the tree. @@ -91,6 +95,9 @@ def find(self, pattern: "Constraints") -> List["Constraints"]: Example: pattern = Equal(GetDataType(TensorFetcher), None) # None matches any second argument matches = constraint_tree.find(pattern) + + # Skip searching inside If constraints: + matches = constraint_tree.find(pattern, skip_within=If) """ def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool: @@ -133,9 +140,13 @@ def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool: if matches_pattern(pattern, self): matches.append(self) + # Skip recursing into children if this constraint matches skip_within type + if skip_within is not None and isinstance(self, skip_within): + return matches + # Recursively search children for child in self.get_children(): - matches.extend(child.find(pattern)) + matches.extend(child.find(pattern, skip_within=skip_within)) return matches diff --git a/tripy/nvtripy/frontend/ops/arange.py b/tripy/nvtripy/frontend/ops/arange.py index 6aa07ca8f..24317d557 100644 --- a/tripy/nvtripy/frontend/ops/arange.py +++ b/tripy/nvtripy/frontend/ops/arange.py @@ -16,8 +16,9 @@ from typing import Union from nvtripy import export -from nvtripy.common import datatype +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.cast import cast from nvtripy.frontend.ops.reshape import reshape @@ -27,16 +28,17 @@ @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def arange( start: Union[numbers.Number, "nvtripy.DimensionSize"], stop: Union[numbers.Number, "nvtripy.DimensionSize"], step: Union[numbers.Number, "nvtripy.DimensionSize"] = 1, - dtype: "nvtripy.dtype" = datatype.float32, + dtype: "nvtripy.dtype" = dt.float32, ) -> "nvtripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval @@ -96,13 +98,14 @@ def arange( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def arange( - stop: Union[numbers.Number, "nvtripy.DimensionSize"], dtype: "nvtripy.dtype" = datatype.float32 + stop: Union[numbers.Number, "nvtripy.DimensionSize"], dtype: "nvtripy.dtype" = dt.float32 ) -> "nvtripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval diff --git a/tripy/nvtripy/frontend/ops/full.py b/tripy/nvtripy/frontend/ops/full.py index e8882250a..6eb409952 100644 --- a/tripy/nvtripy/frontend/ops/full.py +++ b/tripy/nvtripy/frontend/ops/full.py @@ -18,7 +18,8 @@ from typing import Optional from nvtripy import export, utils -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike, TensorLike @@ -27,13 +28,16 @@ @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=(GetInput("value").dtype != dt.float8) + & If( + GetInput("value").dtype == dt.int8, + GetInput("dtype") != dt.bool, + ) + & OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors=True, ) -def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype.float32) -> "nvtripy.Tensor": +def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = dt.float32) -> "nvtripy.Tensor": """ Returns a tensor of the desired shape with all values set to the specified value. @@ -55,9 +59,9 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype. from nvtripy.frontend.ops.cast import cast value_dtype = dtype - if dtype == datatype.int8: + if dtype == dt.int8: # TODO (#580): Remove this workaround for broadcasting INT8 inputs: - value_dtype = datatype.int32 + value_dtype = dt.int32 # We avoid using the `expand` API since it does extra things that we don't need. out = op_utils.create_op(Broadcast, [cast(value, dtype=value_dtype), shape]) @@ -67,11 +71,27 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype. @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("value").dtype != dt.float8) + & If( + GetInput("value").dtype == dt.int8, + If( + GetInput("dtype") != None, + GetInput("dtype") != dt.bool, + GetInput("input").dtype != dt.bool, + ), + ) + & If( + GetInput("dtype") != None, + OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) def full_like(input: "nvtripy.Tensor", value: TensorLike, dtype: Optional["nvtripy.dtype"] = None) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/iota.py b/tripy/nvtripy/frontend/ops/iota.py index 6df4d2668..8b03b77c7 100644 --- a/tripy/nvtripy/frontend/ops/iota.py +++ b/tripy/nvtripy/frontend/ops/iota.py @@ -18,14 +18,15 @@ from typing import Optional from nvtripy import export, utils -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.linspace import Linspace from nvtripy.types import ShapeLike from nvtripy.frontend import wrappers -def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtripy.Tensor": +def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: dt.dtype) -> "nvtripy.Tensor": from nvtripy.frontend.ops.cast import cast from nvtripy.frontend.tensor import Tensor @@ -42,13 +43,14 @@ def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtr @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors=True, ) -def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "nvtripy.Tensor": +def iota(shape: ShapeLike, dim: int = 0, dtype: dt.dtype = dt.float32) -> "nvtripy.Tensor": """ Fills an output tensor with consecutive values starting from zero along the given dimension. @@ -75,13 +77,24 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float3 @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & If( + GetInput("dtype") != None, + OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def iota_like(input: "nvtripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def iota_like(input: "nvtripy.Tensor", dim: int = 0, dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Returns a tensor of the same shape and data type as the input tensor, with consecutive values starting from zero along the given dimension. diff --git a/tripy/nvtripy/frontend/ops/zeros.py b/tripy/nvtripy/frontend/ops/zeros.py index 4f3028a21..72de6341b 100644 --- a/tripy/nvtripy/frontend/ops/zeros.py +++ b/tripy/nvtripy/frontend/ops/zeros.py @@ -15,21 +15,22 @@ from typing import Optional from nvtripy import export -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops.full import full, full_like from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def zeros( shape: "nvtripy.types.ShapeLike", - dtype: datatype.dtype = datatype.float32, + dtype: dt.dtype = dt.float32, ) -> "nvtripy.Tensor": """ Creates a Tensor of the specified shape and dtype with all elements set to 0. @@ -55,13 +56,20 @@ def zeros( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & If( + GetInput("dtype") != None, + OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def zeros_like(input: "nvtripy.Tensor", dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def zeros_like(input: "nvtripy.Tensor", dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Creates a Tensor with all elements set to 0 of the same shape as the input tensor. diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index 208507fb0..89138ee52 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -152,6 +152,19 @@ def get_arg_dtype(arg, func_name, arg_name) -> result.Result["nvtripy.dtype"]: def _find_known_datatypes( merged_args: List[Tuple[str, Any]], input_requirements: Constraints ) -> Dict[str, "nvtripy.dtype"]: + """ + Identify known datatypes from input requirements to enable automatic type casting. + + This function searches for Equal constraints in the input requirements to determine + which arguments should have matching datatypes. It skips Equal constraints that appear + inside If statement conditions, as those represent conditional checks rather than + type requirements. + + Limitation: Automatic type casting will not work for arguments whose datatypes are + conditionally dependent on other values (i.e., when the datatype requirement appears + only in the then_branch or else_branch of an If constraint). + """ + from nvtripy.frontend.constraints.logic import If # We perform this operation in two steps: # 1. Identify all arguments that are expected to have equal data types. @@ -200,8 +213,13 @@ def process_dtype_equality(matched_constraints, input_is_lhs): # expected_equal_dtype to merge disjoint sets - if we have a transitive equality like: # `a == c and b == d and b == a`, then `a`, `c`, `b` will be immediately added to the same set, which `d` will # join when we process `b`. - process_dtype_equality(input_requirements.find(Equal(GetDataType(GetInput(name)), None)), input_is_lhs=True) - process_dtype_equality(input_requirements.find(Equal(None, GetDataType(GetInput(name)))), input_is_lhs=False) + # Skip searching inside If constraints to avoid treating conditional checks as type requirements: + process_dtype_equality( + input_requirements.find(Equal(GetDataType(GetInput(name)), None), skip_within=If), input_is_lhs=True + ) + process_dtype_equality( + input_requirements.find(Equal(None, GetDataType(GetInput(name))), skip_within=If), input_is_lhs=False + ) # We do not need to perform validation, as that will happen during constraints checking. for dtype_set in expected_equal_dtype: From 13599c0bbe5b5182cadb915b32e8411aafcf321d Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 24 Dec 2025 20:51:03 +0000 Subject: [PATCH 24/32] Step 6: Migrate type-preserving operators to new constraint system Migrated 11 operators that preserve input dtype in outputs: - matmul: Combined constraints for both inputs using & operator - cumsum: Float types only - copy: All dtypes supported - pad: Specific dtype subset - softmax: Float types only - tril/triu: Excluded bool (where() backend limitation - legacy was incorrect) - avgpool/maxpool: Pooling-specific dtypes - resize: Both overloads with int8/float support Key findings: - Legacy dtype_constraints claimed tril/triu supported bool, but this was never enforced - tril/triu use where() internally, which doesn't support bool - New constraint system is enforced and must accurately reflect backend capabilities - full_like constraints correctly reject dtype=float8 (broadcast doesn't support it) Test results: 136 passed, 26 skipped for type-preserving operators All full/full_like tests passing: 576 passed, 234 skipped --- tripy/nvtripy/frontend/ops/copy.py | 6 ++++-- tripy/nvtripy/frontend/ops/cumsum.py | 10 +++++++--- tripy/nvtripy/frontend/ops/matmul.py | 9 +++++++-- tripy/nvtripy/frontend/ops/pad.py | 8 ++++++-- tripy/nvtripy/frontend/ops/pooling/avgpool.py | 8 ++++++-- tripy/nvtripy/frontend/ops/pooling/maxpool.py | 8 ++++++-- tripy/nvtripy/frontend/ops/resize.py | 12 ++++++++---- tripy/nvtripy/frontend/ops/softmax.py | 10 ++++++---- tripy/nvtripy/frontend/ops/tril.py | 12 ++++++++---- tripy/nvtripy/frontend/ops/triu.py | 12 ++++++++---- 10 files changed, 66 insertions(+), 29 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index 488cde55a..1dcde27b8 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -25,11 +25,13 @@ from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn + + @register_tensor_method("copy") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": list(DATA_TYPES.keys())}, + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def copy(input: "nvtripy.Tensor", device: tp_device) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/cumsum.py b/tripy/nvtripy/frontend/ops/cumsum.py index 711d717ae..9a31be388 100644 --- a/tripy/nvtripy/frontend/ops/cumsum.py +++ b/tripy/nvtripy/frontend/ops/cumsum.py @@ -17,12 +17,16 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) -def cumsum(input: "nvtripy.Tensor", dim: int) -> "nvtripy.Tensor": +def cumsum(input: "nvtripy.Tensor", dim: int, exclusive: bool = False) -> "nvtripy.Tensor": """ Computes the cumulative sum of elements in the input along the dimension ``dim``. diff --git a/tripy/nvtripy/frontend/ops/matmul.py b/tripy/nvtripy/frontend/ops/matmul.py index 43210a79a..2532f9ac9 100644 --- a/tripy/nvtripy/frontend/ops/matmul.py +++ b/tripy/nvtripy/frontend/ops/matmul.py @@ -22,10 +22,15 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @register_tensor_method("__matmul__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __matmul__(self: "nvtripy.Tensor", other: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/pad.py b/tripy/nvtripy/frontend/ops/pad.py index 42512cf49..0d544118e 100644 --- a/tripy/nvtripy/frontend/ops/pad.py +++ b/tripy/nvtripy/frontend/ops/pad.py @@ -26,10 +26,14 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bool", "int32", "int64"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bool, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def pad( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/pooling/avgpool.py b/tripy/nvtripy/frontend/ops/pooling/avgpool.py index 68038cbd5..1ec6db284 100644 --- a/tripy/nvtripy/frontend/ops/pooling/avgpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/avgpool.py @@ -26,10 +26,14 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "bfloat16", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.bfloat16, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def avgpool( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/pooling/maxpool.py b/tripy/nvtripy/frontend/ops/pooling/maxpool.py index dd177e780..d9230be33 100644 --- a/tripy/nvtripy/frontend/ops/pooling/maxpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/maxpool.py @@ -24,10 +24,14 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "bfloat16", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.bfloat16, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def maxpool( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/resize.py b/tripy/nvtripy/frontend/ops/resize.py index 31fba7ad3..86905d86c 100644 --- a/tripy/nvtripy/frontend/ops/resize.py +++ b/tripy/nvtripy/frontend/ops/resize.py @@ -48,10 +48,14 @@ def _create_resize(mode, inputs, scales, align_corners): return op_utils.create_op(ResizeCubic, inputs, scales=scales, align_corners=align_corners) +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, ) def resize( @@ -108,8 +112,8 @@ def resize( @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def resize( input: "nvtripy.Tensor", scales: Sequence[numbers.Number], mode: str = "linear", align_corners: bool = False diff --git a/tripy/nvtripy/frontend/ops/softmax.py b/tripy/nvtripy/frontend/ops/softmax.py index 002332319..afaa836a5 100644 --- a/tripy/nvtripy/frontend/ops/softmax.py +++ b/tripy/nvtripy/frontend/ops/softmax.py @@ -23,12 +23,14 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def softmax(input: "nvtripy.Tensor", dim: Optional[int] = None) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/tril.py b/tripy/nvtripy/frontend/ops/tril.py index a90cc8bb4..bfb5a95b8 100644 --- a/tripy/nvtripy/frontend/ops/tril.py +++ b/tripy/nvtripy/frontend/ops/tril.py @@ -21,12 +21,16 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) def tril(tensor: "nvtripy.Tensor", diagonal: int = 0) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/triu.py b/tripy/nvtripy/frontend/ops/triu.py index 95a3bc02f..69f4ca649 100644 --- a/tripy/nvtripy/frontend/ops/triu.py +++ b/tripy/nvtripy/frontend/ops/triu.py @@ -21,12 +21,16 @@ from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) def triu(tensor: "nvtripy.Tensor", diagonal: int = 0) -> "nvtripy.Tensor": r""" From 575eb1a9e9ca74b32d268a46eaf8f05e3b1bda5d Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 24 Dec 2025 20:59:13 +0000 Subject: [PATCH 25/32] Step 7: Migrate operators with complex dtype relationships MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrated 7 operators with specialized constraint patterns: - gather: input/output preserve dtype, separate indices constraint (int32/int64) - where: condition=bool, input/other must match, output preserves input dtype - masked_fill: mask=bool, input/output preserve dtype (excluding bool - where() limitation) - quantize: float inputs → quantized outputs (int4/int8/float8), scale matches input - dequantize: quantized inputs → float outputs, scale/dtype must match - equal: inputs match, returns Python bool (no output_guarantees) - allclose: float inputs match, returns Python bool (no output_guarantees) Key changes: - Fixed masked_fill to exclude bool (uses where() which doesn't support bool) - Updated test framework to handle None output_guarantees for non-Tensor returns - Quantize/dequantize have coordinated input/output dtype constraints Test results: 1792 passed, 719 skipped --- tripy/nvtripy/frontend/ops/allclose.py | 8 +++++++- tripy/nvtripy/frontend/ops/dequantize.py | 7 +++++++ tripy/nvtripy/frontend/ops/equal.py | 7 +++++++ tripy/nvtripy/frontend/ops/gather.py | 9 +++++++++ tripy/nvtripy/frontend/ops/masked_fill.py | 9 +++++++++ tripy/nvtripy/frontend/ops/quantize.py | 7 +++++++ tripy/nvtripy/frontend/ops/where.py | 7 +++++++ tripy/tests/integration/test_operator_constraints.py | 2 +- 8 files changed, 54 insertions(+), 2 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/allclose.py b/tripy/nvtripy/frontend/ops/allclose.py index 0b4d756e2..47edff941 100644 --- a/tripy/nvtripy/frontend/ops/allclose.py +++ b/tripy/nvtripy/frontend/ops/allclose.py @@ -18,10 +18,16 @@ from nvtripy import export from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"]} + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("other").dtype == GetInput("input").dtype), + dtype_constraints={"input": "T1", "other": "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def allclose(input: "nvtripy.Tensor", other: "nvtripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: r""" diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index cc0d6a39b..8e1a8c957 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -24,9 +24,16 @@ from nvtripy.trace.ops.dequantize import Dequantize from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/quantization") @wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [dt.int4, dt.int8, dt.float8]) + & OneOf(GetInput("scale").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("dtype") == GetInput("scale").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), dtype_constraints={"input": "T1", "scale": "T2", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, dtype_variables={"T1": ["int4", "int8", "float8"], "T2": ["float32", "float16", "bfloat16"]}, convert_to_tensors={"scale"}, diff --git a/tripy/nvtripy/frontend/ops/equal.py b/tripy/nvtripy/frontend/ops/equal.py index 93ac7256f..8d6b7bc9d 100644 --- a/tripy/nvtripy/frontend/ops/equal.py +++ b/tripy/nvtripy/frontend/ops/equal.py @@ -16,9 +16,16 @@ from nvtripy.common.datatype import DATA_TYPES from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("input").dtype), dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, ) diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index bced20763..4ca88ffbe 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -21,9 +21,18 @@ from nvtripy.trace.ops.gather import Gather from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & OneOf(GetInput("index").dtype, [dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, dtype_constraints={"input": "T1", "index": "T2", wrappers.RETURN_VALUE: "T1"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], diff --git a/tripy/nvtripy/frontend/ops/masked_fill.py b/tripy/nvtripy/frontend/ops/masked_fill.py index 540198fc4..8fd02265c 100644 --- a/tripy/nvtripy/frontend/ops/masked_fill.py +++ b/tripy/nvtripy/frontend/ops/masked_fill.py @@ -17,9 +17,18 @@ from nvtripy import export from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64], + ) + & (GetInput("mask").dtype == dt.bool), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, dtype_constraints={"input": "T1", "mask": "T2", wrappers.RETURN_VALUE: "T1"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 3d229789d..8125edd7e 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -24,9 +24,16 @@ from nvtripy.trace.ops.quantize import Quantize from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/quantization") @wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("scale").dtype == GetInput("input").dtype) + & OneOf(GetInput("dtype"), [dt.int4, dt.int8, dt.float8]), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), dtype_constraints={"input": "T1", "scale": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"], "T2": ["int4", "int8", "float8"]}, convert_to_tensors={"scale"}, diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index 277a61490..cc89c4b33 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -22,9 +22,16 @@ from nvtripy.types import TensorLike from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=(GetInput("condition").dtype == dt.bool) + & OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64]) + & (GetInput("other").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, dtype_constraints={"condition": "T2", "input": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"], diff --git a/tripy/tests/integration/test_operator_constraints.py b/tripy/tests/integration/test_operator_constraints.py index 271549766..32f7c07a9 100644 --- a/tripy/tests/integration/test_operator_constraints.py +++ b/tripy/tests/integration/test_operator_constraints.py @@ -240,7 +240,7 @@ def test_operator_constraints(case: OperatorConstraintCase): # certain operations. out._eval_for_internal_methods() - if is_valid: + if is_valid and op_constraint.output_guarantees is not None: output_result = op_constraint.output_guarantees(merged_args, tuple(outputs)) assert output_result, f"Output guarantees not met for {case.func.__name__}: " + " ".join( output_result.error_details From 5de83f5bde9b385478a27efb80537778d7a850dc Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 24 Dec 2025 21:27:08 +0000 Subject: [PATCH 26/32] Step 8: Migrate final remaining operators Migrated last 3 operators to new constraint system: - outer: Both vectors must match dtype (float types only) - shape: Accepts all dtypes, returns tuple of IntLike (no output_guarantees) - split: Preserves input dtype, supports all dtypes All 81 operators with dtype constraints have now been migrated: - Steps 1-5: 60 operators (unary, shape, binary, reduction, initializers) - Step 6: 11 type-preserving operators - Step 7: 7 operators with complex dtype relationships - Step 8: 3 final operators Migration complete! All operators now use the new enforced constraint system with input_requirements and output_guarantees alongside legacy dtype_constraints. Test results: 88 passed, 20 skipped for final operators --- tripy/nvtripy/frontend/ops/outer.py | 6 ++++++ tripy/nvtripy/frontend/ops/shape.py | 9 ++++++++- tripy/nvtripy/frontend/ops/split.py | 8 ++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tripy/nvtripy/frontend/ops/outer.py b/tripy/nvtripy/frontend/ops/outer.py index c25652023..183fb4d0b 100644 --- a/tripy/nvtripy/frontend/ops/outer.py +++ b/tripy/nvtripy/frontend/ops/outer.py @@ -18,9 +18,15 @@ from nvtripy import export from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=OneOf(GetInput("vec1").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("vec2").dtype == GetInput("vec1").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("vec1").dtype, dtype_constraints={"vec1": "T1", "vec2": "T1", wrappers.RETURN_VALUE: "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 08795206d..2cde654bb 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -27,10 +27,17 @@ from nvtripy.types import IntLike from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, OneOf + @register_tensor_method("shape") @property -@wrappers.interface(dtype_constraints={"self": "T1"}, dtype_variables={"T1": list(DATA_TYPES.keys())}) +@wrappers.interface( + input_requirements=OneOf(GetInput("self").dtype, list(DATA_TYPES.values())), + dtype_constraints={"self": "T1"}, + dtype_variables={"T1": list(DATA_TYPES.keys())}, +) def shape(self: "nvtripy.Tensor") -> Tuple[IntLike]: """ Represents the shape of the tensor. diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 829f74b52..588c4007c 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -23,9 +23,17 @@ from nvtripy.types import IntLike from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], From dc40af07fca433baf0ca5c2e3feb404a2e88ec10 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 9 Jan 2026 18:58:42 +0000 Subject: [PATCH 27/32] Complete migration to new constraint system and remove legacy dtype system - Remove deprecated dtype_constraints/dtype_variables/dtype_exceptions system - Remove enable_dtype_checking config option (replaced by enable_input_validation) - Remove DATA_TYPE_CONSTRAINTS registry and DataTypeConstraints dataclass - Remove legacy get_arg_dtype and type_var handling from wrappers.py - Migrate remaining operators to input_requirements/output_guarantees: - Module ops: conv, conv_transpose, instancenorm, layernorm - Frontend ops: allclose, dequantize, equal, gather, masked_fill, outer, quantize, shape, split, tril, triu, where - Enhance GetDataType fetcher to support Python scalar types (int, float, bool) - Fix Equal/NotEqual constraint logic for None comparisons - Improve dtype autocasting to prefer trusted sources (Tensors/dtypes) over inferred scalar literal types - Update all tests to use new error message format - Replace legacy test_datatype_constraints.py with integration tests - Update documentation examples to use new constraint syntax --- ...teOperatorsToNewConstraintSystem.prompt.md | 5 +- tripy/docs/README.md | 38 +- .../01-how-to-add-new-ops.md | 6 +- .../docs/pre0_user_guides/02-quantization.md | 4 +- tripy/examples/nanogpt/README.md | 2 +- tripy/nvtripy/config.py | 9 +- tripy/nvtripy/frontend/constraints/fetcher.py | 23 +- tripy/nvtripy/frontend/constraints/logic.py | 20 +- tripy/nvtripy/frontend/module/conv/conv.py | 10 +- .../frontend/module/conv/conv_transpose.py | 10 +- tripy/nvtripy/frontend/module/instancenorm.py | 10 +- tripy/nvtripy/frontend/module/layernorm.py | 10 +- tripy/nvtripy/frontend/ops/allclose.py | 4 +- tripy/nvtripy/frontend/ops/dequantize.py | 4 +- tripy/nvtripy/frontend/ops/equal.py | 4 +- tripy/nvtripy/frontend/ops/gather.py | 7 +- tripy/nvtripy/frontend/ops/masked_fill.py | 9 +- tripy/nvtripy/frontend/ops/outer.py | 4 +- tripy/nvtripy/frontend/ops/quantize.py | 4 +- tripy/nvtripy/frontend/ops/shape.py | 6 +- tripy/nvtripy/frontend/ops/split.py | 6 +- tripy/nvtripy/frontend/ops/tril.py | 4 +- tripy/nvtripy/frontend/ops/triu.py | 4 +- tripy/nvtripy/frontend/ops/where.py | 7 +- tripy/nvtripy/frontend/wrappers.py | 331 +++++------------- .../frontend/constraints/test_fetcher.py | 16 +- .../tests/frontend/constraints/test_logic.py | 14 +- tripy/tests/frontend/module/test_conv.py | 4 +- tripy/tests/frontend/module/test_embedding.py | 4 +- tripy/tests/frontend/ops/test_binary.py | 4 +- tripy/tests/frontend/ops/test_dequantize.py | 6 +- tripy/tests/frontend/ops/test_equal.py | 4 +- tripy/tests/frontend/ops/test_matmul.py | 4 +- tripy/tests/frontend/ops/test_quantize.py | 6 +- tripy/tests/frontend/ops/test_where.py | 8 +- .../wrappers/test_datatype_constraints.py | 232 +----------- .../tests/frontend/wrappers/test_wrappers.py | 36 +- .../integration/test_operator_constraints.py | 8 +- 38 files changed, 261 insertions(+), 626 deletions(-) diff --git a/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md b/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md index dc206ebe8..321378556 100644 --- a/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md +++ b/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md @@ -1,6 +1,6 @@ ## Plan: Migrate Operators to New Constraint System -This plan will systematically port 86 operators across 66 files from the legacy `dtype_constraints`/`dtype_variables`/`dtype_exceptions` system to the new `input_requirements`/`output_guarantees` constraint system. The migration follows proven patterns from [`cast.py`](nvtripy/frontend/ops/cast.py), [`concatenate.py`](nvtripy/frontend/ops/concatenate.py), and [`ones.py`](nvtripy/frontend/ops/ones.py). +This plan will systematically port operators to use the `input_requirements`/`output_guarantees` constraint system. The migration follows proven patterns from [`cast.py`](nvtripy/frontend/ops/cast.py), [`concatenate.py`](nvtripy/frontend/ops/concatenate.py), and [`ones.py`](nvtripy/frontend/ops/ones.py). ### Steps @@ -18,5 +18,4 @@ This plan will systematically port 86 operators across 66 files from the legacy ### Implementation Notes -- Translate `dtype_exceptions` directly to boolean logic using `&`, `|`, `~` operators without creating helper functions -- No modifications to `CUSTOM_VALUES` in test file needed - it's already complete for dtype constraints +- Translate any existing exception logic directly to boolean logic using `&`, `|`, `~` operators without creating helper functions diff --git a/tripy/docs/README.md b/tripy/docs/README.md index b383eda0b..5b886a30f 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -50,12 +50,18 @@ which specifies doc metadata for each API (e.g. location). **Example:** ```py +from nvtripy import export +from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "bool", "int8"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def relu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" @@ -171,7 +177,7 @@ Code blocks in docstrings/guides are **preprocessed**: ## Documentation Philosophy: Write Less Documentation -> "I didn’t have time to write a short letter, so I wrote a long one instead." - Mark Twain +> "I didn't have time to write a short letter, so I wrote a long one instead." - Mark Twain How much documentation do you want to read? The answer is **none**! @@ -179,7 +185,7 @@ How much documentation do you want to read? The answer is **none**! This is not always possible; sometimes, we need to write docs. -- **Problem:** We don’t think enough about *what* *precisely* we want to convey. +- **Problem:** We don't think enough about *what* *precisely* we want to convey. - **Suggestions**: Write discoverable, concise, but complete documentation. @@ -191,12 +197,12 @@ This is not always possible; sometimes, we need to write docs. - Leverage the medium - pictures, charts, emojis, markup. We are not using printing presses! - - Forget the rules: use contractions, don’t spell out numbers, etc. + - Forget the rules: use contractions, don't spell out numbers, etc. - **Tip:** Write like we're paying for every syllable! - If it’s too wordy to say, it’s too wordy to write. + If it's too wordy to say, it's too wordy to write. -Writing *concisely* forces you to think about what’s **signal** and what’s **noise**. +Writing *concisely* forces you to think about what's **signal** and what's **noise**. Below are examples from previous versions of Tripy documentation that was improved. @@ -204,16 +210,16 @@ Below are examples from previous versions of Tripy documentation that was improv > One important point is that Tripy uses a lazy evaluation model; that is, no computation is performed until a value is actually needed. -* **Tip:** Ask: “What is this really saying?” +* **Tip:** Ask: "What is this really saying?" -> Tensors are evaluated only when they’re used. +> Tensors are evaluated only when they're used. ### Example 2 > ### **Eager Mode: How Does It Work?** > -> If you’ve used TensorRT before, you may know that it does not support an eager mode. +> If you've used TensorRT before, you may know that it does not support an eager mode. > In order to provide eager mode support in Tripy, we actually need to compile the graph under the hood. > > Although we employ several tricks to make compile times faster when using eager mode, we do still need to compile, @@ -225,14 +231,14 @@ Below are examples from previous versions of Tripy documentation that was improv **Ask**: -- **“What is the ONE most important takeaway?”** +- **"What is the ONE most important takeaway?"** *Eager mode is only for debugging.* -- **“What questions does this raise?” - Why?** - *Tripy always compiles since TensorRT doesn’t have eager mode.* +- **"What questions does this raise?" - Why?** + *Tripy always compiles since TensorRT doesn't have eager mode.* Make the **one key point** stand out so skimmers can spot it: > **Best Practice:** Use **eager mode** only for **debugging**; compile for deployment. > -> **Why:** Eager mode internally compiles the graph (slow!) since TensorRT doesn’t have eager execution. +> **Why:** Eager mode internally compiles the graph (slow!) since TensorRT doesn't have eager execution. diff --git a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md index 8a16af3a6..5a1dc87ea 100644 --- a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md @@ -131,11 +131,13 @@ from nvtripy import export from nvtripy.trace.ops.topn import TopN from nvtripy.frontend import wrappers from nvtripy.frontend.ops import utils as op_utils +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: ["T1", "T2"]}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=(GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == dt.int32), ) def topn(input: "nvtripy.Tensor", n: int, dim: int) -> Tuple["nvtripy.Tensor", "nvtripy.Tensor"]: # See docs/README.md for more information on how to write docstrings diff --git a/tripy/docs/pre0_user_guides/02-quantization.md b/tripy/docs/pre0_user_guides/02-quantization.md index 63496c09b..e33ac29fb 100644 --- a/tripy/docs/pre0_user_guides/02-quantization.md +++ b/tripy/docs/pre0_user_guides/02-quantization.md @@ -17,7 +17,7 @@ explains quantization in more detail. ## Post-Training Quantization With ModelOpt If the model was not trained with quantization-aware training (QAT), we can use -[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html) +[TensorRT ModelOpt](https://nvidia.github.io/Model-Optimizer/index.html) to do **calibration** to determine scaling factors. :::{admonition} Info @@ -82,7 +82,7 @@ Let's calibrate a GPT model: ``` 3. Run calibration to replace linear layers with - [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear), + [`QuantLinear`](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear), which contain calibration information: ```py diff --git a/tripy/examples/nanogpt/README.md b/tripy/examples/nanogpt/README.md index 8826a43d7..191497dae 100644 --- a/tripy/examples/nanogpt/README.md +++ b/tripy/examples/nanogpt/README.md @@ -40,7 +40,7 @@ This example implements a [NanoGPT model](https://github.com/karpathy/nanoGPT) u ### Running with Quantization [`quantization.py`](./quantization.py), uses -[NVIDIA TensorRT Model Optimizer](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/1_overview.html) +[NVIDIA TensorRT Model Optimizer](https://nvidia.github.io/Model-Optimizer/getting_started/1_overview.html) to quantize the pytorch model. `load_quant_weights_from_hf` in [`weight_loader.py`](./weight_loader.py) converts the quantization diff --git a/tripy/nvtripy/config.py b/tripy/nvtripy/config.py index f274aa124..a6349306e 100644 --- a/tripy/nvtripy/config.py +++ b/tripy/nvtripy/config.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,13 +45,6 @@ )(os.path.join(tempfile.gettempdir(), "tripy-cache")) """Path to a timing cache file that can be used to speed up compilation time.""" -enable_dtype_checking: bool = export.public_api( - document_under="config.rst", - module=sys.modules[__name__], - symbol="enable_dtype_checking", -)(True) -"""[DEPRECATED - use enable_input_validation] Whether to enable data type checking in API functions.""" - enable_input_validation: bool = export.public_api( document_under="config.rst", module=sys.modules[__name__], diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py index 713bd0649..884cf6e15 100644 --- a/tripy/nvtripy/frontend/constraints/fetcher.py +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,8 +15,10 @@ # limitations under the License. # from abc import abstractmethod +import numbers from typing import Any, List, Optional, Sequence, Tuple +from nvtripy.common import datatype from nvtripy.common.datatype import dtype as tp_dtype from nvtripy.common.exception import raise_error from nvtripy.frontend.constraints.base import Constraints @@ -90,9 +92,18 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = from nvtripy.frontend.tensor import Tensor def get_arg_dtype(arg: Any) -> tp_dtype: - if isinstance(arg, Sequence): + if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)): arg_dtypes = [get_arg_dtype(elem) for elem in arg] + if len(arg_dtypes) == 0: + raise_error( + f"Could not determine data type of {self.value_fetcher}", + [ + "Empty sequence argument.\n", + f"For parameter: '{self.value_fetcher}', the sequence must contain at least one element.", + ], + ) + if len(set(arg_dtypes)) != 1: raise_error( f"Could not determine data type of {self.value_fetcher}", @@ -105,6 +116,14 @@ def get_arg_dtype(arg: Any) -> tp_dtype: arg_dtype = arg_dtypes[0] elif isinstance(arg, Tensor): arg_dtype = arg.dtype + elif isinstance(arg, tp_dtype): + arg_dtype = arg + elif isinstance(arg, bool): + arg_dtype = datatype.bool + elif isinstance(arg, numbers.Integral): + arg_dtype = datatype.int32 if datatype.INT32_MIN <= arg <= datatype.INT32_MAX else datatype.int64 + elif isinstance(arg, numbers.Real): + arg_dtype = datatype.float32 else: raise_error( f"Could not determine data type of {self.value_fetcher}", diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 8755a10d8..5664c8b6e 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -122,7 +122,12 @@ def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: value1 = self.fetcher(args, returns) value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) - if value1 == value2: + + # Avoid triggering overloaded equality implementations (e.g., on Tensor) when comparing to None. + if value1 is None or value2 is None: + if value1 is value2: + return Result.ok() + elif value1 == value2: return Result.ok() # TODO (pranavm): If fetcher_or_value is a Fetcher, include its value in the error message. @@ -132,7 +137,7 @@ def __str__(self): return f"{self.fetcher} == {self.fetcher_or_value}" def doc_str(self) -> str: - return f"{doc_str(self.fetcher)} is {doc_str(self.fetcher_or_value)}" + return f"{doc_str(self.fetcher)} == {doc_str(self.fetcher_or_value)}" def inverse(self) -> "Logic": return NotEqual(self.fetcher, self.fetcher_or_value) @@ -147,7 +152,12 @@ def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: value1 = self.fetcher(args, returns) value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) - if value1 != value2: + + # Avoid triggering overloaded inequality implementations (e.g., on Tensor) when comparing to None. + if value1 is None or value2 is None: + if value1 is not value2: + return Result.ok() + elif value1 != value2: return Result.ok() return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) @@ -156,7 +166,7 @@ def __str__(self): return f"{self.fetcher} != {self.fetcher_or_value}" def doc_str(self) -> str: - return f"{doc_str(self.fetcher)} is not {doc_str(self.fetcher_or_value)}" + return f"{doc_str(self.fetcher)} != {doc_str(self.fetcher_or_value)}" def inverse(self) -> "Logic": return Equal(self.fetcher, self.fetcher_or_value) diff --git a/tripy/nvtripy/frontend/module/conv/conv.py b/tripy/nvtripy/frontend/module/conv/conv.py index 978cc8394..e7ce14209 100644 --- a/tripy/nvtripy/frontend/module/conv/conv.py +++ b/tripy/nvtripy/frontend/module/conv/conv.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,11 +28,15 @@ from nvtripy.trace.ops.convolution import Convolution from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf + # This function is added so that we can do dtype checking. @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & If(GetInput("bias") != None, GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def convolution( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/module/conv/conv_transpose.py b/tripy/nvtripy/frontend/module/conv/conv_transpose.py index 4becc1103..050cb3394 100644 --- a/tripy/nvtripy/frontend/module/conv/conv_transpose.py +++ b/tripy/nvtripy/frontend/module/conv/conv_transpose.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,12 +28,16 @@ from nvtripy.trace.ops.deconvolution import Deconvolution from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf + # This function is added so that we can do dtype checking. # TODO (#565): TRT supposedly supports BF16 in deconv, but actually trying to use it results in a bug. @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & If(GetInput("bias") != None, GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def deconvolution( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/module/instancenorm.py b/tripy/nvtripy/frontend/module/instancenorm.py index 5a9a41ee4..3e0c1a690 100644 --- a/tripy/nvtripy/frontend/module/instancenorm.py +++ b/tripy/nvtripy/frontend/module/instancenorm.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,10 +28,14 @@ from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.instancenorm import InstanceNorm as InstanceNormOp +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def instancenorm( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/module/layernorm.py b/tripy/nvtripy/frontend/module/layernorm.py index 1128d41cf..90845eb6b 100644 --- a/tripy/nvtripy/frontend/module/layernorm.py +++ b/tripy/nvtripy/frontend/module/layernorm.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,10 +29,14 @@ from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def layernorm( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/allclose.py b/tripy/nvtripy/frontend/ops/allclose.py index 47edff941..f89ece4b3 100644 --- a/tripy/nvtripy/frontend/ops/allclose.py +++ b/tripy/nvtripy/frontend/ops/allclose.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,8 +26,6 @@ @wrappers.interface( input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) & (GetInput("other").dtype == GetInput("input").dtype), - dtype_constraints={"input": "T1", "other": "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def allclose(input: "nvtripy.Tensor", other: "nvtripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: r""" diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index 8e1a8c957..2c1462bc8 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,8 +34,6 @@ & OneOf(GetInput("scale").dtype, [dt.float32, dt.float16, dt.bfloat16]) & (GetInput("dtype") == GetInput("scale").dtype), output_guarantees=GetReturn(0).dtype == GetInput("dtype"), - dtype_constraints={"input": "T1", "scale": "T2", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["int4", "int8", "float8"], "T2": ["float32", "float16", "bfloat16"]}, convert_to_tensors={"scale"}, ) def dequantize( diff --git a/tripy/nvtripy/frontend/ops/equal.py b/tripy/nvtripy/frontend/ops/equal.py index 8d6b7bc9d..6feb3617a 100644 --- a/tripy/nvtripy/frontend/ops/equal.py +++ b/tripy/nvtripy/frontend/ops/equal.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,8 +26,6 @@ GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] ) & (GetInput("other").dtype == GetInput("input").dtype), - dtype_constraints={"input": "T1", "other": "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, ) def equal(input: "nvtripy.Tensor", other: "nvtripy.Tensor") -> bool: r""" diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index 4ca88ffbe..083504844 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,11 +33,6 @@ ) & OneOf(GetInput("index").dtype, [dt.int32, dt.int64]), output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, - dtype_constraints={"input": "T1", "index": "T2", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - "T2": ["int32", "int64"], - }, ) def gather(input: "nvtripy.Tensor", dim: int, index: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/masked_fill.py b/tripy/nvtripy/frontend/ops/masked_fill.py index 8fd02265c..13e436a42 100644 --- a/tripy/nvtripy/frontend/ops/masked_fill.py +++ b/tripy/nvtripy/frontend/ops/masked_fill.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,15 +25,10 @@ @wrappers.interface( input_requirements=OneOf( GetInput("input").dtype, - [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64], + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], ) & (GetInput("mask").dtype == dt.bool), output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, - dtype_constraints={"input": "T1", "mask": "T2", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, ) def masked_fill(input: "nvtripy.Tensor", mask: "nvtripy.Tensor", value: numbers.Number) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/outer.py b/tripy/nvtripy/frontend/ops/outer.py index 183fb4d0b..a72333463 100644 --- a/tripy/nvtripy/frontend/ops/outer.py +++ b/tripy/nvtripy/frontend/ops/outer.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,8 +27,6 @@ input_requirements=OneOf(GetInput("vec1").dtype, [dt.float32, dt.float16, dt.bfloat16]) & (GetInput("vec2").dtype == GetInput("vec1").dtype), output_guarantees=GetReturn(0).dtype == GetInput("vec1").dtype, - dtype_constraints={"vec1": "T1", "vec2": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def outer(vec1: "nvtripy.Tensor", vec2: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 8125edd7e..7427b246e 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,8 +34,6 @@ & (GetInput("scale").dtype == GetInput("input").dtype) & OneOf(GetInput("dtype"), [dt.int4, dt.int8, dt.float8]), output_guarantees=GetReturn(0).dtype == GetInput("dtype"), - dtype_constraints={"input": "T1", "scale": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"], "T2": ["int4", "int8", "float8"]}, convert_to_tensors={"scale"}, ) def quantize( diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 2cde654bb..d662714c1 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,8 +26,6 @@ from nvtripy.trace.ops.shape import GetDimensionSize, Shape from nvtripy.types import IntLike from nvtripy.frontend import wrappers - -from nvtripy.common import datatype as dt from nvtripy.frontend.constraints import GetInput, OneOf @@ -35,8 +33,6 @@ @property @wrappers.interface( input_requirements=OneOf(GetInput("self").dtype, list(DATA_TYPES.values())), - dtype_constraints={"self": "T1"}, - dtype_variables={"T1": list(DATA_TYPES.keys())}, ) def shape(self: "nvtripy.Tensor") -> Tuple[IntLike]: """ diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 588c4007c..0fc36787f 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,10 +34,6 @@ [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], ), output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - }, ) def split( input: "nvtripy.Tensor", num_split_or_sizes: Union[int, Sequence[IntLike]], dim: int = 0 diff --git a/tripy/nvtripy/frontend/ops/tril.py b/tripy/nvtripy/frontend/ops/tril.py index bfb5a95b8..dea9e3e06 100644 --- a/tripy/nvtripy/frontend/ops/tril.py +++ b/tripy/nvtripy/frontend/ops/tril.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,7 @@ @export.public_api(document_under="operations/initializers") @wrappers.interface( input_requirements=OneOf( - GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] ), output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) diff --git a/tripy/nvtripy/frontend/ops/triu.py b/tripy/nvtripy/frontend/ops/triu.py index 69f4ca649..9057caf22 100644 --- a/tripy/nvtripy/frontend/ops/triu.py +++ b/tripy/nvtripy/frontend/ops/triu.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,7 @@ @export.public_api(document_under="operations/initializers") @wrappers.interface( input_requirements=OneOf( - GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] ), output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index cc89c4b33..01023705a 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,11 +32,6 @@ & OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64]) & (GetInput("other").dtype == GetInput("input").dtype), output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, - dtype_constraints={"condition": "T2", "input": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"], - "T2": ["bool"], - }, convert_to_tensors=True, ) def where(condition: "nvtripy.Tensor", input: TensorLike, other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index 89138ee52..b4d46643d 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,29 +23,16 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union from nvtripy import config, utils -from nvtripy.common.datatype import DATA_TYPES +from nvtripy.common.datatype import dtype as tp_dtype from nvtripy.common.exception import raise_error from nvtripy.frontend.constraints import Constraints, Equal, GetInput, GetDataType, Fetcher, doc_str -from nvtripy.utils import result - - -@dataclass -class DataTypeConstraints: - func: Callable - constraints: Dict[str, str] - variables: Dict[str, List[str]] - exceptions: List[Dict[str, str]] - - -DATA_TYPE_CONSTRAINTS = [] -RETURN_VALUE = "__RETURN_VALUE" @dataclass class OperatorConstraints: func: Callable - input_requirements: Constraints - output_guarantees: Constraints + input_requirements: Optional[Constraints] + output_guarantees: Optional[Constraints] # A list of tuples of (input_requirements, output_guarantees) for operators. @@ -116,42 +103,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name): source_info.column_range = candidates[0] -# TODO (pranavm): Remove this: -def get_arg_dtype(arg, func_name, arg_name) -> result.Result["nvtripy.dtype"]: - from nvtripy.common.datatype import dtype - from nvtripy.frontend.tensor import Tensor - - if isinstance(arg, Sequence): - arg_dtypes = [] - for elem in arg: - dtype_result = get_arg_dtype(elem, func_name, arg_name) - if not dtype_result: - return result.Result.err( - [f"Could not determine data type of elements in sequence: {arg_name}"] + dtype_result.error_details - ) - arg_dtypes.append(dtype_result.value) - - if len(set(arg_dtypes)) != 1: - return result.Result.err( - [ - f"Mismatched data types in sequence argument for '{func_name}'.\n", - f"For parameter: '{arg_name}', all arguments must have the same data type, but got: " - f"{arg_dtypes}", - ], - ) - arg_dtype = arg_dtypes[0] - elif isinstance(arg, Tensor): - arg_dtype = arg.dtype - elif isinstance(arg, dtype): - arg_dtype = arg - else: - return result.Result.err([f"Expected a tensor or data type argument for {arg_name}, but got: {arg}"]) - return result.Result.ok(arg_dtype) - - -def _find_known_datatypes( - merged_args: List[Tuple[str, Any]], input_requirements: Constraints -) -> Dict[str, "nvtripy.dtype"]: +def _find_known_datatypes(merged_args: List[Tuple[str, Any]], input_requirements: Constraints) -> Dict[str, tp_dtype]: """ Identify known datatypes from input requirements to enable automatic type casting. @@ -178,14 +130,32 @@ def insert_pair(name1, name2): return expected_equal_dtype.append({name1, name2}) - known_dtypes: Dict[str, Optional["nvtripy.dtype"]] = {} + known_dtypes: Dict[str, Optional[tp_dtype]] = {} + arg_map: Dict[str, Any] = {name: value for name, value in merged_args} + + def _is_trusted_dtype_source(value: Any) -> bool: + """Return True if `value` should be treated as a stable dtype source for autocasting. + + We intentionally treat Python scalar literals (int/float/bool) as *untrusted* sources + so that expressions like `tensor_f16 * 1.0` will cast the scalar to the tensor dtype + rather than forcing the tensor to match the scalar's default dtype. + """ + from nvtripy.frontend.tensor import Tensor + + if isinstance(value, Tensor) or isinstance(value, tp_dtype): + return True + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return all(_is_trusted_dtype_source(v) for v in value) if len(value) > 0 else False + return False + for name, _ in merged_args: - # If this argument already has a known dtype, populate it: - try: - known_dtypes[name] = GetDataType(GetInput(name))(merged_args) - except Exception: - pass + # If this argument already has a known dtype and is a trusted source, populate it: + if _is_trusted_dtype_source(arg_map.get(name)): + try: + known_dtypes[name] = GetDataType(GetInput(name))(merged_args) + except Exception: + pass def process_dtype_equality(matched_constraints, input_is_lhs): for constraint in matched_constraints: @@ -223,15 +193,33 @@ def process_dtype_equality(matched_constraints, input_is_lhs): # We do not need to perform validation, as that will happen during constraints checking. for dtype_set in expected_equal_dtype: - known_dtype_in_set = None - for name in dtype_set: - if name in known_dtypes: - known_dtype_in_set = known_dtypes[name] - break - - # dtype might be unknown if the arguments are all non-tensor types. + # Prefer dtypes coming from trusted sources (tensors / explicit dtypes). + trusted_names_in_order = [ + n for (n, _) in merged_args if n in dtype_set and _is_trusted_dtype_source(arg_map.get(n)) + ] + candidate_dtypes = [known_dtypes.get(n) for n in trusted_names_in_order if known_dtypes.get(n) is not None] + + # If there are conflicting trusted dtypes, do not guess. + known_dtype_in_set: Optional[tp_dtype] + if len(set(candidate_dtypes)) > 1: + known_dtype_in_set = None + else: + known_dtype_in_set = candidate_dtypes[0] if candidate_dtypes else None + + # dtype might be unknown if the arguments are all non-tensor / untrusted types. + # In that case, we intentionally do *not* treat inferred scalar literal dtypes (e.g. 1.0 -> float32) + # as "known" to avoid unnecessary/incorrect casting behavior. for name in dtype_set: - known_dtypes[name] = known_dtype_in_set + if known_dtype_in_set is None: + known_dtypes[name] = None + continue + + # If this argument is an untrusted dtype source (e.g., Python scalar), prefer the + # trusted dtype chosen for the group even if we previously inferred a dtype for it. + if not _is_trusted_dtype_source(arg_map.get(name)): + known_dtypes[name] = known_dtype_in_set + else: + known_dtypes.setdefault(name, known_dtype_in_set) return known_dtypes @@ -245,7 +233,6 @@ def convert_input_types( var_arg_info, conversion_targets, conversion_preprocess_func, - dtype_constraints, shape_likes, input_requirements: Constraints, ): @@ -268,16 +255,9 @@ def convert_input_types( else: merged_args[index] = (name, new_args[name]) + known_datatypes: Dict[str, Optional[tp_dtype]] = {} if input_requirements is not None: known_datatypes = _find_known_datatypes(merged_args, input_requirements) - else: - # Materialize type variables from tensors. - type_vars = {} - for name, arg in merged_args: - if name in dtype_constraints: - dtype_result = get_arg_dtype(arg, func.__qualname__, name) - if dtype_result: - type_vars[dtype_constraints[name]] = dtype_result.value new_args = [] new_kwargs = {} @@ -312,9 +292,6 @@ def add_arg(arg): dtype = None if input_requirements is not None: dtype = known_datatypes.get(name) - elif name in dtype_constraints and dtype_constraints[name] in type_vars: - # TODO (pranavm): Remove this deprecated path. - dtype = type_vars[dtype_constraints[name]] if dtype is not None: # Refuse to do unsafe casts like float -> int. @@ -344,12 +321,19 @@ def _update_docstring(func, input_requirements, output_guarantees): if not func.__doc__: return + if input_requirements is None and output_guarantees is None: + return + indentation = " " * 4 code_block_index = func.__doc__.find(".. code-block:: python") assert code_block_index != -1, f"No code example in docstring for {func.__name__}" - input_requirements_str = f"\nINPUT REQUIREMENTS:\n{indent(doc_str(input_requirements), indentation)}\n" - output_guarantees_str = f"\nOUTPUT GUARANTEES:\n{indent(doc_str(output_guarantees), indentation)}\n" + input_requirements_str = ( + f"\nINPUT REQUIREMENTS:\n{indent(doc_str(input_requirements), indentation)}\n" if input_requirements else "" + ) + output_guarantees_str = ( + f"\nOUTPUT GUARANTEES:\n{indent(doc_str(output_guarantees), indentation)}\n" if output_guarantees else "" + ) func.__doc__ = ( func.__doc__[:code_block_index] @@ -360,76 +344,11 @@ def _update_docstring(func, input_requirements, output_guarantees): ) -# Modify the docstring to mention data type variables and exceptions -def _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions): - if not func.__doc__: - return - - # Update the docstring to add data type variables after the parameter documentation. - args_index = func.__doc__.find("Args:") - # Args: may be omitted for functions with no inputs - args_index = args_index if args_index != -1 else 0 - for name, var in dtype_constraints.items(): - find_str = f"\n {name}: " if name != RETURN_VALUE else "\n Returns:\n " - - param_index = func.__doc__.find(find_str, args_index) - assert param_index != -1, f"Parameter: {name} is not present or was not documented in {func.__name__}" - func.__doc__ = ( - func.__doc__[:param_index] - + rf"{find_str}[dtype=\ **{var}**\ ] " - + func.__doc__[param_index + len(find_str) :] - ) - - prefix = " " * 8 - - def sorted_types(dtypes): - return sorted( - dtypes, - key=lambda dtype: ( - tuple(typ.__name__ for typ in DATA_TYPES[dtype].__bases__), - DATA_TYPES[dtype].itemsize, - ), - ) - - dtype_info = "DATA TYPE CONSTRAINTS:\n" - dtype_info += indent( - "\n".join( - [ - f"- **{var}**: {', '.join(map(lambda t: f':class:`{t}`', sorted_types(dtypes)))}" - for var, dtypes in dtype_variables.items() - ] - ), - prefix, - ) - - if dtype_exceptions: - dtype_info += "\n\n UNSUPPORTED DATA TYPE COMBINATIONS:\n" - esc_space = r"\ " - dtype_info += indent( - "\n".join( - [ - f"- {', '.join([f'**{k}**{esc_space}={esc_space}:class:`{v}`' for k, v in exception.items()])}" - for exception in dtype_exceptions - ] - ), - prefix, - ) - - dtype_info += "\n\n " - - code_block_index = func.__doc__.find(".. code-block:: python") - assert code_block_index != -1, f"No code example in docstring for {func.__name__}" - func.__doc__ = func.__doc__[:code_block_index] + dtype_info + func.__doc__[code_block_index:] - - def interface( # TODO (pranavm): These should be required arguments eventually. # TODO (pranavm): Document requirements/guarantees. input_requirements: Constraints = None, output_guarantees: Constraints = None, - dtype_constraints: Dict[str, str] = {}, - dtype_variables: Dict[str, List[str]] = {}, - dtype_exceptions: List[Dict[str, str]] = [], convert_to_tensors: Union[bool, Set[str]] = False, conversion_preprocess_func: Optional[Callable] = None, ): @@ -438,39 +357,13 @@ def interface( layer too many decorators, it is preferable to extend this decorator with further functionality than to add and apply further decorators. - The supported constraints are for data type constraints and for converting `TensorLike` and `ShapeLike` - inputs into `Tensor`s or `DimensionSize`s. - - NOTE: When annotating a new API, you should also update `tests/constraints/object_builders.py`. - Args: - dtype_constraints: Maps parameters and return values to data type constraint variables. - Use the special value `wrappers.RETURN_VALUE` to denote return values - this can be - a list for functions that have multiple outputs. If only one return type is specified for - functions with multiple outputs, it will be applied to all outputs. - For example: - {"input": "T1", "other": T2, wrappers.RETURN_VALUE: "T1"} - - dtype_variables: Maps data type constraints variables to their supported data types. - For example: - {"T1": ["float32", "float16"], "T2": ["int32", "int64"]} - - dtype_exceptions: Indicates specific combinations of data types that are not supported by the API. - For example: - [ - {"T1": "float16", "T2": "int32"}, - ] - - aliases: A list of function name aliases. For methods that are exposed as multiple APIs - (e.g. `__add__` and `__radd__`), this will enable type information to be added to the - documentation for the aliases as well. - convert_to_tensors: If False or an empty set, no argument types will be converted. If True, all arguments with the `TensorLike` or `ShapeLike` annotations will be converted into `Tensor`s or, whenever possible, `DimensionSize`. If the argument is a set of argument names, conversions will be done only for those arguments. - The conversions will respect any datatype constraints, casting the `TensorLike` values as necessary, + The conversions will attempt safe casts as needed based on `input_requirements`, but will raise an exception for lossy casts like float to int (but *not* for, e.g., `float32` to `float16`). conversion_preprocess_func: If `convert_to_tensors` is true, this argument is a callback that is @@ -493,18 +386,14 @@ def decorator(func): ) shape_likes = {name for name, param in signature.parameters.items() if param.annotation is ShapeLike} - # TODO (pranavm): Constraints should never be None eventually. - if input_requirements is not None: + # Register constraints for Tripy operators if either side is specified. + # + # NOTE: The interface decorator is also used in unit tests and potentially by user code. + # We only want Tripy's own operators to appear in the global registry that powers + # public-API validation and integration tests. + if (input_requirements is not None or output_guarantees is not None) and func.__module__.startswith("nvtripy"): OPERATOR_CONSTRAINTS.append(OperatorConstraints(func, input_requirements, output_guarantees)) - _update_docstring(func, input_requirements, output_guarantees) - elif dtype_constraints or dtype_variables or dtype_exceptions: - # if no dtype constraints have been specified at all, do not add to the table so we don't generate invalid tests - DATA_TYPE_CONSTRAINTS.append( - DataTypeConstraints(func, dtype_constraints, dtype_variables, dtype_exceptions) - ) - - _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -521,7 +410,6 @@ def wrapper(*args, **kwargs): var_arg_info, conversion_targets, conversion_preprocess_func, - dtype_constraints, shape_likes, input_requirements, ) @@ -531,80 +419,23 @@ def wrapper(*args, **kwargs): # Input validation needs to know values for arguments that were not provided but have default values: result = input_requirements(merged_args + omitted_default_args) if not result: - raise_error( - f"Invalid inputs for function: '{func.__qualname__}'.", + details = ( ["Expected: "] + result.error_details - + [f".\n\nNote: Requirements are:\n {input_requirements}."], + + [f".\n\nNote: Requirements are:\n {input_requirements}."] ) - if config.enable_dtype_checking: - from nvtripy.common.datatype import dtype - from nvtripy.frontend.tensor import Tensor + # Include source locations for relevant tensor inputs to make constraint + # failures actionable. + for name, value in merged_args + omitted_default_args: + if hasattr(value, "stack_info"): + details.extend([f"\n\nArgument '{name}' was defined here:\n", value]) - # The first arguments seen for each type variable. Other arguments with the same variable - # must use the same data types. - type_var_first_args: Dict[str, Tuple[str, dtype, Any]] = {} - - for name, arg in merged_args: - if name not in dtype_constraints: - continue - - if arg is None: - # This is only possible for omitted optional arguments. Otherwise, None will - # be disallowed by the function registry's type checking. - continue - - type_var = dtype_constraints[name] - - arg_dtype = get_arg_dtype(arg, func.__qualname__, name) - if not arg_dtype: - raise_error(f"Could not determine datatype of {name}.", arg_dtype.error_details) - arg_dtype = arg_dtype.value - - # Check if the type is supported at all - supported_dtypes = dtype_variables[type_var] - if arg_dtype.name not in supported_dtypes: raise_error( - f"Unsupported data type in '{func.__qualname__}'.", - [ - f"For parameter: '{name}', got unsupported data type: '{arg_dtype}'.\n" - f"Supported data types are: {supported_dtypes}." - ] - + ( - [ - f"\nNote: '{name}' was: ", - arg, - ] - if isinstance(arg, Tensor) and "all" in config.extra_error_information - else [] - ), + f"Invalid inputs for function: '{func.__qualname__}'.", + details, ) - # Check if the type matches that of other inputs with the same type_var. - if type_var in type_var_first_args: - other_name, other_arg_dtype, other_arg = type_var_first_args[type_var] - if other_arg_dtype != arg_dtype: - raise_error( - f"Mismatched data types in '{func.__qualname__}'.", - [ - f"Parameters: '{other_name}' and '{name}' must have matching data types, but got: " - f"'{other_arg_dtype.name}' and '{arg_dtype.name}' respectively.\n" - ] - + ( - [ - f"Note: '{other_name}' was: ", - other_arg, - f"While '{name}' was: ", - arg, - ] - if isinstance(arg, Tensor) - else [] - ), - ) - - type_var_first_args[type_var] = (name, arg_dtype, arg) - return func(*args, **kwargs) return wrapper diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py index ac67660f1..9565e1c8f 100644 --- a/tripy/tests/frontend/constraints/test_fetcher.py +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -104,7 +104,19 @@ def test_call_with_mismatched_dtypes_in_sequence(self): def test_call_with_non_tensor_argument(self): fetcher = GetDataType(GetInput("input_data")) with helper.raises(TripyException, match="Expected a tensor or data type argument"): - fetcher([("input_data", 42)]) + fetcher([("input_data", object())]) + + def test_call_with_python_scalar_int(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", 42)]) == tp.int32 + + def test_call_with_python_scalar_float(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", 1.0)]) == tp.float32 + + def test_call_with_python_scalar_bool(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", True)]) == tp.bool def test_call_with_nested_sequence_error(self): fetcher = GetDataType(GetInput("input_data")) diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index c0ea769f4..d0b27e903 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -189,7 +189,7 @@ def test_inverse(self): def test_doc_str(self): or_constraint = Or(Equal(GetInput("a"), tp.float32), Equal(GetInput("a"), tp.float16)) - assert doc_str(or_constraint) == "(``a`` is :class:`float32` *or* ``a`` is :class:`float16`)" + assert doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" class TestEqual: @@ -216,8 +216,8 @@ def test_inverse(self): assert not inverse([("param1", 5)]) def test_doc_str(self): - assert doc_str(Equal(GetInput("a"), GetInput("b"))) == "``a`` is ``b``" - assert doc_str(Equal(GetInput("a"), tp.float32)) == "``a`` is :class:`float32`" + assert doc_str(Equal(GetInput("a"), GetInput("b"))) == "``a`` == ``b``" + assert doc_str(Equal(GetInput("a"), tp.float32)) == "``a`` == :class:`float32`" class TestNotEqual: @@ -244,7 +244,7 @@ def test_inverse(self): assert not inverse([("param1", 10)]) def test_doc_str(self): - assert doc_str(NotEqual(GetInput("a"), GetInput("b"))) == "``a`` is not ``b``" + assert doc_str(NotEqual(GetInput("a"), GetInput("b"))) == "``a`` != ``b``" class TestIf: @@ -298,7 +298,7 @@ def test_doc_str(self): ) assert ( doc_str(if_constraint) - == "``b`` is one of [:class:`float32`, :class:`float16`] **if** ``a.dtype`` is :class:`float32`, **otherwise** ``b`` is one of [:class:`int32`, :class:`int64`]" + == "``b`` is one of [:class:`float32`, :class:`float16`] **if** ``a.dtype`` == :class:`float32`, **otherwise** ``b`` is one of [:class:`int32`, :class:`int64`]" ) def test_call_without_else_branch(self): @@ -322,7 +322,7 @@ def test_doc_str_without_else_branch(self): ) assert ( doc_str(if_constraint) - == "if ``a.dtype`` is :class:`float32`, then ``b`` is one of [:class:`float32`, :class:`float16`]" + == "if ``a.dtype`` == :class:`float32`, then ``b`` is one of [:class:`float32`, :class:`float16`]" ) def test_inverse_without_else_branch(self): diff --git a/tripy/tests/frontend/module/test_conv.py b/tripy/tests/frontend/module/test_conv.py index fd0af012f..1c11cccf6 100644 --- a/tripy/tests/frontend/module/test_conv.py +++ b/tripy/tests/frontend/module/test_conv.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,7 +44,7 @@ def test_op_func(self, ConvType): def test_mismatched_dtypes_fails(self, ConvType): input = tp.ones((4, 3, 8, 8), dtype=tp.float32) conv_layer = make_conv(ConvType, 3, 16, (5, 5), dtype=tp.float16) - with helper.raises(tp.TripyException, match=r"Mismatched data types in", has_stack_info_for=[input]): + with helper.raises(tp.TripyException, match=r"Invalid inputs for function:", has_stack_info_for=[input]): output = conv_layer(input) def test_mismatched_dim_fails(self, ConvType): diff --git a/tripy/tests/frontend/module/test_embedding.py b/tripy/tests/frontend/module/test_embedding.py index 6f440b70e..0ee3f55a2 100644 --- a/tripy/tests/frontend/module/test_embedding.py +++ b/tripy/tests/frontend/module/test_embedding.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,5 +34,5 @@ def test_incorrect_input_dtype(self): embd = tp.Embedding(4, 16) embd.weight = tp.ones(embd.weight.shape) - with helper.raises(tp.TripyException, match="Unsupported data type in 'gather'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'gather'."): out = embd(a) diff --git a/tripy/tests/frontend/ops/test_binary.py b/tripy/tests/frontend/ops/test_binary.py index dc19b0e31..473aa183a 100644 --- a/tripy/tests/frontend/ops/test_binary.py +++ b/tripy/tests/frontend/ops/test_binary.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -63,7 +63,7 @@ def test_mismatched_dtypes_fails(self): with helper.raises( tp.TripyException, # Keep the entire error message here so we'll know if the display becomes horribly corrupted. - match=r"Mismatched data types in '__add__'.", + match=r"Invalid inputs for function: '__add__'.", has_stack_info_for=[a, b], ): c = a + b diff --git a/tripy/tests/frontend/ops/test_dequantize.py b/tripy/tests/frontend/ops/test_dequantize.py index a353db13b..8c80431d7 100644 --- a/tripy/tests/frontend/ops/test_dequantize.py +++ b/tripy/tests/frontend/ops/test_dequantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,7 @@ def test_invalid_input_dtype(self): a = tp.Tensor([1.0, 2.0]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'dequantize'", + match="Invalid inputs for function: 'dequantize'.", ): a = tp.dequantize(a, 0.9, tp.float32) @@ -34,7 +34,7 @@ def test_invalid_dequant_dtype(self): a = tp.ones([2], dtype=tp.int8) with helper.raises( tp.TripyException, - match="Unsupported data type in 'dequantize'", + match="Invalid inputs for function: 'dequantize'.", ): a = tp.dequantize(a, 1, tp.int32) diff --git a/tripy/tests/frontend/ops/test_equal.py b/tripy/tests/frontend/ops/test_equal.py index 83792a6a9..2b6cb25ca 100644 --- a/tripy/tests/frontend/ops/test_equal.py +++ b/tripy/tests/frontend/ops/test_equal.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,5 +18,5 @@ class TestEqual: def test_mismatched_dtypes_disallowed(self): - with helper.raises(tp.TripyException, match="Mismatched data types in 'equal'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'equal'."): tp.equal(tp.ones((2,), dtype=tp.float32), tp.ones((2,), dtype=tp.float16)) diff --git a/tripy/tests/frontend/ops/test_matmul.py b/tripy/tests/frontend/ops/test_matmul.py index f33cb7f44..3d810403d 100644 --- a/tripy/tests/frontend/ops/test_matmul.py +++ b/tripy/tests/frontend/ops/test_matmul.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,7 +33,7 @@ def test_mismatched_dtypes_fails(self): a = tp.ones((2, 3), dtype=tp.float32) b = tp.ones((3, 2), dtype=tp.float16) - with helper.raises(tp.TripyException, match="Mismatched data types in '__matmul__'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: '__matmul__'."): c = a @ b def test_incompatible_1d_shapes_fails(self): diff --git a/tripy/tests/frontend/ops/test_quantize.py b/tripy/tests/frontend/ops/test_quantize.py index f9520a37d..a9055bf18 100644 --- a/tripy/tests/frontend/ops/test_quantize.py +++ b/tripy/tests/frontend/ops/test_quantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ def test_invalid_input_dtype(self): a = tp.Tensor([1, 2]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'quantize'", + match="Invalid inputs for function: 'quantize'.", ): a = tp.quantize(a, 1, tp.int8) @@ -33,7 +33,7 @@ def test_unsupported_quant_dtype(self): a = tp.Tensor([1.0, 2.0]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'quantize'", + match="Invalid inputs for function: 'quantize'.", ): a = tp.quantize(a, 1, tp.float16) diff --git a/tripy/tests/frontend/ops/test_where.py b/tripy/tests/frontend/ops/test_where.py index 85d68f5f5..6921d9175 100644 --- a/tripy/tests/frontend/ops/test_where.py +++ b/tripy/tests/frontend/ops/test_where.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,7 +45,7 @@ def test_mismatched_input_dtypes(self): a = tp.ones((2,), dtype=tp.float32) b = tp.ones((2,), dtype=tp.float16) - with helper.raises(tp.TripyException, match="Mismatched data types in 'where'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'where'."): c = tp.where(cond, a, b) def test_condition_is_not_bool(self): @@ -53,7 +53,7 @@ def test_condition_is_not_bool(self): a = tp.ones((2,), dtype=tp.float32) b = tp.ones((2,), dtype=tp.float32) - with helper.raises(tp.TripyException, match="Unsupported data type in 'where'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'where'."): c = tp.where(cond, a, b) @@ -62,5 +62,5 @@ def test_condition_is_not_bool(self): a = tp.Tensor([0, 1, 0, 1]) mask = tp.Tensor([1.0, 2.0, 3.0, 4.0]) - with helper.raises(tp.TripyException, match="Unsupported data type in 'masked_fill'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'masked_fill'."): b = tp.masked_fill(a, mask, -1) diff --git a/tripy/tests/frontend/wrappers/test_datatype_constraints.py b/tripy/tests/frontend/wrappers/test_datatype_constraints.py index cccb14590..b681f87b1 100644 --- a/tripy/tests/frontend/wrappers/test_datatype_constraints.py +++ b/tripy/tests/frontend/wrappers/test_datatype_constraints.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,230 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO (pranavm): Move into integration tests +"""(Removed) Legacy dtype constraint tests. -import contextlib -import inspect -import itertools -from dataclasses import dataclass -from typing import Callable, Dict, List - -import numpy as np -import nvtripy as tp -import pytest -from nvtripy.common.datatype import DATA_TYPES -from nvtripy.frontend import wrappers -from nvtripy.utils.types import str_from_type_annotation -from nvtripy.utils.utils import make_list -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS -from tests import helper -from tests.conftest import skip_if_older_than_sm89 - - -@dataclass -class DtypeConstraintCase: - func: Callable - constraints: Dict[str, str] - variables: Dict[str, tp.dtype] - negative: bool # Whether this is a negative test case - - def __str__(self): - return f"{self.func.__name__}-{'_'.join(f'{key}-{val}' for key, val in self.variables.items())}" + ( - "-invalid" if self.negative else "-valid" - ) - - -DTYPE_CONSTRAINT_CASES: List[DtypeConstraintCase] = [] - -for dtc in DATA_TYPE_CONSTRAINTS: - keys, values = list(zip(*dtc.variables.items())) - - def add_cases(combinations, negative): - for combination in combinations: - dtype_combination = dict(zip(keys, combination)) - DTYPE_CONSTRAINT_CASES.append( - pytest.param( - DtypeConstraintCase(dtc.func, dtc.constraints, dtype_combination, negative), - marks=( - skip_if_older_than_sm89() - if any(dtype == "float8" for dtype in dtype_combination.values()) - else [] - ), - ) - ) - - # Positive cases: - positive_combinations = list(itertools.product(*values)) - exceptions = [tuple(exception[key] for key in keys) for exception in dtc.exceptions] - positive_combinations = [comb for comb in positive_combinations if comb not in exceptions] - add_cases(positive_combinations, negative=False) - - # Negative cases - we do this by simply generating all possible combinations and removing the positive ones: - total_dtypes = set(map(str, DATA_TYPES.values())) - negative_combinations = list(itertools.product(*(total_dtypes for _ in values))) - negative_combinations = list(comb for comb in negative_combinations if comb not in positive_combinations) - add_cases(negative_combinations, negative=True) - - -DTYPE_CONSTRAINT_CASES.sort(key=lambda case: str(case)) - - -# In some cases, we need to use custom values so that the code is valid. -CUSTOM_VALUES = { - "__getitem__": {"index": 0}, - "arange": {"start": 0, "stop": 10, "step": 1}, - "avgpool": {"kernel_dims": [2, 2]}, - "convolution": { - "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)), - "bias": tp.Tensor([1.0, 2.0]), - "padding": ((0, 0), (0, 0)), - "stride": [1, 1], - "groups": 1, - "dilation": [1, 1], - }, - "copy": { - "input": tp.ones((2, 2)), - "device": tp.device("cpu"), - }, - "deconvolution": { - "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)), - "bias": tp.Tensor([1.0, 2.0]), - "padding": ((0, 0), (0, 0)), - "stride": [1, 1], - "groups": 1, - "dilation": [1, 1], - }, - "dequantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, - "expand": {"sizes": tp.Tensor((2, 2, 5, 5))}, - "full": {"shape": tp.Tensor([2, 2]), "value": tp.Tensor(1.0)}, - "full_like": {"value": tp.Tensor(1.0)}, - "gather": {"index": tp.Tensor([1])}, - "instancenorm": { - "num_channels": 2, - "weight": tp.Tensor(np.ones((2,), dtype=np.float32)), - "bias": tp.Tensor(np.zeros((2,), dtype=np.float32)), - }, - "iota": {"shape": tp.Tensor([2, 2])}, - "maxpool": {"kernel_dims": [2, 2]}, - "ones": {"shape": [2, 2]}, - "outer": {"vec1": tp.Tensor([1, 2, 3]), "vec2": tp.Tensor([1, 2, 3])}, - "pad": {"pad": [(0, 1), (1, 0), (1, 1), (0, 0)]}, - "permute": {"perm": [1, 0, 3, 2]}, - "quantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, - "repeat": {"repeats": 2, "dim": 0}, - "reshape": {"shape": tp.Tensor([2, 25])}, - "resize": { - "mode": "nearest", - "output_shape": tp.Tensor((1, 2, 10, 10)), - "scales": [1, 1, 2, 2], - }, - "squeeze": {"dims": 0}, - "transpose": {"dim0": 0, "dim1": 1}, - "zeros": {"shape": [2, 2]}, -} - -# Arguments that must be constants on CPU. -REQUIRES_CPU_CONST = { - "dequantize": {"scale"}, - "quantize": {"scale"}, -} - -# Some operations require input shapes to be known -REQUIRES_KNOWN_SHAPES = { - "convolution": {"input", "weight", "bias"}, - "deconvolution": {"input", "weight", "bias"}, -} - - -def generate_input_values(case: DtypeConstraintCase): - inputs = {} - for param_name, param in inspect.signature(case.func).parameters.items(): - requires_cpu_const = ( - case.func.__name__ in REQUIRES_CPU_CONST and param_name in REQUIRES_CPU_CONST[case.func.__name__] - ) - requires_known_shapes = ( - case.func.__name__ in REQUIRES_KNOWN_SHAPES and param_name in REQUIRES_KNOWN_SHAPES[case.func.__name__] - ) - - param_type = str_from_type_annotation(param.annotation) - - dtype = None - if param_name in case.constraints: - dtype = DATA_TYPES[case.variables[case.constraints[param_name]]] - - if dtype == tp.int4: - # TODO (#579): Enable int4 inputs - pytest.skip(f"#579: Cannot generate INT4 inputs") - - def copy_input_to_cpu_and_set_shapes(): - if requires_cpu_const: - assert isinstance(inputs[param_name], tp.Tensor) - if inputs[param_name].device.kind != "cpu": - inputs[param_name] = tp.copy(inputs[param_name], device=tp.device("cpu")) - - if requires_known_shapes: - assert isinstance(inputs[param_name], tp.Tensor) - if any(dim == tp.constants.DYNAMIC_DIM for dim in inputs[param_name].trace_tensor.shape): - inputs[param_name].trace_tensor.shape = tuple(map(int, inputs[param_name].shape)) - - if case.func.__name__ in CUSTOM_VALUES and param_name in CUSTOM_VALUES[case.func.__name__]: - inputs[param_name] = CUSTOM_VALUES[case.func.__name__][param_name] - if isinstance(inputs[param_name], tp.Tensor): - if dtype is not None: - inputs[param_name] = tp.cast(inputs[param_name], dtype=dtype) - copy_input_to_cpu_and_set_shapes() - continue - - if param.default is not inspect.Parameter.empty and dtype is None: - continue # Skip optional parameters unless we explicitly need to set a datatype for them. - - if "nvtripy.Tensor" in param_type: - assert dtype is not None, "Tensors must have type annotations" - # Need to cast here because `ones` does not support all types. - tensor = tp.Tensor(np.ones((1, 2, 5, 5), dtype=np.float32)) - if "Sequence" in param_type: - inputs[param_name] = [tp.cast(tensor, dtype=dtype) for _ in range(2)] - copy_input_to_cpu_and_set_shapes() - else: - inputs[param_name] = tp.cast(tensor, dtype=dtype) - copy_input_to_cpu_and_set_shapes() - elif "nvtripy.dtype" in param_type: - assert dtype is not None, "Data types must have type annotations" - inputs[param_name] = dtype - elif "numbers.Number" in param_type or "int" in param_type or "float" in param_type: - inputs[param_name] = 1 - else: - assert False, f"Unsupported parameter type: {param_type}" - return inputs - - -@pytest.mark.parametrize("case", DTYPE_CONSTRAINT_CASES, ids=lambda case: str(case)) -def test_datatype_constraints(case: DtypeConstraintCase): - - # If data type checking is enabled, negative tests will trivially pass (we will throw an - # error before even trying to call the function). - with helper.config("enable_dtype_checking", False): - inputs = generate_input_values(case) - - with contextlib.ExitStack() as stack: - if case.negative: - stack.enter_context(helper.raises(Exception)) - - outputs = make_list(case.func(**inputs)) - - # Some APIs do not generate Tensor outputs (e.g. `allclose`), so we don't need to evaluate those. - if not any(isinstance(out, tp.Tensor) for out in outputs): - return - - expected_return_types = [ - case.variables[cons] for cons in make_list(case.constraints[wrappers.RETURN_VALUE]) - ] - assert expected_return_types, "Return value must have a constraint" - if len(expected_return_types) < len(outputs): - expected_return_types += [expected_return_types[-1]] * (len(outputs) - len(expected_return_types)) - - for out, expected_type in zip(outputs, expected_return_types): - # Avoid evaluating CPU constants since some types (e.g. FP8) don't allow them to be used outside of - # certain operations. - out._eval_for_internal_methods() - assert out.dtype == DATA_TYPES[expected_type], f"Expected {expected_type}, got {out.dtype}" +The old `dtype_constraints`/`dtype_variables` system has been removed. +Operator constraint behavior is now validated via `tests/integration/test_operator_constraints.py`. +""" diff --git a/tripy/tests/frontend/wrappers/test_wrappers.py b/tripy/tests/frontend/wrappers/test_wrappers.py index 1aa26282f..fdc59c9a1 100644 --- a/tripy/tests/frontend/wrappers/test_wrappers.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ from nvtripy.frontend import wrappers from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or -from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _find_known_datatypes, OPERATOR_CONSTRAINTS +from nvtripy.frontend.wrappers import _find_known_datatypes, OPERATOR_CONSTRAINTS from tests import helper # Get all functions/methods which have tensors in the type signature @@ -48,24 +48,15 @@ api.qualname + f".{func.__name__}" if func.__name__ not in api.qualname else "" ) -# TODO (pranavm): Remove old dtype constraints system: -DATA_TYPE_CONSTRAINTS_FUNC_NAMES = {dtc.func.__qualname__ for dtc in DATA_TYPE_CONSTRAINTS} - - -@pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES) -def test_all_public_apis_verified(api): - assert api.__qualname__ in DATA_TYPE_CONSTRAINTS_FUNC_NAMES, f"Missing datatype constraints for: {api.__qualname__}" - - OPERATOR_CONSTRAINTS_FUNC_NAMES = {oc.func.__qualname__ for oc in OPERATOR_CONSTRAINTS} @pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES) -def test_all_public_apis_verified(api): +def test_all_public_apis_have_operator_constraints(api): assert api.__qualname__ in OPERATOR_CONSTRAINTS_FUNC_NAMES, f"Missing operator constraints for: {api.__qualname__}" -@wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]}) +@wrappers.interface(input_requirements=OneOf(GetInput("tensors").dtype, [tp.float32])) def sequence_func(tensors: List[tp.Tensor]): return @@ -152,7 +143,10 @@ def test_works_with_sequences(self): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)]) def test_raises_on_mismatched_sequence_dtypes(self): - with helper.raises(tp.TripyException, match="Mismatched data types in sequence argument for 'sequence_func'."): + with helper.raises( + tp.TripyException, + match=r"Mismatched data types in sequence argument", + ): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.int32)]) @@ -256,8 +250,8 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike): def test_cast_dtype(self): # When type constraints are included, the decorator should automatically cast when possible. @wrappers.interface( - dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float16"]}, + input_requirements=(GetInput("a").dtype == tp.float16) & (GetInput("b").dtype == GetInput("a").dtype), + output_guarantees=(GetReturn(0).dtype == GetInput("a").dtype) & (GetReturn(1).dtype == GetInput("a").dtype), convert_to_tensors=True, ) def func(a: tp.Tensor, b: tp.types.TensorLike): @@ -273,11 +267,15 @@ def func(a: tp.Tensor, b: tp.types.TensorLike): assert isinstance(b, tp.Tensor) assert b.dtype == tp.float16 - @pytest.mark.parametrize("arg, dtype", [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)]) + @pytest.mark.parametrize( + "arg, dtype", + [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)], + ) def test_refuse_unsafe_cast(self, arg, dtype): @wrappers.interface( - dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["int32", "int64"]}, + input_requirements=OneOf(GetInput("a").dtype, [tp.int32, tp.int64, tp.bool]) + & (GetInput("b").dtype == GetInput("a").dtype), + output_guarantees=(GetReturn(0).dtype == GetInput("a").dtype) & (GetReturn(1).dtype == GetInput("a").dtype), convert_to_tensors=True, ) def func(a: tp.Tensor, b: tp.types.TensorLike): diff --git a/tripy/tests/integration/test_operator_constraints.py b/tripy/tests/integration/test_operator_constraints.py index 32f7c07a9..e793ec378 100644 --- a/tripy/tests/integration/test_operator_constraints.py +++ b/tripy/tests/integration/test_operator_constraints.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -226,7 +226,11 @@ def test_operator_constraints(case: OperatorConstraintCase): inputs = generate_input_values(case) merged_args = list(inputs.items()) - is_valid = bool(op_constraint.input_requirements(merged_args)) + # Some operators may only define output guarantees. + # In that case, we cannot predict input validity via constraints. + is_valid = ( + True if op_constraint.input_requirements is None else bool(op_constraint.input_requirements(merged_args)) + ) with contextlib.ExitStack() as stack: if not is_valid: From 33960b417a0b1c629aabc5e1b59c60a6d7b9f8c8 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 6 Feb 2026 18:25:48 +0000 Subject: [PATCH 28/32] Removes extra files --- ...teOperatorsToNewConstraintSystem.prompt.md | 21 ------------------- .../wrappers/test_datatype_constraints.py | 20 ------------------ 2 files changed, 41 deletions(-) delete mode 100644 tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md delete mode 100644 tripy/tests/frontend/wrappers/test_datatype_constraints.py diff --git a/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md b/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md deleted file mode 100644 index 321378556..000000000 --- a/tripy/.github/prompts/plan-migrateOperatorsToNewConstraintSystem.prompt.md +++ /dev/null @@ -1,21 +0,0 @@ -## Plan: Migrate Operators to New Constraint System - -This plan will systematically port operators to use the `input_requirements`/`output_guarantees` constraint system. The migration follows proven patterns from [`cast.py`](nvtripy/frontend/ops/cast.py), [`concatenate.py`](nvtripy/frontend/ops/concatenate.py), and [`ones.py`](nvtripy/frontend/ops/ones.py). - -### Steps - -1. **Migrate simple type-preserving unary operators** (~16 ops in [`nvtripy/frontend/ops/unary/`](nvtripy/frontend/ops/unary/)): Convert operators like `exp`, `log`, `sqrt`, `sin`, `cos`, `tanh`, `sigmoid`, `relu`, `gelu`, `silu`, `abs`, `neg`, `invert` using pattern: `input_requirements=OneOf(GetInput("input").dtype, [dtypes])` and `output_guarantees=GetReturn(0).dtype == GetInput("input").dtype`. Run tests with `pytest tests/integration/test_operator_constraints.py -k "exp|log|sqrt" -s` and corresponding sanity tests. **Commit changes** with message like "Migrate unary operators to new constraint system". - -2. **Migrate shape manipulation operators** (~12 ops in [`nvtripy/frontend/ops/`](nvtripy/frontend/ops/)): Convert `reshape`, `transpose`, `permute`, `flatten`, `squeeze`, `unsqueeze`, `expand`, `repeat`, `stack`, `slice`, `flip` using the same type-preserving pattern. Test with `pytest tests/integration/test_operator_constraints.py -k "reshape|transpose|permute" -s`. **Commit changes** with message like "Migrate shape manipulation operators to new constraint system". - -3. **Migrate binary arithmetic and comparison operators** (~16 ops in [`nvtripy/frontend/ops/binary/`](nvtripy/frontend/ops/binary/)): Convert `add`, `sub`, `mul`, `div`, `floor_div`, `mod`, `pow`, `maximum`, `minimum`, `equal`, `not_equal`, `greater`, `greater_equal`, `less`, `less_equal`, `logical_or` using pattern: `OneOf(GetInput("self").dtype, [...]) & (GetInput("other").dtype == GetInput("self").dtype)`. Comparison ops return `dt.bool`. Test with `pytest tests/integration/test_operator_constraints.py -k "add|mul|equal" -s`. **Commit changes** with message like "Migrate binary operators to new constraint system". - -4. **Migrate reduction operators** (~11 ops in [`nvtripy/frontend/ops/reduce/`](nvtripy/frontend/ops/reduce/)): Convert `sum`, `prod`, `mean`, `max`, `min`, `var`, `all`, `any`, plus multi-type ops `argmax`, `argmin`, `topk` (note: `topk` has multiple returns). Use `If` constraints for multi-type variables. Test with `pytest tests/integration/test_operator_constraints.py -k "sum|mean|argmax" -s`. **Commit changes** with message like "Migrate reduction operators to new constraint system". - -5. **Migrate initializers with optional dtype parameters** (~8 ops in [`nvtripy/frontend/ops/`](nvtripy/frontend/ops/)): Convert `zeros`, `zeros_like`, `full`, `full_like`, `iota`, `arange` using `If(GetInput("dtype") != None, ...)` pattern for conditional output guarantees. Test with `pytest tests/integration/test_operator_constraints.py -k "zeros|full|iota|arange" -s`. **Commit changes** with message like "Migrate initializer operators to new constraint system". - -6. **Migrate advanced operations and special cases** (~15 ops): Convert `matmul`, `gather`, `where`, `masked_fill`, `outer`, `pad`, `softmax`, `cumsum`, `copy`, `resize`, `triu`, `tril`, `avgpool`, `maxpool` in their respective files. Handle multi-type variables and complex logic. Migrate `quantize`/`dequantize` in [`plugin_qdp.py`](nvtripy/frontend/ops/plugin_qdp.py) and [`quantize.py`](nvtripy/frontend/ops/quantize.py)/[`dequantize.py`](nvtripy/frontend/ops/dequantize.py) with coordinated type pairs. Test each with `-k` filtering. **Commit changes** with message like "Migrate advanced operators to new constraint system". - -### Implementation Notes - -- Translate any existing exception logic directly to boolean logic using `&`, `|`, `~` operators without creating helper functions diff --git a/tripy/tests/frontend/wrappers/test_datatype_constraints.py b/tripy/tests/frontend/wrappers/test_datatype_constraints.py deleted file mode 100644 index b681f87b1..000000000 --- a/tripy/tests/frontend/wrappers/test_datatype_constraints.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""(Removed) Legacy dtype constraint tests. - -The old `dtype_constraints`/`dtype_variables` system has been removed. -Operator constraint behavior is now validated via `tests/integration/test_operator_constraints.py`. -""" From 208b992daba3a080729be063c4ea1fa7555230de Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 6 Feb 2026 18:51:26 +0000 Subject: [PATCH 29/32] Removes a hallucinated parameter --- tripy/nvtripy/frontend/ops/cumsum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tripy/nvtripy/frontend/ops/cumsum.py b/tripy/nvtripy/frontend/ops/cumsum.py index 9a31be388..8b08e70c0 100644 --- a/tripy/nvtripy/frontend/ops/cumsum.py +++ b/tripy/nvtripy/frontend/ops/cumsum.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,7 @@ input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) -def cumsum(input: "nvtripy.Tensor", dim: int, exclusive: bool = False) -> "nvtripy.Tensor": +def cumsum(input: "nvtripy.Tensor", dim: int) -> "nvtripy.Tensor": """ Computes the cumulative sum of elements in the input along the dimension ``dim``. From 1f693c4716e2ffc8c6a8534705debcb52be62e2e Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 6 Feb 2026 11:24:27 -0800 Subject: [PATCH 30/32] Adds HF token --- .github/workflows/tripy-l0.yml | 7 ++++--- .github/workflows/tripy-l1.yml | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tripy-l0.yml b/.github/workflows/tripy-l0.yml index af7db1628..476a097ba 100644 --- a/.github/workflows/tripy-l0.yml +++ b/.github/workflows/tripy-l0.yml @@ -11,6 +11,7 @@ env: REGISTRY: ghcr.io DEFAULT_IMAGE: ghcr.io/nvidia/tensorrt-incubator/tripy:latest NEW_TEST_IMAGE: test-image:latest + HF_TOKEN: ${{ secrets.HF_TOKEN }} jobs: @@ -58,7 +59,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | python3 docs/generate_rsts.py sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W -n @@ -67,7 +68,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | pytest --cov=nvtripy/ --cov-config=.coveragerc tests/ -v -m "not l1" -n 4 --durations=15 --ignore tests/performance @@ -75,7 +76,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | pytest tests/performance -v -m "not l1" --benchmark-warmup=on --benchmark-json benchmark.json diff --git a/.github/workflows/tripy-l1.yml b/.github/workflows/tripy-l1.yml index 650dd183f..bf90d5d9e 100644 --- a/.github/workflows/tripy-l1.yml +++ b/.github/workflows/tripy-l1.yml @@ -18,6 +18,8 @@ concurrency: jobs: l1-test: runs-on: tripy-self-hosted + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} container: image: ghcr.io/nvidia/tensorrt-incubator/tripy:latest volumes: From 00a4fe372d587feab0b0f7405fcace07aeda9767 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 10 Feb 2026 18:31:50 +0000 Subject: [PATCH 31/32] Adds an optimizer for the constraints sysytem, optimizes `merge_function_arguments` --- .../nvtripy/frontend/constraints/__init__.py | 15 ++- tripy/nvtripy/frontend/constraints/logic.py | 39 ++++++- .../nvtripy/frontend/constraints/optimizer.py | 108 ++++++++++++++++++ tripy/nvtripy/frontend/wrappers.py | 27 ++++- tripy/nvtripy/utils/utils.py | 23 ++-- .../frontend/constraints/test_optimizer.py | 43 +++++++ 6 files changed, 236 insertions(+), 19 deletions(-) create mode 100644 tripy/nvtripy/frontend/constraints/optimizer.py create mode 100644 tripy/tests/frontend/constraints/test_optimizer.py diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py index 8105d1c72..e2650a51d 100644 --- a/tripy/nvtripy/frontend/constraints/__init__.py +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,4 +17,15 @@ from nvtripy.frontend.constraints.base import Constraints from nvtripy.frontend.constraints.doc_str import doc_str from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher -from nvtripy.frontend.constraints.logic import And, Equal, If, Logic, NotEqual, NotOneOf, OneOf, Or +from nvtripy.frontend.constraints.logic import ( + AlwaysFalse, + AlwaysTrue, + And, + Equal, + If, + Logic, + NotEqual, + NotOneOf, + OneOf, + Or, +) diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 5664c8b6e..23f776a78 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -58,14 +58,45 @@ def __invert__(self) -> "Logic": return self.inverse() +class AlwaysTrue(Logic): + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + return Result.ok() + + def __str__(self): + return "true" + + def doc_str(self) -> str: + return "true" + + def inverse(self) -> "Logic": + return AlwaysFalse() + + +class AlwaysFalse(Logic): + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + return Result.err(["true"]) + + def __str__(self): + return "false" + + def doc_str(self) -> str: + return "false" + + def inverse(self) -> "Logic": + return AlwaysTrue() + + class OneOf(Logic): - def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + def __init__(self, fetcher: Fetcher, options: Optional[Sequence[Any]]): super().__init__() self.fetcher = fetcher - # Need to convert generator expressions so we can use them more than once - self.options = list(options) + # Need to convert generator expressions so we can use them more than once. + # `None` is allowed to support pattern matching wildcards. + self.options = list(options) if options is not None else None def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + if self.options is None: + raise_error("OneOf constraint cannot be evaluated with wildcard options.") value = self.fetcher(args, returns) if value in self.options: return Result.ok() @@ -76,6 +107,8 @@ def __str__(self): return f"{self.fetcher} is one of {self.options}" def doc_str(self) -> str: + if self.options is None: + return f"{doc_str(self.fetcher)} is one of [*]" return f"{doc_str(self.fetcher)} is one of [{', '.join(f'{doc_str(opt)}' for opt in self.options)}]" def inverse(self) -> "Logic": diff --git a/tripy/nvtripy/frontend/constraints/optimizer.py b/tripy/nvtripy/frontend/constraints/optimizer.py new file mode 100644 index 000000000..4aa36d70b --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/optimizer.py @@ -0,0 +1,108 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterable, List, Optional + +from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.constraints import AlwaysTrue, Constraints +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput +from nvtripy.frontend.constraints.logic import OneOf + + +@dataclass(frozen=True) +class ConstraintPass(ABC): + name: str + pattern: Constraints + + def predicate(self, constraint: Constraints) -> bool: + return True + + @abstractmethod + def rewrite(self, constraint: Constraints) -> Constraints: ... + + +def _optimize_children( + constraint: Constraints, + constraint_pass: ConstraintPass, + pass_matches: List[Constraints], +) -> Constraints: + """Recursively rewrite child nodes after local rewrites have been applied.""" + for attr_name, attr_value in vars(constraint).items(): + if isinstance(attr_value, Constraints): + setattr(constraint, attr_name, _optimize_constraints(attr_value, constraint_pass, pass_matches)) + continue + + if isinstance(attr_value, (list, tuple)): + optimized_items = [] + changed = False + for item in attr_value: + if isinstance(item, Constraints): + optimized_item = _optimize_constraints(item, constraint_pass, pass_matches) + optimized_items.append(optimized_item) + changed = changed or optimized_item is not item + else: + optimized_items.append(item) + if changed: + new_value = tuple(optimized_items) if isinstance(attr_value, tuple) else optimized_items + setattr(constraint, attr_name, new_value) + + return constraint + + +def _optimize_constraints( + constraint: Constraints, + constraint_pass: ConstraintPass, + pass_matches: List[Constraints], +) -> Constraints: + """Apply passes to this node, then recurse into its children.""" + rewritten = constraint + if any(rewritten is match for match in pass_matches) and constraint_pass.predicate(rewritten): + rewritten = constraint_pass.rewrite(rewritten) + + return _optimize_children(rewritten, constraint_pass, pass_matches) + + +class DropAllDtypesOneOf(ConstraintPass): + def __init__(self): + super().__init__( + name="drop-all-dtypes-oneof", + pattern=OneOf(GetDataType(GetInput(None)), None), + ) + + def predicate(self, constraint: Constraints) -> bool: + return set(constraint.options).issuperset(set(DATA_TYPES.values())) + + def rewrite(self, constraint: Constraints) -> Constraints: + return AlwaysTrue() + + +def _default_passes() -> Iterable[ConstraintPass]: + return (DropAllDtypesOneOf(),) + + +def optimize_constraints(constraints: Optional[Constraints]) -> Optional[Constraints]: + if constraints is None: + return None + + passes = tuple(_default_passes()) + optimized = constraints + for constraint_pass in passes: + matches = optimized.find(constraint_pass.pattern) + optimized = _optimize_constraints(optimized, constraint_pass, matches) + return optimized diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index b4d46643d..c75c8d3f5 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -25,7 +25,8 @@ from nvtripy import config, utils from nvtripy.common.datatype import dtype as tp_dtype from nvtripy.common.exception import raise_error -from nvtripy.frontend.constraints import Constraints, Equal, GetInput, GetDataType, Fetcher, doc_str +from nvtripy.frontend.constraints import AlwaysTrue, Constraints, Equal, GetInput, GetDataType, Fetcher, doc_str +from nvtripy.frontend.constraints.optimizer import optimize_constraints @dataclass @@ -378,6 +379,10 @@ def interface( def decorator(func): from nvtripy.types import ShapeLike, TensorLike + optimized_input_requirements = optimize_constraints(input_requirements) + if isinstance(optimized_input_requirements, AlwaysTrue): + optimized_input_requirements = None + signature = inspect.signature(func) conversion_targets = ( convert_to_tensors @@ -397,11 +402,20 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - merged_args, omitted_default_args, var_arg_info = utils.utils.merge_function_arguments( - func, *args, **kwargs - ) + merged_args = None + omitted_default_args = None + var_arg_info = None + + def get_merged_args(): + nonlocal merged_args, omitted_default_args, var_arg_info + if merged_args is None: + merged_args, omitted_default_args, var_arg_info = utils.utils.merge_function_arguments( + func, *args, **kwargs + ) + return merged_args, omitted_default_args, var_arg_info if convert_to_tensors: + merged_args, omitted_default_args, var_arg_info = get_merged_args() args, kwargs, merged_args = convert_input_types( func, args, @@ -415,9 +429,10 @@ def wrapper(*args, **kwargs): ) if config.enable_input_validation: - if input_requirements is not None: + if optimized_input_requirements is not None: + merged_args, omitted_default_args, _ = get_merged_args() # Input validation needs to know values for arguments that were not provided but have default values: - result = input_requirements(merged_args + omitted_default_args) + result = optimized_input_requirements(merged_args + omitted_default_args) if not result: details = ( ["Expected: "] diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py index 6a4ed8cf4..98ca78d07 100644 --- a/tripy/nvtripy/utils/utils.py +++ b/tripy/nvtripy/utils/utils.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -333,11 +333,8 @@ def gen_uid(inputs=None, outputs=None): ## ## Functions ## -def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]: - # Returns the names of positional arguments by inspecting the function signature. - # In the case of variadic positional arguments, we cannot determine names, so we use - # None instead. To assist in further processing, this function also returns the name - # and start index of the variadic args in a pair if present (None if not). +@functools.lru_cache(maxsize=None) +def _get_signature_info(func): signature = inspect.signature(func) arg_names = [] varargs_name = None @@ -347,9 +344,19 @@ def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], # (they would just be absorbed into the variadic argument). varargs_name = name break - arg_names.append(name) + return signature, tuple(arg_names), varargs_name + + +def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]: + # Returns the names of positional arguments by inspecting the function signature. + # In the case of variadic positional arguments, we cannot determine names, so we use + # None instead. To assist in further processing, this function also returns the name + # and start index of the variadic args in a pair if present (None if not). + _, arg_names, varargs_name = _get_signature_info(func) + arg_names = list(arg_names) + # For all variadic positional arguments, assign the name of the variadic group. num_variadic_args = len(args) - len(arg_names) variadic_start_idx = len(arg_names) @@ -367,7 +374,7 @@ def merge_function_arguments( all_args, var_arg_info = get_positional_args_with_names(func, *args) all_args.extend(kwargs.items()) - signature = inspect.signature(func) + signature, _, _ = _get_signature_info(func) provided_arg_names = {name for name, _ in all_args} omitted_args = [] diff --git a/tripy/tests/frontend/constraints/test_optimizer.py b/tripy/tests/frontend/constraints/test_optimizer.py new file mode 100644 index 000000000..b90d23597 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_optimizer.py @@ -0,0 +1,43 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.constraints import AlwaysTrue, And, GetInput, OneOf +from nvtripy.frontend.constraints.fetcher import GetDataType +from nvtripy.frontend.constraints.optimizer import optimize_constraints + + +class TestOptimizeConstraints: + def test_drops_all_dtype_oneof(self): + constraint = OneOf(GetDataType(GetInput("x")), list(DATA_TYPES.values())) + optimized = optimize_constraints(constraint) + assert isinstance(optimized, AlwaysTrue) + + def test_keeps_non_exhaustive_oneof(self): + dtypes = list(DATA_TYPES.values()) + constraint = OneOf(GetDataType(GetInput("x")), dtypes[:-1]) + optimized = optimize_constraints(constraint) + assert optimized is constraint + + def test_applies_to_nested_constraints(self): + constraint = And( + OneOf(GetDataType(GetInput("x")), list(DATA_TYPES.values())), + OneOf(GetInput("y"), [1, 2, 3]), + ) + optimized = optimize_constraints(constraint) + assert isinstance(optimized, And) + assert any(isinstance(child, AlwaysTrue) for child in optimized.constraints) From 0a177a628c2ef99d16446a740a8afdb145d19ba8 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 10 Feb 2026 18:55:38 +0000 Subject: [PATCH 32/32] Addresses various TODOs and fixes tests --- tripy/nvtripy/frontend/constraints/logic.py | 9 +++++++- tripy/nvtripy/frontend/wrappers.py | 10 ++++---- tripy/nvtripy/utils/utils.py | 23 +++++++++++++++---- tripy/tests/conftest.py | 17 ++++++++++---- .../tests/frontend/constraints/test_logic.py | 2 +- 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py index 23f776a78..31f5a7a29 100644 --- a/tripy/nvtripy/frontend/constraints/logic.py +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -163,7 +163,14 @@ def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = elif value1 == value2: return Result.ok() - # TODO (pranavm): If fetcher_or_value is a Fetcher, include its value in the error message. + if isinstance(self.fetcher_or_value, Fetcher): + return Result.err( + [ + f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' " + f"(but it was '{value1}' while '{self.fetcher_or_value}' was '{value2}')" + ] + ) + return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) def __str__(self): diff --git a/tripy/nvtripy/frontend/wrappers.py b/tripy/nvtripy/frontend/wrappers.py index c75c8d3f5..b1ab44336 100644 --- a/tripy/nvtripy/frontend/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -346,10 +346,8 @@ def _update_docstring(func, input_requirements, output_guarantees): def interface( - # TODO (pranavm): These should be required arguments eventually. - # TODO (pranavm): Document requirements/guarantees. - input_requirements: Constraints = None, - output_guarantees: Constraints = None, + input_requirements: Optional[Constraints] = None, + output_guarantees: Optional[Constraints] = None, convert_to_tensors: Union[bool, Set[str]] = False, conversion_preprocess_func: Optional[Callable] = None, ): @@ -359,6 +357,10 @@ def interface( than to add and apply further decorators. Args: + input_requirements: A constraints tree that validates function inputs. + If provided and input validation is enabled, these constraints are checked at runtime. + output_guarantees: A constraints tree describing guarantees about the function output. + If provided, these are used for documentation and tooling. convert_to_tensors: If False or an empty set, no argument types will be converted. If True, all arguments with the `TensorLike` or `ShapeLike` annotations will be converted into `Tensor`s or, whenever possible, `DimensionSize`. If the argument diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py index 98ca78d07..26190b1da 100644 --- a/tripy/nvtripy/utils/utils.py +++ b/tripy/nvtripy/utils/utils.py @@ -333,9 +333,24 @@ def gen_uid(inputs=None, outputs=None): ## ## Functions ## +def _get_signature_cache_key(func): + try: + hash(func) + return func + except TypeError: + call = getattr(func, "__call__", None) + if call is not None: + try: + hash(call) + return call + except TypeError: + pass + return func.__class__ + + @functools.lru_cache(maxsize=None) -def _get_signature_info(func): - signature = inspect.signature(func) +def _get_signature_info(signature_source): + signature = inspect.signature(signature_source) arg_names = [] varargs_name = None for name, param in signature.parameters.items(): @@ -354,7 +369,7 @@ def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], # In the case of variadic positional arguments, we cannot determine names, so we use # None instead. To assist in further processing, this function also returns the name # and start index of the variadic args in a pair if present (None if not). - _, arg_names, varargs_name = _get_signature_info(func) + _, arg_names, varargs_name = _get_signature_info(_get_signature_cache_key(func)) arg_names = list(arg_names) # For all variadic positional arguments, assign the name of the variadic group. @@ -374,7 +389,7 @@ def merge_function_arguments( all_args, var_arg_info = get_positional_args_with_names(func, *args) all_args.extend(kwargs.items()) - signature, _, _ = _get_signature_info(func) + signature, _, _ = _get_signature_info(_get_signature_cache_key(func)) provided_arg_names = {name for name, _ in all_args} omitted_args = [] diff --git a/tripy/tests/conftest.py b/tripy/tests/conftest.py index 495e15aea..515d6847a 100644 --- a/tripy/tests/conftest.py +++ b/tripy/tests/conftest.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -84,15 +84,24 @@ def add_plugin_aot_impl( ) compiled_kernel = triton.compile(src) + metadata = compiled_kernel.metadata + if isinstance(metadata, dict): + kernel_name = metadata["name"] + num_warps = metadata["num_warps"] + shared_mem = metadata.get("shared", 0) + else: + kernel_name = metadata.name + num_warps = metadata.num_warps + shared_mem = metadata.shared N = inp0.shape_expr.numel() launch_params = trtp.KernelLaunchParams() launch_params.grid_x = trtp.cdiv(N, block_size) - launch_params.block_x = compiled_kernel.metadata.num_warps * 32 - launch_params.shared_mem = compiled_kernel.metadata.shared + launch_params.block_x = num_warps * 32 + launch_params.shared_mem = shared_mem extra_args = trtp.SymIntExprs(1) extra_args[0] = trtp.SymInt32(N) - return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args + return kernel_name, compiled_kernel.asm["ptx"], launch_params, extra_args diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py index d0b27e903..e5b2521f9 100644 --- a/tripy/tests/frontend/constraints/test_logic.py +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -198,7 +198,7 @@ def test_call(self): assert constraint([("param1", 5), ("param2", 5)]) result = constraint([("param1", 5), ("param2", 10)]) assert not result - assert "'param1' to be equal to 'param2' (but it was '5')" in result.error_details + assert "'param1' to be equal to 'param2' (but it was '5' while 'param2' was '10')" in result.error_details def test_str(self): assert str(Equal(GetInput("param1"), GetInput("param2"))) == "param1 == param2"