diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py index 234b2b4204..810d6f4d98 100644 --- a/examples/dynamo/aot_plugin.py +++ b/examples/dynamo/aot_plugin.py @@ -34,10 +34,11 @@ import tensorrt as trt import tensorrt.plugin as trtp import torch -import torch_tensorrt import triton import triton.language as tl +import torch_tensorrt + trt_logger = trt.Logger(trt.Logger.VERBOSE) @@ -51,7 +52,9 @@ @triton.jit -def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr): +def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + # AOT path requires (inputs, outputs, extra_args) order — swapping any + # two slots feeds the wrong value into the kernel and segfaults. pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -61,6 +64,18 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offsets, output, mask=mask) +@triton.jit +def add_one_inplace_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + # Aliased-I/O variant: TRT routes a single pointer for the shared + # input/output buffer, so the kernel takes one pointer, not two. 
+ pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + tl.store(x_ptr + offsets, x + 1, mask=mask) + + # %% # Step 2: Register the PyTorch op # ----------------------------------------- @@ -77,7 +92,7 @@ def add_one(X: torch.Tensor) -> torch.Tensor: Y = torch.empty_like(X) BLOCK_SIZE = 256 grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) - add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE) + add_one_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE) return Y @@ -148,8 +163,8 @@ def add_plugin_aot_impl( fn=add_one_kernel, signature={ "x_ptr": f"*{type_str}", - "n_elements": "i32", "y_ptr": f"*{type_str}", + "n_elements": "i32", }, constexprs={ "BLOCK_SIZE": block_size, @@ -166,15 +181,13 @@ def add_plugin_aot_impl( launch_params.block_x = compiled_kernel.metadata.num_warps * 32 # threads per block launch_params.shared_mem = compiled_kernel.metadata.shared # bytes of shared mem - # ``extra_args`` are scalar arguments appended to the kernel's argument list at - # launch. Here ``n_elements`` is passed as a 32-bit symbolic integer so TRT - # evaluates it from the actual tensor size at runtime. 
- extra_args = trtp.SymIntExprs(1) - extra_args[0] = trtp.SymInt32(N) + extra_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args( + [trtp.SymInt32(N)], compiled_kernel=compiled_kernel + ) return ( - compiled_kernel.metadata.name, # kernel function name in PTX - compiled_kernel.asm["ptx"], # PTX source — embedded in TRT engine + compiled_kernel.metadata.name, + compiled_kernel.asm["ptx"], launch_params, extra_args, ) @@ -201,6 +214,87 @@ def add_plugin_aot_impl( ) +# %% +# In-place variant: aliased plugin I/O +# ----------------------------------------- +# +# Same kernel exposed as an in-place plugin: the engine mutates the input +# buffer directly via QDP ``TensorDesc.aliased()`` instead of allocating a +# separate output. Useful for KV-cache updates and similar patterns. +# +# Two signals together declare aliasing to the framework: +# * ``mutates_args=("X",)`` on the torch op +# * the registered fake returns ``X`` by identity +# The eager impl must mutate ``X`` itself and return a clone — ``torch.library`` +# rejects returning an input by identity. 
+ + +@torch.library.custom_op("my::add_one_inplace", mutates_args=("X",)) # type: ignore[misc] +def add_one_inplace(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) + add_one_inplace_kernel[grid](X, X.numel(), BLOCK_SIZE=BLOCK_SIZE) + return X.clone() + + +@torch.library.register_fake("my::add_one_inplace") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +@trtp.register("my::add_one_inplace") +def add_one_inplace_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]: + return X.aliased() + + +@trtp.aot_impl("my::add_one_inplace") +def add_one_inplace_aot_impl( + X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int +) -> Tuple[ + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs +]: + type_str = "fp32" if X.dtype == trt.float32 else "fp16" + + block_size = 256 + src = triton.compiler.ASTSource( + fn=add_one_inplace_kernel, + signature={ + "x_ptr": f"*{type_str}", + "n_elements": "i32", + }, + constexprs={ + "BLOCK_SIZE": block_size, + }, + ) + compiled_kernel = triton.compile(src) + + N = X.shape_expr.numel() + launch_params = trtp.KernelLaunchParams() + launch_params.grid_x = trtp.cdiv(N, block_size) + launch_params.block_x = compiled_kernel.metadata.num_warps * 32 + launch_params.shared_mem = compiled_kernel.metadata.shared + + extra_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args( + [trtp.SymInt32(N)], compiled_kernel=compiled_kernel + ) + + return ( + compiled_kernel.metadata.name, + compiled_kernel.asm["ptx"], + launch_params, + extra_args, + ) + + +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( + "my::add_one_inplace", + supports_dynamic_shapes=False, + requires_output_allocator=False, + use_aot_if_available=True, +) + + # %% # Step 6: Compile and Run # ----------------------------------------- @@ -219,6 +313,11 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: return res +class 
MyInplaceModel(torch.nn.Module): + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.ops.my.add_one_inplace.default(X) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -246,3 +345,50 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: assert torch.allclose(res, my_model(m)), "Results do not match!" print("Inference successful!") + + # %% + # In-place plugin demo + # --------------------- + # + # In-place ops mutate their input, so eager and TRT must run on separate + # cloned buffers; otherwise each comparison double-applies the mutation. + + print("\nIn-place plugin demo:") + inplace_model = MyInplaceModel().to("cuda").eval() + base = torch.full((64, 64), 2, device="cuda", dtype=torch.float) + expected_post = base + 1 + + model_trt_inplace = torch_tensorrt.compile( + inplace_model, + inputs=[base.clone()], + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in model_trt_inplace.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + assert engine_submodules, ( + "Expected a TRT engine submodule for the in-place plugin path; got a" + f" pure-PyTorch fallback. Graph:\n{model_trt_inplace.graph}" + ) + print(f" TRT engine submodule(s) present: {len(engine_submodules)}") + + with torch.no_grad(): + trt_input = base.clone() + trt_out = model_trt_inplace(trt_input) + assert torch.allclose(trt_out, expected_post), "TRT output mismatch" + assert torch.allclose( + trt_input, expected_post + ), "Engine did not mutate the input buffer — aliased plugin I/O is not active." 
+ + print(" Output matches expected post-mutation value.") + print(" Input buffer was mutated in place by the TRT engine.") + print("In-place inference successful!") diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index f5ffdafda2..3de546404a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import torch + from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.types import TRTNetwork @@ -23,6 +24,7 @@ class ConversionContext: ) requires_output_allocator: bool = False requires_native_multidevice: bool = False + requires_aliased_plugin_io: bool = False weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict) cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index e1f4d8bafb..9b7bdd1196 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -24,6 +24,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes + from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._enums import dtype from torch_tensorrt._features import needs_refit @@ -310,6 +311,19 @@ def _populate_trt_builder_config( if self.compilation_settings.enable_weight_streaming: builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING) + if self.ctx.requires_aliased_plugin_io: + aliased_io_feature = getattr( + trt.PreviewFeature, "ALIASED_PLUGIN_IO_10_03", None + ) + if aliased_io_feature is None: + raise RuntimeError( + "An in-place QDP plugin declared aliased I/O, but this TensorRT" + " 
version does not expose PreviewFeature.ALIASED_PLUGIN_IO_10_03." + " TensorRT 10.3+ is required for aliased plugin I/O." + ) + builder_config.set_preview_feature(aliased_io_feature, True) + _LOGGER.info("Enabling preview feature ALIASED_PLUGIN_IO_10_03") + if is_tensorrt_version_supported("10.8"): TilingOptimizationLevel = { "none": trt.TilingOptimizationLevel.NONE, diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py index fc5e973560..e71b3e0044 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py @@ -1,3 +1,4 @@ +from torch_tensorrt.dynamo.conversion.plugins._aot_utils import make_aot_extra_args from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import ( diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py b/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py new file mode 100644 index 0000000000..03ec4fd9d1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py @@ -0,0 +1,37 @@ +"""Helpers for writing AOT QDP plugins backed by Triton kernels.""" + +from typing import Any, Sequence + + +def _has_triton_scratch_params(compiled_kernel: Any) -> bool: + md = getattr(compiled_kernel, "metadata", None) + if md is None: + return False + return hasattr(md, "global_scratch_size") and hasattr(md, "profile_scratch_size") + + +def make_aot_extra_args( + user_args: Sequence[Any], + *, + compiled_kernel: Any = None, +) -> Any: + """Build a ``trtp.SymIntExprs`` for an AOT plugin's ``extra_args`` return. 
+ + When ``compiled_kernel`` is a Triton-compiled kernel, four trailing + ``SymInt32(0)`` are appended to cover the two ``.param .u64 .ptr`` slots + (``global_scratch``, ``profile_scratch``) that Triton >= 3.x always emits + in PTX even when their sizes are zero. TRT's AOT plugin path does not + plumb those slots through, so without padding ``enqueueV3`` reads stale + register state for them and segfaults on the first call. + """ + import tensorrt.plugin as trtp + + pad = 4 if _has_triton_scratch_params(compiled_kernel) else 0 + total = len(user_args) + pad + out = trtp.SymIntExprs(total) + for i, arg in enumerate(user_args): + out[i] = arg + zero = trtp.SymInt32(0) + for i in range(len(user_args), total): + out[i] = zero + return out diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index bf087d01cc..9475eab86e 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -71,12 +71,34 @@ def _generate_plugin(plugin_name: str) -> None: # retrieve the corresponding torch operation using the passed in string torch_op = getattr(getattr(torch.ops, namespace), name) + default_schema = torch_op.default._schema + + # Positional indices of tensor inputs marked as mutated in the schema + # (Tensor(aN!) ... -> ...) — candidates for in-place QDP outputs via + # `TensorDesc.aliased()`. The actual alias map (output_idx -> input_idx) + # is decided at fake-run time by checking output-to-input identity. 
+ _tensor_arg_positions = [ + i + for i, a in enumerate(default_schema.arguments) + if a.type.isSubtypeOf(torch._C.TensorType.get()) + ] + _mutated_tensor_arg_positions = { + _tensor_arg_positions.index(i) + for i, a in enumerate(default_schema.arguments) + if a.type.isSubtypeOf(torch._C.TensorType.get()) + and a.alias_info is not None + and a.alias_info.is_write + } + + # Cached frozenset of aliased output indices, populated by the descriptor + # on first build-time call so the runtime impl skips the per-call probe. + _aliased_indices_cache: list[Any] = [None] # helper function that generates the required signature based on the torch operation def generate_signature( torch_op: Callable[[Any], Any], ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: - schema = torch_op._schemas[""] + schema = torch_op.default._schema arg_list = [] @@ -183,10 +205,20 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. output = torch_op(*fake_args, *non_tensor_args, **torch_kwargs) - # Normalize to a list of fake outputs. Multi-output torch ops return - # a tuple; single-output ops return a bare Tensor. outputs_list = list(output) if isinstance(output, (tuple, list)) else [output] + # output_idx -> input_idx alias map: an output aliases a mutated input + # iff the schema marks the input as mutated AND the fake returns that + # tensor by identity. The schema gate prevents accidental aliasing on + # incidental identity returns from non-mutating ops. 
+ alias_map: dict[int, int] = {} + for out_idx, fake_out in enumerate(outputs_list): + for in_idx in _mutated_tensor_arg_positions: + if in_idx < len(fake_args) and fake_out is fake_args[in_idx]: + alias_map[out_idx] = in_idx + break + _aliased_indices_cache[0] = frozenset(alias_map.keys()) + input_node_expr = list( itertools.chain.from_iterable( [sym.node.expr for sym in syms_arg] for syms_arg in syms_args @@ -208,6 +240,17 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc, .. out_descs = [] for out_idx, fake_out in enumerate(outputs_list): + if out_idx in alias_map: + # Aliased output shares its buffer with a mutated input; + # `aliased()` carries the input's shape/dtype. + # Limitation: TRT preview-feature ALIASED_PLUGIN_IO_10_03 + # inserts a defensive copy that breaks aliasing when a + # multi-output plugin's aliased output is consumed by + # another TRT layer in the same engine. Single-output + # plugins (the common KV-cache pattern) work end-to-end. + out_descs.append(tensor_args[alias_map[out_idx]].aliased()) + continue + shape_calc_fns: list[Any] = [None] * fake_out.ndim for i in range(fake_out.ndim): out_dim = fake_out.shape[i] @@ -274,12 +317,25 @@ def _generic_plugin_impl( dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + # Skip copy_ for aliased outputs: storage is shared with the input + # the eager op already mutated, so copying would either no-op against + # itself or clobber the in-place result when the eager op returns a + # fresh tensor (as torch.library forces it to do). 
+ aliased_outputs = _aliased_indices_cache[0] + if aliased_outputs is None: + aliased_outputs = frozenset( + i for i, o in enumerate(outputs) if o.get_aliased() is not None + ) + stream = torch.cuda.ExternalStream(stream) with torch.cuda.stream(stream): out_tensors = torch_op(*in_tensors, *non_tensor_args, **torch_kwargs) if isinstance(out_tensors, torch.Tensor): out_tensors = (out_tensors,) - [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + for i, (d, o) in enumerate(zip(dest_tensors, out_tensors)): + if i in aliased_outputs: + continue + d.copy_(o) plugin_impl_func = f""" {plugin_impl_signature} diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index a2c39e6f7e..d43296a267 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target @@ -44,6 +45,80 @@ def _coerce_plugin_attr_for_qdp(value: Any, attr_annotation: Any) -> Any: return value +_PYTHON_SCALAR_TO_NUMPY_DTYPE = { + float: np.float64, + int: np.int64, + bool: np.bool_, +} + + +def _patch_trtp_scalar_attr_roundtrip() -> None: + """Work around ``_TemplatePluginCreator.create_plugin`` calling + ``float(np.array([v]))`` on scalar plugin attrs and crashing because + ``f.data`` is always 1-d after the C++ PluginField round-trip. We + temporarily promote the annotation to ``npt.NDArray[dtype]``, run the + upstream path, then unwrap back to the declared Python scalar type. + Idempotent; no-op once upstream ships a fix. 
+ """ + try: + from tensorrt_bindings.plugin import _lib as _trtp_lib + from tensorrt_bindings.plugin._utils import _is_numpy_array + except ImportError: + return + + creator_cls = getattr(_trtp_lib, "_TemplatePluginCreator", None) + if creator_cls is None or getattr(creator_cls, "_torch_trt_scalar_patched", False): + return + + orig_create_plugin = creator_cls.create_plugin + + def _patched_create_plugin( + self: Any, + name: str, + namespace: str, + fc: Any, + phase: Any, + qpcr: Any = None, + ) -> Any: + from tensorrt_bindings.plugin._lib import QDP_REGISTRY + + desc = QDP_REGISTRY.get(f"{namespace}::{name}") + if desc is None: + return orig_create_plugin(self, name, namespace, fc, phase, qpcr) + + scalar_attrs: dict[str, type] = {} + for f in fc: + ann = desc.input_attrs.get(f.name) + if ann is None or _is_numpy_array(ann): + continue + if not isinstance(ann, type): + continue + if ann in _PYTHON_SCALAR_TO_NUMPY_DTYPE: + scalar_attrs[f.name] = ann + + if not scalar_attrs: + return orig_create_plugin(self, name, namespace, fc, phase, qpcr) + + saved_annotations = {n: desc.input_attrs[n] for n in scalar_attrs} + for n, ann in scalar_attrs.items(): + desc.input_attrs[n] = npt.NDArray[_PYTHON_SCALAR_TO_NUMPY_DTYPE[ann]] # type: ignore[valid-type] + try: + plg = orig_create_plugin(self, name, namespace, fc, phase, qpcr) + finally: + for n, ann in saved_annotations.items(): + desc.input_attrs[n] = ann + + for n, ann in scalar_attrs.items(): + value = plg.attrs.get(n) + if isinstance(value, np.ndarray) and value.size == 1: + plg.attrs[n] = ann(value.reshape(()).item()) + + return plg + + creator_cls.create_plugin = _patched_create_plugin + creator_cls._torch_trt_scalar_patched = True + + def _is_numpy_attr_annotation(annotation: Any) -> bool: return annotation is np.ndarray or typing.get_origin(annotation) is np.ndarray @@ -91,6 +166,8 @@ def _generate_plugin_converter( ) from tensorrt.plugin._lib import QDP_REGISTRY + _patch_trtp_scalar_attr_roundtrip() + 
torch_target = getattr(getattr(torch.ops, namespace), op_name) overload_str = overload if overload else "" overload_name = overload_str if overload else "default" @@ -99,7 +176,14 @@ def _generate_plugin_converter( f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}," " unable to generate converter" ) - torch_schema = torch_target._schemas[overload_str] + torch_schema = torch_overload._schema + + schema_declares_mutation = any( + arg.alias_info is not None + and arg.alias_info.is_write + and arg.type.isSubtypeOf(torch._C.TensorType.get()) + for arg in torch_schema.arguments + ) use_aot_plugin = use_aot_if_available @@ -159,8 +243,19 @@ def custom_kernel_converter( f"Adding generated plugin for {namespace}::{name} to tensorrt network" ) layer.name = f"[{target}]-[{name}]" - # Single-output ops expect a bare ITensor; multi-output ops expect a - # tuple so the downstream ``getitem`` converter can unpack it. + + # JIT path: layer.plugin is the Python `_TemplateJITPlugin` whose + # `aliased_map` is populated by TRT during `add_plugin`. + # AOT path: layer.plugin is a C++ wrapper that does not expose the + # map, so fall back to the op schema's mutation declaration — the + # same signal `_generate_plugin` uses to emit `.aliased()`. 
+ layer_plugin = getattr(layer, "plugin", None) + aliased_map = getattr(layer_plugin, "aliased_map", None) + if aliased_map and any(v != -1 for v in aliased_map.values()): + ctx.requires_aliased_plugin_io = True + elif schema_declares_mutation: + ctx.requires_aliased_plugin_io = True + num_outputs = len(torch_schema.returns) if num_outputs == 1: return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b25219bc82..dc7a955065 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional, Sequence, Union import torch + from torch_tensorrt._utils import is_tegra_platform from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater @@ -23,6 +24,7 @@ from .replace_fused_rms_norm import replace_fused_rms_norm from .replace_max_pool_with_indices import replace_max_pool_with_indices from .rule_based_autocast import rule_based_autocast +from .unfunctionalize_qdp_inplace import unfunctionalize_qdp_inplace pre_lowering_pass_list = [ remove_detach, @@ -31,6 +33,10 @@ ] post_lowering_pass_list = [ + # Must run before remove_num_users_is_0_nodes and any pass that walks the + # converter registry, so the underlying mutating op is restored ahead of + # partitioning and the QDP `.aliased()` descriptor can take effect. 
+ unfunctionalize_qdp_inplace, replace_fused_rms_norm, remove_input_alias_fixing_clones, constant_fold, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py new file mode 100644 index 0000000000..605802687b --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/unfunctionalize_qdp_inplace.py @@ -0,0 +1,168 @@ +"""Reverse ``run_decompositions``' functionalization of mutating custom ops +that have a registered Dynamo converter (QDP in-place plugins). + +``run_decompositions`` rewrites ``my_inplace_op(x, ...)`` into:: + + %af = auto_functionalized_v2(my_inplace_op, _x_base_index=N, _all_bases=[%x], ...) + %g0 = af[0] # the op's actual return + %gk = af[k] # post-mutation base (k = 1..len(_all_bases)) + %copy_ = aten.copy_.default(%x, %gk) + +Correct in eager, but our converter is registered against the original +mutating overload — the partitioner sees the HOP wrapper as unsupported and +bails. This pass restores the direct mutating call when a converter exists +and drops the synthesized copy_ nodes. +""" + +import logging +import operator +from typing import Any, Dict, List + +import torch + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + +logger = logging.getLogger(__name__) + + +def _auto_functionalized_targets() -> List[Any]: + targets: List[Any] = [] + higher_order = getattr(torch.ops, "higher_order", None) + if higher_order is None: + return targets + for name in ("auto_functionalized_v2", "auto_functionalized"): + op = getattr(higher_order, name, None) + if op is not None: + targets.append(op) + return targets + + +def _reconstruct_op_args(op_overload: Any, node_kwargs: Dict[str, Any]) -> List[Any]: + """Rebuild positional args for a direct call to ``op_overload``. 
+ + The QDP-generated converter reads tensor inputs positionally + (``args[0 : len(tensor_inputs)]``), so args must be in schema order. + ``auto_functionalized_v2`` packs mutated tensors via + ``_{arg.name}_base_index: N`` + ``_all_bases: [t0, ...]``; non-mutated args + are passed by name as-is. + """ + bases = node_kwargs.get("_all_bases", []) + out: List[Any] = [] + schema = op_overload._schema + for arg in schema.arguments: + base_key = f"_{arg.name}_base_index" + if base_key in node_kwargs: + out.append(bases[node_kwargs[base_key]]) + elif arg.name in node_kwargs: + out.append(node_kwargs[arg.name]) + elif arg.has_default_value(): + out.append(arg.default_value) + else: + raise RuntimeError( + f"auto_functionalized_v2 missing argument '{arg.name}' for" + f" {op_overload} (no value and no default)" + ) + return out + + +def unfunctionalize_qdp_inplace( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + af_targets = _auto_functionalized_targets() + if not af_targets: + return gm + + converter_check_cache: Dict[Any, bool] = {} + aten_copy_ = torch.ops.aten.copy_.default + hops_to_rewrite: List[Any] = [] + copy_candidates: List[Any] = [] + + for node in list(gm.graph.nodes): + if node.op != "call_function": + continue + if node.target in af_targets: + if not node.args or not hasattr(node.args[0], "_schema"): + continue + op_overload = node.args[0] + has_converter = converter_check_cache.get(op_overload) + if has_converter is None: + has_converter = op_overload in DYNAMO_CONVERTERS + converter_check_cache[op_overload] = has_converter + if has_converter: + hops_to_rewrite.append(node) + elif ( + node.target is aten_copy_ + and len(node.args) >= 2 + and getattr(node.args[0], "op", None) == "placeholder" + ): + copy_candidates.append(node) + + if not hops_to_rewrite: + return gm + + for node in hops_to_rewrite: + op_overload = node.args[0] + op_args = _reconstruct_op_args(op_overload, dict(node.kwargs)) + n_outputs = 
len(op_overload._schema.returns) + hop_val = node.meta.get("val") + bases = node.kwargs.get("_all_bases", []) + + with gm.graph.inserting_before(node): + new_call = gm.graph.call_function(op_overload, args=tuple(op_args)) + if isinstance(hop_val, tuple) and len(hop_val) >= 1: + if n_outputs == 1: + new_call.meta["val"] = hop_val[0] + else: + new_call.meta["val"] = tuple(hop_val[:n_outputs]) + if "tensor_meta" in node.meta: + new_call.meta["tensor_meta"] = node.meta["tensor_meta"] + + # For single-output ops the op's return and every post-mutation base + # are the same tensor (the in-place result), so routing all getitem + # users to ``new_call`` is correct and keeps it alive. For + # multi-output ops we materialize one getitem per return and route + # base-slot users to the corresponding base placeholder — the + # mutation has already been applied in place by ``new_call``. + getitem_users = [u for u in list(node.users) if u.target is operator.getitem] + if n_outputs == 1: + for user in getitem_users: + user.replace_all_uses_with(new_call) + gm.graph.erase_node(user) + else: + return_getitems: List[Any] = [] + with gm.graph.inserting_after(new_call): + for i in range(n_outputs): + g = gm.graph.call_function( + operator.getitem, args=(new_call, i) + ) + if isinstance(hop_val, tuple) and i < len(hop_val): + g.meta["val"] = hop_val[i] + return_getitems.append(g) + for user in getitem_users: + idx = user.args[1] + if idx < n_outputs: + user.replace_all_uses_with(return_getitems[idx]) + else: + base_idx = idx - n_outputs + user.replace_all_uses_with(bases[base_idx]) + gm.graph.erase_node(user) + + if list(node.users): + raise RuntimeError( + f"auto_functionalized_v2 node {node.name} has non-getitem users" + f" {list(node.users)}; cannot un-functionalize safely." + ) + gm.graph.erase_node(node) + + # Functionalization adds copy_(base, op_return) to write the mutation + # back through the placeholder. 
The direct call already mutates the + # buffer, so the copy_ is redundant and would block partitioning. + for node in copy_candidates: + node.replace_all_uses_with(node.args[1]) + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + logger.debug(f"Un-functionalized QDP in-place ops:\n{gm.graph}") + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3c454933bb..3ddb812b9f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -7,8 +7,9 @@ import torch import torch.distributed as dist -import torch_tensorrt from torch.nn import Module + +import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -436,6 +437,20 @@ def setup_engine(self) -> None: for input_name in self.input_names } + # For QDP plugins with aliased I/O, the output binding must share the + # input binding's buffer at runtime; otherwise the in-place mutation is + # lost. Resolve the alias mapping to (output_idx -> input_idx) once so + # the forward path can rebind without per-call name lookups. + self.aliased_output_idx_to_input_idx: Dict[int, int] = {} + if hasattr(self.engine, "get_aliased_input_tensor"): + input_name_to_idx = {n: i for i, n in enumerate(self.input_names)} + for out_idx, output_name in enumerate(self.output_names): + aliased_input_name = self.engine.get_aliased_input_tensor(output_name) + if aliased_input_name: + self.aliased_output_idx_to_input_idx[out_idx] = input_name_to_idx[ + aliased_input_name + ] + def _setup_runtime_config(self) -> None: """Create a RuntimeConfig with runtime cache for TensorRT-RTX. 
@@ -727,6 +742,14 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ) outputs = self.create_output_tensors() + # Rebind aliased outputs to their paired input buffer so + # the in-place mutation lands in the caller's tensor. + for ( + out_idx, + in_idx, + ) in self.aliased_output_idx_to_input_idx.items(): + outputs[out_idx] = contiguous_inputs[in_idx] + for o, output_name in enumerate(self.output_names): if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py new file mode 100644 index 0000000000..609ba1c075 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace.py @@ -0,0 +1,99 @@ +import platform +import unittest + +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op("torchtrt_ex::add_one_inplace", mutates_args=("X",)) # type: ignore[misc] +def add_one_inplace(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePlugin(unittest.TestCase): + """In-place ops mutate their input, so DispatchTestCase.run_test (which + feeds the same tensor to eager and TRT) double-applies the mutation. We + use cloned inputs and check both the return value and the in-place write. 
+ """ + + @parameterized.expand( + [ + ((64, 64), torch.float), + ((128, 32), torch.float), + ] + ) + def test_add_one_inplace(self, input_shape, dtype): + class Model(nn.Module): + def forward(self, x): + return torch.ops.torchtrt_ex.add_one_inplace.default(x) + + base = torch.randn(input_shape, device="cuda", dtype=dtype) + + eager_input = base.clone() + eager_out = Model()(eager_input) + expected_post = base + 1 + torch.testing.assert_close(eager_input, expected_post, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(eager_out, expected_post, rtol=1e-5, atol=1e-5) + + trt_input = base.clone() + compiled = torch_tensorrt.compile( + Model(), + inputs=[trt_input.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + # Guard against a silent fallback to pure-PyTorch: the eager op + # already mutates the input, so output-only checks pass even when no + # TRT engine was built. + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + trt_out = compiled(trt_input) + torch.testing.assert_close(trt_out, expected_post, rtol=1e-5, atol=1e-5) + # Aliased plugin I/O is only active if the engine mutated trt_input; + # a fresh-output engine would leave it at its pre-call values. 
+ torch.testing.assert_close(trt_input, expected_post, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py new file mode 100644 index 0000000000..d441bdb467 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_consumed.py @@ -0,0 +1,87 @@ +"""Single-output aliased plugin whose output is consumed by another TRT layer. + +This is the realistic production pattern (e.g. a KV-cache update whose +post-update tensor is read by an attention layer in the same engine). +""" + +import platform +import unittest + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_one_inplace_consumed", mutates_args=("X",) +) # type: ignore[misc] +def add_one_inplace_consumed(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace_consumed") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace_consumed", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePluginConsumed(unittest.TestCase): + def test_aliased_output_consumed_downstream(self): + class Model(nn.Module): + def forward(self, x): + a = torch.ops.torchtrt_ex.add_one_inplace_consumed.default(x) + return a * 2 + + x_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + expected_post = x_base + 1 + expected = expected_post * 2 + + x_trt = x_base.clone() + compiled = 
torch_tensorrt.compile( + Model(), + inputs=[x_trt.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + result = compiled(x_trt) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(x_trt, expected_post, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py new file mode 100644 index 0000000000..2b3b04886e --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_dynamic.py @@ -0,0 +1,90 @@ +"""Aliased plugin I/O combined with dynamic shapes — the production case for +KV-cache-style ops where the cache tensor's batch dim varies.""" + +import platform +import unittest + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_one_inplace_dyn", mutates_args=("X",) +) # type: ignore[misc] +def add_one_inplace_dyn(X: torch.Tensor) -> torch.Tensor: + assert X.is_cuda + X.add_(1) + return X.clone() + + +@torch.library.register_fake("torchtrt_ex::add_one_inplace_dyn") +def _(X: torch.Tensor) -> torch.Tensor: + return X + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_one_inplace_dyn", supports_dynamic_shapes=True + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + 
not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestInplacePluginDynamicShapes(unittest.TestCase): + def test_dynamic_batch(self): + class Model(nn.Module): + def forward(self, x): + return torch.ops.torchtrt_ex.add_one_inplace_dyn.default(x) + + compile_input = torch.randn(8, 32, device="cuda", dtype=torch.float) + compiled = torch_tensorrt.compile( + Model(), + inputs=[ + torch_tensorrt.Input( + min_shape=(1, 32), + opt_shape=(8, 32), + max_shape=(16, 32), + dtype=torch.float, + ) + ], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + for batch in (1, 4, 16): + base = torch.randn(batch, 32, device="cuda", dtype=torch.float) + expected = base + 1 + trt_input = base.clone() + trt_out = compiled(trt_input) + torch.testing.assert_close(trt_out, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(trt_input, expected, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py new file mode 100644 index 0000000000..b3e1359186 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin_inplace_multi.py @@ -0,0 +1,106 @@ +"""Multi-input partial-mutation + multi-output coverage. + +Exercises the un-functionalize pass's multi-output branch, the alias-map +build in ``_generate_plugin._generic_plugin_desc`` for ops where only one +input is mutated, and the JIT impl's aliased-output ``copy_`` filter. 
+ +Only the *fresh* output is returned by the model. The aliased output is +unused. This is deliberate: TRT's preview-feature ``ALIASED_PLUGIN_IO_10_03`` +inserts a defensive copy that breaks aliasing when a multi-output plugin's +aliased output is consumed by another TRT layer in the same engine. The +correctness-critical path the test covers is the multi-output plumbing +itself; coverage for "aliased output consumed downstream" is provided by +the single-output test (which TRT handles correctly). +""" + +import platform +import unittest +from typing import Tuple + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests + +import torch_tensorrt + + +@torch.library.custom_op( + "torchtrt_ex::add_inplace_two_out", mutates_args=("X",) +) # type: ignore[misc] +def add_inplace_two_out( + X: torch.Tensor, Y: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + assert X.is_cuda and Y.is_cuda + X.add_(Y) + return X.clone(), X * 2 + + +@torch.library.register_fake("torchtrt_ex::add_inplace_two_out") +def _(X: torch.Tensor, Y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return X, torch.empty_like(X) + + +if torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::add_inplace_two_out", supports_dynamic_shapes=False + ) + + +@unittest.skipIf( + platform.system() == "Windows", + "QDP in-place test requires Linux", +) +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + "QDP Plugin is not available", +) +class TestMultiOutputInplacePlugin(unittest.TestCase): + def test_partial_mutation_fresh_output(self): + class Model(nn.Module): + def forward(self, x, y): + _, b = torch.ops.torchtrt_ex.add_inplace_two_out.default(x, y) + return b + + x_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + y_base = torch.randn(64, 64, device="cuda", dtype=torch.float) + + x_eager = x_base.clone() + _, eager_b = add_inplace_two_out(x_eager, y_base.clone()) + 
expected_x = x_base + y_base + expected_b = expected_x * 2 + torch.testing.assert_close(x_eager, expected_x, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(eager_b, expected_b, rtol=1e-5, atol=1e-5) + + x_trt = x_base.clone() + compiled = torch_tensorrt.compile( + Model(), + inputs=[x_trt.clone(), y_base.clone()], + ir="dynamo", + min_block_size=1, + immutable_weights=True, + ) + + from torch_tensorrt.dynamo.runtime import ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + ) + + engine_submodules = [ + m + for _, m in compiled.named_modules() + if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule)) + ] + self.assertGreaterEqual( + len(engine_submodules), + 1, + f"Expected at least one TRT engine submodule, got graph:\n{compiled.graph}", + ) + + result = compiled(x_trt, y_base.clone()) + torch.testing.assert_close(result, expected_b, rtol=1e-5, atol=1e-5) + # X was mutated in place; Y was not. + torch.testing.assert_close(x_trt, expected_x, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + run_tests()