diff --git a/.github/workflows/tripy-l0.yml b/.github/workflows/tripy-l0.yml index af7db1628..476a097ba 100644 --- a/.github/workflows/tripy-l0.yml +++ b/.github/workflows/tripy-l0.yml @@ -11,6 +11,7 @@ env: REGISTRY: ghcr.io DEFAULT_IMAGE: ghcr.io/nvidia/tensorrt-incubator/tripy:latest NEW_TEST_IMAGE: test-image:latest + HF_TOKEN: ${{ secrets.HF_TOKEN }} jobs: @@ -58,7 +59,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | python3 docs/generate_rsts.py sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W -n @@ -67,7 +68,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | pytest --cov=nvtripy/ --cov-config=.coveragerc tests/ -v -m "not l1" -n 4 --durations=15 --ignore tests/performance @@ -75,7 +76,7 @@ jobs: uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} - options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy -e HF_TOKEN=${{ env.HF_TOKEN }} run: | pytest tests/performance -v -m "not l1" --benchmark-warmup=on --benchmark-json benchmark.json diff --git a/.github/workflows/tripy-l1.yml b/.github/workflows/tripy-l1.yml index 650dd183f..bf90d5d9e 100644 --- a/.github/workflows/tripy-l1.yml +++ b/.github/workflows/tripy-l1.yml @@ -18,6 +18,8 @@ concurrency: jobs: l1-test: runs-on: tripy-self-hosted + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} container: image: ghcr.io/nvidia/tensorrt-incubator/tripy:latest volumes: diff --git a/tripy/.devcontainer/devcontainer.json b/tripy/.devcontainer/devcontainer.json new file mode 100644 index 000000000..a847d8c5e --- /dev/null +++ b/tripy/.devcontainer/devcontainer.json @@ -0,0 +1,50 @@ +// For format details, see https://aka.ms/devcontainer.json. 
For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Tripy", + "build": { + "context": "${localWorkspaceFolder}", + "dockerfile": "${localWorkspaceFolder}/Dockerfile", + "args": { + "username": "${localEnv:USER}" + } + }, + "workspaceMount": "source=${localWorkspaceFolder}/..,target=/workspaces/TensorRT-Incubator,type=bind,consistency=cached", + "workspaceFolder": "/workspaces/TensorRT-Incubator/tripy", + "runArgs": [ + "--gpus", + "all", + "-it", + "--cap-add=SYS_PTRACE" + ], + "mounts": [ + "source=${localEnv:HOME}${localEnv:USERPROFILE},target=/home/${localEnv:USER},type=bind,consistency=cached" + ], + "remoteEnv": { + "SHELL": "${localEnv:SHELL:/bin/bash}", + "ZSH": "/home/${localEnv:USER}/.oh-my-zsh", + "PYTHONPATH": "/workspaces/TensorRT-Incubator/tripy:${localEnv:PYTHONPATH}", + "PATH": "/usr/local/bin/:${localEnv:PATH}" + }, + "remoteUser": "${localEnv:USER}", + "forwardPorts": [ + 8080 + ], + "customizations": { + "vscode": { + "extensions": [ + "ms-python.black-formatter", + "lextudio.restructuredtext", + "github.vscode-github-actions", + "ms-python.isort", + "ms-toolsai.jupyter", + "ms-toolsai.vscode-jupyter-cell-tags", + "ms-toolsai.jupyter-renderers", + "llvm-vs-code-extensions.vscode-mlir", + "ms-python.python", + "ms-python.vscode-pylance", + "eamodio.gitlens" + ] + } + } +} diff --git a/tripy/CONTRIBUTING.md b/tripy/CONTRIBUTING.md index 85490044f..790a1fdbe 100644 --- a/tripy/CONTRIBUTING.md +++ b/tripy/CONTRIBUTING.md @@ -37,6 +37,8 @@ Thanks for your interest in contributing to Tripy! docker run --gpus all -it --cap-add=SYS_PTRACE -p 8080:8080 -v $(pwd):/tripy/ --rm tripy:latest ``` + - If you are using Visual Studio Code, you can alternatively use the included `.devcontainer` configuration. + 3. **[Optional]** Run a sanity check in the container: ```bash diff --git a/tripy/Dockerfile b/tripy/Dockerfile index cd9a796a4..8151fbbd4 100644 --- a/tripy/Dockerfile +++ b/tripy/Dockerfile @@ -9,16 +9,17 @@ ENTRYPOINT ["/bin/bash"] # Setup user account ARG uid=1000 ARG gid=1000 +ARG username=trtuser ENV DEBIAN_FRONTEND=noninteractive -RUN groupadd -r -f -g ${gid} trtuser && \ - useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash trtuser && \ - usermod -aG sudo trtuser && \ - echo 'trtuser:nvidia' | chpasswd && \ - mkdir -p /workspace && chown trtuser /workspace +RUN groupadd -r -f -g ${gid} ${username} && \ + useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash ${username} && \ + usermod -aG sudo ${username} && \ + echo "${username}:nvidia" | chpasswd && \ + mkdir -p /workspace && chown ${username} /workspace RUN apt-get update && \ - apt-get install -y sudo python3 python3-pip gdb git wget curl && \ + apt-get install -y sudo python3 python3-pip gdb git wget curl zsh && \ apt-get clean && \ python3 -m pip install --upgrade pip diff --git a/tripy/docs/README.md b/tripy/docs/README.md index a107706ed..5b886a30f 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -45,17 +45,23 @@ which specifies doc metadata for each API (e.g. location). - Docstring must include *at least* **one [code example](#code-examples)**. - If the function accepts `tp.Tensor`s, must indicate **data type constraints** - with the [`wrappers.interface`](../nvtripy/utils/wrappers.py) decorator. + with the [`wrappers.interface`](../nvtripy/frontend/wrappers.py) decorator. 
**Example:** ```py +from nvtripy import export +from nvtripy.frontend import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "bool", "int8"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def relu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" @@ -167,3 +173,72 @@ Code blocks in docstrings/guides are **preprocessed**: - **Include** only specific variables: `# doc: print-locals ...` - **Exclude** *specific* variables: `# doc: no-print-locals ...` - **Exclude** *all* variables: `# doc: no-print-locals` (with no arguments). + + +## Documentation Philosophy: Write Less Documentation + +> "I didn't have time to write a short letter, so I wrote a long one instead." - Mark Twain + +How much documentation do you want to read? The answer is **none**! + +- **Best Case:** Make docs unnecessary with an intuitive API and clear errors. + +This is not always possible; sometimes, we need to write docs. + +- **Problem:** We don't think enough about *what* *precisely* we want to convey. + +- **Suggestions**: Write discoverable, concise, but complete documentation. + + - **Highlight key points** but make it easy to find details. + + - Use bullets and **bold** to break up monotony. + Paragraphs are *so* 2024. + + - Leverage the medium - pictures, charts, emojis, markup. + We are not using printing presses! + + - Forget the rules: use contractions, don't spell out numbers, etc. + + - **Tip:** Write like we're paying for every syllable! + If it's too wordy to say, it's too wordy to write. + +Writing *concisely* forces you to think about what's **signal** and what's **noise**. + +Below are examples from previous versions of Tripy documentation that was improved. + +### Example 1 + +> One important point is that Tripy uses a lazy evaluation model; that is, no computation is performed until a value is actually needed. + +* **Tip:** Ask: "What is this really saying?" + +> Tensors are evaluated only when they're used. + +### Example 2 + + +> ### **Eager Mode: How Does It Work?** +> +> If you've used TensorRT before, you may know that it does not support an eager mode. +> In order to provide eager mode support in Tripy, we actually need to compile the graph under the hood. +> +> Although we employ several tricks to make compile times faster when using eager mode, we do still need to compile, +> and so eager mode will likely be slower than other comparable frameworks. +> +> Consequently, we suggest that you use eager mode primarily for debugging and compiled mode for deployments. + +**Problem**: We must sift through filler to find key points. + +**Ask**: + +- **"What is the ONE most important takeaway?"** + *Eager mode is only for debugging.* + +- **"What questions does this raise?" - Why?** + *Tripy always compiles since TensorRT doesn't have eager mode.* + +Make the **one key point** stand out so skimmers can spot it: + +> **Best Practice:** Use **eager mode** only for **debugging**; compile for deployment. +> +> **Why:** Eager mode internally compiles the graph (slow!) since TensorRT doesn't have eager execution. 
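For a concrete sense of how the requirement/guarantee expressions above compose, here is a minimal sketch using the constraints API this patch introduces (the op shape and parameter names here are hypothetical):

```py
from nvtripy.common import datatype as dt
from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf

# Hypothetical op taking `input`, `other`, and an optional `bias`:
requirements = (
    OneOf(GetInput("input").dtype, [dt.float32, dt.float16])  # Allowed dtypes for `input`.
    & (GetInput("other").dtype == GetInput("input").dtype)  # `==` on fetchers builds an `Equal` constraint.
    & If(GetInput("bias") != None, GetInput("bias").dtype == GetInput("input").dtype)  # Only checked when `bias` is given.
)
guarantees = GetReturn(0).dtype == GetInput("input").dtype
```

The `If(GetInput("bias") != None, ...)` idiom mirrors how the convolution ops later in this diff handle optional parameters.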
diff --git a/tripy/docs/post0_developer_guides/00-architecture.md b/tripy/docs/post0_developer_guides/00-architecture.md index 055dd0b3b..880cf5441 100644 --- a/tripy/docs/post0_developer_guides/00-architecture.md +++ b/tripy/docs/post0_developer_guides/00-architecture.md @@ -76,7 +76,7 @@ and various operations, e.g. {class}`nvtripy.resize`. :::{admonition} Info Most operations are decorated with: 1. [`@export.public_api`](source:/nvtripy/export.py): Enables documentation, type checking, and overloading. -2. [`@wrappers.interface`](source:/nvtripy/utils/wrappers.py): Enforces (and generates tests for) data type constraints. +2. [`@wrappers.interface`](source:/nvtripy/frontend/wrappers.py): Enforces (and generates tests for) data type constraints. ::: Operations are **lazily evaluated**. diff --git a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md index bd354d582..5a1dc87ea 100644 --- a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md @@ -129,13 +129,15 @@ from typing import Tuple from nvtripy import export from nvtripy.trace.ops.topn import TopN -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.frontend.ops import utils as op_utils +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: ["T1", "T2"]}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=(GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == dt.int32), ) def topn(input: "nvtripy.Tensor", n: int, dim: int) -> Tuple["nvtripy.Tensor", "nvtripy.Tensor"]: # See docs/README.md for more information on how to write docstrings diff --git a/tripy/docs/pre0_user_guides/02-quantization.md b/tripy/docs/pre0_user_guides/02-quantization.md index 63496c09b..e33ac29fb 100644 --- a/tripy/docs/pre0_user_guides/02-quantization.md +++ b/tripy/docs/pre0_user_guides/02-quantization.md @@ -17,7 +17,7 @@ explains quantization in more detail. ## Post-Training Quantization With ModelOpt If the model was not trained with quantization-aware training (QAT), we can use -[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html) +[TensorRT ModelOpt](https://nvidia.github.io/Model-Optimizer/index.html) to do **calibration** to determine scaling factors. :::{admonition} Info @@ -82,7 +82,7 @@ Let's calibrate a GPT model: ``` 3. 
Run calibration to replace linear layers with - [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear), + [`QuantLinear`](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear), which contain calibration information: ```py diff --git a/tripy/examples/nanogpt/README.md b/tripy/examples/nanogpt/README.md index 8826a43d7..191497dae 100644 --- a/tripy/examples/nanogpt/README.md +++ b/tripy/examples/nanogpt/README.md @@ -40,7 +40,7 @@ This example implements a [NanoGPT model](https://github.com/karpathy/nanoGPT) u ### Running with Quantization [`quantization.py`](./quantization.py), uses -[NVIDIA TensorRT Model Optimizer](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/1_overview.html) +[NVIDIA TensorRT Model Optimizer](https://nvidia.github.io/Model-Optimizer/getting_started/1_overview.html) to quantize the pytorch model. `load_quant_weights_from_hf` in [`weight_loader.py`](./weight_loader.py) converts the quantization diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index d9296f53a..befc2386f 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -248,7 +248,7 @@ def process_arg(name, arg): compiled_arg_names = [] new_args = [] - positional_arg_info, variadic_info = utils.utils.get_positional_arg_names(func, *args) + positional_arg_info, variadic_info = utils.utils.get_positional_args_with_names(func, *args) varargs_name = None varargs_index = None diff --git a/tripy/nvtripy/config.py b/tripy/nvtripy/config.py index ba9d8d10c..a6349306e 100644 --- a/tripy/nvtripy/config.py +++ b/tripy/nvtripy/config.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,12 +45,12 @@ )(os.path.join(tempfile.gettempdir(), "tripy-cache")) """Path to a timing cache file that can be used to speed up compilation time.""" -enable_dtype_checking: bool = export.public_api( +enable_input_validation: bool = export.public_api( document_under="config.rst", module=sys.modules[__name__], - symbol="enable_dtype_checking", + symbol="enable_input_validation", )(True) -"""Whether to enable data type checking in API functions.""" +"""Whether to enable input validation in API functions.""" extra_error_information: List[str] = export.public_api( document_under="config.rst", diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py new file mode 100644 index 000000000..e2650a51d --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -0,0 +1,31 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.doc_str import doc_str +from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher +from nvtripy.frontend.constraints.logic import ( + AlwaysFalse, + AlwaysTrue, + And, + Equal, + If, + Logic, + NotEqual, + NotOneOf, + OneOf, + Or, +) diff --git a/tripy/nvtripy/frontend/constraints/base.py b/tripy/nvtripy/frontend/constraints/base.py new file mode 100644 index 000000000..c254ea177 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/base.py @@ -0,0 +1,159 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +The constraints system has two purposes: + +1. Imposing input requirements. +2. Describing output guarantees. + +Constraints are specified by composing one or more `Constraints` subclasses: + +```py +constraint = And( + Equal(GetDataType(GetInput("input0")), GetInput("dtype")), + Equal(GetDataType(GetInput("input1")), GetInput("dtype")), + ) +``` + +We also override several bitwise operators and properties to provide a convenient shorthand. +For example, the above can be written as: + +```py +constraint = (GetInput("input0").dtype == GetInput("dtype")) & (GetInput("input1").dtype == GetInput("dtype")) +``` + +The constraints class also provides a pattern matcher. +For example, we may want to find all constraints that check the data type of an input (`None` is a wildcard). + +```py +matches = constraint.find(Equal(GetDataType(GetInput), None)) +``` +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + + +class Constraints(ABC): + """ + Base class for the entire constraints system. + """ + + def __init__(self): + self._info: Optional[str] = None + + @abstractmethod + def doc_str(self) -> str: + """ + Returns a string representation for use in documentation. + """ + ... + + def get_children(self) -> List["Constraints"]: + children = [] + for attr_value in vars(self).values(): + if isinstance(attr_value, Constraints): + children.append(attr_value) + elif isinstance(attr_value, (list, tuple)): + children.extend(v for v in attr_value if isinstance(v, Constraints)) + return children + + def find(self, pattern: "Constraints", skip_within: Optional[type] = None) -> List["Constraints"]: + """ + Find all constraints in the tree that match the given pattern. 
+
+        Performs a depth-first search through the constraint tree to find all
+        constraints that structurally match the given pattern, using the current
+        constraint as the root node.
+
+        Args:
+            pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)).
+                Use None as a wildcard to match anything.
+            skip_within: Optional constraint type. If provided, the search will not recurse
+                into the children of constraints of this type. This allows skipping
+                specific nested structures (e.g., skip_within=If to avoid searching
+                inside If condition branches).
+
+        Returns:
+            A list of all matching constraints found in the tree.
+
+        Example:
+            pattern = Equal(GetDataType(ValueFetcher), None)  # None matches any second argument
+            matches = constraint_tree.find(pattern)
+
+            # Skip searching inside If constraints:
+            matches = constraint_tree.find(pattern, skip_within=If)
+        """
+
+        def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool:
+            # None is a wildcard that matches anything
+            if pattern is None:
+                return True
+
+            if isinstance(pattern, type):
+                return isinstance(constraint, pattern)
+
+            if type(pattern) != type(constraint):
+                return False
+
+            # Need to index into sequences rather than comparing directly since there may be patterns in the sequence.
+            if isinstance(pattern, (list, tuple)) and isinstance(constraint, (list, tuple)):
+                if len(pattern) != len(constraint):
+                    return False
+                return all(matches_pattern(p_val, c_val) for p_val, c_val in zip(pattern, constraint))
+
+            if not isinstance(pattern, Constraints):
+                return pattern == constraint
+
+            # Compare attributes
+            pattern_attrs = vars(pattern)
+            constraint_attrs = vars(constraint)
+
+            for key, pattern_value in pattern_attrs.items():
+                if key not in constraint_attrs:
+                    return False
+
+                constraint_value = constraint_attrs[key]
+
+                if not matches_pattern(pattern_value, constraint_value):
+                    return False
+
+            return True
+
+        matches = []
+
+        if matches_pattern(pattern, self):
+            matches.append(self)
+
+        # Skip recursing into children if this constraint matches skip_within type
+        if skip_within is not None and isinstance(self, skip_within):
+            return matches
+
+        # Recursively search children
+        for child in self.get_children():
+            matches.extend(child.find(pattern, skip_within=skip_within))
+
+        return matches
+
+    def info(self, message: str) -> "Constraints":
+        """
+        Sets additional information about this constraint.
+        For example, this might express the constraint in natural language or explain why it exists.
+        """
+        self._info = message
+        return self
diff --git a/tripy/nvtripy/frontend/constraints/doc_str.py b/tripy/nvtripy/frontend/constraints/doc_str.py
new file mode 100644
index 000000000..fee352771
--- /dev/null
+++ b/tripy/nvtripy/frontend/constraints/doc_str.py
@@ -0,0 +1,38 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from typing import Any + +from nvtripy.frontend.constraints.base import Constraints + + +def doc_str(obj: Any) -> str: + """ + Returns a string representation of an object for use in the public documentation. + """ + from nvtripy.common.datatype import dtype as tp_dtype + + if obj is None: + return "``None``" + + if isinstance(obj, tp_dtype): + return f":class:`{obj.name}`" + + if isinstance(obj, Constraints): + return obj.doc_str() + + assert False, f"Unsupported object type for doc string generation: {type(obj)}. Please add handling here!" diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py new file mode 100644 index 000000000..884cf6e15 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -0,0 +1,142 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import abstractmethod +import numbers +from typing import Any, List, Optional, Sequence, Tuple + +from nvtripy.common import datatype +from nvtripy.common.datatype import dtype as tp_dtype +from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints.base import Constraints + + +class Fetcher(Constraints): + """ + Fetches a value based on the function parameters or return value. + """ + + @abstractmethod + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: ... + + def __eq__(self, other: "Fetcher") -> "Equal": + from nvtripy.frontend.constraints.logic import Equal + + return Equal(self, other) + + def __ne__(self, other: "Fetcher") -> "NotEqual": + from nvtripy.frontend.constraints.logic import NotEqual + + return NotEqual(self, other) + + +class ValueFetcher(Fetcher): + @property + def dtype(self) -> "GetDataType": + return GetDataType(self) + + +class GetInput(ValueFetcher): + def __init__(self, name: str): + super().__init__() + self.name = name + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + for name, value in args: + if name == self.name: + return value + assert False, f"Input '{self.name}' not found in arguments." + + def __str__(self): + return self.name + + def doc_str(self) -> str: + return f"``{self.name}``" + + +class GetReturn(ValueFetcher): + def __init__(self, index: int): + super().__init__() + self.index = index + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + assert returns is not None, "No return values available." 
+        return returns[self.index]
+
+    def __str__(self):
+        return f"return[{self.index}]"
+
+    def doc_str(self) -> str:
+        return f"``return[{self.index}]``"
+
+
+class GetDataType(Fetcher):
+    def __init__(self, value_fetcher: ValueFetcher):
+        super().__init__()
+        self.value_fetcher = value_fetcher
+
+    def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any:
+        from nvtripy.frontend.tensor import Tensor
+
+        def get_arg_dtype(arg: Any) -> tp_dtype:
+            if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
+                arg_dtypes = [get_arg_dtype(elem) for elem in arg]
+
+                if len(arg_dtypes) == 0:
+                    raise_error(
+                        f"Could not determine data type of {self.value_fetcher}",
+                        [
+                            "Empty sequence argument.\n",
+                            f"For parameter: '{self.value_fetcher}', the sequence must contain at least one element.",
+                        ],
+                    )
+
+                if len(set(arg_dtypes)) != 1:
+                    raise_error(
+                        f"Could not determine data type of {self.value_fetcher}",
+                        [
+                            "Mismatched data types in sequence argument.\n",
+                            f"For parameter: '{self.value_fetcher}', all arguments must have the same data type, but got: "
+                            f"{arg_dtypes}",
+                        ],
+                    )
+                arg_dtype = arg_dtypes[0]
+            elif isinstance(arg, Tensor):
+                arg_dtype = arg.dtype
+            elif isinstance(arg, tp_dtype):
+                arg_dtype = arg
+            elif isinstance(arg, bool):
+                arg_dtype = datatype.bool
+            elif isinstance(arg, numbers.Integral):
+                arg_dtype = datatype.int32 if datatype.INT32_MIN <= arg <= datatype.INT32_MAX else datatype.int64
+            elif isinstance(arg, numbers.Real):
+                arg_dtype = datatype.float32
+            else:
+                raise_error(
+                    f"Could not determine data type of {self.value_fetcher}",
+                    [f"Expected a tensor or data type argument for {self.value_fetcher}, but got: {arg}"],
+                )
+            return arg_dtype
+
+        tensor = self.value_fetcher(args, returns)
+        return get_arg_dtype(tensor)
+
+    def __str__(self):
+        return f"{self.value_fetcher}.dtype"
+
+    def doc_str(self) -> str:
+        # Intentionally do not use doc_str() on the value_fetcher so we can wrap it in backticks correctly.
+        return f"``{self.value_fetcher}.dtype``"
diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py
new file mode 100644
index 000000000..31f5a7a29
--- /dev/null
+++ b/tripy/nvtripy/frontend/constraints/logic.py
@@ -0,0 +1,292 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from abc import abstractmethod
+from textwrap import indent
+from typing import Any, List, Optional, Sequence, Tuple
+
+from nvtripy.common.exception import raise_error
+from nvtripy.frontend.constraints.base import Constraints
+from nvtripy.frontend.constraints.doc_str import doc_str
+from nvtripy.frontend.constraints.fetcher import Fetcher
+from nvtripy.utils.result import Result
+
+
+class Logic(Constraints):
+    """
+    Represents logical operations on constraints.
+    """
+
+    # When the constraint is not met, the error details should complete the sentence: "Expected ..."
+ @abstractmethod + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: ... + + @abstractmethod + def inverse(self) -> "Logic": + """ + Returns the logical inverse of this constraint. + """ + ... + + def __and__(self, other: "Logic") -> "Logic": + if isinstance(self, And): + return And(*self.constraints, other) + elif isinstance(other, And): + return And(self, *other.constraints) + return And(self, other) + + def __or__(self, other: "Logic") -> "Logic": + if isinstance(self, Or): + return Or(*self.constraints, other) + elif isinstance(other, Or): + return Or(self, *other.constraints) + return Or(self, other) + + def __invert__(self) -> "Logic": + return self.inverse() + + +class AlwaysTrue(Logic): + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + return Result.ok() + + def __str__(self): + return "true" + + def doc_str(self) -> str: + return "true" + + def inverse(self) -> "Logic": + return AlwaysFalse() + + +class AlwaysFalse(Logic): + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + return Result.err(["true"]) + + def __str__(self): + return "false" + + def doc_str(self) -> str: + return "false" + + def inverse(self) -> "Logic": + return AlwaysTrue() + + +class OneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Optional[Sequence[Any]]): + super().__init__() + self.fetcher = fetcher + # Need to convert generator expressions so we can use them more than once. + # `None` is allowed to support pattern matching wildcards. + self.options = list(options) if options is not None else None + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + if self.options is None: + raise_error("OneOf constraint cannot be evaluated with wildcard options.") + value = self.fetcher(args, returns) + if value in self.options: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to be one of {self.options} (but it was '{value}')"]) + + def __str__(self): + return f"{self.fetcher} is one of {self.options}" + + def doc_str(self) -> str: + if self.options is None: + return f"{doc_str(self.fetcher)} is one of [*]" + return f"{doc_str(self.fetcher)} is one of [{', '.join(f'{doc_str(opt)}' for opt in self.options)}]" + + def inverse(self) -> "Logic": + return NotOneOf(self.fetcher, self.options) + + +class NotOneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + super().__init__() + self.fetcher = fetcher + self.options = list(options) + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value = self.fetcher(args, returns) + if value not in self.options: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to not be one of {self.options} (but it was '{value}')"]) + + def __str__(self): + return f"{self.fetcher} is not one of {self.options}" + + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} is not one of [{', '.join(f'{doc_str(opt)}' for opt in self.options)}]" + + def inverse(self) -> "Logic": + return OneOf(self.fetcher, self.options) + + +def get_val_or_call_fetcher( + fetcher_or_value: Any, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None +) -> Any: + if isinstance(fetcher_or_value, Fetcher): + return fetcher_or_value(args, returns) + return fetcher_or_value + + +class Equal(Logic): + def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): + super().__init__() + self.fetcher = fetcher + 
self.fetcher_or_value = fetcher_or_value + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) + + # Avoid triggering overloaded equality implementations (e.g., on Tensor) when comparing to None. + if value1 is None or value2 is None: + if value1 is value2: + return Result.ok() + elif value1 == value2: + return Result.ok() + + if isinstance(self.fetcher_or_value, Fetcher): + return Result.err( + [ + f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' " + f"(but it was '{value1}' while '{self.fetcher_or_value}' was '{value2}')" + ] + ) + + return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) + + def __str__(self): + return f"{self.fetcher} == {self.fetcher_or_value}" + + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} == {doc_str(self.fetcher_or_value)}" + + def inverse(self) -> "Logic": + return NotEqual(self.fetcher, self.fetcher_or_value) + + +class NotEqual(Logic): + def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): + super().__init__() + self.fetcher = fetcher + self.fetcher_or_value = fetcher_or_value + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) + + # Avoid triggering overloaded inequality implementations (e.g., on Tensor) when comparing to None. + if value1 is None or value2 is None: + if value1 is not value2: + return Result.ok() + elif value1 != value2: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) + + def __str__(self): + return f"{self.fetcher} != {self.fetcher_or_value}" + + def doc_str(self) -> str: + return f"{doc_str(self.fetcher)} != {doc_str(self.fetcher_or_value)}" + + def inverse(self) -> "Logic": + return Equal(self.fetcher, self.fetcher_or_value) + + +class And(Logic): + def __init__(self, *constraints: Logic): + super().__init__() + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + errors = [] + for constraint in self.constraints: + result = constraint(args, returns) + if not result: + errors.extend(([" and "] if errors else []) + result.error_details) + if errors: + return Result.err(errors) + return Result.ok() + + def __str__(self): + return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" + + def doc_str(self) -> str: + return ", **and**\n".join("- " + indent(doc_str(constraint), " ").lstrip() for constraint in self.constraints) + + def inverse(self) -> "Logic": + # De Morgan's law: not (A and B) = (not A) or (not B) + return Or(*(constraint.inverse() for constraint in self.constraints)) + + +class Or(Logic): + def __init__(self, *constraints: Logic): + super().__init__() + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + all_errors = [] + for constraint in self.constraints: + result = constraint(args, returns) + if result: + return Result.ok() + all_errors.extend(([" or "] if all_errors else []) + result.error_details) + return Result.err(all_errors) + + def __str__(self): + return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" + + def 
doc_str(self) -> str: + return "(" + " *or* ".join(doc_str(constraint) for constraint in self.constraints) + ")" + + def inverse(self) -> "Logic": + # De Morgan's law: not (A or B) = (not A) and (not B) + return And(*(constraint.inverse() for constraint in self.constraints)) + + +class If(Logic): + def __init__(self, condition: Logic, then_branch: Logic, else_branch: Optional[Logic] = None): + super().__init__() + self.condition = condition + self.then_branch = then_branch + self.else_branch = else_branch + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + condition_result = self.condition(args, returns) + if condition_result: + return self.then_branch(args, returns) + else: + return self.else_branch(args, returns) if self.else_branch else Result.ok() + + def __str__(self): + if self.else_branch: + return f"if ({self.condition}) then ({self.then_branch}) else ({self.else_branch})" + return f"if ({self.condition}) then ({self.then_branch})" + + def doc_str(self) -> str: + if self.else_branch: + return f"{doc_str(self.then_branch)} **if** {doc_str(self.condition)}, **otherwise** {doc_str(self.else_branch)}" + return f"if {doc_str(self.condition)}, then {doc_str(self.then_branch)}" + + def inverse(self) -> "Logic": + return If(self.condition, self.then_branch.inverse(), self.else_branch.inverse() if self.else_branch else None) diff --git a/tripy/nvtripy/frontend/constraints/optimizer.py b/tripy/nvtripy/frontend/constraints/optimizer.py new file mode 100644 index 000000000..4aa36d70b --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/optimizer.py @@ -0,0 +1,108 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterable, List, Optional + +from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.constraints import AlwaysTrue, Constraints +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput +from nvtripy.frontend.constraints.logic import OneOf + + +@dataclass(frozen=True) +class ConstraintPass(ABC): + name: str + pattern: Constraints + + def predicate(self, constraint: Constraints) -> bool: + return True + + @abstractmethod + def rewrite(self, constraint: Constraints) -> Constraints: ... 
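+
+# Example (hypothetical) pass; a sketch only, not registered by default:
+# rewriting `NotOneOf(fetcher, [])` into `AlwaysTrue()` is safe because
+# excluding nothing can never fail.
+#
+#   class DropEmptyNotOneOf(ConstraintPass):
+#       def __init__(self):
+#           super().__init__(name="drop-empty-not-oneof", pattern=NotOneOf(None, []))
+#
+#       def rewrite(self, constraint: Constraints) -> Constraints:
+#           return AlwaysTrue()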
+ + +def _optimize_children( + constraint: Constraints, + constraint_pass: ConstraintPass, + pass_matches: List[Constraints], +) -> Constraints: + """Recursively rewrite child nodes after local rewrites have been applied.""" + for attr_name, attr_value in vars(constraint).items(): + if isinstance(attr_value, Constraints): + setattr(constraint, attr_name, _optimize_constraints(attr_value, constraint_pass, pass_matches)) + continue + + if isinstance(attr_value, (list, tuple)): + optimized_items = [] + changed = False + for item in attr_value: + if isinstance(item, Constraints): + optimized_item = _optimize_constraints(item, constraint_pass, pass_matches) + optimized_items.append(optimized_item) + changed = changed or optimized_item is not item + else: + optimized_items.append(item) + if changed: + new_value = tuple(optimized_items) if isinstance(attr_value, tuple) else optimized_items + setattr(constraint, attr_name, new_value) + + return constraint + + +def _optimize_constraints( + constraint: Constraints, + constraint_pass: ConstraintPass, + pass_matches: List[Constraints], +) -> Constraints: + """Apply passes to this node, then recurse into its children.""" + rewritten = constraint + if any(rewritten is match for match in pass_matches) and constraint_pass.predicate(rewritten): + rewritten = constraint_pass.rewrite(rewritten) + + return _optimize_children(rewritten, constraint_pass, pass_matches) + + +class DropAllDtypesOneOf(ConstraintPass): + def __init__(self): + super().__init__( + name="drop-all-dtypes-oneof", + pattern=OneOf(GetDataType(GetInput(None)), None), + ) + + def predicate(self, constraint: Constraints) -> bool: + return set(constraint.options).issuperset(set(DATA_TYPES.values())) + + def rewrite(self, constraint: Constraints) -> Constraints: + return AlwaysTrue() + + +def _default_passes() -> Iterable[ConstraintPass]: + return (DropAllDtypesOneOf(),) + + +def optimize_constraints(constraints: Optional[Constraints]) -> Optional[Constraints]: + if constraints is None: + return None + + passes = tuple(_default_passes()) + optimized = constraints + for constraint_pass in passes: + matches = optimized.find(constraint_pass.pattern) + optimized = _optimize_constraints(optimized, constraint_pass, matches) + return optimized diff --git a/tripy/nvtripy/frontend/module/batchnorm.py b/tripy/nvtripy/frontend/module/batchnorm.py index f7c1380a2..d08a2ff78 100644 --- a/tripy/nvtripy/frontend/module/batchnorm.py +++ b/tripy/nvtripy/frontend/module/batchnorm.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_features"]) +@constant_fields(["num_features"]) class BatchNorm(Module): r""" Applies batch normalization over an N-dimensional input tensor using precomputed statistics: @@ -105,8 +106,8 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": Returns: A tensor of the same shape as the input. 
""" - from nvtripy.frontend.ops.unary.rsqrt import rsqrt from nvtripy.frontend.ops.reshape import reshape + from nvtripy.frontend.ops.unary.rsqrt import rsqrt x_shape = (1, self.num_features, *([1] * (x.rank - 2))) diff --git a/tripy/nvtripy/frontend/module/conv/base.py b/tripy/nvtripy/frontend/module/conv/base.py index 812e08fef..fcebf79f0 100644 --- a/tripy/nvtripy/frontend/module/conv/base.py +++ b/tripy/nvtripy/frontend/module/conv/base.py @@ -23,10 +23,11 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @dataclass -@utils.wrappers.constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) +@constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) class ConvBase(Module): r"""Base class for sharing common functionality between Conv and ConvTranspose.""" diff --git a/tripy/nvtripy/frontend/module/conv/conv.py b/tripy/nvtripy/frontend/module/conv/conv.py index 2c322fdf8..e7ce14209 100644 --- a/tripy/nvtripy/frontend/module/conv/conv.py +++ b/tripy/nvtripy/frontend/module/conv/conv.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,13 +26,17 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.convolution import Convolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf # This function is added so that we can do dtype checking. @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & If(GetInput("bias") != None, GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def convolution( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/module/conv/conv_transpose.py b/tripy/nvtripy/frontend/module/conv/conv_transpose.py index 3c5b5666e..050cb3394 100644 --- a/tripy/nvtripy/frontend/module/conv/conv_transpose.py +++ b/tripy/nvtripy/frontend/module/conv/conv_transpose.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,14 +26,18 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.deconvolution import Deconvolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf # This function is added so that we can do dtype checking. # TODO (#565): TRT supposedly supports BF16 in deconv, but actually trying to use it results in a bug. 
@wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & If(GetInput("bias") != None, GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def deconvolution( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/module/embedding.py b/tripy/nvtripy/frontend/module/embedding.py index b8e5b72f0..d72e4134c 100644 --- a/tripy/nvtripy/frontend/module/embedding.py +++ b/tripy/nvtripy/frontend/module/embedding.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype"]) +@constant_fields(["dtype"]) class Embedding(Module): """ A lookup table for embedding vectors of a fixed size. diff --git a/tripy/nvtripy/frontend/module/groupnorm.py b/tripy/nvtripy/frontend/module/groupnorm.py index 5f98807ca..390f07fa8 100644 --- a/tripy/nvtripy/frontend/module/groupnorm.py +++ b/tripy/nvtripy/frontend/module/groupnorm.py @@ -25,11 +25,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_groups", "num_channels", "dtype"]) +@constant_fields(["num_groups", "num_channels", "dtype"]) class GroupNorm(Module): r""" Applies group normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/instancenorm.py b/tripy/nvtripy/frontend/module/instancenorm.py index 54b29a250..3e0c1a690 100644 --- a/tripy/nvtripy/frontend/module/instancenorm.py +++ b/tripy/nvtripy/frontend/module/instancenorm.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,21 +17,25 @@ from dataclasses import dataclass -from nvtripy import constants, export, utils +from nvtripy import constants, export from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.instancenorm import InstanceNorm as InstanceNormOp +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def instancenorm( input: "nvtripy.Tensor", @@ -81,7 +85,7 @@ def instancenorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_channels", "dtype", "eps"]) +@constant_fields(["num_channels", "dtype", "eps"]) class InstanceNorm(Module): r""" Applies Instance Normalization over a mini-batch of inputs: diff --git a/tripy/nvtripy/frontend/module/layernorm.py b/tripy/nvtripy/frontend/module/layernorm.py index be2627ed1..90845eb6b 100644 --- a/tripy/nvtripy/frontend/module/layernorm.py +++ b/tripy/nvtripy/frontend/module/layernorm.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,18 +21,22 @@ from nvtripy import export, utils from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + @wrappers.interface( - dtype_constraints={"input": "T1", "weight": "T1", "bias": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def layernorm( input: "nvtripy.Tensor", @@ -70,7 +74,7 @@ def layernorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "normalized_shape"]) +@constant_fields(["dtype", "normalized_shape"]) class LayerNorm(Module): r""" Applies layer normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/linear.py b/tripy/nvtripy/frontend/module/linear.py index 6c3c06eba..919c9e1f1 100644 --- a/tripy/nvtripy/frontend/module/linear.py +++ b/tripy/nvtripy/frontend/module/linear.py @@ -23,11 +23,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter, OptionalParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "quant_dtype"]) +@constant_fields(["dtype", "quant_dtype"]) class Linear(Module): r""" Applies a linear transformation to the input: @@ -117,11 +118,11 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": A tensor of shape :math:`[*, \text{out_features}]`. """ from nvtripy.common.exception import raise_error - from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops.transpose import transpose - from nvtripy.frontend.ops.unsqueeze import unsqueeze from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize + from nvtripy.frontend.ops.transpose import transpose + from nvtripy.frontend.ops.unsqueeze import unsqueeze + from nvtripy.frontend.tensor import Tensor if self.quant_dtype is not None: if isinstance(self.input_scale, Tensor): diff --git a/tripy/nvtripy/frontend/ops/allclose.py b/tripy/nvtripy/frontend/ops/allclose.py index 28ac163ed..f89ece4b3 100644 --- a/tripy/nvtripy/frontend/ops/allclose.py +++ b/tripy/nvtripy/frontend/ops/allclose.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,12 +16,16 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"]} + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("other").dtype == GetInput("input").dtype), ) def allclose(input: "nvtripy.Tensor", other: "nvtripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: r""" diff --git a/tripy/nvtripy/frontend/ops/arange.py b/tripy/nvtripy/frontend/ops/arange.py index e20be7917..24317d557 100644 --- a/tripy/nvtripy/frontend/ops/arange.py +++ b/tripy/nvtripy/frontend/ops/arange.py @@ -16,27 +16,29 @@ from typing import Union from nvtripy import export -from nvtripy.common import datatype +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.cast import cast from nvtripy.frontend.ops.reshape import reshape from nvtripy.trace.ops.linspace import Linspace -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def arange( start: Union[numbers.Number, "nvtripy.DimensionSize"], stop: Union[numbers.Number, "nvtripy.DimensionSize"], step: Union[numbers.Number, "nvtripy.DimensionSize"] = 1, - dtype: "nvtripy.dtype" = datatype.float32, + dtype: "nvtripy.dtype" = dt.float32, ) -> "nvtripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval @@ -96,13 +98,14 @@ def arange( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def arange( - stop: Union[numbers.Number, "nvtripy.DimensionSize"], dtype: "nvtripy.dtype" = datatype.float32 + stop: Union[numbers.Number, "nvtripy.DimensionSize"], dtype: "nvtripy.dtype" = dt.float32 ) -> "nvtripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval diff --git a/tripy/nvtripy/frontend/ops/binary/add.py b/tripy/nvtripy/frontend/ops/binary/add.py index f673c452c..f6e7dd008 100644 --- a/tripy/nvtripy/frontend/ops/binary/add.py +++ b/tripy/nvtripy/frontend/ops/binary/add.py @@ -12,18 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Add from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__add__") @register_tensor_method("__radd__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __add__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/div.py b/tripy/nvtripy/frontend/ops/binary/div.py index 125b42b46..812e3b523 100644 --- a/tripy/nvtripy/frontend/ops/binary/div.py +++ b/tripy/nvtripy/frontend/ops/binary/div.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Div from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__truediv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __truediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __truediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rtruediv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rtruediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/equal.py b/tripy/nvtripy/frontend/ops/binary/equal.py index 55c899b36..32e38fe5a 100644 --- a/tripy/nvtripy/frontend/ops/binary/equal.py +++ b/tripy/nvtripy/frontend/ops/binary/equal.py @@ -12,20 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Equal from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__eq__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __eq__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/floor_div.py b/tripy/nvtripy/frontend/ops/binary/floor_div.py index 4ddbe0a5b..0e71e2a1b 100644 --- a/tripy/nvtripy/frontend/ops/binary/floor_div.py +++ b/tripy/nvtripy/frontend/ops/binary/floor_div.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import FloorDiv from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__floordiv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "bool", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.bool, dt.int8, dt.int32, dt.int64] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __floordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __floordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rfloordiv__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "bool", "int4", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.bool, dt.int4, dt.int8, dt.int32, dt.int64] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rfloordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/greater.py b/tripy/nvtripy/frontend/ops/binary/greater.py index c82ded594..e9d95b1e8 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater.py +++ b/tripy/nvtripy/frontend/ops/binary/greater.py @@ -12,20 +12,22 @@ # 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Greater from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__gt__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __gt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/greater_equal.py b/tripy/nvtripy/frontend/ops/binary/greater_equal.py index 6be66c9db..e4db549b6 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/greater_equal.py @@ -12,18 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ge__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __ge__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/less.py b/tripy/nvtripy/frontend/ops/binary/less.py index 6495319fa..d882bf105 100644 --- a/tripy/nvtripy/frontend/ops/binary/less.py +++ b/tripy/nvtripy/frontend/ops/binary/less.py @@ -12,20 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
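Every comparison method in this diff (`__eq__` and `__gt__`/`__ge__` above; `__lt__`, `__le__`, and `__ne__` below) follows the same template: both operands must share a dtype drawn from the supported list, and the result dtype is guaranteed to be `bool`. Under the old scheme this took a second type variable (`"T2": ["bool"]`); now it is a direct output guarantee. A usage sketch, assuming the `import nvtripy as tp` convention from Tripy's docs and the documented default of `int32` for Python integer inputs:

import nvtripy as tp

a = tp.Tensor([1, 2, 3])   # int32 by default
b = tp.Tensor([2, 2, 2])

mask = a > b               # __gt__: operands share int32, which is in the allowed set
assert mask.dtype == tp.bool   # guaranteed by GetReturn(0).dtype == dt.bool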
+from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Less from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__lt__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __lt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/less_equal.py b/tripy/nvtripy/frontend/ops/binary/less_equal.py index 14bf35c9b..d3b4678bc 100644 --- a/tripy/nvtripy/frontend/ops/binary/less_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/less_equal.py @@ -12,18 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__le__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __le__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/logical_or.py b/tripy/nvtripy/frontend/ops/binary/logical_or.py index 2302415ce..00e9dbd7a 100644 --- a/tripy/nvtripy/frontend/ops/binary/logical_or.py +++ b/tripy/nvtripy/frontend/ops/binary/logical_or.py @@ -12,16 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import LogicalOr -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__or__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=(GetInput("self").dtype == dt.bool) & (GetInput("other").dtype == dt.bool), + output_guarantees=GetReturn(0).dtype == dt.bool, ) def __or__(self: "nvtripy.Tensor", other: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/binary/maximum.py b/tripy/nvtripy/frontend/ops/binary/maximum.py index 05023378d..91a4a9fae 100644 --- a/tripy/nvtripy/frontend/ops/binary/maximum.py +++ b/tripy/nvtripy/frontend/ops/binary/maximum.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Max from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("lhs").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("rhs").dtype == GetInput("lhs").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("lhs").dtype, convert_to_tensors=True, ) def maximum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/minimum.py b/tripy/nvtripy/frontend/ops/binary/minimum.py index 0a1954b1b..454c3016f 100644 --- a/tripy/nvtripy/frontend/ops/binary/minimum.py +++ b/tripy/nvtripy/frontend/ops/binary/minimum.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
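`maximum` above (and `minimum`, whose diff continues below) keeps `convert_to_tensors=True`, so `TensorLike` arguments such as Python numbers are converted to tensors before the requirements are evaluated; the expression anchors on `lhs` and forces `rhs` to match it. A usage sketch, again assuming `import nvtripy as tp`:

import nvtripy as tp

x = tp.Tensor([1.0, -2.0, 3.0])  # float32
# The Python scalar is converted to a float32 tensor, satisfying
# GetInput("rhs").dtype == GetInput("lhs").dtype:
y = tp.maximum(x, 0.0)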
from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Min from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("lhs").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("rhs").dtype == GetInput("lhs").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("lhs").dtype, convert_to_tensors=True, ) def minimum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/mod.py b/tripy/nvtripy/frontend/ops/binary/mod.py index f38203f43..10142484e 100644 --- a/tripy/nvtripy/frontend/ops/binary/mod.py +++ b/tripy/nvtripy/frontend/ops/binary/mod.py @@ -12,9 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def mod_impl(lhs, rhs): @@ -23,8 +25,11 @@ def mod_impl(lhs, rhs): @register_tensor_method("__mod__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __mod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -53,8 +58,11 @@ def __mod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rmod__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rmod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/mul.py b/tripy/nvtripy/frontend/ops/binary/mul.py index e6bcd9940..77a4b101d 100644 --- a/tripy/nvtripy/frontend/ops/binary/mul.py +++ b/tripy/nvtripy/frontend/ops/binary/mul.py @@ -12,18 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Mul from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__mul__") @register_tensor_method("__rmul__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __mul__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/not_equal.py b/tripy/nvtripy/frontend/ops/binary/not_equal.py index 6e175dc84..e48e454be 100644 --- a/tripy/nvtripy/frontend/ops/binary/not_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/not_equal.py @@ -12,18 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ne__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == dt.bool, convert_to_tensors=True, ) def __ne__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/pow.py b/tripy/nvtripy/frontend/ops/binary/pow.py index 285e73141..cbea9bd4b 100644 --- a/tripy/nvtripy/frontend/ops/binary/pow.py +++ b/tripy/nvtripy/frontend/ops/binary/pow.py @@ -12,17 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Pow from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__pow__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int64"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int64]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __pow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +54,9 @@ def __pow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rpow__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int64"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int64]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rpow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/binary/sub.py b/tripy/nvtripy/frontend/ops/binary/sub.py index 94b51ec9e..24cf69135 100644 --- a/tripy/nvtripy/frontend/ops/binary/sub.py +++ b/tripy/nvtripy/frontend/ops/binary/sub.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
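One asymmetry worth noting in `pow.py` above: unlike most of the other binary ops, `__pow__`/`__rpow__` omit `int32` and `bool` from the allowed set. A caller holding default-`int32` tensors would cast first, e.g. via the `cast` tensor method registered later in this diff (usage sketch, assuming `import nvtripy as tp`):

import nvtripy as tp

x = tp.Tensor([1, 2, 3])         # int32 by default, which __pow__ rejects
y = x.cast(tp.float32) ** 2.0    # float32 is in the allowed set; 2.0 is
                                 # converted to float32 via convert_to_tensors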
+from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Sub from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__sub__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __sub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @@ -51,8 +56,11 @@ def __sub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": @register_tensor_method("__rsub__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, convert_to_tensors=True, ) def __rsub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index 9c197af57..32a4a9994 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -17,32 +17,25 @@ from nvtripy import export -from nvtripy.common.datatype import bool as tp_bool -from nvtripy.common.datatype import float32, int8 +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize from nvtripy.trace.ops.cast import Cast -from nvtripy.utils import wrappers @register_tensor_method("cast") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, - dtype_exceptions=[ - {"T1": "float8", "T2": "int4"}, - {"T1": "float8", "T2": "int8"}, - {"T1": "int8", "T2": "float8"}, - {"T1": "int4", "T2": "float8"}, - {"T1": "int4", "T2": "int8"}, - {"T1": "int4", "T2": "int64"}, - ], + input_requirements=( + ((GetInput("input").dtype != dt.float8) | ~OneOf(GetInput("dtype"), [dt.int4, dt.int8])) + & ((GetInput("input").dtype != dt.int8) | (GetInput("dtype") != dt.float8)) + & ((GetInput("input").dtype != dt.int4) | ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> 
"nvtripy.Tensor": r""" @@ -79,14 +72,14 @@ def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor": # If given a quantized input, dequantize before converting. If bool is the target dtype, # we do still need to quantize int8s because it compiles into an MLIR-TRT *comparison* op - if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != int8 or dtype == tp_bool): - dequant_dtype = float32 + if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != dt.int8 or dtype == dt.bool): + dequant_dtype = dt.float32 input = dequantize(input, 1.0, dequant_dtype) if dtype == dequant_dtype: return input - if op_utils.is_quantized_dtype(dtype) and dtype != int8: - if input.dtype != float32: - input = op_utils.create_op(Cast, [input], float32) + if op_utils.is_quantized_dtype(dtype) and dtype != dt.int8: + if input.dtype != dt.float32: + input = op_utils.create_op(Cast, [input], dt.float32) return quantize(input, 1.0, dtype) return op_utils.create_op(Cast, [input], dtype) diff --git a/tripy/nvtripy/frontend/ops/concatenate.py b/tripy/nvtripy/frontend/ops/concatenate.py index de9d95de7..730784745 100644 --- a/tripy/nvtripy/frontend/ops/concatenate.py +++ b/tripy/nvtripy/frontend/ops/concatenate.py @@ -18,18 +18,21 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.concatenate import Concatenate -from nvtripy.utils import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensors").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensors").dtype, ) def concatenate(tensors: Sequence["nvtripy.Tensor"], dim: int) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index 4e5c10d6a..1dcde27b8 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -22,14 +22,16 @@ from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.frontend.constraints import GetInput, GetReturn @register_tensor_method("copy") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": list(DATA_TYPES.keys())}, + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def copy(input: "nvtripy.Tensor", device: tp_device) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/cumsum.py b/tripy/nvtripy/frontend/ops/cumsum.py index c0dd76902..8b08e70c0 100644 --- a/tripy/nvtripy/frontend/ops/cumsum.py +++ b/tripy/nvtripy/frontend/ops/cumsum.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,13 +14,17 @@ # limitations under the License. from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def cumsum(input: "nvtripy.Tensor", dim: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index 816efa28c..2c1462bc8 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,13 +22,18 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.dequantize import Dequantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/quantization") @wrappers.interface( - dtype_constraints={"input": "T1", "scale": "T2", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["int4", "int8", "float8"], "T2": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.int4, dt.int8, dt.float8]) + & OneOf(GetInput("scale").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("dtype") == GetInput("scale").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors={"scale"}, ) def dequantize( diff --git a/tripy/nvtripy/frontend/ops/equal.py b/tripy/nvtripy/frontend/ops/equal.py index 7d5fc3a7b..6feb3617a 100644 --- a/tripy/nvtripy/frontend/ops/equal.py +++ b/tripy/nvtripy/frontend/ops/equal.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,13 +14,18 @@ # limitations under the License. 
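The `cast` requirements earlier in this diff are the most involved: each row of the old `dtype_exceptions` table, i.e. a forbidden (input dtype, target dtype) pair, becomes a material implication. Since `~(a == b)` is equivalent to `(a != b)` in the DSL, `(GetInput("input").dtype != dt.float8) | ~OneOf(GetInput("dtype"), [dt.int4, dt.int8])` reads "if the input is float8, the target must not be int4 or int8". A hypothetical `implies` helper (not part of this diff) would make that intent explicit:

# Hypothetical helper, not in nvtripy: material implication over expressions.
def implies(antecedent, consequent):
    return ~antecedent | consequent

# cast()'s input_requirements, restated one old dtype_exceptions row per clause:
input_requirements = (
    implies(GetInput("input").dtype == dt.float8,
            ~OneOf(GetInput("dtype"), [dt.int4, dt.int8]))
    & implies(GetInput("input").dtype == dt.int8,
              GetInput("dtype") != dt.float8)
    & implies(GetInput("input").dtype == dt.int4,
              ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64]))
)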
from nvtripy import export from nvtripy.common.datatype import DATA_TYPES -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "other": "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("other").dtype == GetInput("input").dtype), ) def equal(input: "nvtripy.Tensor", other: "nvtripy.Tensor") -> bool: r""" diff --git a/tripy/nvtripy/frontend/ops/expand.py b/tripy/nvtripy/frontend/ops/expand.py index 6cced90bc..14fde14ad 100644 --- a/tripy/nvtripy/frontend/ops/expand.py +++ b/tripy/nvtripy/frontend/ops/expand.py @@ -17,11 +17,13 @@ from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): @@ -53,10 +55,10 @@ def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, conversion_preprocess_func=process_sizes, ) diff --git a/tripy/nvtripy/frontend/ops/flatten.py b/tripy/nvtripy/frontend/ops/flatten.py index a356bf123..e4db23128 100644 --- a/tripy/nvtripy/frontend/ops/flatten.py +++ b/tripy/nvtripy/frontend/ops/flatten.py @@ -15,17 +15,21 @@ import math from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("flatten") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def flatten(input: "nvtripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/flip.py b/tripy/nvtripy/frontend/ops/flip.py index 16912f97b..beb771851 100644 --- a/tripy/nvtripy/frontend/ops/flip.py +++ b/tripy/nvtripy/frontend/ops/flip.py @@ -18,16 +18,18 @@ from typing import Optional, Sequence, Union from nvtripy import export +from 
nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def flip(input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/full.py b/tripy/nvtripy/frontend/ops/full.py index 044cfe1e5..6eb409952 100644 --- a/tripy/nvtripy/frontend/ops/full.py +++ b/tripy/nvtripy/frontend/ops/full.py @@ -18,22 +18,26 @@ from typing import Optional from nvtripy import export, utils -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike, TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=(GetInput("value").dtype != dt.float8) + & If( + GetInput("value").dtype == dt.int8, + GetInput("dtype") != dt.bool, + ) + & OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors=True, ) -def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype.float32) -> "nvtripy.Tensor": +def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = dt.float32) -> "nvtripy.Tensor": """ Returns a tensor of the desired shape with all values set to the specified value. @@ -55,9 +59,9 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype. from nvtripy.frontend.ops.cast import cast value_dtype = dtype - if dtype == datatype.int8: + if dtype == dt.int8: # TODO (#580): Remove this workaround for broadcasting INT8 inputs: - value_dtype = datatype.int32 + value_dtype = dt.int32 # We avoid using the `expand` API since it does extra things that we don't need. out = op_utils.create_op(Broadcast, [cast(value, dtype=value_dtype), shape]) @@ -67,11 +71,27 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype. 
@export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & (GetInput("value").dtype != dt.float8) + & If( + GetInput("value").dtype == dt.int8, + If( + GetInput("dtype") != None, + GetInput("dtype") != dt.bool, + GetInput("input").dtype != dt.bool, + ), + ) + & If( + GetInput("dtype") != None, + OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) def full_like(input: "nvtripy.Tensor", value: TensorLike, dtype: Optional["nvtripy.dtype"] = None) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index 7af68d070..083504844 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,16 +19,20 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.gather import Gather -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "index": "T2", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - "T2": ["int32", "int64"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & OneOf(GetInput("index").dtype, [dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def gather(input: "nvtripy.Tensor", dim: int, index: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/iota.py b/tripy/nvtripy/frontend/ops/iota.py index fa5b247f1..8b03b77c7 100644 --- a/tripy/nvtripy/frontend/ops/iota.py +++ b/tripy/nvtripy/frontend/ops/iota.py @@ -18,14 +18,15 @@ from typing import Optional from nvtripy import export, utils -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.linspace import Linspace from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers -def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtripy.Tensor": +def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: dt.dtype) -> "nvtripy.Tensor": from nvtripy.frontend.ops.cast import cast from 
nvtripy.frontend.tensor import Tensor @@ -42,13 +43,14 @@ def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtr @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors=True, ) -def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "nvtripy.Tensor": +def iota(shape: ShapeLike, dim: int = 0, dtype: dt.dtype = dt.float32) -> "nvtripy.Tensor": """ Fills an output tensor with consecutive values starting from zero along the given dimension. @@ -75,13 +77,24 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float3 @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & If( + GetInput("dtype") != None, + OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def iota_like(input: "nvtripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def iota_like(input: "nvtripy.Tensor", dim: int = 0, dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Returns a tensor of the same shape and data type as the input tensor, with consecutive values starting from zero along the given dimension. diff --git a/tripy/nvtripy/frontend/ops/masked_fill.py b/tripy/nvtripy/frontend/ops/masked_fill.py index 3bfb54226..13e436a42 100644 --- a/tripy/nvtripy/frontend/ops/masked_fill.py +++ b/tripy/nvtripy/frontend/ops/masked_fill.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
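`full_like` and `iota_like` above introduce the three-argument `If(condition, then_constraint, else_constraint)` form for optional `dtype` parameters (the same pattern appears in `ones_like` below): when `dtype` is given it drives both the requirement and the output guarantee; otherwise the output falls back to the input's dtype. In plain Python, `iota_like`'s output guarantee amounts to:

# Plain-Python restatement of iota_like's output_guarantees (illustrative only):
def iota_like_output_dtype(input, dtype=None):
    # If(GetInput("dtype") != None,
    #    GetReturn(0).dtype == GetInput("dtype"),
    #    GetReturn(0).dtype == GetInput("input").dtype)
    return dtype if dtype is not None else input.dtype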
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,16 +15,20 @@ import numbers from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", "mask": "T2", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - "T2": ["bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & (GetInput("mask").dtype == dt.bool), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def masked_fill(input: "nvtripy.Tensor", mask: "nvtripy.Tensor", value: numbers.Number) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/matmul.py b/tripy/nvtripy/frontend/ops/matmul.py index 1f323987f..2532f9ac9 100644 --- a/tripy/nvtripy/frontend/ops/matmul.py +++ b/tripy/nvtripy/frontend/ops/matmul.py @@ -19,13 +19,18 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.matmul import MatrixMultiply -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__matmul__") @wrappers.interface( - dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("other").dtype == GetInput("self").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __matmul__(self: "nvtripy.Tensor", other: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/ones.py b/tripy/nvtripy/frontend/ops/ones.py index 2abcd81cc..bc5ad5ba5 100644 --- a/tripy/nvtripy/frontend/ops/ones.py +++ b/tripy/nvtripy/frontend/ops/ones.py @@ -15,21 +15,22 @@ from typing import Optional from nvtripy import export -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def ones( shape: "nvtripy.types.ShapeLike", - dtype: datatype.dtype = datatype.float32, + dtype: dt.dtype = dt.float32, ) -> "nvtripy.Tensor": """ Creates a Tensor of the specified shape and dtype with all elements set to 1. 
@@ -55,13 +56,24 @@ def ones( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & If( + GetInput("dtype") != None, + OneOf( + GetInput("dtype"), + [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool], + ), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def ones_like(input: "nvtripy.Tensor", dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def ones_like(input: "nvtripy.Tensor", dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Creates a tensor with all elements set to 1 of the same shape as the input tensor. diff --git a/tripy/nvtripy/frontend/ops/outer.py b/tripy/nvtripy/frontend/ops/outer.py index fae200134..a72333463 100644 --- a/tripy/nvtripy/frontend/ops/outer.py +++ b/tripy/nvtripy/frontend/ops/outer.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,13 +16,17 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"vec1": "T1", "vec2": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("vec1").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("vec2").dtype == GetInput("vec1").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("vec1").dtype, ) def outer(vec1: "nvtripy.Tensor", vec2: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/pad.py b/tripy/nvtripy/frontend/ops/pad.py index f58afdfd7..0d544118e 100644 --- a/tripy/nvtripy/frontend/ops/pad.py +++ b/tripy/nvtripy/frontend/ops/pad.py @@ -23,13 +23,17 @@ from nvtripy.trace.ops.shape import Shape from nvtripy.trace.ops.slice import SliceFill, SliceReflect from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bool", "int32", "int64"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bool, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def pad( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index a367ce939..cf6c56fe4 100644 --- a/tripy/nvtripy/frontend/ops/permute.py 
+++ b/tripy/nvtripy/frontend/ops/permute.py @@ -18,18 +18,22 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.permute import Permute -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("permute") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def permute(input: "nvtripy.Tensor", perm: Sequence[int]) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/pooling/avgpool.py b/tripy/nvtripy/frontend/ops/pooling/avgpool.py index 50eb3a18d..1ec6db284 100644 --- a/tripy/nvtripy/frontend/ops/pooling/avgpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/avgpool.py @@ -23,13 +23,17 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import AvgPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "bfloat16", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.bfloat16, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def avgpool( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/pooling/maxpool.py b/tripy/nvtripy/frontend/ops/pooling/maxpool.py index c38ab13f6..d9230be33 100644 --- a/tripy/nvtripy/frontend/ops/pooling/maxpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/maxpool.py @@ -21,13 +21,17 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import MaxPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "bfloat16", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.bfloat16, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def maxpool( input: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 69c062a0b..7427b246e 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,13 +22,18 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.quantize import Quantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/quantization") @wrappers.interface( - dtype_constraints={"input": "T1", "scale": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"], "T2": ["int4", "int8", "float8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("scale").dtype == GetInput("input").dtype) + & OneOf(GetInput("dtype"), [dt.int4, dt.int8, dt.float8]), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), convert_to_tensors={"scale"}, ) def quantize( diff --git a/tripy/nvtripy/frontend/ops/reduce/all.py b/tripy/nvtripy/frontend/ops/reduce/all.py index fc3116287..a8cb8f0f8 100644 --- a/tripy/nvtripy/frontend/ops/reduce/all.py +++ b/tripy/nvtripy/frontend/ops/reduce/all.py @@ -15,14 +15,15 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=GetInput("input").dtype == dt.bool, + output_guarantees=GetReturn(0).dtype == dt.bool, ) def all( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -50,5 +51,19 @@ def all( """ from nvtripy.frontend.ops.reduce.prod import prod from nvtripy.frontend.ops.cast import cast + from nvtripy.common.exception import raise_error - return cast(prod(cast(input, dtype=datatype.int32), dim, keepdim), dtype=datatype.bool) + # Validate that input is bool - constraint system has already checked this + # but we need to enforce it at runtime when validation is disabled + if input.dtype != dt.bool: + raise_error( + f"Input must have bool dtype for all(), but got {input.dtype}.", + [ + "This function only accepts bool tensors. 
", + "Note: If you need to check if all elements are non-zero, first compare with zero: ", + "tp.all(input != 0)", + ], + ) + + # Cast to int32 since prod doesn't accept bool, then cast back to bool + return cast(prod(cast(input, dtype=dt.int32), dim, keepdim), dtype=dt.bool) diff --git a/tripy/nvtripy/frontend/ops/reduce/any.py b/tripy/nvtripy/frontend/ops/reduce/any.py index 01630dd60..ddbff8f59 100644 --- a/tripy/nvtripy/frontend/ops/reduce/any.py +++ b/tripy/nvtripy/frontend/ops/reduce/any.py @@ -15,14 +15,15 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=GetInput("input").dtype == dt.bool, + output_guarantees=GetReturn(0).dtype == dt.bool, ) def any( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -50,5 +51,19 @@ def any( """ from nvtripy.frontend.ops.reduce.sum import sum from nvtripy.frontend.ops.cast import cast + from nvtripy.common.exception import raise_error - return cast(sum(cast(input, dtype=datatype.int32), dim, keepdim), dtype=datatype.bool) + # Validate that input is bool - constraint system has already checked this + # but we need to enforce it at runtime when validation is disabled + if input.dtype != dt.bool: + raise_error( + f"Input must have bool dtype for any(), but got {input.dtype}.", + [ + "This function only accepts bool tensors. ", + "Note: If you need to check if any elements are non-zero, first compare with zero: ", + "tp.any(input != 0)", + ], + ) + + # Cast to int32 since sum doesn't accept bool, then cast back to bool + return cast(sum(cast(input, dtype=dt.int32), dim, keepdim), dtype=dt.bool) diff --git a/tripy/nvtripy/frontend/ops/reduce/argmax.py b/tripy/nvtripy/frontend/ops/reduce/argmax.py index e3abdb4c0..d2adcdf85 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmax.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmax.py @@ -15,15 +15,17 @@ from typing import Optional from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == dt.int32, ) def argmax(input: "nvtripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reduce/argmin.py b/tripy/nvtripy/frontend/ops/reduce/argmin.py index 69d616162..84c9c175c 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmin.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmin.py @@ -15,15 +15,17 @@ from typing import Optional from nvtripy import export +from nvtripy.common import datatype as dt +from 
nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=GetReturn(0).dtype == dt.int32, ) def argmin(input: "nvtripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reduce/max.py b/tripy/nvtripy/frontend/ops/reduce/max.py index 5b28f99a4..dc2e68052 100644 --- a/tripy/nvtripy/frontend/ops/reduce/max.py +++ b/tripy/nvtripy/frontend/ops/reduce/max.py @@ -15,15 +15,17 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Max -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def max( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/mean.py b/tripy/nvtripy/frontend/ops/reduce/mean.py index 5a85e6357..fdca8a967 100644 --- a/tripy/nvtripy/frontend/ops/reduce/mean.py +++ b/tripy/nvtripy/frontend/ops/reduce/mean.py @@ -15,15 +15,17 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Avg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def mean( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/min.py b/tripy/nvtripy/frontend/ops/reduce/min.py index 1bd15785b..8a43c9c18 100644 --- a/tripy/nvtripy/frontend/ops/reduce/min.py +++ b/tripy/nvtripy/frontend/ops/reduce/min.py @@ -15,15 +15,17 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Min -from nvtripy.utils import 
wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def min( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/prod.py b/tripy/nvtripy/frontend/ops/reduce/prod.py index 75ab67f21..d70d9ce06 100644 --- a/tripy/nvtripy/frontend/ops/reduce/prod.py +++ b/tripy/nvtripy/frontend/ops/reduce/prod.py @@ -15,15 +15,17 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Prod -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def prod( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/sum.py b/tripy/nvtripy/frontend/ops/reduce/sum.py index e118eb397..efaa47ae7 100644 --- a/tripy/nvtripy/frontend/ops/reduce/sum.py +++ b/tripy/nvtripy/frontend/ops/reduce/sum.py @@ -15,15 +15,17 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Sum -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.int32, dt.int64, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sum( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False diff --git a/tripy/nvtripy/frontend/ops/reduce/topk.py b/tripy/nvtripy/frontend/ops/reduce/topk.py index 57e1e2dfe..511e7f156 100644 --- a/tripy/nvtripy/frontend/ops/reduce/topk.py +++ b/tripy/nvtripy/frontend/ops/reduce/topk.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +15,17 @@ from typing import Tuple from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops.reduce.utils import topk_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: ["T1", "T2"]}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32", "int64"], "T2": ["int32"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64]), + output_guarantees=(GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == dt.int32), ) def topk(input: "nvtripy.Tensor", k: int, dim: int) -> Tuple["nvtripy.Tensor", "nvtripy.Tensor"]: """ diff --git a/tripy/nvtripy/frontend/ops/reduce/var.py b/tripy/nvtripy/frontend/ops/reduce/var.py index ad8cd4f8d..cf17b3d9d 100644 --- a/tripy/nvtripy/frontend/ops/reduce/var.py +++ b/tripy/nvtripy/frontend/ops/reduce/var.py @@ -16,14 +16,16 @@ from typing import Optional, Sequence, Union from nvtripy import export +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def var( input: "nvtripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False, correction: int = 1 diff --git a/tripy/nvtripy/frontend/ops/repeat.py b/tripy/nvtripy/frontend/ops/repeat.py index d7e531b41..5c3f8aea9 100644 --- a/tripy/nvtripy/frontend/ops/repeat.py +++ b/tripy/nvtripy/frontend/ops/repeat.py @@ -15,18 +15,20 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def repeat(input: "nvtripy.Tensor", repeats: IntLike, dim: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py index 262dc78b8..7843f5074 100644 --- a/tripy/nvtripy/frontend/ops/reshape.py +++ b/tripy/nvtripy/frontend/ops/reshape.py @@ -18,12 +18,14 @@ import math from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error 
from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.reshape import Reshape from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: @@ -46,8 +48,10 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: @register_tensor_method("reshape") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, conversion_preprocess_func=infer_dimensions, ) diff --git a/tripy/nvtripy/frontend/ops/resize.py b/tripy/nvtripy/frontend/ops/resize.py index 477623062..86905d86c 100644 --- a/tripy/nvtripy/frontend/ops/resize.py +++ b/tripy/nvtripy/frontend/ops/resize.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.resize import ResizeCubic, ResizeLinear, ResizeNearest from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers SUPPORTED_MODES = ("cubic", "linear", "nearest") @@ -48,10 +48,14 @@ def _create_resize(mode, inputs, scales, align_corners): return op_utils.create_op(ResizeCubic, inputs, scales=scales, align_corners=align_corners) +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, ) def resize( @@ -108,8 +112,8 @@ def resize( @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def resize( input: "nvtripy.Tensor", scales: Sequence[numbers.Number], mode: str = "linear", align_corners: bool = False diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 94e3fa19c..d662714c1 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,12 +25,15 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.shape import GetDimensionSize, Shape from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, OneOf @register_tensor_method("shape") @property -@wrappers.interface(dtype_constraints={"self": "T1"}, dtype_variables={"T1": list(DATA_TYPES.keys())}) +@wrappers.interface( + input_requirements=OneOf(GetInput("self").dtype, list(DATA_TYPES.values())), +) def shape(self: "nvtripy.Tensor") -> Tuple[IntLike]: """ Represents the shape of the tensor. diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index 010e04dee..09e1b72b7 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -18,12 +18,14 @@ from typing import Sequence, Union from nvtripy import utils +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.slice import Slice from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.utils.types import type_str_from_arg from nvtripy.utils.utils import make_list @@ -32,8 +34,10 @@ @register_tensor_method("__getitem__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __getitem__( self: "nvtripy.Tensor", diff --git a/tripy/nvtripy/frontend/ops/softmax.py b/tripy/nvtripy/frontend/ops/softmax.py index f195eac81..afaa836a5 100644 --- a/tripy/nvtripy/frontend/ops/softmax.py +++ b/tripy/nvtripy/frontend/ops/softmax.py @@ -20,15 +20,17 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.softmax import Softmax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def softmax(input: "nvtripy.Tensor", dim: Optional[int] = None) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 76a3d4a8d..0fc36787f 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,15 +21,19 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool], + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def split( input: "nvtripy.Tensor", num_split_or_sizes: Union[int, Sequence[IntLike]], dim: int = 0 diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py index 5ec260e15..2aa22911c 100644 --- a/tripy/nvtripy/frontend/ops/squeeze.py +++ b/tripy/nvtripy/frontend/ops/squeeze.py @@ -15,16 +15,20 @@ from typing import Sequence, Union from nvtripy import export, utils +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("squeeze") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def squeeze(input: "nvtripy.Tensor", dims: Union[Sequence[int], int]) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/stack.py b/tripy/nvtripy/frontend/ops/stack.py index 0f2df667e..959923d5c 100644 --- a/tripy/nvtripy/frontend/ops/stack.py +++ b/tripy/nvtripy/frontend/ops/stack.py @@ -16,16 +16,18 @@ from typing import Sequence from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensors").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensors").dtype, ) def stack(tensors: Sequence["nvtripy.Tensor"], dim: int = 0) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/transpose.py b/tripy/nvtripy/frontend/ops/transpose.py index 1d4b8cf3d..7c019bd2d 100644 --- a/tripy/nvtripy/frontend/ops/transpose.py +++ b/tripy/nvtripy/frontend/ops/transpose.py @@ -13,16 +13,20 @@ # See the License for the 
specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("transpose") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def transpose(input: "nvtripy.Tensor", dim0: int, dim1: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/tril.py b/tripy/nvtripy/frontend/ops/tril.py index 563400e16..dea9e3e06 100644 --- a/tripy/nvtripy/frontend/ops/tril.py +++ b/tripy/nvtripy/frontend/ops/tril.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,15 +18,19 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.zeros import zeros_like from nvtripy.frontend.ops.where import where -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) def tril(tensor: "nvtripy.Tensor", diagonal: int = 0) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/triu.py b/tripy/nvtripy/frontend/ops/triu.py index 80dc08103..9057caf22 100644 --- a/tripy/nvtripy/frontend/ops/triu.py +++ b/tripy/nvtripy/frontend/ops/triu.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,15 +18,19 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.where import where from nvtripy.frontend.ops.zeros import zeros_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("tensor").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("tensor").dtype, ) def triu(tensor: "nvtripy.Tensor", diagonal: int = 0) -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/abs.py b/tripy/nvtripy/frontend/ops/unary/abs.py index 9447cecbb..86d45672f 100644 --- a/tripy/nvtripy/frontend/ops/unary/abs.py +++ b/tripy/nvtripy/frontend/ops/unary/abs.py @@ -16,16 +16,20 @@ # +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Abs -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__abs__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __abs__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/cos.py b/tripy/nvtripy/frontend/ops/unary/cos.py index cabe9f697..9871e6d2f 100644 --- a/tripy/nvtripy/frontend/ops/unary/cos.py +++ b/tripy/nvtripy/frontend/ops/unary/cos.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Cos -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def cos(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/exp.py b/tripy/nvtripy/frontend/ops/unary/exp.py index 4895a4556..22112a1f1 100644 --- a/tripy/nvtripy/frontend/ops/unary/exp.py +++ b/tripy/nvtripy/frontend/ops/unary/exp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Exp -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def exp(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/gelu.py b/tripy/nvtripy/frontend/ops/unary/gelu.py index a99dd240c..5416282b8 100644 --- a/tripy/nvtripy/frontend/ops/unary/gelu.py +++ b/tripy/nvtripy/frontend/ops/unary/gelu.py @@ -17,17 +17,17 @@ from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import GeluErf from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def gelu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/invert.py b/tripy/nvtripy/frontend/ops/unary/invert.py index 343f7525f..b191cfe7f 100644 --- a/tripy/nvtripy/frontend/ops/unary/invert.py +++ b/tripy/nvtripy/frontend/ops/unary/invert.py @@ -12,16 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Not -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__invert__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["bool"]}, + input_requirements=OneOf(GetInput("self").dtype, [dt.bool]), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __invert__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/log.py b/tripy/nvtripy/frontend/ops/unary/log.py index 74257a948..1e1fd87b9 100644 --- a/tripy/nvtripy/frontend/ops/unary/log.py +++ b/tripy/nvtripy/frontend/ops/unary/log.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Log -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def log(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/neg.py b/tripy/nvtripy/frontend/ops/unary/neg.py index 3849364b4..a446a9372 100644 --- a/tripy/nvtripy/frontend/ops/unary/neg.py +++ b/tripy/nvtripy/frontend/ops/unary/neg.py @@ -14,16 +14,20 @@ # limitations under the License. 
+from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Neg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("__neg__") @wrappers.interface( - dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + input_requirements=OneOf( + GetInput("self").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64] + ), + output_guarantees=GetReturn(0).dtype == GetInput("self").dtype, ) def __neg__(self: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/relu.py b/tripy/nvtripy/frontend/ops/unary/relu.py index 2db6aac9f..61526001f 100644 --- a/tripy/nvtripy/frontend/ops/unary/relu.py +++ b/tripy/nvtripy/frontend/ops/unary/relu.py @@ -18,16 +18,18 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Relu -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "int8"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int4, dt.int32, dt.int64, dt.int8] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def relu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/rsqrt.py b/tripy/nvtripy/frontend/ops/unary/rsqrt.py index 5cd215073..5fe2e095f 100644 --- a/tripy/nvtripy/frontend/ops/unary/rsqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/rsqrt.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Recip -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def rsqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/sigmoid.py b/tripy/nvtripy/frontend/ops/unary/sigmoid.py index be7ea05a7..c50ed7d2c 100644 --- a/tripy/nvtripy/frontend/ops/unary/sigmoid.py +++ b/tripy/nvtripy/frontend/ops/unary/sigmoid.py @@ -16,17 +16,17 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import Sigmoid from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sigmoid(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/silu.py b/tripy/nvtripy/frontend/ops/unary/silu.py index 3813f06e3..14a35c612 100644 --- a/tripy/nvtripy/frontend/ops/unary/silu.py +++ b/tripy/nvtripy/frontend/ops/unary/silu.py @@ -16,15 +16,15 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8"], - }, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def silu(input: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" diff --git a/tripy/nvtripy/frontend/ops/unary/sin.py b/tripy/nvtripy/frontend/ops/unary/sin.py index 7078a30b7..eaa905cf9 100644 --- a/tripy/nvtripy/frontend/ops/unary/sin.py +++ b/tripy/nvtripy/frontend/ops/unary/sin.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sin(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/sqrt.py b/tripy/nvtripy/frontend/ops/unary/sqrt.py index 6a67ed9db..c95fafdc6 100644 --- a/tripy/nvtripy/frontend/ops/unary/sqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/sqrt.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sqrt -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def sqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unary/tanh.py b/tripy/nvtripy/frontend/ops/unary/tanh.py index fac66c675..85c187aeb 100644 --- a/tripy/nvtripy/frontend/ops/unary/tanh.py +++ b/tripy/nvtripy/frontend/ops/unary/tanh.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Tanh -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def tanh(input: "nvtripy.Tensor") -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/unsqueeze.py b/tripy/nvtripy/frontend/ops/unsqueeze.py index fa3e25045..3a6c3bf69 100644 --- a/tripy/nvtripy/frontend/ops/unsqueeze.py +++ b/tripy/nvtripy/frontend/ops/unsqueeze.py @@ -16,16 +16,20 @@ # from nvtripy import export +from nvtripy.common import datatype as dt from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @register_tensor_method("unsqueeze") @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, ) def unsqueeze(input: "nvtripy.Tensor", dim: int) -> "nvtripy.Tensor": """ diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index 7c41a4ae0..01023705a 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,16 +20,18 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.where import Where from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf @export.public_api(document_under="operations/functions") @wrappers.interface( - dtype_constraints={"condition": "T2", "input": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"], - "T2": ["bool"], - }, + input_requirements=(GetInput("condition").dtype == dt.bool) + & OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64]) + & (GetInput("other").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, convert_to_tensors=True, ) def where(condition: "nvtripy.Tensor", input: TensorLike, other: TensorLike) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/zeros.py b/tripy/nvtripy/frontend/ops/zeros.py index 177d20cf0..72de6341b 100644 --- a/tripy/nvtripy/frontend/ops/zeros.py +++ b/tripy/nvtripy/frontend/ops/zeros.py @@ -15,21 +15,22 @@ from typing import Optional from nvtripy import export -from nvtripy.common import datatype +from nvtripy.common import datatype as dt +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf, If from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), ) def zeros( shape: "nvtripy.types.ShapeLike", - dtype: datatype.dtype = datatype.float32, + dtype: dt.dtype = dt.float32, ) -> "nvtripy.Tensor": """ Creates a Tensor of the specified shape and dtype with all elements set to 0. @@ -55,13 +56,20 @@ def zeros( @export.public_api(document_under="operations/initializers") @wrappers.interface( - dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], - }, + input_requirements=OneOf( + GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool] + ) + & If( + GetInput("dtype") != None, + OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), ) -def zeros_like(input: "nvtripy.Tensor", dtype: Optional[datatype.dtype] = None) -> "nvtripy.Tensor": +def zeros_like(input: "nvtripy.Tensor", dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor": """ Creates a Tensor with all elements set to 0 of the same shape as the input tensor. 
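For reference, the migration pattern applied in each of the operator diffs above, condensed into one self-contained sketch. The operator `my_op` is hypothetical; the constraints API (`OneOf`, `If`, `GetInput`, `GetReturn`, combined with `&`) is used exactly as in the changes above, so this is an illustration of the pattern, not a definitive API reference:

    from typing import Optional

    from nvtripy.common import datatype as dt
    from nvtripy.frontend import wrappers
    from nvtripy.frontend.constraints import GetInput, GetReturn, If, OneOf


    @wrappers.interface(
        # `input` must be one of the listed dtypes; if `dtype` is provided, it is
        # restricted to the same set (mirrors `zeros_like` above).
        input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16])
        & If(
            GetInput("dtype") != None,
            OneOf(GetInput("dtype"), [dt.float32, dt.float16]),
        ),
        # The result takes `dtype` when given, otherwise the input's dtype.
        output_guarantees=If(
            GetInput("dtype") != None,
            GetReturn(0).dtype == GetInput("dtype"),
            GetReturn(0).dtype == GetInput("input").dtype,
        ),
    )
    def my_op(input: "nvtripy.Tensor", dtype: Optional[dt.dtype] = None) -> "nvtripy.Tensor":
        ...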
diff --git a/tripy/nvtripy/utils/wrappers.py b/tripy/nvtripy/frontend/wrappers.py
similarity index 50%
rename from tripy/nvtripy/utils/wrappers.py
rename to tripy/nvtripy/frontend/wrappers.py
index 7f84427c3..b1ab44336 100644
--- a/tripy/nvtripy/utils/wrappers.py
+++ b/tripy/nvtripy/frontend/wrappers.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,29 +15,29 @@
 # limitations under the License.
 #
+
 import functools
 import inspect
-import types
 from dataclasses import dataclass
 from textwrap import indent
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from nvtripy import config, utils
+from nvtripy.common.datatype import dtype as tp_dtype
 from nvtripy.common.exception import raise_error
-from nvtripy.utils import result
-from nvtripy.common.datatype import DATA_TYPES
+from nvtripy.frontend.constraints import AlwaysTrue, Constraints, Equal, GetInput, GetDataType, Fetcher, doc_str
+from nvtripy.frontend.constraints.optimizer import optimize_constraints
 
 
 @dataclass
-class DataTypeConstraints:
+class OperatorConstraints:
     func: Callable
-    constraints: Dict[str, str]
-    variables: Dict[str, List[str]]
-    exceptions: List[Dict[str, str]]
 
+    input_requirements: Optional[Constraints]
+    output_guarantees: Optional[Constraints]
 
-DATA_TYPE_CONSTRAINTS = []
-RETURN_VALUE = "__RETURN_VALUE"
+# A registry of OperatorConstraints entries (func, input_requirements, output_guarantees) for operators.
+OPERATOR_CONSTRAINTS: List[OperatorConstraints] = []
 
 
 # Try to include correct column offsets for non-tensor arguments.
@@ -104,36 +104,125 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name):
         source_info.column_range = candidates[0]
 
 
-def get_arg_dtype(arg, func_name, arg_name) -> result.Result["nvtripy.dtype"]:
-    from nvtripy.common.datatype import dtype
-    from nvtripy.frontend.tensor import Tensor
+def _find_known_datatypes(merged_args: List[Tuple[str, Any]], input_requirements: Constraints) -> Dict[str, tp_dtype]:
+    """
+    Identify known datatypes from input requirements to enable automatic type casting.
 
-    if isinstance(arg, Sequence):
-        arg_dtypes = []
-        for elem in arg:
-            dtype_result = get_arg_dtype(elem, func_name, arg_name)
-            if not dtype_result:
-                return result.Result.err(
-                    [f"Could not determine data type of elements in sequence: {arg_name}"] + dtype_result.error_details
-                )
-            arg_dtypes.append(dtype_result.value)
-
-        if len(set(arg_dtypes)) != 1:
-            return result.Result.err(
-                [
-                    f"Mismatched data types in sequence argument for '{func_name}'.\n",
-                    f"For parameter: '{arg_name}', all arguments must have the same data type, but got: "
-                    f"{arg_dtypes}",
-                ],
-            )
-        arg_dtype = arg_dtypes[0]
-    elif isinstance(arg, Tensor):
-        arg_dtype = arg.dtype
-    elif isinstance(arg, dtype):
-        arg_dtype = arg
-    else:
-        return result.Result.err([f"Expected a tensor or data type argument for {arg_name}, but got: {arg}"])
-    return result.Result.ok(arg_dtype)
+    This function searches for Equal constraints in the input requirements to determine
+    which arguments should have matching datatypes. It skips Equal constraints that appear
+    inside If statement conditions, as those represent conditional checks rather than
+    type requirements.
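+
+    For example, a requirement like ``GetInput("other").dtype == GetInput("input").dtype``
+    (used by `where` above) places `input` and `other` in the same equality set, so a
+    dtype known for `input` can be propagated to `other` for casting.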
+ + Limitation: Automatic type casting will not work for arguments whose datatypes are + conditionally dependent on other values (i.e., when the datatype requirement appears + only in the then_branch or else_branch of an If constraint). + """ + from nvtripy.frontend.constraints.logic import If + + # We perform this operation in two steps: + # 1. Identify all arguments that are expected to have equal data types. + # 2. Propagate known data types to all arguments in each equality set. + expected_equal_dtype: List[Set[str]] = [] + + def insert_pair(name1, name2): + for pair_set in expected_equal_dtype: + if name1 in pair_set or name2 in pair_set: + pair_set.update({name1, name2}) + return + expected_equal_dtype.append({name1, name2}) + + known_dtypes: Dict[str, Optional[tp_dtype]] = {} + arg_map: Dict[str, Any] = {name: value for name, value in merged_args} + + def _is_trusted_dtype_source(value: Any) -> bool: + """Return True if `value` should be treated as a stable dtype source for autocasting. + + We intentionally treat Python scalar literals (int/float/bool) as *untrusted* sources + so that expressions like `tensor_f16 * 1.0` will cast the scalar to the tensor dtype + rather than forcing the tensor to match the scalar's default dtype. + """ + from nvtripy.frontend.tensor import Tensor + + if isinstance(value, Tensor) or isinstance(value, tp_dtype): + return True + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return all(_is_trusted_dtype_source(v) for v in value) if len(value) > 0 else False + return False + + for name, _ in merged_args: + + # If this argument already has a known dtype and is a trusted source, populate it: + if _is_trusted_dtype_source(arg_map.get(name)): + try: + known_dtypes[name] = GetDataType(GetInput(name))(merged_args) + except Exception: + pass + + def process_dtype_equality(matched_constraints, input_is_lhs): + for constraint in matched_constraints: + expected = constraint.fetcher_or_value if input_is_lhs else constraint.fetcher + if isinstance(expected, GetDataType): + # This might be too restrictive, in which case the assertion can be lifted and the logic here updated. + # However, it should generally be the case that input requirements are in terms of the inputs. + assert isinstance( + expected.value_fetcher, GetInput + ), f"Input requirements should only look at inputs" + other_name = expected.value_fetcher.name + + insert_pair(name, other_name) + + try: + known_dtypes[other_name] = expected(merged_args) + except Exception: + # dtype is not yet known (i.e. might be comparing two inputs with unknown dtypes) + pass + + else: + known_dtypes[name] = expected(merged_args) if isinstance(expected, Fetcher) else expected + + # NOTE: Because we check for the input on both sides of the equality, we do not need to do another pass over + # expected_equal_dtype to merge disjoint sets - if we have a transitive equality like: + # `a == c and b == d and b == a`, then `a`, `c`, `b` will be immediately added to the same set, which `d` will + # join when we process `b`. + # Skip searching inside If constraints to avoid treating conditional checks as type requirements: + process_dtype_equality( + input_requirements.find(Equal(GetDataType(GetInput(name)), None), skip_within=If), input_is_lhs=True + ) + process_dtype_equality( + input_requirements.find(Equal(None, GetDataType(GetInput(name))), skip_within=If), input_is_lhs=False + ) + + # We do not need to perform validation, as that will happen during constraints checking. 
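+    # Illustrative walk-through of the resolution below (hypothetical values): given
+    # merged_args of [("input", tensor_f16), ("other", 2.0)] and the equality set
+    # {"input", "other"}, "input" is the only trusted dtype source, so float16 is
+    # chosen for the whole set and the untrusted scalar "other" resolves to float16.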
+ for dtype_set in expected_equal_dtype: + # Prefer dtypes coming from trusted sources (tensors / explicit dtypes). + trusted_names_in_order = [ + n for (n, _) in merged_args if n in dtype_set and _is_trusted_dtype_source(arg_map.get(n)) + ] + candidate_dtypes = [known_dtypes.get(n) for n in trusted_names_in_order if known_dtypes.get(n) is not None] + + # If there are conflicting trusted dtypes, do not guess. + known_dtype_in_set: Optional[tp_dtype] + if len(set(candidate_dtypes)) > 1: + known_dtype_in_set = None + else: + known_dtype_in_set = candidate_dtypes[0] if candidate_dtypes else None + + # dtype might be unknown if the arguments are all non-tensor / untrusted types. + # In that case, we intentionally do *not* treat inferred scalar literal dtypes (e.g. 1.0 -> float32) + # as "known" to avoid unnecessary/incorrect casting behavior. + for name in dtype_set: + if known_dtype_in_set is None: + known_dtypes[name] = None + continue + + # If this argument is an untrusted dtype source (e.g., Python scalar), prefer the + # trusted dtype chosen for the group even if we previously inferred a dtype for it. + if not _is_trusted_dtype_source(arg_map.get(name)): + known_dtypes[name] = known_dtype_in_set + else: + known_dtypes.setdefault(name, known_dtype_in_set) + + return known_dtypes # Performs type conversions if needed. Returns updated values of args, kwargs, and merged args @@ -141,12 +230,12 @@ def convert_input_types( func, args, kwargs, - merged_args, + merged_args: List[Tuple[str, Any]], var_arg_info, conversion_targets, conversion_preprocess_func, - dtype_constraints, shape_likes, + input_requirements: Constraints, ): from nvtripy.common.datatype import bool as tp_bool from nvtripy.common.datatype import floating, integer @@ -167,13 +256,9 @@ def convert_input_types( else: merged_args[index] = (name, new_args[name]) - # Materialize type variables from tensors. - type_vars = {} - for name, arg in merged_args: - if name in dtype_constraints: - dtype_result = get_arg_dtype(arg, func.__qualname__, name) - if dtype_result: - type_vars[dtype_constraints[name]] = dtype_result.value + known_datatypes: Dict[str, Optional[tp_dtype]] = {} + if input_requirements is not None: + known_datatypes = _find_known_datatypes(merged_args, input_requirements) new_args = [] new_kwargs = {} @@ -206,8 +291,8 @@ def add_arg(arg): ) dtype = None - if name in dtype_constraints and dtype_constraints[name] in type_vars: - dtype = type_vars[dtype_constraints[name]] + if input_requirements is not None: + dtype = known_datatypes.get(name) if dtype is not None: # Refuse to do unsafe casts like float -> int. @@ -232,72 +317,37 @@ def add_arg(arg): return new_args, new_kwargs, new_merged_args -# Modify the docstring to mention data type variables and exceptions -def _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions): +# Modify the docstring to include constraints +def _update_docstring(func, input_requirements, output_guarantees): if not func.__doc__: return - # Update the docstring to add data type variables after the parameter documentation. 
- args_index = func.__doc__.find("Args:") - # Args: may be omitted for functions with no inputs - args_index = args_index if args_index != -1 else 0 - for name, var in dtype_constraints.items(): - find_str = f"\n {name}: " if name != RETURN_VALUE else "\n Returns:\n " - - param_index = func.__doc__.find(find_str, args_index) - assert param_index != -1, f"Parameter: {name} is not present or was not documented in {func.__name__}" - func.__doc__ = ( - func.__doc__[:param_index] - + rf"{find_str}[dtype=\ **{var}**\ ] " - + func.__doc__[param_index + len(find_str) :] - ) - - prefix = " " * 8 + if input_requirements is None and output_guarantees is None: + return - def sorted_types(dtypes): - return sorted( - dtypes, - key=lambda dtype: ( - tuple(typ.__name__ for typ in DATA_TYPES[dtype].__bases__), - DATA_TYPES[dtype].itemsize, - ), - ) + indentation = " " * 4 + code_block_index = func.__doc__.find(".. code-block:: python") + assert code_block_index != -1, f"No code example in docstring for {func.__name__}" - dtype_info = "DATA TYPE CONSTRAINTS:\n" - dtype_info += indent( - "\n".join( - [ - f"- **{var}**: {', '.join(map(lambda t: f':class:`{t}`', sorted_types(dtypes)))}" - for var, dtypes in dtype_variables.items() - ] - ), - prefix, + input_requirements_str = ( + f"\nINPUT REQUIREMENTS:\n{indent(doc_str(input_requirements), indentation)}\n" if input_requirements else "" + ) + output_guarantees_str = ( + f"\nOUTPUT GUARANTEES:\n{indent(doc_str(output_guarantees), indentation)}\n" if output_guarantees else "" ) - if dtype_exceptions: - dtype_info += "\n\n UNSUPPORTED DATA TYPE COMBINATIONS:\n" - esc_space = r"\ " - dtype_info += indent( - "\n".join( - [ - f"- {', '.join([f'**{k}**{esc_space}={esc_space}:class:`{v}`' for k, v in exception.items()])}" - for exception in dtype_exceptions - ] - ), - prefix, - ) - - dtype_info += "\n\n " - - code_block_index = func.__doc__.find(".. code-block:: python") - assert code_block_index != -1, f"No code example in docstring for {func.__name__}" - func.__doc__ = func.__doc__[:code_block_index] + dtype_info + func.__doc__[code_block_index:] + func.__doc__ = ( + func.__doc__[:code_block_index] + + indent(input_requirements_str + output_guarantees_str, indentation) + + "\n" + + indentation + + func.__doc__[code_block_index:] + ) def interface( - dtype_constraints: Dict[str, str] = {}, - dtype_variables: Dict[str, List[str]] = {}, - dtype_exceptions: List[Dict[str, str]] = [], + input_requirements: Optional[Constraints] = None, + output_guarantees: Optional[Constraints] = None, convert_to_tensors: Union[bool, Set[str]] = False, conversion_preprocess_func: Optional[Callable] = None, ): @@ -306,39 +356,17 @@ def interface( layer too many decorators, it is preferable to extend this decorator with further functionality than to add and apply further decorators. - The supported constraints are for data type constraints and for converting `TensorLike` and `ShapeLike` - inputs into `Tensor`s or `DimensionSize`s. - - NOTE: When annotating a new API, you should also update `tests/constraints/object_builders.py`. - Args: - dtype_constraints: Maps parameters and return values to data type constraint variables. - Use the special value `wrappers.RETURN_VALUE` to denote return values - this can be - a list for functions that have multiple outputs. If only one return type is specified for - functions with multiple outputs, it will be applied to all outputs. 
- For example: - {"input": "T1", "other": T2, wrappers.RETURN_VALUE: "T1"} - - dtype_variables: Maps data type constraints variables to their supported data types. - For example: - {"T1": ["float32", "float16"], "T2": ["int32", "int64"]} - - dtype_exceptions: Indicates specific combinations of data types that are not supported by the API. - For example: - [ - {"T1": "float16", "T2": "int32"}, - ] - - aliases: A list of function name aliases. For methods that are exposed as multiple APIs - (e.g. `__add__` and `__radd__`), this will enable type information to be added to the - documentation for the aliases as well. - + input_requirements: A constraints tree that validates function inputs. + If provided and input validation is enabled, these constraints are checked at runtime. + output_guarantees: A constraints tree describing guarantees about the function output. + If provided, these are used for documentation and tooling. convert_to_tensors: If False or an empty set, no argument types will be converted. If True, all arguments with the `TensorLike` or `ShapeLike` annotations will be converted into `Tensor`s or, whenever possible, `DimensionSize`. If the argument is a set of argument names, conversions will be done only for those arguments. - The conversions will respect any datatype constraints, casting the `TensorLike` values as necessary, + The conversions will attempt safe casts as needed based on `input_requirements`, but will raise an exception for lossy casts like float to int (but *not* for, e.g., `float32` to `float16`). conversion_preprocess_func: If `convert_to_tensors` is true, this argument is a callback that is @@ -353,6 +381,10 @@ def interface( def decorator(func): from nvtripy.types import ShapeLike, TensorLike + optimized_input_requirements = optimize_constraints(input_requirements) + if isinstance(optimized_input_requirements, AlwaysTrue): + optimized_input_requirements = None + signature = inspect.signature(func) conversion_targets = ( convert_to_tensors @@ -361,19 +393,31 @@ def decorator(func): ) shape_likes = {name for name, param in signature.parameters.items() if param.annotation is ShapeLike} - # if no dtype constraints have been specified at all, do not add to the table so we don't generate invalid tests - if dtype_constraints or dtype_variables or dtype_exceptions: - DATA_TYPE_CONSTRAINTS.append( - DataTypeConstraints(func, dtype_constraints, dtype_variables, dtype_exceptions) - ) - - _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions) + # Register constraints for Tripy operators if either side is specified. + # + # NOTE: The interface decorator is also used in unit tests and potentially by user code. + # We only want Tripy's own operators to appear in the global registry that powers + # public-API validation and integration tests. 
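+        # For instance (hypothetical): an operator defined under `nvtripy.frontend.ops` and
+        # decorated with `@wrappers.interface(input_requirements=...)` is registered below, while
+        # a user function decorated the same way outside the `nvtripy` package is still validated
+        # at call time but is never added to OPERATOR_CONSTRAINTS.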
+ if (input_requirements is not None or output_guarantees is not None) and func.__module__.startswith("nvtripy"): + OPERATOR_CONSTRAINTS.append(OperatorConstraints(func, input_requirements, output_guarantees)) + _update_docstring(func, input_requirements, output_guarantees) @functools.wraps(func) def wrapper(*args, **kwargs): - merged_args, var_arg_info = utils.utils.merge_function_arguments(func, *args, **kwargs) + merged_args = None + omitted_default_args = None + var_arg_info = None + + def get_merged_args(): + nonlocal merged_args, omitted_default_args, var_arg_info + if merged_args is None: + merged_args, omitted_default_args, var_arg_info = utils.utils.merge_function_arguments( + func, *args, **kwargs + ) + return merged_args, omitted_default_args, var_arg_info if convert_to_tensors: + merged_args, omitted_default_args, var_arg_info = get_merged_args() args, kwargs, merged_args = convert_input_types( func, args, @@ -382,77 +426,33 @@ def wrapper(*args, **kwargs): var_arg_info, conversion_targets, conversion_preprocess_func, - dtype_constraints, shape_likes, + input_requirements, ) - if config.enable_dtype_checking: - from nvtripy.common.datatype import dtype - from nvtripy.frontend.tensor import Tensor - - # The first arguments seen for each type variable. Other arguments with the same variable - # must use the same data types. - type_var_first_args: Dict[str, Tuple[str, dtype, Any]] = {} - - for name, arg in merged_args: - if name not in dtype_constraints: - continue - - if arg is None: - # This is only possible for omitted optional arguments. Otherwise, None will - # be disallowed by the function registry's type checking. - continue - - type_var = dtype_constraints[name] + if config.enable_input_validation: + if optimized_input_requirements is not None: + merged_args, omitted_default_args, _ = get_merged_args() + # Input validation needs to know values for arguments that were not provided but have default values: + result = optimized_input_requirements(merged_args + omitted_default_args) + if not result: + details = ( + ["Expected: "] + + result.error_details + + [f".\n\nNote: Requirements are:\n {input_requirements}."] + ) - arg_dtype = get_arg_dtype(arg, func.__qualname__, name) - if not arg_dtype: - raise_error(f"Could not determine datatype of {name}.", arg_dtype.error_details) - arg_dtype = arg_dtype.value + # Include source locations for relevant tensor inputs to make constraint + # failures actionable. + for name, value in merged_args + omitted_default_args: + if hasattr(value, "stack_info"): + details.extend([f"\n\nArgument '{name}' was defined here:\n", value]) - # Check if the type is supported at all - supported_dtypes = dtype_variables[type_var] - if arg_dtype.name not in supported_dtypes: raise_error( - f"Unsupported data type in '{func.__qualname__}'.", - [ - f"For parameter: '{name}', got unsupported data type: '{arg_dtype}'.\n" - f"Supported data types are: {supported_dtypes}." - ] - + ( - [ - f"\nNote: '{name}' was: ", - arg, - ] - if isinstance(arg, Tensor) and "all" in config.extra_error_information - else [] - ), + f"Invalid inputs for function: '{func.__qualname__}'.", + details, ) - # Check if the type matches that of other inputs with the same type_var. 
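+                    # For example, per the updated frontend tests, mismatched operand dtypes now
+                    # fail with "Invalid inputs for function: '__add__'." followed by the unmet
+                    # requirement and the stack information for each offending tensor argument.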
-                if type_var in type_var_first_args:
-                    other_name, other_arg_dtype, other_arg = type_var_first_args[type_var]
-                    if other_arg_dtype != arg_dtype:
-                        raise_error(
-                            f"Mismatched data types in '{func.__qualname__}'.",
-                            [
-                                f"Parameters: '{other_name}' and '{name}' must have matching data types, but got: "
-                                f"'{other_arg_dtype.name}' and '{arg_dtype.name}' respectively.\n"
-                            ]
-                            + (
-                                [
-                                    f"Note: '{other_name}' was: ",
-                                    other_arg,
-                                    f"While '{name}' was: ",
-                                    arg,
-                                ]
-                                if isinstance(arg, Tensor)
-                                else []
-                            ),
-                        )
-
-                type_var_first_args[type_var] = (name, arg_dtype, arg)
-
             return func(*args, **kwargs)
 
         return wrapper
diff --git a/tripy/nvtripy/utils/stack_info.py b/tripy/nvtripy/utils/stack_info.py
index 19fab4871..fe033e707 100644
--- a/tripy/nvtripy/utils/stack_info.py
+++ b/tripy/nvtripy/utils/stack_info.py
@@ -17,6 +17,8 @@
 # NOTE: We avoid using `inspect` functions as much as possible because they are much slower than
 # working directly with the frame.
+
+import linecache
 import sys
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -53,14 +56,19 @@ def fetch_source_code(self):
         if self.code is not None:
             return
 
-        # Note that in some cases, e.g. when code is being provided via the interactive shell, we may not be able to retrieve it.
-        # In that case we just leave it empty.
-        try:
-            lines = open(self.file, "r").readlines()
-        except:
-            self.code = ""
+        # Handle cases where bytecode was compiled with a different path than the current filesystem:
+        file_path = self.file
+        module = sys.modules.get(self.module)
+        if module and hasattr(module, "__file__") and module.__file__:
+            file_path = module.__file__
+
+        line = linecache.getline(file_path, self.line)
+        if line:
+            self.code = line.rstrip()
         else:
-            self.code = lines[self.line - 1].rstrip()
+            # Note that in some cases, e.g. when code is being provided via the interactive shell, we may not be able to retrieve it.
+            # In that case we just leave it empty.
+            self.code = ""
 
 
 class StackInfo(list):
@@ -138,7 +146,7 @@ def get_module_names_to_exclude_from_stack_info():
     Returns a set of module names to exclude from stack information when displaying
     exceptions or trying to retrieve column information from code.
     """
+    import nvtripy.frontend.wrappers as wrappers
     import nvtripy.utils.function_registry as function_registry
-    import nvtripy.utils.wrappers as wrappers
 
     return {mod.__name__ for mod in [function_registry, wrappers]}
diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py
index 572cfb13d..26190b1da 100644
--- a/tripy/nvtripy/utils/utils.py
+++ b/tripy/nvtripy/utils/utils.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -333,12 +333,24 @@ def gen_uid(inputs=None, outputs=None):
 ##
 ## Functions
 ##
-def get_positional_arg_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
-    # Returns the names of positional arguments by inspecting the function signature.
-    # In the case of variadic positional arguments, we cannot determine names, so we use
-    # None instead. To assist in further processing, this function also returns the name
-    # and start index of the variadic args in a pair if present (None if not).
-    signature = inspect.signature(func)
+def _get_signature_cache_key(func):
+    try:
+        hash(func)
+        return func
+    except TypeError:
+        call = getattr(func, "__call__", None)
+        if call is not None:
+            try:
+                hash(call)
+                return call
+            except TypeError:
+                pass
+        return func.__class__
+
+
+@functools.lru_cache(maxsize=None)
+def _get_signature_info(signature_source):
+    signature = inspect.signature(signature_source)
     arg_names = []
     varargs_name = None
     for name, param in signature.parameters.items():
@@ -347,9 +359,19 @@
             # (they would just be absorbed into the variadic argument).
             varargs_name = name
             break
         arg_names.append(name)
+
+    return signature, tuple(arg_names), varargs_name
+
+
+def get_positional_args_with_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
+    # Returns the names of positional arguments by inspecting the function signature.
+    # In the case of variadic positional arguments, we cannot determine names, so we use
+    # None instead. To assist in further processing, this function also returns the name
+    # and start index of the variadic args in a pair if present (None if not).
+    _, arg_names, varargs_name = _get_signature_info(_get_signature_cache_key(func))
+    arg_names = list(arg_names)
+
     # For all variadic positional arguments, assign the name of the variadic group.
     num_variadic_args = len(args) - len(arg_names)
     variadic_start_idx = len(arg_names)
@@ -357,9 +379,24 @@
     return list(zip(arg_names, args)), (varargs_name, variadic_start_idx) if num_variadic_args > 0 else None
 
 
-def merge_function_arguments(func, *args, **kwargs) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
-    # Merge positional and keyword arguments, trying to determine names where possible.
-    # Also returns a pair containing the variadic arg name and start index if present (None otherwise).
-    all_args, var_arg_info = get_positional_arg_names(func, *args)
+def merge_function_arguments(
+    func, *args, **kwargs
+) -> Tuple[List[Tuple[str, Any]], List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
+    # Returns 3 things:
+    # 1. A list of all arguments (positional and keyword) as (name, value) pairs.
+    # 2. A list of omitted arguments with default values filled in.
+    # 3. A pair containing the variadic arg name and start index if present (None otherwise).
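+    # For example (hypothetical): for `def f(a, b=1): ...`, calling
+    # `merge_function_arguments(f, 10)` returns ([("a", 10)], [("b", 1)], None).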
+ all_args, var_arg_info = get_positional_args_with_names(func, *args) all_args.extend(kwargs.items()) - return all_args, var_arg_info + + signature, _, _ = _get_signature_info(_get_signature_cache_key(func)) + provided_arg_names = {name for name, _ in all_args} + + omitted_args = [] + for name, param in signature.parameters.items(): + if name in provided_arg_names: + continue + if param.default is not inspect.Parameter.empty: + omitted_args.append((name, param.default)) + + return all_args, omitted_args, var_arg_info diff --git a/tripy/tests/common/test_exception.py b/tripy/tests/common/test_exception.py index 0c9fbdd87..b8a8a0aec 100644 --- a/tripy/tests/common/test_exception.py +++ b/tripy/tests/common/test_exception.py @@ -122,7 +122,7 @@ def test_can_determine_column_range(self): ) def test_wrappers_is_excluded(self): - from nvtripy.utils import wrappers + from nvtripy.frontend import wrappers tensor = tp.ones((2, 3)) diff --git a/tripy/tests/conftest.py b/tripy/tests/conftest.py index 495e15aea..515d6847a 100644 --- a/tripy/tests/conftest.py +++ b/tripy/tests/conftest.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -84,15 +84,24 @@ def add_plugin_aot_impl( ) compiled_kernel = triton.compile(src) + metadata = compiled_kernel.metadata + if isinstance(metadata, dict): + kernel_name = metadata["name"] + num_warps = metadata["num_warps"] + shared_mem = metadata.get("shared", 0) + else: + kernel_name = metadata.name + num_warps = metadata.num_warps + shared_mem = metadata.shared N = inp0.shape_expr.numel() launch_params = trtp.KernelLaunchParams() launch_params.grid_x = trtp.cdiv(N, block_size) - launch_params.block_x = compiled_kernel.metadata.num_warps * 32 - launch_params.shared_mem = compiled_kernel.metadata.shared + launch_params.block_x = num_warps * 32 + launch_params.shared_mem = shared_mem extra_args = trtp.SymIntExprs(1) extra_args[0] = trtp.SymInt32(N) - return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args + return kernel_name, compiled_kernel.asm["ptx"], launch_params, extra_args diff --git a/tripy/tests/frontend/constraints/__init__.py b/tripy/tests/frontend/constraints/__init__.py new file mode 100644 index 000000000..8bb95d5cb --- /dev/null +++ b/tripy/tests/frontend/constraints/__init__.py @@ -0,0 +1,16 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tripy/tests/frontend/constraints/test_base.py b/tripy/tests/frontend/constraints/test_base.py new file mode 100644 index 000000000..3bf12ffce --- /dev/null +++ b/tripy/tests/frontend/constraints/test_base.py @@ -0,0 +1,125 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints import And, Equal, GetDataType, GetInput, OneOf + + +class TestConstraints: + def test_find_exact_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_no_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = OneOf(GetInput("a"), [1, 2, 3]) + assert len(constraint.find(pattern)) == 0 + + def test_find_in_nested_and(self): + inner_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(inner_constraint, OneOf(GetInput("c"), [1, 2, 3])) + pattern = Equal(GetInput, GetInput) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is inner_constraint + + def test_find_multiple_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2, OneOf(GetInput("e"), [1, 2, 3])) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_dtype_pattern(self): + constraint = Equal(GetDataType(GetInput("tensor1")), GetDataType(GetInput("tensor2"))) + pattern = Equal(GetDataType(GetInput), GetDataType(GetInput)) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_deeply_nested_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("d"), GetInput("e")) + constraint = And( + And(equal1, OneOf(GetInput("c"), [1, 2, 3])), + And(equal2, OneOf(GetInput("f"), [4, 5, 6])), + ) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_specific_names(self): + match_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(match_constraint, Equal(GetInput("c"), GetInput("d"))) + matches = constraint.find(Equal(GetInput("a"), GetInput("b"))) + assert len(matches) == 1 and matches[0] is match_constraint + + def test_find_with_multiple_children(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + oneof1 = OneOf(GetInput("e"), [1, 2, 3]) + equal3 = Equal(GetInput("f"), GetInput("g")) + constraint = And(equal1, equal2, oneof1, equal3) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 3 + assert equal1 in matches + assert equal2 in matches + assert equal3 in matches + + def 
test_find_and_constraint(self): + and1 = And(Equal(GetInput("a"), GetInput("b")), OneOf(GetInput("c"), [1, 2, 3])) + and2 = And(Equal(GetInput("d"), GetInput("e")), OneOf(GetInput("f"), [4, 5, 6])) + constraint = And(and1, and2) + matches = constraint.find(And(Equal, OneOf)) + assert len(matches) == 2 + assert and1 in matches + assert and2 in matches + + def test_find_with_none_wildcard_second_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_first_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(None, GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_in_nested(self): + equal1 = Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2) + pattern = Equal(GetDataType(GetInput), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is equal1 + + def test_find_with_none_wildcard_matches_different_types(self): + equal = Equal(GetInput("a"), GetInput("b")) + oneof = OneOf(GetInput("c"), [1, 2, 3]) + constraint = And(equal, oneof) + pattern = None + matches = constraint.find(pattern) + assert len(matches) == 6 + assert constraint in matches + + def test_info_method(self): + constraint = Equal(GetInput("a"), GetInput("b")) + assert constraint._info is None + + result = constraint.info("This checks that a equals b") + + assert constraint._info == "This checks that a equals b" + assert result is constraint # Test method chaining diff --git a/tripy/tests/frontend/constraints/test_doc_str.py b/tripy/tests/frontend/constraints/test_doc_str.py new file mode 100644 index 000000000..d199ae0f7 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_doc_str.py @@ -0,0 +1,58 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import nvtripy as tp +import pytest +from nvtripy.frontend.constraints import And, Equal, OneOf, Or, doc_str +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput + + +class TestDocStr: + @pytest.mark.parametrize( + "obj, expected", + [ + (tp.float32, ":class:`float32`"), + (None, "``None``"), + ], + ) + def test_basic_types(self, obj, expected): + assert doc_str(obj) == expected + + def test_nested_constraints(self): + input_a = GetInput("a") + input_b = GetInput("b") + + or_part = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + and_constraint = And(or_part, OneOf(input_b, [tp.int32])) + + assert ( + doc_str(and_constraint) + == "- (``a`` == :class:`float32` *or* ``a`` == :class:`float16`), **and**\n- ``b`` is one of [:class:`int32`]" + ) + + def test_complex_real_world_constraint(self): + input_a = GetInput("input") + input_b = GetInput("other") + dtype_a = GetDataType(input_a) + dtype_b = GetDataType(input_b) + + and_constraint = And(Equal(dtype_a, dtype_b), OneOf(dtype_a, [tp.float32, tp.float16])) + + assert ( + doc_str(and_constraint) + == "- ``input.dtype`` == ``other.dtype``, **and**\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" + ) diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py new file mode 100644 index 000000000..9565e1c8f --- /dev/null +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -0,0 +1,128 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import nvtripy as tp +from nvtripy.common.exception import TripyException +from nvtripy.frontend.constraints import Equal, GetDataType, GetInput, GetReturn, NotEqual, doc_str +from tests import helper + + +class TestFetcher: + def test_eq_operator_returns_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 == fetcher2 + assert isinstance(constraint, Equal) + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 + + def test_ne_operator_returns_not_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 != fetcher2 + assert isinstance(constraint, NotEqual) + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 + + +class TestValueFetcher: + def test_dtype_property(self): + fetcher = GetInput("tensor") + dtype_fetcher = fetcher.dtype + assert isinstance(dtype_fetcher, GetDataType) + assert dtype_fetcher.value_fetcher == fetcher + + +class TestGetInput: + def test_call(self): + fetcher = GetInput("data") + args = [("data", 42), ("other", "hello")] + assert fetcher(args) == 42 + + def test_str(self): + fetcher = GetInput("data") + assert str(fetcher) == "data" + + def test_doc_str(self): + assert doc_str(GetInput("x")) == "``x``" + + +class TestGetReturn: + def test_init(self): + fetcher = GetReturn(0) + assert fetcher.index == 0 + + def test_call(self): + fetcher = GetReturn(0) + returns = (42, "hello", 3.14) + assert fetcher([], returns) == 42 + + fetcher2 = GetReturn(2) + assert fetcher2([], returns) == 3.14 + + def test_str(self): + fetcher = GetReturn(0) + assert str(fetcher) == "return[0]" + + fetcher2 = GetReturn(2) + assert str(fetcher2) == "return[2]" + + def test_doc_str(self): + assert doc_str(GetReturn(0)) == "``return[0]``" + + +class TestGetDataType: + def test_call(self): + tensor = tp.ones((2, 3), dtype=tp.float32) + fetcher = GetDataType(GetInput("input_tensor")) + assert fetcher([("input_tensor", tensor)]) == tp.float32 + + def test_call_with_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32)] * 2 + fetcher = GetDataType(GetInput("input_tensors")) + assert fetcher([("input_tensors", tensors)]) == tp.float32 + + def test_call_with_mismatched_dtypes_in_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.int32)] + fetcher = GetDataType(GetInput("input_tensors")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_tensors", tensors)]) + + def test_call_with_non_tensor_argument(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Expected a tensor or data type argument"): + fetcher([("input_data", object())]) + + def test_call_with_python_scalar_int(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", 42)]) == tp.int32 + + def test_call_with_python_scalar_float(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", 1.0)]) == tp.float32 + + def test_call_with_python_scalar_bool(self): + fetcher = GetDataType(GetInput("value")) + assert fetcher([("value", True)]) == tp.bool + + def test_call_with_nested_sequence_error(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_data", [tp.ones((2, 3), dtype=tp.float32), [42]])]) + + def test_doc_str(self): + assert doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" + assert 
doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py new file mode 100644 index 000000000..e5b2521f9 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -0,0 +1,337 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import nvtripy as tp +from nvtripy.frontend.constraints import And, Equal, GetInput, If, NotEqual, NotOneOf, OneOf, Or, doc_str +from nvtripy.frontend.constraints.fetcher import GetDataType + + +class TestLogic: + def test_operator_and_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 & constraint2 + assert isinstance(combined, And) + assert combined([("param1", 2), ("param2", "b")]) + + def test_operator_and_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 & constraint2 & constraint3 + assert isinstance(combined, And) + assert len(combined.constraints) == 3 + assert combined([("param1", 2), ("param2", "b"), ("param3", True)]) + + def test_operator_or_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 | constraint2 + assert isinstance(combined, Or) + assert combined([("param1", 5), ("param2", "b")]) + + def test_operator_or_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 | constraint2 | constraint3 + assert isinstance(combined, Or) + assert len(combined.constraints) == 3 + assert combined([("param1", 5), ("param2", "z"), ("param3", True)]) + + def test_operator_not_basic(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + negated = ~constraint + assert isinstance(negated, NotOneOf) + assert negated([("param", 5)]) + assert not negated([("param", 2)]) + + +class TestOneOf: + def test_call(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 2)]) + result = constraint([("param", 5)]) + assert not result + assert "'param' to be one of [1, 2, 3] (but it was '5')" in result.error_details + + def test_str(self): + assert str(OneOf(GetInput("param"), [1, 2, 3])) == "param is one of [1, 2, 3]" + + def test_inverse(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + inverse = constraint.inverse() + assert isinstance(inverse, NotOneOf) + assert inverse([("param", 5)]) + assert not inverse([("param", 2)]) + + def test_doc_str(self): + constraint = OneOf(GetInput("x"), [tp.float32, tp.float16]) + assert doc_str(constraint) == "``x`` is one of 
[:class:`float32`, :class:`float16`]" + + +class TestNotOneOf: + def test_call(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 5)]) + result = constraint([("param", 2)]) + assert not result + assert "'param' to not be one of [1, 2, 3] (but it was '2')" in result.error_details + + def test_str(self): + assert str(NotOneOf(GetInput("param"), [1, 2, 3])) == "param is not one of [1, 2, 3]" + + def test_inverse(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + inverse = constraint.inverse() + assert isinstance(inverse, OneOf) + assert inverse([("param", 2)]) + assert not inverse([("param", 5)]) + + def test_doc_str(self): + constraint = NotOneOf(GetInput("x"), [tp.int8, tp.int32]) + assert doc_str(constraint) == "``x`` is not one of [:class:`int8`, :class:`int32`]" + + +class TestAnd: + def test_call_all_pass(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert and_constraint([("param1", 2), ("param2", "b")]) + + def test_call_one_fails(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "b")]) + assert not result + assert "'param1' to be one of [1, 2, 3] (but it was '5')" in result.error_details + + def test_call_all_fail(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "z")]) + assert not result + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') and 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) + + def test_str(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(and_constraint) == "(param1 is one of [1, 2, 3] and param2 is one of ['a', 'b'])" + + def test_inverse(self): + # De Morgan's law: not (A and B) = (not A) or (not B) + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = and_constraint.inverse() + assert isinstance(inverse, Or) + assert str(inverse) == "(param1 is not one of [1, 2, 3] or param2 is not one of ['a', 'b'])" + + def test_doc_str(self): + and_constraint = And(OneOf(GetInput("a"), [tp.float32]), OneOf(GetInput("b"), [tp.int32])) + assert ( + doc_str(and_constraint) + == "- ``a`` is one of [:class:`float32`], **and**\n- ``b`` is one of [:class:`int32`]" + ) + + +class TestOr: + def test_call_first_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "z")]) + + def test_call_second_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 5), ("param2", "b")]) + + def test_call_all_pass(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "b")]) + + def test_call_all_fail(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = or_constraint([("param1", 5), ("param2", "z")]) + assert not result + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') or 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) + + def 
test_str(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(or_constraint) == "(param1 is one of [1, 2, 3] or param2 is one of ['a', 'b'])" + + def test_call_multiple_constraints(self): + or_constraint = Or( + OneOf(GetInput("param1"), [1, 2, 3]), + OneOf(GetInput("param2"), ["a", "b", "c"]), + OneOf(GetInput("param3"), [True, False]), + ) + assert or_constraint([("param1", 5), ("param2", "z"), ("param3", True)]) + assert not or_constraint([("param1", 5), ("param2", "z"), ("param3", None)]) + + def test_inverse(self): + # De Morgan's law: not (A or B) = (not A) and (not B) + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = or_constraint.inverse() + assert isinstance(inverse, And) + assert str(inverse) == "(param1 is not one of [1, 2, 3] and param2 is not one of ['a', 'b'])" + + def test_doc_str(self): + or_constraint = Or(Equal(GetInput("a"), tp.float32), Equal(GetInput("a"), tp.float16)) + assert doc_str(or_constraint) == "(``a`` == :class:`float32` *or* ``a`` == :class:`float16`)" + + +class TestEqual: + def test_call(self): + constraint = Equal(GetInput("param1"), GetInput("param2")) + assert constraint([("param1", 5), ("param2", 5)]) + result = constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param1' to be equal to 'param2' (but it was '5' while 'param2' was '10')" in result.error_details + + def test_str(self): + assert str(Equal(GetInput("param1"), GetInput("param2"))) == "param1 == param2" + assert str(Equal(GetInput("param1"), 5)) == "param1 == 5" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") == GetInput("param2") + assert isinstance(constraint, Equal) + + def test_inverse(self): + constraint = Equal(GetInput("param1"), 5) + inverse = constraint.inverse() + assert isinstance(inverse, NotEqual) + assert inverse([("param1", 10)]) + assert not inverse([("param1", 5)]) + + def test_doc_str(self): + assert doc_str(Equal(GetInput("a"), GetInput("b"))) == "``a`` == ``b``" + assert doc_str(Equal(GetInput("a"), tp.float32)) == "``a`` == :class:`float32`" + + +class TestNotEqual: + def test_call(self): + constraint = NotEqual(GetInput("param1"), GetInput("param2")) + assert constraint([("param1", 5), ("param2", 10)]) + result = constraint([("param1", 5), ("param2", 5)]) + assert not result + assert "'param1' to be not equal to 'param2' (but it was '5')" in result.error_details + + def test_str(self): + assert str(NotEqual(GetInput("param1"), GetInput("param2"))) == "param1 != param2" + assert str(NotEqual(GetInput("param1"), 5)) == "param1 != 5" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") != GetInput("param2") + assert isinstance(constraint, NotEqual) + + def test_inverse(self): + constraint = NotEqual(GetInput("param1"), 5) + inverse = constraint.inverse() + assert isinstance(inverse, Equal) + assert inverse([("param1", 5)]) + assert not inverse([("param1", 10)]) + + def test_doc_str(self): + assert doc_str(NotEqual(GetInput("a"), GetInput("b"))) == "``a`` != ``b``" + + +class TestIf: + def test_call(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20, 30]), + ) + # Condition true, then branch passes + assert if_constraint([("param1", 5), ("param2", 2)]) + # Condition true, then branch fails + result = if_constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param2' to be one of 
[1, 2, 3] (but it was '10')" in result.error_details + # Condition false, else branch passes + assert if_constraint([("param1", 10), ("param2", 20)]) + # Condition false, else branch fails + result = if_constraint([("param1", 10), ("param2", 100)]) + assert not result + assert "'param2' to be one of [10, 20, 30] (but it was '100')" in result.error_details + + def test_str(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20]), + ) + assert ( + str(if_constraint) == "if (param1 == 5) then (param2 is one of [1, 2, 3]) else (param2 is one of [10, 20])" + ) + + def test_inverse(self): + if_constraint = If( + Equal(GetInput("param1"), 5), + OneOf(GetInput("param2"), [1, 2, 3]), + OneOf(GetInput("param2"), [10, 20]), + ) + inverse = ~if_constraint + assert isinstance(inverse.then_branch, NotOneOf) + assert isinstance(inverse.else_branch, NotOneOf) + # When param1 == 5, param2 should NOT be in [1, 2, 3] + assert inverse([("param1", 5), ("param2", 10)]) + assert not inverse([("param1", 5), ("param2", 2)]) + + def test_doc_str(self): + if_constraint = If( + Equal(GetDataType(GetInput("a")), tp.float32), + OneOf(GetInput("b"), [tp.float32, tp.float16]), + OneOf(GetInput("b"), [tp.int32, tp.int64]), + ) + assert ( + doc_str(if_constraint) + == "``b`` is one of [:class:`float32`, :class:`float16`] **if** ``a.dtype`` == :class:`float32`, **otherwise** ``b`` is one of [:class:`int32`, :class:`int64`]" + ) + + def test_call_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + # Condition true, then branch passes + assert if_constraint([("param1", 5), ("param2", 2)]) + # Condition true, then branch fails + result = if_constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param2' to be one of [1, 2, 3] (but it was '10')" in result.error_details + # Condition false, no else branch - should always pass + assert if_constraint([("param1", 10), ("param2", 999)]) + + def test_str_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + assert str(if_constraint) == "if (param1 == 5) then (param2 is one of [1, 2, 3])" + + def test_doc_str_without_else_branch(self): + if_constraint = If( + Equal(GetDataType(GetInput("a")), tp.float32), OneOf(GetInput("b"), [tp.float32, tp.float16]) + ) + assert ( + doc_str(if_constraint) + == "if ``a.dtype`` == :class:`float32`, then ``b`` is one of [:class:`float32`, :class:`float16`]" + ) + + def test_inverse_without_else_branch(self): + if_constraint = If(Equal(GetInput("param1"), 5), OneOf(GetInput("param2"), [1, 2, 3])) + inverse = ~if_constraint + assert isinstance(inverse.then_branch, NotOneOf) + assert inverse.else_branch is None + # When param1 == 5, param2 should NOT be in [1, 2, 3] + assert inverse([("param1", 5), ("param2", 10)]) + assert not inverse([("param1", 5), ("param2", 2)]) + # When param1 != 5, should always pass + assert inverse([("param1", 10), ("param2", 2)]) diff --git a/tripy/tests/frontend/constraints/test_optimizer.py b/tripy/tests/frontend/constraints/test_optimizer.py new file mode 100644 index 000000000..b90d23597 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_optimizer.py @@ -0,0 +1,43 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.constraints import AlwaysTrue, And, GetInput, OneOf +from nvtripy.frontend.constraints.fetcher import GetDataType +from nvtripy.frontend.constraints.optimizer import optimize_constraints + + +class TestOptimizeConstraints: + def test_drops_all_dtype_oneof(self): + constraint = OneOf(GetDataType(GetInput("x")), list(DATA_TYPES.values())) + optimized = optimize_constraints(constraint) + assert isinstance(optimized, AlwaysTrue) + + def test_keeps_non_exhaustive_oneof(self): + dtypes = list(DATA_TYPES.values()) + constraint = OneOf(GetDataType(GetInput("x")), dtypes[:-1]) + optimized = optimize_constraints(constraint) + assert optimized is constraint + + def test_applies_to_nested_constraints(self): + constraint = And( + OneOf(GetDataType(GetInput("x")), list(DATA_TYPES.values())), + OneOf(GetInput("y"), [1, 2, 3]), + ) + optimized = optimize_constraints(constraint) + assert isinstance(optimized, And) + assert any(isinstance(child, AlwaysTrue) for child in optimized.constraints) diff --git a/tripy/tests/frontend/module/test_conv.py b/tripy/tests/frontend/module/test_conv.py index fd0af012f..1c11cccf6 100644 --- a/tripy/tests/frontend/module/test_conv.py +++ b/tripy/tests/frontend/module/test_conv.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,7 +44,7 @@ def test_op_func(self, ConvType): def test_mismatched_dtypes_fails(self, ConvType): input = tp.ones((4, 3, 8, 8), dtype=tp.float32) conv_layer = make_conv(ConvType, 3, 16, (5, 5), dtype=tp.float16) - with helper.raises(tp.TripyException, match=r"Mismatched data types in", has_stack_info_for=[input]): + with helper.raises(tp.TripyException, match=r"Invalid inputs for function:", has_stack_info_for=[input]): output = conv_layer(input) def test_mismatched_dim_fails(self, ConvType): diff --git a/tripy/tests/frontend/module/test_embedding.py b/tripy/tests/frontend/module/test_embedding.py index 6f440b70e..0ee3f55a2 100644 --- a/tripy/tests/frontend/module/test_embedding.py +++ b/tripy/tests/frontend/module/test_embedding.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,5 +34,5 @@ def test_incorrect_input_dtype(self): embd = tp.Embedding(4, 16) embd.weight = tp.ones(embd.weight.shape) - with helper.raises(tp.TripyException, match="Unsupported data type in 'gather'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'gather'."): out = embd(a) diff --git a/tripy/tests/frontend/ops/test_binary.py b/tripy/tests/frontend/ops/test_binary.py index dc19b0e31..473aa183a 100644 --- a/tripy/tests/frontend/ops/test_binary.py +++ b/tripy/tests/frontend/ops/test_binary.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -63,7 +63,7 @@ def test_mismatched_dtypes_fails(self): with helper.raises( tp.TripyException, # Keep the entire error message here so we'll know if the display becomes horribly corrupted. - match=r"Mismatched data types in '__add__'.", + match=r"Invalid inputs for function: '__add__'.", has_stack_info_for=[a, b], ): c = a + b diff --git a/tripy/tests/frontend/ops/test_dequantize.py b/tripy/tests/frontend/ops/test_dequantize.py index a353db13b..8c80431d7 100644 --- a/tripy/tests/frontend/ops/test_dequantize.py +++ b/tripy/tests/frontend/ops/test_dequantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,7 @@ def test_invalid_input_dtype(self): a = tp.Tensor([1.0, 2.0]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'dequantize'", + match="Invalid inputs for function: 'dequantize'.", ): a = tp.dequantize(a, 0.9, tp.float32) @@ -34,7 +34,7 @@ def test_invalid_dequant_dtype(self): a = tp.ones([2], dtype=tp.int8) with helper.raises( tp.TripyException, - match="Unsupported data type in 'dequantize'", + match="Invalid inputs for function: 'dequantize'.", ): a = tp.dequantize(a, 1, tp.int32) diff --git a/tripy/tests/frontend/ops/test_equal.py b/tripy/tests/frontend/ops/test_equal.py index 83792a6a9..2b6cb25ca 100644 --- a/tripy/tests/frontend/ops/test_equal.py +++ b/tripy/tests/frontend/ops/test_equal.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,5 +18,5 @@ class TestEqual: def test_mismatched_dtypes_disallowed(self): - with helper.raises(tp.TripyException, match="Mismatched data types in 'equal'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'equal'."): tp.equal(tp.ones((2,), dtype=tp.float32), tp.ones((2,), dtype=tp.float16)) diff --git a/tripy/tests/frontend/ops/test_matmul.py b/tripy/tests/frontend/ops/test_matmul.py index f33cb7f44..3d810403d 100644 --- a/tripy/tests/frontend/ops/test_matmul.py +++ b/tripy/tests/frontend/ops/test_matmul.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,7 +33,7 @@ def test_mismatched_dtypes_fails(self): a = tp.ones((2, 3), dtype=tp.float32) b = tp.ones((3, 2), dtype=tp.float16) - with helper.raises(tp.TripyException, match="Mismatched data types in '__matmul__'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: '__matmul__'."): c = a @ b def test_incompatible_1d_shapes_fails(self): diff --git a/tripy/tests/frontend/ops/test_quantize.py b/tripy/tests/frontend/ops/test_quantize.py index f9520a37d..a9055bf18 100644 --- a/tripy/tests/frontend/ops/test_quantize.py +++ b/tripy/tests/frontend/ops/test_quantize.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ def test_invalid_input_dtype(self): a = tp.Tensor([1, 2]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'quantize'", + match="Invalid inputs for function: 'quantize'.", ): a = tp.quantize(a, 1, tp.int8) @@ -33,7 +33,7 @@ def test_unsupported_quant_dtype(self): a = tp.Tensor([1.0, 2.0]) with helper.raises( tp.TripyException, - match="Unsupported data type in 'quantize'", + match="Invalid inputs for function: 'quantize'.", ): a = tp.quantize(a, 1, tp.float16) diff --git a/tripy/tests/frontend/ops/test_where.py b/tripy/tests/frontend/ops/test_where.py index 85d68f5f5..6921d9175 100644 --- a/tripy/tests/frontend/ops/test_where.py +++ b/tripy/tests/frontend/ops/test_where.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,7 +45,7 @@ def test_mismatched_input_dtypes(self): a = tp.ones((2,), dtype=tp.float32) b = tp.ones((2,), dtype=tp.float16) - with helper.raises(tp.TripyException, match="Mismatched data types in 'where'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'where'."): c = tp.where(cond, a, b) def test_condition_is_not_bool(self): @@ -53,7 +53,7 @@ def test_condition_is_not_bool(self): a = tp.ones((2,), dtype=tp.float32) b = tp.ones((2,), dtype=tp.float32) - with helper.raises(tp.TripyException, match="Unsupported data type in 'where'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'where'."): c = tp.where(cond, a, b) @@ -62,5 +62,5 @@ def test_condition_is_not_bool(self): a = tp.Tensor([0, 1, 0, 1]) mask = tp.Tensor([1.0, 2.0, 3.0, 4.0]) - with helper.raises(tp.TripyException, match="Unsupported data type in 'masked_fill'."): + with helper.raises(tp.TripyException, match="Invalid inputs for function: 'masked_fill'."): b = tp.masked_fill(a, mask, -1) diff --git a/tripy/tests/utils/wrappers/test_interface.py b/tripy/tests/frontend/wrappers/test_wrappers.py old mode 100755 new mode 100644 similarity index 60% rename from tripy/tests/utils/wrappers/test_interface.py rename to tripy/tests/frontend/wrappers/test_wrappers.py index 7a9466afe..fdc59c9a1 --- a/tripy/tests/utils/wrappers/test_interface.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,8 +22,10 @@ import nvtripy as tp import pytest from nvtripy.export import PUBLIC_APIS -from nvtripy.utils import wrappers -from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn +from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or +from nvtripy.frontend.wrappers import _find_known_datatypes, OPERATOR_CONSTRAINTS from tests import helper # Get all functions/methods which have tensors in the type signature @@ -46,25 +48,105 @@ api.qualname + f".{func.__name__}" if func.__name__ not in api.qualname else "" ) -DATA_TYPE_CONSTRAINTS_FUNC_NAMES = {dtc.func.__qualname__ for dtc in DATA_TYPE_CONSTRAINTS} +OPERATOR_CONSTRAINTS_FUNC_NAMES = {oc.func.__qualname__ for oc in OPERATOR_CONSTRAINTS} @pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES) -def test_all_public_apis_verified(api): - assert api.__qualname__ in DATA_TYPE_CONSTRAINTS_FUNC_NAMES, f"Missing datatype constraints for: {api.__qualname__}" +def test_all_public_apis_have_operator_constraints(api): + assert api.__qualname__ in OPERATOR_CONSTRAINTS_FUNC_NAMES, f"Missing operator constraints for: {api.__qualname__}" -@wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]}) +@wrappers.interface(input_requirements=OneOf(GetInput("tensors").dtype, [tp.float32])) def sequence_func(tensors: List[tp.Tensor]): return +class TestFindKnownDatatypes: + def test_equal_dtypes_propagation(self): + tensor_a = tp.Tensor([1.0, 2.0]) + merged_args = [("a", tensor_a), ("b", 1.0)] + + 
diff --git a/tripy/tests/utils/wrappers/test_interface.py b/tripy/tests/frontend/wrappers/test_wrappers.py
old mode 100755
new mode 100644
similarity index 60%
rename from tripy/tests/utils/wrappers/test_interface.py
rename to tripy/tests/frontend/wrappers/test_wrappers.py
index 7a9466afe..fdc59c9a1
--- a/tripy/tests/utils/wrappers/test_interface.py
+++ b/tripy/tests/frontend/wrappers/test_wrappers.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,8 +22,10 @@
 import nvtripy as tp
 import pytest
 from nvtripy.export import PUBLIC_APIS
-from nvtripy.utils import wrappers
-from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS
+from nvtripy.frontend import wrappers
+from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn
+from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or
+from nvtripy.frontend.wrappers import _find_known_datatypes, OPERATOR_CONSTRAINTS
 from tests import helper
 
 # Get all functions/methods which have tensors in the type signature
@@ -46,25 +48,105 @@ api.qualname + f".{func.__name__}" if func.__name__ not in api.qualname else ""
     )
 
-DATA_TYPE_CONSTRAINTS_FUNC_NAMES = {dtc.func.__qualname__ for dtc in DATA_TYPE_CONSTRAINTS}
+OPERATOR_CONSTRAINTS_FUNC_NAMES = {oc.func.__qualname__ for oc in OPERATOR_CONSTRAINTS}
 
 
 @pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES)
-def test_all_public_apis_verified(api):
-    assert api.__qualname__ in DATA_TYPE_CONSTRAINTS_FUNC_NAMES, f"Missing datatype constraints for: {api.__qualname__}"
+def test_all_public_apis_have_operator_constraints(api):
+    assert api.__qualname__ in OPERATOR_CONSTRAINTS_FUNC_NAMES, f"Missing operator constraints for: {api.__qualname__}"
 
 
-@wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]})
+@wrappers.interface(input_requirements=OneOf(GetInput("tensors").dtype, [tp.float32]))
 def sequence_func(tensors: List[tp.Tensor]):
     return
 
 
+class TestFindKnownDatatypes:
+    def test_equal_dtypes_propagation(self):
+        tensor_a = tp.Tensor([1.0, 2.0])
+        merged_args = [("a", tensor_a), ("b", 1.0)]
+
+        input_requirements = And(Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))))
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] == tp.float32
+        assert result["b"] == tp.float32
+
+    def test_multiple_equal_dtypes_chain(self):
+        tensor_c = tp.Tensor([1.0, 2.0])
+        merged_args = [("a", 0.0), ("b", 1.0), ("c", tensor_c)]
+
+        input_requirements = And(
+            Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))),
+            Equal(GetDataType(GetInput("b")), GetDataType(GetInput("c"))),
+        )
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] == tp.float32
+        assert result["b"] == tp.float32
+        assert result["c"] == tp.float32
+
+    def test_chain_with_disjoint_sets(self):
+        # Ensure the implementation applies transitive equality correctly even when the
+        # constraints are ordered such that an incorrect implementation would form disjoint sets.
+        tensor_a = tp.Tensor([1.0, 2.0])
+        merged_args = [("a", tensor_a), ("b", 1.0), ("c", 1), ("d", 1)]
+
+        input_requirements = And(
+            Equal(GetDataType(GetInput("a")), GetDataType(GetInput("c"))),
+            Equal(GetDataType(GetInput("b")), GetDataType(GetInput("d"))),
+            Equal(GetDataType(GetInput("b")), GetDataType(GetInput("a"))),
+        )
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] is tp.float32
+        assert result["b"] is tp.float32
+        assert result["c"] is tp.float32
+        assert result["d"] is tp.float32
+
+    def test_equal_to_constant_dtype(self):
+        merged_args = [("a", 1.0)]
+
+        input_requirements = And(Equal(GetDataType(GetInput("a")), tp.float16))
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] == tp.float16
+
+    def test_multiple_separate_dtype_groups(self):
+        tensor_a = tp.Tensor([1.0, 2.0])
+        tensor_c = tp.Tensor([1, 2])
+        merged_args = [("a", tensor_a), ("b", 1.0), ("c", tensor_c), ("d", 1)]
+
+        input_requirements = And(
+            Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))),
+            Equal(GetDataType(GetInput("c")), GetDataType(GetInput("d"))),
+        )
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] == tp.float32
+        assert result["b"] == tp.float32
+        assert result["c"] == tp.int32
+        assert result["d"] == tp.int32
+
+    def test_unknown_dtypes(self):
+        merged_args = [("a", 1.0), ("b", 2.0)]
+
+        input_requirements = And(Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))))
+
+        result = _find_known_datatypes(merged_args, input_requirements)
+        assert result["a"] is None
+        assert result["b"] is None
+
+
 class TestDtypes:
     def test_works_with_sequences(self):
         sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)])
 
     def test_raises_on_mismatched_sequence_dtypes(self):
-        with helper.raises(tp.TripyException, match="Mismatched data types in sequence argument for 'sequence_func'."):
+        with helper.raises(
+            tp.TripyException,
+            match=r"Mismatched data types in sequence argument",
+        ):
             sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.int32)])
 
@@ -168,8 +250,8 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
     def test_cast_dtype(self):
         # When type constraints are included, the decorator should automatically cast when possible.
         @wrappers.interface(
-            dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"},
-            dtype_variables={"T1": ["float16"]},
+            input_requirements=(GetInput("a").dtype == tp.float16) & (GetInput("b").dtype == GetInput("a").dtype),
+            output_guarantees=(GetReturn(0).dtype == GetInput("a").dtype) & (GetReturn(1).dtype == GetInput("a").dtype),
             convert_to_tensors=True,
         )
         def func(a: tp.Tensor, b: tp.types.TensorLike):
@@ -185,11 +267,15 @@ def func(a: tp.Tensor, b: tp.types.TensorLike):
         assert isinstance(b, tp.Tensor)
         assert b.dtype == tp.float16
 
-    @pytest.mark.parametrize("arg, dtype", [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)])
+    @pytest.mark.parametrize(
+        "arg, dtype",
+        [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)],
+    )
     def test_refuse_unsafe_cast(self, arg, dtype):
         @wrappers.interface(
-            dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"},
-            dtype_variables={"T1": ["int32", "int64"]},
+            input_requirements=OneOf(GetInput("a").dtype, [tp.int32, tp.int64, tp.bool])
+            & (GetInput("b").dtype == GetInput("a").dtype),
+            output_guarantees=(GetReturn(0).dtype == GetInput("a").dtype) & (GetReturn(1).dtype == GetInput("a").dtype),
             convert_to_tensors=True,
         )
         def func(a: tp.Tensor, b: tp.types.TensorLike):
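Taken together, the hunks above replace the string-keyed `dtype_constraints`/`dtype_variables` dicts with composable predicates. A usage sketch of the new decorator form, using only constructs that appear in this diff (`OneOf`, `Equal` via `==`, and `&` for `And`); the decorated function and its particular requirement are hypothetical, not from the codebase:

```python
import nvtripy as tp
from nvtripy.frontend import wrappers
from nvtripy.frontend.constraints.fetcher import GetInput, GetReturn
from nvtripy.frontend.constraints.logic import OneOf

# Hypothetical operator: both inputs must share one of two dtypes, and the
# (single) return value must match the dtype of `a`.
@wrappers.interface(
    input_requirements=OneOf(GetInput("a").dtype, [tp.float32, tp.float16])
    & (GetInput("b").dtype == GetInput("a").dtype),
    output_guarantees=GetReturn(0).dtype == GetInput("a").dtype,
)
def scaled_sum(a: tp.Tensor, b: tp.Tensor) -> tp.Tensor:
    return a + b
```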
diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py
index 73733021a..a789e9394 100644
--- a/tripy/tests/helper.py
+++ b/tripy/tests/helper.py
@@ -275,6 +275,7 @@ def from_name(name: str) -> "Marker":
     # Marks an entire block as being expected to fail.
     "test: xfail": Marker.from_name("TEST: XFAIL"),
     # Marks that a block contains the expected output from the immediate previous block.
+    # Used by example tests to verify correctness.
     "test: expected_stdout": Marker.from_name("TEST: EXPECTED_STDOUT"),
     # Marks that a block should be run under pytest.
     "test: use_pytest": Marker.from_name("TEST: USE_PYTEST"),
@@ -282,7 +283,9 @@ def from_name(name: str) -> "Marker":
     "example: if_fp8": Marker.from_name("EXAMPLE: IF_FP8"),
     # Indicates that a block should be omitted from the rendered documentation. Such blocks may still be evaluated.
     "doc: omit": Marker.from_name("DOC: OMIT"),
-    # Indicates that a block should not be evaluated for the documentation.
+    # Indicates that a block of Python code should be neither evaluated nor formatted for the documentation.
+    # To format a block without evaluating it, add a `# doc: no-eval` tag to the code block instead.
+    # Non-Python code is never formatted, so always use this marker for it rather than `# doc: no-eval`.
     "doc: no_eval_or_format": Marker.from_name("DOC: NO_EVAL_OR_FORMAT"),
     # Indicates that local variables should not be displayed for a code block in the documentation.
     # Useful when the raw code block is also publicly visible and we don't want inline markers (e.g. in the main README.md).
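For reference, the `TestFindKnownDatatypes` cases earlier in this diff pin down the transitive-propagation semantics of `_find_known_datatypes`: arguments linked by `Equal` share one dtype, and a single known dtype (from a tensor or a constant) labels the whole group. A minimal union-find sketch of that behavior, independent of the real implementation:

```python
from typing import Dict, List, Optional, Tuple

def propagate_dtypes(
    known: Dict[str, Optional[str]], equal_pairs: List[Tuple[str, str]]
) -> Dict[str, Optional[str]]:
    # Union-find over argument names; `known` maps names to a dtype or None.
    parent = {name: name for name in known}

    def find(x: str) -> str:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in equal_pairs:
        parent[find(a)] = find(b)

    # Any member with a known dtype labels its entire group.
    group_dtype: Dict[str, Optional[str]] = {}
    for name, dtype in known.items():
        if dtype is not None:
            group_dtype[find(name)] = dtype
    return {name: group_dtype.get(find(name)) for name in known}

# Mirrors test_chain_with_disjoint_sets: only "a" starts known, yet all four resolve.
print(propagate_dtypes(
    {"a": "float32", "b": None, "c": None, "d": None},
    [("a", "c"), ("b", "d"), ("b", "a")],
))  # {'a': 'float32', 'b': 'float32', 'c': 'float32', 'd': 'float32'}
```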
diff --git a/tripy/tests/integration/test_operator_constraints.py b/tripy/tests/integration/test_operator_constraints.py
new file mode 100644
index 000000000..e793ec378
--- /dev/null
+++ b/tripy/tests/integration/test_operator_constraints.py
@@ -0,0 +1,251 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for operator constraints.
+
+This test suite validates that Tripy's operator constraints correctly predict what the underlying
+software stack (MLIR-TRT/TensorRT) will accept or reject. For each operator with constraints, we:
+
+1. Generate all possible combinations of data types for parameters that accept Tensors or dtypes
+2. Use Tripy's input_requirements to predict whether each combination should be valid
+3. Call the operator with Tripy's validation disabled, so errors come from the underlying stack
+4. For valid combinations: verify the outputs match the output_guarantees
+5. For invalid combinations: ensure the underlying stack raises an exception during evaluation
+
+This ensures Tripy's constraint system stays in sync with the underlying implementation.
+"""
+
+import contextlib
+import inspect
+import itertools
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List
+
+import numpy as np
+import pytest
+
+import nvtripy as tp
+from nvtripy.common.datatype import DATA_TYPES
+from nvtripy.frontend.wrappers import OPERATOR_CONSTRAINTS
+from nvtripy.utils.types import str_from_type_annotation
+from nvtripy.utils.utils import make_list
+from tests import helper
+from tests.conftest import skip_if_older_than_sm89
+
+
+@dataclass
+class OperatorConstraintCase:
+    func: Callable
+    dtype_values: Dict[str, tp.dtype]  # The specific dtype combination for this test
+
+    def __str__(self):
+        param_str = "_".join(f"{key}-{val}" for key, val in self.dtype_values.items())
+        return f"{self.func.__name__}-{param_str}"
+
+
+def get_dtype_constrained_params(func: Callable) -> List[str]:
+    sig = inspect.signature(func)
+    return [
+        param_name
+        for param_name, param in sig.parameters.items()
+        if param.annotation is not inspect.Parameter.empty
+        and (type_str := str_from_type_annotation(param.annotation))
+        and ("Tensor" in type_str or "nvtripy.dtype" in type_str)
+    ]
+
+
+def generate_test_cases() -> List[OperatorConstraintCase]:
+    cases = []
+
+    for op_constraint in OPERATOR_CONSTRAINTS:
+        func = op_constraint.func
+        dtype_params = get_dtype_constrained_params(func)
+
+        if not dtype_params:
+            continue
+
+        all_dtypes = list(DATA_TYPES.values())
+
+        for dtype_combination in itertools.product(all_dtypes, repeat=len(dtype_params)):
+            dtype_values = dict(zip(dtype_params, dtype_combination))
+
+            # Skip FP8 on older hardware
+            marks = [skip_if_older_than_sm89()] if any(dtype == tp.float8 for dtype in dtype_values.values()) else []
+
+            cases.append(pytest.param(OperatorConstraintCase(func, dtype_values), marks=marks))
+
+    # Sort for deterministic ordering
+    cases.sort(key=lambda case: str(case.values[0]))
+    return cases
+
+
+CUSTOM_VALUES = {
+    "__getitem__": {"index": 0},
+    "arange": {"start": 0, "stop": 10, "step": 1},
+    "avgpool": {"kernel_dims": [2, 2]},
+    "convolution": {
+        "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)),
+        "bias": tp.Tensor([1.0, 2.0]),
+        "padding": ((0, 0), (0, 0)),
+        "stride": [1, 1],
+        "groups": 1,
+        "dilation": [1, 1],
+    },
+    "copy": {
+        "input": tp.ones((2, 2)),
+        "device": tp.device("cpu"),
+    },
+    "deconvolution": {
"weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)), + "bias": tp.Tensor([1.0, 2.0]), + "padding": ((0, 0), (0, 0)), + "stride": [1, 1], + "groups": 1, + "dilation": [1, 1], + }, + "dequantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, + "expand": {"sizes": tp.Tensor((2, 2, 5, 5))}, + "full": {"shape": tp.Tensor([2, 2]), "value": tp.Tensor(1.0)}, + "full_like": {"value": tp.Tensor(1.0)}, + "gather": {"index": tp.Tensor([1])}, + "instancenorm": { + "num_channels": 2, + "weight": tp.Tensor(np.ones((2,), dtype=np.float32)), + "bias": tp.Tensor(np.zeros((2,), dtype=np.float32)), + }, + "iota": {"shape": tp.Tensor([2, 2])}, + "maxpool": {"kernel_dims": [2, 2]}, + "ones": {"shape": [2, 2]}, + "outer": {"vec1": tp.Tensor([1, 2, 3]), "vec2": tp.Tensor([1, 2, 3])}, + "pad": {"pad": [(0, 1), (1, 0), (1, 1), (0, 0)]}, + "permute": {"perm": [1, 0, 3, 2]}, + "quantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1}, + "repeat": {"repeats": 2, "dim": 0}, + "reshape": {"shape": tp.Tensor([2, 25])}, + "resize": { + "mode": "nearest", + "output_shape": tp.Tensor((1, 2, 10, 10)), + "scales": [1, 1, 2, 2], + }, + "squeeze": {"dims": 0}, + "transpose": {"dim0": 0, "dim1": 1}, + "zeros": {"shape": [2, 2]}, +} + +# Arguments that must be constants on CPU +REQUIRES_CPU_CONST = { + "dequantize": {"scale"}, + "quantize": {"scale"}, +} + +# Some operations require input shapes to be known +REQUIRES_KNOWN_SHAPES = { + "convolution": {"input", "weight", "bias"}, + "deconvolution": {"input", "weight", "bias"}, +} + + +def _apply_tensor_adjustments(tensor: tp.Tensor, func_name: str, param_name: str) -> tp.Tensor: + if func_name in REQUIRES_CPU_CONST and param_name in REQUIRES_CPU_CONST[func_name]: + if tensor.device.kind != "cpu": + tensor = tp.copy(tensor, device=tp.device("cpu")) + + if func_name in REQUIRES_KNOWN_SHAPES and param_name in REQUIRES_KNOWN_SHAPES[func_name]: + if any(dim == tp.constants.DYNAMIC_DIM for dim in tensor.trace_tensor.shape): + tensor.trace_tensor.shape = tuple(map(int, tensor.shape)) + + return tensor + + +def generate_input_values(case: OperatorConstraintCase) -> Dict[str, Any]: + if tp.int4 in case.dtype_values.values(): + pytest.skip(f"#579: Cannot generate INT4 inputs") + + inputs = {} + sig = inspect.signature(case.func) + func_name = case.func.__name__ + + for param_name, param in sig.parameters.items(): + dtype = case.dtype_values.get(param_name) + param_type = str_from_type_annotation(param.annotation) + + # Handle custom values first + if func_name in CUSTOM_VALUES and param_name in CUSTOM_VALUES[func_name]: + inputs[param_name] = CUSTOM_VALUES[func_name][param_name] + if isinstance(inputs[param_name], tp.Tensor) and dtype is not None: + inputs[param_name] = tp.cast(inputs[param_name], dtype=dtype) + inputs[param_name] = _apply_tensor_adjustments(inputs[param_name], func_name, param_name) + continue + + # Skip optional parameters unless they need a specific dtype + if param.default is not inspect.Parameter.empty and dtype is None: + continue + + # Generate values based on parameter type + if "Tensor" in param_type: + assert dtype is not None, f"Tensor parameter '{param_name}' must have a dtype constraint" + base_tensor = tp.cast(tp.Tensor(np.ones((1, 2, 5, 5), dtype=np.float32)), dtype=dtype) + + if "Sequence" in param_type or "List" in param_type: + inputs[param_name] = [_apply_tensor_adjustments(base_tensor, func_name, param_name) for _ in range(2)] + else: + inputs[param_name] = _apply_tensor_adjustments(base_tensor, func_name, param_name) + elif "nvtripy.dtype" in 
+            assert dtype is not None, f"dtype parameter '{param_name}' must have a dtype constraint"
+            inputs[param_name] = dtype
+        elif "numbers.Number" in param_type or "int" in param_type or "float" in param_type:
+            inputs[param_name] = 1
+
+    return inputs
+
+
+OPERATOR_CONSTRAINT_CASES = generate_test_cases()
+
+
+@pytest.mark.parametrize("case", OPERATOR_CONSTRAINT_CASES, ids=lambda case: str(case))
+def test_operator_constraints(case: OperatorConstraintCase):
+    op_constraint = next((oc for oc in OPERATOR_CONSTRAINTS if oc.func == case.func), None)
+    assert op_constraint is not None, f"Could not find constraints for {case.func.__name__}"
+
+    # If input validation is enabled, negative tests will trivially pass (we will throw an
+    # error before even trying to call the underlying implementation).
+    with helper.config("enable_input_validation", False):
+        inputs = generate_input_values(case)
+        merged_args = list(inputs.items())
+
+        # Some operators may only define output guarantees.
+        # In that case, we cannot predict input validity via constraints.
+        is_valid = (
+            True if op_constraint.input_requirements is None else bool(op_constraint.input_requirements(merged_args))
+        )
+
+        with contextlib.ExitStack() as stack:
+            if not is_valid:
+                stack.enter_context(helper.raises(Exception))
+
+            outputs = make_list(case.func(**inputs))
+
+            for out in outputs:
+                if isinstance(out, tp.Tensor):
+                    # Avoid evaluating CPU constants since some types (e.g. FP8) don't allow them to be used outside of
+                    # certain operations.
+                    out._eval_for_internal_methods()
+
+            if is_valid and op_constraint.output_guarantees is not None:
+                output_result = op_constraint.output_guarantees(merged_args, tuple(outputs))
+                assert output_result, f"Output guarantees not met for {case.func.__name__}: " + " ".join(
+                    output_result.error_details
+                )
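The sweep performed by `generate_test_cases` above is easiest to see in miniature: for an operator with two dtype-constrained parameters, every dtype pair is enumerated and a predicate partitions them into predicted-valid and predicted-invalid cases. A self-contained illustration (the dtype list and predicate are stand-ins, not the real `DATA_TYPES` or `input_requirements`):

```python
import itertools

dtypes = ["float32", "float16", "int32", "bool"]  # stand-in for DATA_TYPES.values()
params = ("input", "other")

def input_requirements(combo: dict) -> bool:
    # Stand-in predicate: both parameters must share a floating-point dtype.
    return combo["input"] == combo["other"] and combo["input"].startswith("float")

cases = [dict(zip(params, combo)) for combo in itertools.product(dtypes, repeat=len(params))]
predicted_valid = [c for c in cases if input_requirements(c)]
print(len(cases), len(predicted_valid))  # 16 combinations total, 2 predicted valid
```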
diff --git a/tripy/tests/utils/test_utils.py b/tripy/tests/utils/test_utils.py
index 54725b4e1..298e4f166 100644
--- a/tripy/tests/utils/test_utils.py
+++ b/tripy/tests/utils/test_utils.py
@@ -20,6 +20,7 @@
 import nvtripy as tp
 import pytest
 from nvtripy import utils
+from nvtripy.frontend.wrappers import constant_fields
 from tests import helper
 
 
@@ -46,7 +47,7 @@ def test_hash_equivalence(self, func):
 
 
 def make_with_constant_field():
-    @utils.wrappers.constant_fields("field")
+    @constant_fields("field")
     class WithConstField:
         def __init__(self):
             self.custom_setter_called_count = defaultdict(int)
@@ -108,3 +109,35 @@ def test_gen_uid(self, inputs, outputs, expected_prefix):
     def test_uniqueness(self):
         uids = [utils.utils.UniqueNameGen.gen_uid() for _ in range(100)]
         assert len(set(uids)) == 100
+
+
+class TestMergeFunctionArguments:
+    def test_defaults_filled_in_with_mixed_args(self):
+        # Tests that defaults are filled in and positioned correctly before kwargs
+        def func(a, b=2, c=3, d=4):
+            pass
+
+        all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1, d=99)
+        assert all_args == [("a", 1), ("d", 99)]
+        assert omitted_args == [("b", 2), ("c", 3)]
+        assert var_arg_info is None
+
+    def test_variadic_positional_args(self):
+        # Tests variadic args tracking and defaults after variadic args
+        def func(a, *args, b=10):
+            pass
+
+        all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1, 2, 3)
+        assert all_args == [("a", 1), ("args", 2), ("args", 3)]
+        assert omitted_args == [("b", 10)]
+        assert var_arg_info == ("args", 1)
+
+    def test_no_variadic_args_provided(self):
+        # Tests that var_arg_info is None when no variadic args are provided
+        def func(a, *args, b=10):
+            pass
+
+        all_args, omitted_args, var_arg_info = utils.utils.merge_function_arguments(func, 1)
+        assert all_args == [("a", 1)]
+        assert omitted_args == [("b", 10)]
+        assert var_arg_info is None
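The new `TestMergeFunctionArguments` cases above describe the contract: `merge_function_arguments(func, *args, **kwargs)` returns the bound `(name, value)` pairs in signature order, the defaulted parameters that were omitted, and, when variadic positionals are supplied, the name and starting index of the variadic bundle. A behavioral sketch consistent with those assertions (not the shipped implementation):

```python
import inspect

def merge_function_arguments_sketch(func, *args, **kwargs):
    sig = inspect.signature(func)
    bound = sig.bind(*args, **kwargs)

    all_args, omitted_args, var_arg_info = [], [], None
    for name, param in sig.parameters.items():
        if name in bound.arguments:
            value = bound.arguments[name]
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                if value:  # only report variadic info when values were actually passed
                    var_arg_info = (name, len(all_args))
                    all_args.extend((name, v) for v in value)
            else:
                all_args.append((name, value))
        elif param.default is not inspect.Parameter.empty:
            omitted_args.append((name, param.default))
    return all_args, omitted_args, var_arg_info

def func(a, *args, b=10):
    pass

print(merge_function_arguments_sketch(func, 1, 2, 3))
# ([('a', 1), ('args', 2), ('args', 3)], [('b', 10)], ('args', 1))
```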
diff --git a/tripy/tests/utils/wrappers/test_datatype_constraints.py b/tripy/tests/utils/wrappers/test_datatype_constraints.py
deleted file mode 100644
index 548023fe2..000000000
--- a/tripy/tests/utils/wrappers/test_datatype_constraints.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-import inspect
-import itertools
-from dataclasses import dataclass
-from typing import Callable, Dict, List
-
-import numpy as np
-import nvtripy as tp
-import pytest
-from nvtripy.common.datatype import DATA_TYPES
-from nvtripy.utils import wrappers
-from nvtripy.utils.types import str_from_type_annotation
-from nvtripy.utils.utils import make_list
-from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS
-from tests import helper
-from tests.conftest import skip_if_older_than_sm89
-
-
-@dataclass
-class DtypeConstraintCase:
-    func: Callable
-    constraints: Dict[str, str]
-    variables: Dict[str, tp.dtype]
-    negative: bool  # Whether this is a negative test case
-
-    def __str__(self):
-        return f"{self.func.__name__}-{'_'.join(f'{key}-{val}' for key, val in self.variables.items())}" + (
-            "-invalid" if self.negative else "-valid"
-        )
-
-
-DTYPE_CONSTRAINT_CASES: List[DtypeConstraintCase] = []
-
-for dtc in DATA_TYPE_CONSTRAINTS:
-    keys, values = list(zip(*dtc.variables.items()))
-
-    def add_cases(combinations, negative):
-        for combination in combinations:
-            dtype_combination = dict(zip(keys, combination))
-            DTYPE_CONSTRAINT_CASES.append(
-                pytest.param(
-                    DtypeConstraintCase(dtc.func, dtc.constraints, dtype_combination, negative),
-                    marks=(
-                        skip_if_older_than_sm89()
-                        if any(dtype == "float8" for dtype in dtype_combination.values())
-                        else []
-                    ),
-                )
-            )
-
-    # Positive cases:
-    positive_combinations = list(itertools.product(*values))
-    exceptions = [tuple(exception[key] for key in keys) for exception in dtc.exceptions]
-    positive_combinations = [comb for comb in positive_combinations if comb not in exceptions]
-    add_cases(positive_combinations, negative=False)
-
-    # Negative cases - we do this by simply generating all possible combinations and removing the positive ones:
-    total_dtypes = set(map(str, DATA_TYPES.values()))
-    negative_combinations = list(itertools.product(*(total_dtypes for _ in values)))
-    negative_combinations = list(comb for comb in negative_combinations if comb not in positive_combinations)
-    add_cases(negative_combinations, negative=True)
-
-
-DTYPE_CONSTRAINT_CASES.sort(key=lambda case: str(case))
-
-
-# In some cases, we need to use custom values so that the code is valid.
-CUSTOM_VALUES = {
-    "__getitem__": {"index": 0},
-    "arange": {"start": 0, "stop": 10, "step": 1},
-    "avgpool": {"kernel_dims": [2, 2]},
-    "convolution": {
-        "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)),
-        "bias": tp.Tensor([1.0, 2.0]),
-        "padding": ((0, 0), (0, 0)),
-        "stride": [1, 1],
-        "groups": 1,
-        "dilation": [1, 1],
-    },
-    "copy": {
-        "input": tp.ones((2, 2)),
-        "device": tp.device("cpu"),
-    },
-    "deconvolution": {
-        "weight": tp.Tensor(np.ones((2, 2, 3, 3), dtype=np.float32)),
-        "bias": tp.Tensor([1.0, 2.0]),
-        "padding": ((0, 0), (0, 0)),
-        "stride": [1, 1],
-        "groups": 1,
-        "dilation": [1, 1],
-    },
-    "dequantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1},
-    "expand": {"sizes": tp.Tensor((2, 2, 5, 5))},
-    "full": {"shape": tp.Tensor([2, 2]), "value": tp.Tensor(1.0)},
-    "full_like": {"value": tp.Tensor(1.0)},
-    "gather": {"index": tp.Tensor([1])},
-    "instancenorm": {
-        "num_channels": 2,
-        "weight": tp.Tensor(np.ones((2,), dtype=np.float32)),
-        "bias": tp.Tensor(np.zeros((2,), dtype=np.float32)),
-    },
-    "iota": {"shape": tp.Tensor([2, 2])},
-    "maxpool": {"kernel_dims": [2, 2]},
-    "ones": {"shape": [2, 2]},
-    "outer": {"vec1": tp.Tensor([1, 2, 3]), "vec2": tp.Tensor([1, 2, 3])},
-    "pad": {"pad": [(0, 1), (1, 0), (1, 1), (0, 0)]},
-    "permute": {"perm": [1, 0, 3, 2]},
-    "quantize": {"scale": tp.Tensor([1.0, 2.0]), "dim": 1},
-    "repeat": {"repeats": 2, "dim": 0},
-    "reshape": {"shape": tp.Tensor([2, 25])},
-    "resize": {
-        "mode": "nearest",
-        "output_shape": tp.Tensor((1, 2, 10, 10)),
-        "scales": [1, 1, 2, 2],
-    },
-    "squeeze": {"dims": 0},
-    "transpose": {"dim0": 0, "dim1": 1},
-    "zeros": {"shape": [2, 2]},
-}
-
-# Arguments that must be constants on CPU.
-REQUIRES_CPU_CONST = {
-    "dequantize": {"scale"},
-    "quantize": {"scale"},
-}
-
-# Some operations require input shapes to be known
-REQUIRES_KNOWN_SHAPES = {
-    "convolution": {"input", "weight", "bias"},
-    "deconvolution": {"input", "weight", "bias"},
-}
-
-
-def generate_input_values(case: DtypeConstraintCase):
-    inputs = {}
-    for param_name, param in inspect.signature(case.func).parameters.items():
-        requires_cpu_const = (
-            case.func.__name__ in REQUIRES_CPU_CONST and param_name in REQUIRES_CPU_CONST[case.func.__name__]
-        )
-        requires_known_shapes = (
-            case.func.__name__ in REQUIRES_KNOWN_SHAPES and param_name in REQUIRES_KNOWN_SHAPES[case.func.__name__]
-        )
-
-        param_type = str_from_type_annotation(param.annotation)
-
-        dtype = None
-        if param_name in case.constraints:
-            dtype = DATA_TYPES[case.variables[case.constraints[param_name]]]
-
-        if dtype == tp.int4:
-            # TODO (#579): Enable int4 inputs
-            pytest.skip(f"#579: Cannot generate INT4 inputs")
-
-        def copy_input_to_cpu_and_set_shapes():
-            if requires_cpu_const:
-                assert isinstance(inputs[param_name], tp.Tensor)
-                if inputs[param_name].device.kind != "cpu":
-                    inputs[param_name] = tp.copy(inputs[param_name], device=tp.device("cpu"))
-
-            if requires_known_shapes:
-                assert isinstance(inputs[param_name], tp.Tensor)
-                if any(dim == tp.constants.DYNAMIC_DIM for dim in inputs[param_name].trace_tensor.shape):
-                    inputs[param_name].trace_tensor.shape = tuple(map(int, inputs[param_name].shape))
-
-        if case.func.__name__ in CUSTOM_VALUES and param_name in CUSTOM_VALUES[case.func.__name__]:
-            inputs[param_name] = CUSTOM_VALUES[case.func.__name__][param_name]
-            if isinstance(inputs[param_name], tp.Tensor):
-                if dtype is not None:
-                    inputs[param_name] = tp.cast(inputs[param_name], dtype=dtype)
-                copy_input_to_cpu_and_set_shapes()
-            continue
-
-        if param.default is not inspect.Parameter.empty and dtype is None:
-            continue  # Skip optional parameters unless we explicitly need to set a datatype for them.
-
-        if "nvtripy.Tensor" in param_type:
-            assert dtype is not None, "Tensors must have type annotations"
-            # Need to cast here because `ones` does not support all types.
-            tensor = tp.Tensor(np.ones((1, 2, 5, 5), dtype=np.float32))
-            if "Sequence" in param_type:
-                inputs[param_name] = [tp.cast(tensor, dtype=dtype) for _ in range(2)]
-                copy_input_to_cpu_and_set_shapes()
-            else:
-                inputs[param_name] = tp.cast(tensor, dtype=dtype)
-                copy_input_to_cpu_and_set_shapes()
-        elif "nvtripy.dtype" in param_type:
-            assert dtype is not None, "Data types must have type annotations"
-            inputs[param_name] = dtype
-        elif "numbers.Number" in param_type or "int" in param_type or "float" in param_type:
-            inputs[param_name] = 1
-        else:
-            assert False, f"Unsupported parameter type: {param_type}"
-    return inputs
-
-
-@pytest.mark.parametrize("case", DTYPE_CONSTRAINT_CASES, ids=lambda case: str(case))
-def test_datatype_constraints(case: DtypeConstraintCase):
-
-    # If data type checking is enabled, negative tests will trivially pass (we will throw an
-    # error before even trying to call the function).
-    with helper.config("enable_dtype_checking", False):
-        inputs = generate_input_values(case)
-
-        with contextlib.ExitStack() as stack:
-            if case.negative:
-                stack.enter_context(helper.raises(Exception))
-
-            outputs = make_list(case.func(**inputs))
-
-            # Some APIs do not generate Tensor outputs (e.g. `allclose`), so we don't need to evaluate those.
-            if not any(isinstance(out, tp.Tensor) for out in outputs):
-                return
-
-            expected_return_types = [
-                case.variables[cons] for cons in make_list(case.constraints[wrappers.RETURN_VALUE])
-            ]
-            assert expected_return_types, "Return value must have a constraint"
-            if len(expected_return_types) < len(outputs):
-                expected_return_types += [expected_return_types[-1]] * (len(outputs) - len(expected_return_types))
-
-            for out, expected_type in zip(outputs, expected_return_types):
-                # Avoid evaluating CPU constants since some types (e.g. FP8) don't allow them to be used outside of
-                # certain operations.
-                out._eval_for_internal_methods()
-                assert out.dtype == DATA_TYPES[expected_type], f"Expected {expected_type}, got {out.dtype}"
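The key behavioral shift between the deleted suite above and its replacement: negative cases were previously derived by set difference against an explicit whitelist of positive combinations, whereas the new suite asks the operator's own `input_requirements` predicate per combination. In miniature (stand-in dtypes and predicate, not the real constraint objects):

```python
import itertools

dtypes = {"float32", "float16", "int32"}
whitelist = {("float32",), ("float16",)}  # old style: explicit positive combinations

# Old style: everything not whitelisted is a negative case.
old_negatives = set(itertools.product(dtypes)) - whitelist

# New style: one predicate decides validity per combination.
predicate = lambda combo: combo[0].startswith("float")
new_valid = {combo for combo in itertools.product(dtypes) if predicate(combo)}

assert new_valid == whitelist
assert old_negatives == set(itertools.product(dtypes)) - new_valid
```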