Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 157 additions & 11 deletions examples/dynamo/aot_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

import torch_tensorrt

trt_logger = trt.Logger(trt.Logger.VERBOSE)


Expand All @@ -51,7 +52,9 @@


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# AOT path requires (inputs, outputs, extra_args) order — swapping any
# two slots feeds the wrong value into the kernel and segfaults.
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
Expand All @@ -61,6 +64,18 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offsets, output, mask=mask)


@triton.jit
def add_one_inplace_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # In-place add-one for aliased plugin I/O: the input and output share one
    # buffer, so the kernel receives a single pointer and writes back through it.
    #
    # x_ptr:      pointer to the shared input/output buffer
    # n_elements: number of valid elements reachable from x_ptr
    # BLOCK_SIZE: compile-time number of elements handled per program instance
    first = tl.program_id(0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the buffer; mask those lanes off.
    in_bounds = idx < n_elements
    addr = x_ptr + idx
    vals = tl.load(addr, mask=in_bounds)
    tl.store(addr, vals + 1, mask=in_bounds)


# %%
# Step 2: Register the PyTorch op
# -----------------------------------------
Expand All @@ -77,7 +92,7 @@ def add_one(X: torch.Tensor) -> torch.Tensor:
Y = torch.empty_like(X)
BLOCK_SIZE = 256
grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)
add_one_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE)
return Y


Expand Down Expand Up @@ -148,8 +163,8 @@ def add_plugin_aot_impl(
fn=add_one_kernel,
signature={
"x_ptr": f"*{type_str}",
"n_elements": "i32",
"y_ptr": f"*{type_str}",
"n_elements": "i32",
},
constexprs={
"BLOCK_SIZE": block_size,
Expand All @@ -166,15 +181,13 @@ def add_plugin_aot_impl(
launch_params.block_x = compiled_kernel.metadata.num_warps * 32 # threads per block
launch_params.shared_mem = compiled_kernel.metadata.shared # bytes of shared mem

# ``extra_args`` are scalar arguments appended to the kernel's argument list at
# launch. Here ``n_elements`` is passed as a 32-bit symbolic integer so TRT
# evaluates it from the actual tensor size at runtime.
extra_args = trtp.SymIntExprs(1)
extra_args[0] = trtp.SymInt32(N)
extra_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args(
[trtp.SymInt32(N)], compiled_kernel=compiled_kernel
)

return (
compiled_kernel.metadata.name, # kernel function name in PTX
compiled_kernel.asm["ptx"], # PTX source — embedded in TRT engine
compiled_kernel.metadata.name,
compiled_kernel.asm["ptx"],
launch_params,
extra_args,
)
Expand All @@ -201,6 +214,87 @@ def add_plugin_aot_impl(
)


# %%
# In-place variant: aliased plugin I/O
# -----------------------------------------
#
# The same add-one operation exposed as an in-place plugin, backed by the
# single-pointer ``add_one_inplace_kernel``: the engine mutates the input
# buffer directly via QDP ``TensorDesc.aliased()`` instead of allocating a
# separate output. Useful for KV-cache updates and similar patterns.
#
# Two signals together declare aliasing to the framework:
# * ``mutates_args=("X",)`` on the torch op
# * the registered fake returns ``X`` by identity
# The eager impl must mutate ``X`` itself and return a clone — ``torch.library``
# rejects returning an input by identity.


@torch.library.custom_op("my::add_one_inplace", mutates_args=("X",))  # type: ignore[misc]
def add_one_inplace(X: torch.Tensor) -> torch.Tensor:
    """Eager implementation of ``my::add_one_inplace``.

    Launches ``add_one_inplace_kernel`` over ``X`` so that ``X`` itself is
    incremented by one, then returns a clone of the mutated tensor rather than
    ``X`` (the op is declared with ``mutates_args=("X",)``).

    Args:
        X: CUDA tensor to increment in place.

    Returns:
        A clone of ``X`` taken after the in-place update.
    """
    assert X.is_cuda
    n_elements = X.numel()

    def launch_grid(meta):
        # One program per BLOCK_SIZE-sized slice, rounded up.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    add_one_inplace_kernel[launch_grid](X, n_elements, BLOCK_SIZE=256)
    return X.clone()


# Fake implementation for ``my::add_one_inplace``. It returns the input tensor
# ``X`` itself (by identity) and performs no computation; together with
# ``mutates_args=("X",)`` on the op, this identity return is what declares the
# output as aliasing the input.
@torch.library.register_fake("my::add_one_inplace")
def _(X: torch.Tensor) -> torch.Tensor:
return X


# QDP descriptor registration for ``my::add_one_inplace``. The output
# descriptor is produced by ``X.aliased()``, i.e. it is declared as aliasing
# the input ``X`` instead of describing a separately allocated tensor.
# NOTE(review): the annotation says ``Tuple[trtp.TensorDesc]`` while the body
# returns the bare result of ``X.aliased()`` -- confirm this is the form the
# plugin registry expects.
@trtp.register("my::add_one_inplace")
def add_one_inplace_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
return X.aliased()


@trtp.aot_impl("my::add_one_inplace")
def add_one_inplace_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    """AOT implementation of the in-place add-one plugin.

    Compiles ``add_one_inplace_kernel`` with Triton for the element type of
    ``X`` and describes how to launch it.

    Args:
        X: Descriptor of the tensor that is both read and written.
        outputs: Output descriptors (not read by this implementation).
        tactic: Tactic index (not read by this implementation).

    Returns:
        A 4-tuple ``(kernel_name, ptx, launch_params, extra_args)`` where
        ``extra_args`` carries the symbolic element count, padded by
        ``make_aot_extra_args``.
    """
    block_size = 256

    # Any dtype other than float32 is compiled as fp16.
    if X.dtype == trt.float32:
        elem_type = "fp32"
    else:
        elem_type = "fp16"

    kernel_source = triton.compiler.ASTSource(
        fn=add_one_inplace_kernel,
        signature={"x_ptr": f"*{elem_type}", "n_elements": "i32"},
        constexprs={"BLOCK_SIZE": block_size},
    )
    compiled = triton.compile(kernel_source)
    metadata = compiled.metadata

    # Symbolic element count: drives both the grid size and the kernel's
    # ``n_elements`` scalar argument.
    numel = X.shape_expr.numel()

    params = trtp.KernelLaunchParams()
    params.grid_x = trtp.cdiv(numel, block_size)
    params.block_x = metadata.num_warps * 32
    params.shared_mem = metadata.shared

    scalar_args = torch_tensorrt.dynamo.conversion.plugins.make_aot_extra_args(
        [trtp.SymInt32(numel)], compiled_kernel=compiled
    )

    return metadata.name, compiled.asm["ptx"], params, scalar_args


# Generate the Torch-TensorRT converter for the in-place op so that
# ``my::add_one_inplace`` can be lowered into a TensorRT plugin layer.
# Dynamic shapes and the output allocator are both disabled here;
# ``use_aot_if_available=True`` presumably selects the ``@trtp.aot_impl``
# registered for this op -- confirm against the converter generator.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"my::add_one_inplace",
supports_dynamic_shapes=False,
requires_output_allocator=False,
use_aot_if_available=True,
)


# %%
# Step 6: Compile and Run
# -----------------------------------------
Expand All @@ -219,6 +313,11 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
return res


class MyInplaceModel(torch.nn.Module):
    """Minimal module whose forward pass is a single in-place add-one op."""

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply ``my::add_one_inplace`` to ``X`` and return the op's result."""
        inplace_op = torch.ops.my.add_one_inplace.default
        return inplace_op(X)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -246,3 +345,50 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
assert torch.allclose(res, my_model(m)), "Results do not match!"

print("Inference successful!")

# %%
# In-place plugin demo
# ---------------------
#
# In-place ops mutate their input, so eager and TRT must run on separate
# cloned buffers; otherwise each comparison double-applies the mutation.

print("\nIn-place plugin demo:")
inplace_model = MyInplaceModel().to("cuda").eval()
base = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
expected_post = base + 1

model_trt_inplace = torch_tensorrt.compile(
inplace_model,
inputs=[base.clone()],
min_block_size=1,
immutable_weights=True,
)

from torch_tensorrt.dynamo.runtime import (
PythonTorchTensorRTModule,
TorchTensorRTModule,
)

engine_submodules = [
m
for _, m in model_trt_inplace.named_modules()
if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule))
]
assert engine_submodules, (
"Expected a TRT engine submodule for the in-place plugin path; got a"
f" pure-PyTorch fallback. Graph:\n{model_trt_inplace.graph}"
)
print(f" TRT engine submodule(s) present: {len(engine_submodules)}")

with torch.no_grad():
trt_input = base.clone()
trt_out = model_trt_inplace(trt_input)
assert torch.allclose(trt_out, expected_post), "TRT output mismatch"
assert torch.allclose(
trt_input, expected_post
), "Engine did not mutate the input buffer — aliased plugin I/O is not active."

print(" Output matches expected post-mutation value.")
print(" Input buffer was mutated in place by the TRT engine.")
print("In-place inference successful!")
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field

import torch

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.types import TRTNetwork

Expand All @@ -23,6 +24,7 @@ class ConversionContext:
)
requires_output_allocator: bool = False
requires_native_multidevice: bool = False
requires_aliased_plugin_io: bool = False
weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)

Expand Down
14 changes: 14 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes

from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import needs_refit
Expand Down Expand Up @@ -310,6 +311,19 @@ def _populate_trt_builder_config(
if self.compilation_settings.enable_weight_streaming:
builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)

if self.ctx.requires_aliased_plugin_io:
aliased_io_feature = getattr(
trt.PreviewFeature, "ALIASED_PLUGIN_IO_10_03", None
)
if aliased_io_feature is None:
raise RuntimeError(
"An in-place QDP plugin declared aliased I/O, but this TensorRT"
" version does not expose PreviewFeature.ALIASED_PLUGIN_IO_10_03."
" TensorRT 10.3+ is required for aliased plugin I/O."
)
builder_config.set_preview_feature(aliased_io_feature, True)
_LOGGER.info("Enabling preview feature ALIASED_PLUGIN_IO_10_03")

if is_tensorrt_version_supported("10.8"):
TilingOptimizationLevel = {
"none": trt.TilingOptimizationLevel.NONE,
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from torch_tensorrt.dynamo.conversion.plugins._aot_utils import make_aot_extra_args
from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
Expand Down
37 changes: 37 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/plugins/_aot_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Helpers for writing AOT QDP plugins backed by Triton kernels."""

from typing import Any, Sequence


def _has_triton_scratch_params(compiled_kernel: Any) -> bool:
    """Return whether ``compiled_kernel.metadata`` exposes both scratch sizes.

    The result is ``True`` only when the object has a non-``None`` ``metadata``
    attribute that carries both ``global_scratch_size`` and
    ``profile_scratch_size``. Anything else (including ``None``) gives
    ``False``.
    """
    metadata = getattr(compiled_kernel, "metadata", None)
    return metadata is not None and all(
        hasattr(metadata, attr)
        for attr in ("global_scratch_size", "profile_scratch_size")
    )


def make_aot_extra_args(
    user_args: Sequence[Any],
    *,
    compiled_kernel: Any = None,
) -> Any:
    """Build the ``trtp.SymIntExprs`` returned as an AOT plugin's ``extra_args``.

    The caller's scalar arguments are copied in order. If ``compiled_kernel``
    reports Triton scratch parameters (its metadata exposes both
    ``global_scratch_size`` and ``profile_scratch_size``), four
    ``SymInt32(0)`` entries are appended after them. Those four 32-bit zeros
    fill the two 64-bit pointer parameters (``global_scratch``,
    ``profile_scratch``) that Triton emits in the PTX signature; TRT's AOT
    plugin path does not supply them itself, and leaving them unset makes the
    launch read garbage and crash.

    Args:
        user_args: Symbolic scalar arguments, in kernel-signature order.
        compiled_kernel: Optional compiled kernel used to decide whether
            padding is required. ``None`` means no padding.

    Returns:
        A ``trtp.SymIntExprs`` holding ``user_args`` followed by any padding.
    """
    import tensorrt.plugin as trtp

    padding = 4 if _has_triton_scratch_params(compiled_kernel) else 0
    n_user = len(user_args)

    extra_args = trtp.SymIntExprs(n_user + padding)
    for slot, value in enumerate(user_args):
        extra_args[slot] = value

    # The same zero expression is reused for every padding slot.
    filler = trtp.SymInt32(0)
    for slot in range(n_user, n_user + padding):
        extra_args[slot] = filler

    return extra_args
Loading
Loading