diff --git a/.gitignore b/.gitignore index f13b9871d4..6cf9e1aa4c 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,13 @@ py/torch_tensorrt/bin py/torch_tensorrt/BUILD py/torch_tensorrt/LICENSE py/torch_tensorrt/WORKSPACE +# Build copies these into the package dir for wheel packaging; sources of +# record live at the repo root (core/, csrc/). +py/torch_tensorrt/CMakeLists.txt +py/torch_tensorrt/README.md +py/torch_tensorrt/cmake/ +py/torch_tensorrt/core/ +py/torch_tensorrt/examples/ py/wheelhouse py/.eggs notebooks/.ipynb_checkpoints/ diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index bc95b93ab1..b4b8a1a4c6 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -69,7 +69,8 @@ TRTEngine::TRTEngine( bool hardware_compatible, bool requires_output_allocator, const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) + const ResourceAllocationStrategy resource_allocation_strategy, + const std::unordered_map& aliased_io) : TRTEngine( "deserialized_trt", serialized_engine, @@ -80,7 +81,8 @@ TRTEngine::TRTEngine( hardware_compatible, requires_output_allocator, serialized_metadata, - resource_allocation_strategy) {} + resource_allocation_strategy, + aliased_io) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -95,7 +97,8 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) { + : ResourceAllocationStrategy::kStatic), + deserialize_aliased_io(serialized_info[ALIASED_IO_IDX])) { this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]); if (this->requires_native_multidevice) { LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution"); @@ -112,7 +115,8 @@ TRTEngine::TRTEngine( bool hardware_compatible, bool requires_output_allocator, const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) { + const ResourceAllocationStrategy resource_allocation_strategy, + const std::unordered_map& aliased_io) { TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -269,6 +273,40 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs_size, outputs); } + // Store the build-time aliased_io map as the starting point. Then reconcile + // against ICudaEngine::getAliasedInputTensor — that API is the source of + // truth for KV-cache-style aliasing (TRT-enforced via IKVCacheUpdateLayer) + // and may report aliases the build-time path didn't record (e.g. for + // engines built outside Torch-TensorRT). User-declared aliases (kind=kUser) + // are preserved as-is since TRT doesn't know about them. + this->aliased_io = aliased_io; + for (const auto& out_name : this->out_binding_names) { + // TRT returns nullptr / empty string for non-aliased outputs; any thrown + // exception is a real error in the engine state and propagates. + const char* aliased_in = cuda_engine->getAliasedInputTensor(out_name.c_str()); + if (aliased_in == nullptr || aliased_in[0] == '\0') { + continue; + } + auto it = this->aliased_io.find(out_name); + if (it == this->aliased_io.end()) { + this->aliased_io[out_name] = AliasedIOSpec{std::string(aliased_in), AliasKind::kKVCacheUpdate}; + LOG_DEBUG("aliased_io reconciliation: discovered " << out_name << " -> " << aliased_in << " (kv_cache_update)"); + } else if (it->second.input_binding_name != std::string(aliased_in)) { + LOG_WARNING( + "aliased_io: build-time map disagrees with engine for output " + << out_name << " (build: " << it->second.input_binding_name << ", engine: " << aliased_in + << "); using engine value."); + it->second = AliasedIOSpec{std::string(aliased_in), AliasKind::kKVCacheUpdate}; + } + // Validation: aliased outputs must not also require an output allocator, + // since aliasing requires the output shape to match the input's static + // shape, which is incompatible with the dynamic-allocation path. + TORCHTRT_CHECK( + !this->requires_output_allocator, + "Aliased output " << out_name + << " is incompatible with dynamic output allocator. Aliasing requires fixed output shape."); + } + #ifndef NDEBUG this->enable_profiling(); #endif @@ -505,7 +543,8 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]), + std::tuple("aliased_io", serialized_info[ALIASED_IO_IDX])); } std::vector TRTEngine::serialize() { @@ -531,6 +570,7 @@ std::vector TRTEngine::serialize() { serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; + serialized_info[ALIASED_IO_IDX] = serialize_aliased_io(this->aliased_io); // rank/world_size are runtime facts (may differ at load time); not serialized. return serialized_info; diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index d851cda07e..ebc4ca09f6 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "ATen/core/function_schema.h" @@ -33,6 +34,40 @@ namespace torch_tensorrt { namespace core { namespace runtime { +// Origin of an aliased input/output binding pair. KV_CACHE_UPDATE is enforced +// by TensorRT itself (via IKVCacheUpdateLayer; reported through +// ICudaEngine::getAliasedInputTensor); USER is declared by the Torch-TensorRT +// compile flow (TRT doesn't know about it; runtime validates and binds). +enum class AliasKind : int8_t { + kKVCacheUpdate = 0, + kUser = 1, +}; + +struct AliasedIOSpec { + std::string input_binding_name; + AliasKind kind; +}; + +inline std::string alias_kind_to_string(AliasKind k) { + switch (k) { + case AliasKind::kKVCacheUpdate: + return "kv_cache_update"; + case AliasKind::kUser: + return "user"; + } + return "unknown"; +} + +inline AliasKind alias_kind_from_string(const std::string& s) { + if (s == "kv_cache_update") + return AliasKind::kKVCacheUpdate; + if (s == "user") + return AliasKind::kUser; + // Unknown kinds are conservatively treated as KV-cache-update — TRT enforces + // those without extra runtime work, so worst case the runtime silently no-ops. + return AliasKind::kKVCacheUpdate; +} + using FlattenedState = std::tuple< std::tuple, // ABI_VERSION std::tuple, // name @@ -45,7 +80,8 @@ using FlattenedState = std::tuple< std::tuple, // serialized metadata std::tuple, // Platform std::tuple, // Resource Allocation Strategy - std::tuple>; // requires_native_multidevice + std::tuple, // requires_native_multidevice + std::tuple>; // aliased_io struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -133,6 +169,14 @@ struct TRTEngine : torch::CustomClassHolder { std::vector in_binding_names = {}; // ITO: PYT IDX std::vector out_binding_names = {}; // ITO: PYT IDX + // For each output binding name that aliases an input binding, the alias spec. + // Populated either by build-time conversion records (forwarded from + // TRTInterpreterResult) or by reconciliation against the engine's own + // ICudaEngine::getAliasedInputTensor at construction time. The runtime + // consults this map in the output-binding loop to skip allocation and bind + // the same device pointer as the source input. + std::unordered_map aliased_io = {}; + bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used // in compilation @@ -149,7 +193,8 @@ struct TRTEngine : torch::CustomClassHolder { bool requires_output_allocator = false, const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + const std::unordered_map& aliased_io = {}); TRTEngine(std::vector serialized_info); @@ -164,7 +209,8 @@ struct TRTEngine : torch::CustomClassHolder { bool requires_output_allocator = false, const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + const std::unordered_map& aliased_io = {}); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index ffefa2c742..177ec0f3c6 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -97,7 +97,8 @@ void setup_input_tensors( c10::intrusive_ptr compiled_engine, bool cudagraphs_enabled, bool need_cudagraphs_record, - std::list>& inputShapeTensorValues) { + std::list>& inputShapeTensorValues, + std::unordered_map& bound_inputs_by_name) { std::list formatted_inputs(compiled_engine->num_io.first); for (size_t i = 0; i < inputs.size(); i++) { @@ -141,7 +142,22 @@ void setup_input_tensors( at::Tensor contig_input = inputs[i].view(shape).contiguous(); formatted_inputs.emplace_back(std::move(contig_input)); - if (need_cudagraphs_record) { + // An aliased input is one whose data_ptr will also be the address of an + // aliased output binding. Cudagraphs normally clone inputs into a + // persistent buffer so addresses are stable across replays; for an + // aliased input we deliberately bind to the user's tensor instead, so + // the engine writes through to the user's storage. The user is already + // required to pass stable input addresses under cudagraphs, so the + // aliasing contract is compatible. + bool is_aliased_input = false; + for (const auto& kv : compiled_engine->aliased_io) { + if (kv.second.input_binding_name == name) { + is_aliased_input = true; + break; + } + } + + if (need_cudagraphs_record && !is_aliased_input) { // Create a new persistent input buffer compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone()); } @@ -150,12 +166,12 @@ void setup_input_tensors( compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); at::Tensor final_input; - if (cudagraphs_enabled) { + if (cudagraphs_enabled && !is_aliased_input) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); final_input = compiled_engine->input_buffers[i]; } else { - // Otherwise use the formatted buffer directly + // Aliased inputs OR non-cudagraphs path: use the user's tensor directly. final_input = formatted_inputs.back(); } @@ -167,11 +183,17 @@ void setup_input_tensors( TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress(name.c_str(), input_addr), "Failed to bind tensor address for " << name); + + // Record the bound tensor by binding name so the output-binding loop + // can resolve aliased outputs to their source input's storage. + bound_inputs_by_name[name] = final_input; } } } -std::vector create_output_tensors(c10::intrusive_ptr compiled_engine) { +std::vector create_output_tensors( + c10::intrusive_ptr compiled_engine, + const std::unordered_map& bound_inputs_by_name) { std::vector outputs(compiled_engine->num_io.second); for (auto output_indices : compiled_engine->out_binding_map) { // out_binding_map stores TRT_IDX: PYT_IDX @@ -183,6 +205,30 @@ std::vector create_output_tensors(c10::intrusive_ptr comp auto dims = core::util::toVec(out_shape); auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); + + // Aliased outputs share storage with a source input binding. Don't + // allocate; reuse the input tensor by identity. The wrapping Python + // module is responsible for excluding aliased outputs from the + // user-facing return tuple. + auto alias_it = compiled_engine->aliased_io.find(name); + if (alias_it != compiled_engine->aliased_io.end()) { + auto in_it = bound_inputs_by_name.find(alias_it->second.input_binding_name); + TORCHTRT_CHECK( + in_it != bound_inputs_by_name.end(), + "Aliased output " << name << " references input binding " << alias_it->second.input_binding_name + << " but that input was not bound during this call."); + const auto& aliased_input = in_it->second; + TORCHTRT_CHECK( + aliased_input.sizes() == c10::IntArrayRef(dims), + "Aliased output " << name << " shape (" << dims << ") does not match source input " + << alias_it->second.input_binding_name << " shape (" << aliased_input.sizes() << ")"); + outputs[pyt_idx] = aliased_input; + LOG_DEBUG( + "Aliased output " << name << " (kind=" << alias_kind_to_string(alias_it->second.kind) << ") bound to input " + << alias_it->second.input_binding_name << " — skipping fresh allocation"); + continue; + } + outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous()); } @@ -255,6 +301,13 @@ std::vector execute_engine(std::vector inputs, c10::intr // Shape tensor CPU buffers must outlive inferShapes() and enqueueV3() std::list> inputShapeTensorValues; + // Bound input tensors keyed by binding name. Populated by setup_input_tensors + // and consumed by create_output_tensors / the output binding loop to alias + // outputs to their source-input device pointers (no fresh allocation, no + // post-engine copy). The map's tensor refs keep the storage alive for the + // duration of the engine call. + std::unordered_map bound_inputs_by_name; + // Intialize inputs and outputs to be available throughout the succeeding scopes { // Input Setup std::unique_ptr input_profiler_guard; @@ -263,7 +316,13 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues); + setup_input_tensors( + inputs, + compiled_engine, + cudagraphs_enabled, + need_cudagraphs_record, + inputShapeTensorValues, + bound_inputs_by_name); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -284,12 +343,30 @@ std::vector execute_engine(std::vector inputs, c10::intr if (can_use_pre_allocated_outputs) { outputs = compiled_engine->pre_allocated_outputs; } else { - outputs = create_output_tensors(compiled_engine); + outputs = create_output_tensors(compiled_engine, bound_inputs_by_name); } for (auto output_indices : compiled_engine->out_binding_map) { auto pyt_idx = output_indices.second; std::string name = compiled_engine->out_binding_names[pyt_idx]; + + // Aliased outputs share storage with a source input. We bind directly + // to the input's data_ptr and intentionally bypass the cudagraphs + // persistent-output-buffer path: there is no separate buffer to keep + // in sync, and copying into a persistent buffer would defeat the + // aliasing. + auto alias_it = compiled_engine->aliased_io.find(name); + if (alias_it != compiled_engine->aliased_io.end()) { + auto in_it = bound_inputs_by_name.find(alias_it->second.input_binding_name); + TORCHTRT_CHECK( + in_it != bound_inputs_by_name.end(), + "Aliased output " << name << " references unbound input " << alias_it->second.input_binding_name); + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), in_it->second.data_ptr()), + "Failed to bind aliased output " << name << " to input " << alias_it->second.input_binding_name); + continue; + } + if (need_cudagraphs_record) { // If we are recording the cuda graph then we need to update the persistent output buffer compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); @@ -357,11 +434,19 @@ std::vector execute_engine(std::vector inputs, c10::intr } // End engine exeuction (resets to caller stream) // When the pre-allocated output mode is turned on, for intermediate modules, we only create the output in the first - // execution or when shape is changed. - if (compiled_engine->use_pre_allocated_outputs && + // execution or when shape is changed. If the engine has aliased outputs we + // disable pre-allocation entirely: aliased outputs share storage with + // user-supplied inputs that may change between calls, so caching the + // tensor reference would lead to writes against stale storage. + if (compiled_engine->use_pre_allocated_outputs && !compiled_engine->aliased_io.empty()) { + LOG_DEBUG( + "Skipping pre_allocated_outputs cache because engine has aliased I/O; " + "aliased outputs reuse the user's input storage on every call."); + } else if ( + compiled_engine->use_pre_allocated_outputs && (compiled_engine->pre_allocated_outputs.size() == 0 || compiled_engine->output_tensors_are_unowned || shape_changed)) { - compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); + compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine, bound_inputs_by_name); } // Block caller stream until engine execution is complete @@ -370,8 +455,17 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.block(compiled_engine->caller_stream); if (cudagraphs_enabled) { - // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) + // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream). + // Aliased outputs are skipped: the engine wrote directly into the user's + // input storage (we bound the aliased output binding to the user's + // tensor data_ptr in create_output_tensors / the output-binding loop), + // so no copy-back is needed AND output_buffers[o] is uninitialized for + // aliased indices. for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { + const auto& name = compiled_engine->out_binding_names[o]; + if (compiled_engine->aliased_io.find(name) != compiled_engine->aliased_io.end()) { + continue; + } outputs[o].copy_(compiled_engine->output_buffers[o], false); } } @@ -389,6 +483,10 @@ std::vector execute_engine(std::vector inputs, c10::intr // Shape tensor CPU buffers must outlive inferShapes() and enqueueV3() std::list> inputShapeTensorValues; + // Discard map: the output-allocator path is incompatible with aliased I/O + // (validated at engine construction). The bound-inputs map is unused here. + std::unordered_map bound_inputs_by_name; + { // Input Setup std::unique_ptr input_profiler_guard; if (compiled_engine->profile_execution) { @@ -396,7 +494,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues); + setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues, bound_inputs_by_name); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 268609a03f..af8cb3998d 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -139,6 +139,7 @@ TORCH_LIBRARY(tensorrt, m) { m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); m.def("REQUIRES_NATIVE_MULTIDEVICE_IDX", []() -> int64_t { return REQUIRES_NATIVE_MULTIDEVICE_IDX; }); + m.def("ALIASED_IO_IDX", []() -> int64_t { return ALIASED_IO_IDX; }); m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool { #ifdef ENABLE_TRT_NCCL_COLLECTIVES return true; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 59e861f79b..6488632c1d 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "ATen/core/function_schema.h" #include "NvInfer.h" @@ -17,8 +18,10 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "9"; +const std::string ABI_VERSION = "10"; extern bool MULTI_DEVICE_SAFE_MODE; +// AliasKind, AliasedIOSpec, and the alias_kind_(to|from)_string helpers are +// declared in core/runtime/TRTEngine.h since runtime.h includes that header. typedef enum { STANDARD = 0, @@ -41,6 +44,7 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, + ALIASED_IO_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -63,6 +67,12 @@ std::string base64_encode(const std::string& in); std::string base64_decode(const std::string& in); std::string serialize_bindings(const std::vector& bindings); +// Encode/decode the aliased_io map. Records are separated by BINDING_DELIM +// ('%') and each record is "output_name@input_name@kind" (the '@' avoids +// collision with TRT binding names which are alphanumeric + underscore). +std::string serialize_aliased_io(const std::unordered_map& aliased_io); +std::unordered_map deserialize_aliased_io(const std::string& s); + c10::optional get_most_compatible_device( const RTDevice& target_device, const RTDevice& curr_device = RTDevice(), diff --git a/core/runtime/runtime_utils.cpp b/core/runtime/runtime_utils.cpp index 9f67690873..db9937d569 100644 --- a/core/runtime/runtime_utils.cpp +++ b/core/runtime/runtime_utils.cpp @@ -37,6 +37,63 @@ std::string serialize_bindings(const std::vector& bindings) { return serialized_binding_info; } +// Aliased I/O wire format: +// record: "@@" +// joined: records separated by TRTEngine::BINDING_DELIM ('%') +// '@' is the intra-record field separator — TRT binding names are +// alphanumeric + underscore so '@' cannot collide with a binding name. +static const char ALIASED_IO_FIELD_DELIM = '@'; + +std::string serialize_aliased_io(const std::unordered_map& aliased_io) { + if (aliased_io.empty()) { + return ""; + } + std::stringstream ss; + bool first = true; + for (const auto& kv : aliased_io) { + if (!first) { + ss << TRTEngine::BINDING_DELIM; + } + first = false; + ss << kv.first << ALIASED_IO_FIELD_DELIM << kv.second.input_binding_name << ALIASED_IO_FIELD_DELIM + << alias_kind_to_string(kv.second.kind); + } + std::string out = ss.str(); + LOG_DEBUG("Serialized aliased_io: " << out); + return out; +} + +std::unordered_map deserialize_aliased_io(const std::string& s) { + std::unordered_map out; + if (s.empty()) { + return out; + } + size_t pos = 0; + while (pos < s.size()) { + size_t rec_end = s.find(TRTEngine::BINDING_DELIM, pos); + std::string rec = (rec_end == std::string::npos) ? s.substr(pos) : s.substr(pos, rec_end - pos); + + size_t f1 = rec.find(ALIASED_IO_FIELD_DELIM); + if (f1 == std::string::npos) { + LOG_WARNING("Skipping malformed aliased_io record (missing first field delim): " << rec); + } else { + size_t f2 = rec.find(ALIASED_IO_FIELD_DELIM, f1 + 1); + if (f2 == std::string::npos) { + LOG_WARNING("Skipping malformed aliased_io record (missing second field delim): " << rec); + } else { + std::string out_name = rec.substr(0, f1); + std::string in_name = rec.substr(f1 + 1, f2 - f1 - 1); + std::string kind_str = rec.substr(f2 + 1); + out[out_name] = AliasedIOSpec{in_name, alias_kind_from_string(kind_str)}; + } + } + if (rec_end == std::string::npos) + break; + pos = rec_end + 1; + } + return out; +} + // Base64 alphabet (RFC 4648 §4) static const std::string sym_table = // NOLINT(cert-err58-cpp) "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //= diff --git a/docsrc/contributing.rst b/docsrc/contributing.rst index 9d7a317670..f46656020e 100644 --- a/docsrc/contributing.rst +++ b/docsrc/contributing.rst @@ -27,3 +27,4 @@ understanding the system architecture, and managing resources. contributors/dynamic_memory_allocation contributors/autocast contributors/export_serialization + contributors/inplace_operations diff --git a/docsrc/contributors/inplace_operations.rst b/docsrc/contributors/inplace_operations.rst new file mode 100644 index 0000000000..80377e447c --- /dev/null +++ b/docsrc/contributors/inplace_operations.rst @@ -0,0 +1,505 @@ +.. _inplace_operations: + +In-Place Operation Support +============================ + +This document describes the design for native (non-plugin) in-place operator +support in Torch-TensorRT. The motivating workload is streaming inference with a +key/value cache (e.g. ZoomASR, autoregressive LLM decoding) where the cache is +updated each step and the round-trip copy between PyTorch input buffer and +TensorRT output buffer dominates per-step cost. + +Plugin-side aliasing (custom ops via +``PreviewFeature.ALIASED_PLUGIN_IO_10_03``) is intentionally out of scope here. +This design covers only built-in TensorRT operators and the runtime path for +declaring that an engine output shares its buffer with one of its inputs. + +Implementation Status +^^^^^^^^^^^^^^^^^^^^^^ + +* **Implemented and verified end-to-end (C++ runtime path)** + + * The ``aten.slice_scatter.default`` decomposition is disabled; a converter + handles it directly, emitting ``IKVCacheUpdateLayer`` when the cache is a + direct network input with a 4-D static shape and write on dim 2, + otherwise falling back to a scatter sequence. + * ``aliased_io`` (mapping ``output_binding -> (input_binding, kind)``) is + plumbed from the converter through ``TRTInterpreterResult``, + ``SerializedInterpreterResult``, the Python wrapper, and the serialized + engine blob (new ``ALIASED_IO_IDX`` at ABI v10). + * The C++ runtime (``execute_engine``) honors the map: aliased outputs + skip ``at::empty`` allocation, bind to the source input's ``data_ptr``, + and are filtered from the user-facing return tuple. Pre-allocated + outputs are disabled when aliased I/O is present. + * ``TRTEngine`` constructor reconciles its build-time map against + ``ICudaEngine::getAliasedInputTensor`` so the TRT API is the source of + truth even for engines built outside Torch-TensorRT. + * Streaming use case (``user passes the same cache each step``) works + end-to-end: identity and ``data_ptr`` of the user's cache tensor are + preserved across repeated calls; the engine writes in place. + +* **Also implemented**: the ``BUFFER`` / ``BUFFER_MUTATION`` flow. A new + ``lift_mutated_buffers`` pre-compile pass detects the trailing + ``aten.copy_(get_attr, value)`` that ``ExportedProgram.module()`` + generates for a BUFFER_MUTATION, converts the ``get_attr`` to a + ``placeholder``, and removes the trailing ``copy_``. The buffer becomes + an engine input binding so the KV-cache fast path can fire. The compiled + module is wrapped in ``BufferThreadingModule``, which owns the buffers + as module state and threads them into each forward call. With aliased + I/O the engine writes through the binding into the buffer's storage and + the buffer state persists across calls — the user just calls + ``module(x)`` without managing the cache. + +Motivation +----------- + +PyTorch's ``torch.export`` runs functionalization during decomposition. By the +time the FX graph reaches ``TRTInterpreter`` every in-place operation has been +rewritten to a functional equivalent followed by a ``copy_``:: + + x.add_(y) → x_new = aten.add(x, y); x.copy_(x_new) + cache.scatter_(...) → cache_new = aten.scatter(cache, ...); + cache.copy_(cache_new) + +The compiled engine therefore produces a fresh output tensor and the wrapping +module copies it back into the user's input buffer. For workloads where the +mutated tensor is the dominant tensor (KV cache, ring buffers in streaming ASR) +this is a meaningful loss of bandwidth and an unnecessary allocation. + +TensorRT exposes two relevant primitives: + +* ``IKVCacheUpdateLayer`` — a built-in layer that performs a scatter into a + static-shape cache and is automatically aliased to its cache input. The + output binding shares device memory with the cache input. +* ``ICudaEngine.getAliasedInputTensor(output_name)`` — a runtime API that + returns the input binding name an output is aliased with, or ``nullptr`` if + the output is not aliased. + +This design wires those primitives through the Torch-TensorRT pipeline so the +user can declare "this input is mutated" and the compiled engine will update it +in place, with no post-engine copy. + +Background: TensorRT Primitives +-------------------------------- + +KVCacheUpdate Operator +^^^^^^^^^^^^^^^^^^^^^^^ + +The ``IKVCacheUpdateLayer`` performs ``output[i, :, writeIndices[i] + s, :] = +update[i, :, s, :]`` and aliases ``output`` to ``cache``. Inputs: + +* ``cache`` — shape ``[b, d, s_max, h]``, network input, static ``s_max``. +* ``update`` — shape ``[b, d, s, h]`` with ``s ≤ s_max``. +* ``writeIndices`` — shape ``[b]``, ``int32`` or ``int64``, satisfying + ``writeIndices[i] + s <= s_max``. + +The output is the updated cache, which must be a network output and shares +memory with the cache input. K and V are independent layers. DLA is not +supported. The maximum sequence length must be static; dynamic ``s_max`` is not +permitted. + +Aliased I/O Query API +^^^^^^^^^^^^^^^^^^^^^^ + +At runtime, ``engine->getAliasedInputTensor(out_name)`` returns the name of the +input binding aliased with the given output, or ``nullptr``. This is the source +of truth for the runtime: regardless of how aliasing was established (via +``IKVCacheUpdateLayer`` or any future API) the engine reports the +post-build wiring through this single call. + +Design Overview +---------------- + +The work is layered into two tiers. + +Tier A — KVCacheUpdate Fast Path +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A pattern-match lowering pass detects scatter-into-static-cache patterns and +rewrites them to a marker op ``torch_trt.kv_cache_update``. The converter for +that marker emits ``IKVCacheUpdateLayer``. Aliasing is automatic on the +TensorRT side; the runtime sees it through ``getAliasedInputTensor``. + +Tier B — General Input/Output Aliasing +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For mutated inputs that do not match the KV-cache pattern, the user supplies a +``mutated_inputs`` argument to :func:`torch_tensorrt.dynamo.compile`. The +converter records the pairing and the runtime treats the corresponding output +binding as aliased to the input binding — same device pointer, no fresh +allocation, and the input tensor is returned by identity. + +Whether Tier B can also use a public TensorRT network-build API to declare +non-KV aliasing (without involving plugins) is an open question (see +:ref:`open_questions`). If TensorRT 10.x does not expose such an API, Tier B +collapses to Tier A and ships only KV-cache support; the rest of this design +still applies in that case. + +User-Facing API +---------------- + +The ``compile()`` entry point gains one new keyword argument: + +.. code-block:: python + + compiled = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[cache_k, cache_v, x], + mutated_inputs={"cache_k": "cache_k_out", + "cache_v": "cache_v_out"}, + ) + + # cache_k and cache_v are mutated in place; out[2] is the real output. + out = compiled(cache_k, cache_v, x) + assert out[0] is cache_k + assert out[1] is cache_v + +``mutated_inputs`` maps an input binding name (or index) to the output binding +name (or index) that should alias it. When the compiled module is called, the +aliased output positions in the returned tuple contain the *same* ``at::Tensor`` +that was passed in — same storage, same ``data_ptr()``, observably mutated. +This mirrors PyTorch's in-place op convention in which ``x.add_(y)`` returns +``x``. + +Implementation Phases +---------------------- + +Lowering +^^^^^^^^^ + +A new post-lowering pass ``mark_aliased_outputs`` runs after +``remove_input_alias_fixing_clones``. For each entry in ``mutated_inputs``: + +1. Resolve the input ``placeholder`` node and the corresponding output node. +2. Tag the output node with ``node.meta["aliased_input"] = ""``. + +The metadata travels through to ``TRTInterpreter`` and is the carrier for +aliasing intent. + +A second lowering sub-pass detects KV-cache-shaped patterns among the +aliased-output nodes: + +* Scatter or ``index_put`` into a tensor of shape ``[b, d, s_max, h]``. +* ``writeIndices`` is a ``[b]``-shaped integer tensor. +* ``s_max`` is static. + +Matching nodes are replaced with a ``torch_trt.kv_cache_update`` marker so the +converter can emit ``IKVCacheUpdateLayer`` directly. Non-matching aliased +outputs fall through to Tier B. + +Conversion +^^^^^^^^^^^ + +Two touch points in ``py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py``: + +1. **New converter** for ``torch_trt.kv_cache_update`` which calls + ``net.add_kv_cache_update(cache, update, write_indices)``. The cache input + must already be added via ``net.add_input``; converters look up its + ``ITensor`` handle via the conversion context. + +2. **``output()`` (around line 837)** — when marking each output, check + ``node.meta["aliased_input"]``. If present and the output came from + ``IKVCacheUpdateLayer``, no extra call is needed (TRT aliases automatically). + For Tier B (if the public API exists) call the corresponding TRT method + before ``mark_output``. + +The interpreter records the alias map on its result type: + +.. code-block:: cpp + + struct TRTInterpreterResult { + // ...existing fields + std::unordered_map aliased_io; // out_name → in_name + }; + +This map is the bridge from build-time intent to the runtime engine object. + +C++ Runtime Changes +^^^^^^^^^^^^^^^^^^^^ + +All runtime work lives under ``core/runtime/``. The Python runtime +(``PythonTorchTensorRTModule``) is not modified by this design. + +``TRTEngine`` (``core/runtime/TRTEngine.h``) + Add one field:: + + std::unordered_map aliased_io; // out → in + + Populated from two sources: + + 1. The serialized engine metadata (see :ref:`serialization_format`). + 2. **Source-of-truth reconciliation at deserialize time:** for every output + binding, query ``cuda_engine->getAliasedInputTensor(out_name)`` and merge + the result into ``aliased_io``. The TRT API is authoritative; the + serialized map is a cache that allows the runtime to avoid a per-call + query and lets engines built from external TensorRT plans (e.g. via + ``IKVCacheUpdateLayer`` from a non-Torch-TRT build flow) be loaded + transparently. + +``execute_engine`` (``core/runtime/execute_engine.cpp``) + Three narrow changes: + + 1. **Input binding** — record each contiguous input's ``data_ptr()`` + keyed by binding name into a local ``input_addrs`` map. The existing + ``setTensorAddress`` call is unchanged. + + 2. **Output binding (~line 188)** — branch: + + .. code-block:: cpp + + if (auto it = engine.aliased_io.find(out_name); + it != engine.aliased_io.end()) { + // Aliased output: bind the same device ptr as its input, + // do NOT allocate, and return the input tensor by identity. + void* aliased_ptr = input_addrs.at(it->second); + ctx->setTensorAddress(out_name.c_str(), aliased_ptr); + output_tensors.push_back(input_tensors_by_name.at(it->second)); + } else { + auto out = at::empty(dims, options).contiguous(); + ctx->setTensorAddress(out_name.c_str(), out.data_ptr()); + output_tensors.push_back(out); + } + + 3. **Shape consistency check** — before binding, assert + ``dims == input_tensor.sizes()`` for aliased pairs. A mismatch + indicates compilation produced an output shape that differs from the + input it claims to alias; abort with a clear error rather than silently + corrupting memory. + +Output allocator interaction + The existing ``OutputAllocator`` path (used when + ``requires_output_allocator=true``) is incompatible with aliasing by + construction: aliasing requires the output's storage to match the input's, + while the output allocator exists precisely because the output shape is not + known ahead of time. ``TRTEngine`` construction validates that no binding + appears in both ``aliased_io`` (as an output) and the + ``requires_output_allocator`` set, throwing on construction if it does. + +Stream and synchronization + Unchanged. Aliased I/O is a pointer-identity trick, not a synchronization + trick. The pre-/post-enqueue stream handling remains valid. + +CUDA Graph capture + Aliased I/O makes capture *more* deterministic — fewer allocations, stable + addresses. The user-supplied input tensor's ``data_ptr()`` must be stable + across replays; if the caller passes a different tensor each call, capture + is invalidated, the same as today for non-aliased inputs. + +.. _serialization_format: + +Serialization Format Update +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The C++ engine serialization format in ``core/runtime/runtime.h`` is a +positional ``std::vector`` indexed by an enum, gated by a string +``ABI_VERSION``. The change is additive: one new index, one new field, +one ABI bump. + +In ``core/runtime/runtime.h``: + +.. code-block:: cpp + + const std::string ABI_VERSION = "10"; // bumped from "9" + + typedef enum { + ABI_TARGET_IDX = 0, + // ...existing entries unchanged... + REQUIRES_NATIVE_MULTIDEVICE_IDX, + ALIASED_IO_IDX, // NEW + SERIALIZATION_LEN, + } SerializedInfoIndex; + + std::string serialize_aliased_io( + const std::unordered_map& aliased_io); + std::unordered_map deserialize_aliased_io( + const std::string& s); + +Encoding follows the same convention as ``serialize_bindings``: pairs joined +with a key/value delimiter, records joined with a record delimiter. No JSON, +no protobuf, no new dependencies. An empty map serializes to the empty +string. + +In ``core/runtime/TRTEngine.cpp``: + +* **Constructor (line 85)** delegates one extra deserialized argument to the + primary constructor:: + + deserialize_aliased_io(serialized_info[ALIASED_IO_IDX]) + +* **``serialize()`` (line 508)** writes one extra entry:: + + serialized_info[ALIASED_IO_IDX] = + serialize_aliased_io(this->aliased_io); + +* **``__obj_flatten__()`` (line 484)** gains a corresponding tuple so + Python introspection sees the new field:: + + std::tuple("aliased_io", serialized_info[ALIASED_IO_IDX]), + +* **``verify_serialization_fmt`` (line 471)** is unchanged. Its existing length + check (``size() == SERIALIZATION_LEN``) plus the ABI version check + collectively reject any pre-bump engine cleanly. + +Compatibility + Pre-bump engines (ABI ``"9"``) are rejected by ``verify_serialization_fmt`` + with the existing ABI mismatch error. Users recompile. This matches the + behavior of all prior ABI bumps; the format is intentionally not + forward/backward compatible across version changes. + + Engines with no aliased outputs serialize ``aliased_io`` to the empty + string and take the existing allocate-fresh-output path at runtime. Zero + behavioral change for unaffected users. + +.. _buffer-style: + +Buffer-Backed KV Cache +----------------------- + +PyTorch's canonical pattern for streaming inference is to hold the cache as a +module buffer: + +.. code-block:: python + + class StreamingKV(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache_k", torch.zeros(B, H, S_MAX, D)) + + def forward(self, x_k): + self.cache_k[:, :, t:t+1, :] = x_k + return attention(self.cache_k, x_k) + +``torch.export`` captures this cleanly: ``cache_k`` becomes a ``BUFFER`` +input in ``graph_signature.input_specs`` and the slice-write becomes a +``BUFFER_MUTATION`` output spec. + +``ExportedProgram.module()`` rewrites this into a ``GraphModule`` where the +buffer is read via ``get_attr`` and the mutation is emitted as a trailing +``aten.copy_(get_attr_buffer, slice_scatter_result)`` node. Left alone, +Torch-TensorRT's constant-folding pass would bake the buffer into the +engine as a weight; the slice-scatter would scatter against that constant +and produce a result the engine discards. + +To preserve buffer state across calls and let the KV-cache aliasing fast +path fire, a pre-compile pass ``lift_mutated_buffers`` runs after +``ep.module()`` and before ``post_lowering``: + +1. Scans for ``aten.copy_.default(get_attr, _)`` patterns — the marker for + a BUFFER_MUTATION. +2. For each match, converts the ``get_attr`` to a ``placeholder``, + redirects all uses, and erases the trailing ``copy_``. +3. Rebuilds the ``GraphModule`` with the default ``CodeGen`` so the + ``forward`` signature reflects the new placeholder set (the original + ``_PyTreeCodeGen`` would re-impose the original arity through a stored + pytree spec). + +The compiled result is wrapped in ``BufferThreadingModule``, which owns +the buffers as ``register_buffer`` state and threads them into the +underlying compiled module on each forward call. Combined with the +engine-level KV-cache aliasing, the engine writes directly into the +buffer's storage; the buffer is observably mutated and the next call +reads the updated state. The user-facing API is just ``module(x)``. + +Both the user-passed-cache pattern (caller owns the cache) and the +buffer-backed pattern (module owns the cache) work and produce identical +results. + +Constraints and Known Limitations +---------------------------------- + +Static cache shape (Tier A) + ``IKVCacheUpdateLayer`` requires static ``s_max``. Streaming ASR with a + fixed context window satisfies this; truly dynamic-length caches do not + and fall through to Tier B. + +Single-input aliasing + Each output binding aliases at most one input binding. There is no design + here for an output that aliases multiple inputs; that has no clear + semantics anyway. + +Tensor identity through PyTorch + When ``execute_engine`` returns the aliased input tensor as one of its + outputs, downstream code observes ``out is input`` as ``True``. Wrappers in + ``_TorchTensorRTModule.forward()`` that touch outputs (e.g. an unsuspecting + ``output.contiguous()``) must remain a no-op for already-contiguous + tensors; this is the existing PyTorch contract and the design relies on + it. + +DLA + ``IKVCacheUpdateLayer`` is not supported on DLA. ``mutated_inputs`` on a + DLA target raises at compile time. + +User responsibility for buffer reuse + Buffers passed in via ``mutated_inputs`` are mutated by the engine. Users + that need a pre-mutation snapshot must clone before calling. This is + consistent with PyTorch in-place op semantics. + +.. _open_questions: + +Open Questions +--------------- + +Tier B network-build API + ``IKVCacheUpdateLayer`` aliases its output automatically, which fully + covers Tier A. Whether TensorRT 10.x exposes a public network-build API + for declaring output↔input aliasing on arbitrary layers (without involving + the plugin path) is an open question. If such an API does not exist, Tier + B is deferred and the design ships Tier A only — the runtime, lowering, + and serialization changes described above remain valid; only the second + lowering sub-pass and its corresponding converter route are unbuilt. + +ZoomASR fit + The KVCacheUpdate constraints (static ``s_max``, K/V split, ``[b, d, s_max, + h]`` layout) need confirmation against the ZoomASR cache layout. If the + model uses a different memory layout (e.g. ``[b, s_max, d, h]``) a + ``permute`` may be required, and the cost of that permute may exceed the + savings from aliasing. A small benchmark on the actual model is the + gating criterion before investing in pattern-matching. + +Phased Rollout +--------------- + +1. **Phase 1 — Tier A only.** Lowering pass, converter for + ``IKVCacheUpdateLayer``, C++ runtime aliased-output path, + serialization-format bump. This is the smallest end-to-end surface that + solves the streaming-ASR / KV-cache use case. + +2. **Phase 2 — Tier B (general aliasing).** Only if the TensorRT public API + supports non-KV aliasing. Reuses the same runtime code path; new converter + route. No further runtime or serialization changes. + +3. **Phase 3 — Auto-detection.** Pattern-match common in-place residue + post-functionalization (e.g. ``copy_`` of a scattered tensor back into + the original input) so users do not need to specify ``mutated_inputs`` + for the common case. Ergonomics-only; behavior identical to Phase 1/2. + +Summary of Code Touch Points +----------------------------- + ++-------------------------------------------+--------------------------------------------------------------------+ +| File | Change | ++===========================================+====================================================================+ +| ``core/runtime/runtime.h`` | Bump ``ABI_VERSION`` to ``"10"``; add ``ALIASED_IO_IDX``; | +| | declare aliased-IO serialization helpers. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``core/runtime/runtime.cpp`` | Implement ``serialize_aliased_io`` / | +| | ``deserialize_aliased_io``. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``core/runtime/TRTEngine.h`` | Add ``aliased_io`` field; update constructor signatures. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``core/runtime/TRTEngine.cpp`` | Read/write ``ALIASED_IO_IDX``; populate from | +| | ``getAliasedInputTensor`` post-deserialize; add to | +| | ``__obj_flatten__``. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``core/runtime/execute_engine.cpp`` | Branch in output binding loop: skip ``at::empty`` for aliased | +| | outputs, bind input ``data_ptr``, return input tensor by identity. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``py/.../lowering/passes/`` | New ``mark_aliased_outputs`` pass; KV-cache pattern matcher. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``py/.../conversion/_TRTInterpreter.py`` | New converter for ``torch_trt.kv_cache_update``; aliased-output | +| | handling in ``output()``; ``aliased_io`` on | +| | ``TRTInterpreterResult``. | ++-------------------------------------------+--------------------------------------------------------------------+ +| ``py/.../dynamo/_compiler.py`` | New ``mutated_inputs`` argument; plumb to lowering pass and | +| | engine builder. | ++-------------------------------------------+--------------------------------------------------------------------+ diff --git a/examples/dynamo/aliased_io_buffers.py b/examples/dynamo/aliased_io_buffers.py new file mode 100644 index 0000000000..fb52ee769b --- /dev/null +++ b/examples/dynamo/aliased_io_buffers.py @@ -0,0 +1,134 @@ +""" +.. _aliased_io_buffers_example: + +In-place aliased I/O: module-owned buffers +================================================== + +This is the PyTorch-canonical pattern for streaming inference: the cache +lives inside the model via ``register_buffer``. The user simply calls +``model(x)`` — no need to thread the cache through manually. + +How it flows through the compile pipeline: + +* ``torch.export`` captures ``cache`` as a ``BUFFER`` input in the + ``graph_signature``; mutations to it become ``BUFFER_MUTATION`` + output specs. +* ``ExportedProgram.module()`` rewrites the buffer into a ``get_attr`` + node plus a trailing ``aten.copy_(get_attr_buffer, slice_scatter_result)`` + that represents the mutation. +* Torch-TensorRT's ``lift_mutated_buffers`` pre-compile pass detects that + trailing ``copy_``, lifts the ``get_attr`` to a ``placeholder``, and + rebuilds the GraphModule so the engine treats the buffer as a regular + input binding. +* The slice-scatter converter sees the cache as a network input and emits + ``IKVCacheUpdateLayer`` with aliased I/O. +* The compiled result's lifted-buffer placeholders are rewritten in place + to ``get_attr`` reads from registered ``nn.Module`` buffers (via + ``inline_lifted_buffers_into_gm``). The buffer state lives on the + compiled module itself; ``forward`` takes only user inputs. Because the + result is a plain ``fx.GraphModule`` with buffers, it serializes through + ``torch_tensorrt.save`` / ``torch.export`` without any external wrapper. + +The net effect: the engine writes through the buffer's storage in place, +and the next call sees the updated state. No copy-back, no allocation +per call. +""" + +# %% +# Imports +# ------- +import torch +import torch_tensorrt +from torch.export import export + +# %% +# Model with a buffer-backed KV cache +# ----------------------------------- +# Two caches (K and V), both held as buffers. Each forward call writes one +# timestep into position 3 of each cache. + + +class StreamingKV(torch.nn.Module): + def __init__(self, b=1, h=4, s_max=16, d=8): + super().__init__() + self.register_buffer("cache_k", torch.zeros(b, h, s_max, d)) + self.register_buffer("cache_v", torch.zeros(b, h, s_max, d)) + + def forward(self, x_k, x_v): + self.cache_k[:, :, 3:4, :] = x_k + self.cache_v[:, :, 3:4, :] = x_v + return self.cache_k.sum() + self.cache_v.sum() + + +# %% +# Compile +# ------- +model = StreamingKV().cuda() +x_k = torch.ones(1, 4, 1, 8, device="cuda") * 3.0 +x_v = torch.ones(1, 4, 1, 8, device="cuda") * 5.0 + +ep = export(model, (x_k.clone(), x_v.clone())) + +# Show how torch.export sees the model: +print("graph_signature.input_specs:") +for s in ep.graph_signature.input_specs: + print(f" {s}") +print("graph_signature.output_specs:") +for s in ep.graph_signature.output_specs: + print(f" {s}") + +compiled = torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=[x_k.clone(), x_v.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, +) + + +# %% +# Verify aliasing was established for both caches +# ------------------------------------------------ +for _, mod in compiled.named_modules(): + if hasattr(mod, "aliased_io") and mod.aliased_io: + print(f"\nEngine input bindings: {list(mod.input_binding_names)}") + for out, (inp, kind) in mod.aliased_io.items(): + print(f" {out} <-aliased-> {inp} (kind={kind})") + + +# %% +# Run: the compiled module owns the cache buffers +# ------------------------------------------------ +# The user calls ``compiled(x_k, x_v)`` — same signature as the original +# model. The buffers are owned by the wrapping module and threaded into +# the engine automatically. After the call, ``compiled.cache_k`` and +# ``compiled.cache_v`` reflect the mutation. + +returned = compiled(x_k, x_v) +returned_val = returned[0] if isinstance(returned, tuple) else returned + +# Eager reference for comparison. +eager = StreamingKV().cuda() +eager_returned = eager(x_k.clone(), x_v.clone()) + +print(f"\nreturn matches eager: {torch.allclose(returned_val, eager_returned)}") +print(f"cache_k matches eager: {torch.allclose(compiled.cache_k, eager.cache_k)}") +print(f"cache_v matches eager: {torch.allclose(compiled.cache_v, eager.cache_v)}") + + +# %% +# Streaming: module-held state persists across calls +# --------------------------------------------------- +# Reset the cache and step through three updates. Each call mutates the +# module's buffer state in place; the next call sees the updated value. + +compiled.cache_k.zero_() +compiled.cache_v.zero_() +for step, val in enumerate([1.0, 5.0, 0.0]): + x = torch.ones(1, 4, 1, 8, device="cuda") * val + compiled(x, x) + print( + f"step {step}: cache_k.sum()={compiled.cache_k.sum().item():.1f}, " + f"cache_v.sum()={compiled.cache_v.sum().item():.1f}" + ) diff --git a/examples/dynamo/aliased_io_kv_attention.py b/examples/dynamo/aliased_io_kv_attention.py new file mode 100644 index 0000000000..5d2f8cd2d3 --- /dev/null +++ b/examples/dynamo/aliased_io_kv_attention.py @@ -0,0 +1,153 @@ +""" +.. _aliased_io_kv_attention_example: + +Streaming attention with a static KV cache +============================================ + +A realistic single-layer transformer attention block with a static-shape +KV cache held as module buffers. This is the canonical PyTorch pattern +for streaming decoder inference: at each step the module takes a single +token's hidden state, projects K and V, writes them into a fixed-size +cache, and attends over the cache. + +Compared to the simpler aliased-I/O examples this one exercises: + +* ``LayerNorm``, ``Linear`` projections, ``scaled_dot_product_attention`` +* multi-head reshapes / transposes around the cache writes +* the ``register_buffer`` + slice-write pattern for both K and V + +The compiled engine emits two ``IKVCacheUpdateLayer`` ops (one each for +K and V) with aliased outputs. The C++ runtime writes the new K/V +directly into the buffer storage; the next step's attention reads the +updated cache without any copy. +""" + +# %% +# Imports +# ------- +import torch +import torch.nn.functional as F +import torch_tensorrt +from torch.export import export + +# %% +# Single-layer attention block +# ---------------------------- +# The model takes one timestep at a time and uses a compile-time +# constant ``write_pos`` for the cache slot. In a real generation loop +# you'd vary ``write_pos`` per step; a few practical recipes: +# +# * Recompile once per ``write_pos`` (cheap with engine caching). +# * Bake ``write_pos`` into the model's state via a buffer and increment +# it inside ``forward`` (requires an extra in-place op pattern we +# support separately). +# * Use the lower-level ``user_inputs`` flow and pass ``write_pos`` as +# an integer argument to a wrapper that selects between pre-compiled +# engines. + + +class StaticKVAttention(torch.nn.Module): + def __init__(self, batch=1, max_seq=64, n_heads=4, head_dim=16, write_pos=3): + super().__init__() + self.batch = batch + self.n_heads = n_heads + self.head_dim = head_dim + self.hidden = n_heads * head_dim + self.write_pos = write_pos + + # Static KV cache — fixed shape across the whole generation. + self.register_buffer("cache_k", torch.zeros(batch, n_heads, max_seq, head_dim)) + self.register_buffer("cache_v", torch.zeros(batch, n_heads, max_seq, head_dim)) + + self.norm = torch.nn.LayerNorm(self.hidden) + self.q_proj = torch.nn.Linear(self.hidden, self.hidden, bias=False) + self.k_proj = torch.nn.Linear(self.hidden, self.hidden, bias=False) + self.v_proj = torch.nn.Linear(self.hidden, self.hidden, bias=False) + self.o_proj = torch.nn.Linear(self.hidden, self.hidden, bias=False) + + def forward(self, hidden_states): + # hidden_states: [B, 1, H] + B = hidden_states.shape[0] + h = self.norm(hidden_states) + + q = self.q_proj(h).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(h).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(h).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2) + + # In-place writes into the static KV cache. These two lines are + # what trigger the KV-cache aliasing fast path. + self.cache_k[:, :, self.write_pos : self.write_pos + 1, :] = k + self.cache_v[:, :, self.write_pos : self.write_pos + 1, :] = v + + attn_out = F.scaled_dot_product_attention(q, self.cache_k, self.cache_v) + attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, self.hidden) + return self.o_proj(attn_out) + + +# %% +# Compile +# ------- +torch.manual_seed(0) +model = StaticKVAttention().cuda() +hidden = torch.randn(1, 1, model.hidden, device="cuda") + +ep = export(model, (hidden.clone(),)) +compiled = torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=[hidden.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, +) + + +# %% +# Inspect aliasing +# ---------------- +for _, mod in compiled.named_modules(): + if hasattr(mod, "aliased_io") and mod.aliased_io: + print(f"Engine input bindings: {list(mod.input_binding_names)}") + print(f"Engine output bindings: {list(mod.output_binding_names)}") + for out, (inp, kind) in mod.aliased_io.items(): + print(f" {out} <-aliased-> {inp} (kind={kind})") + + +# %% +# Numerical check against eager +# ------------------------------ +eager_model = StaticKVAttention().cuda() +eager_model.load_state_dict(model.state_dict()) +eager_out = eager_model(hidden.clone()) + +# Reset the compiled cache to match eager's fresh state +compiled.cache_k.zero_() +compiled.cache_v.zero_() +compiled_out = compiled(hidden.clone()) +compiled_val = compiled_out[0] if isinstance(compiled_out, tuple) else compiled_out + +print(f"\nmax output diff: {(compiled_val - eager_out).abs().max().item():.6f}") +print( + f"cache_k matches eager: {torch.allclose(compiled.cache_k, eager_model.cache_k, atol=1e-4)}" +) +print( + f"cache_v matches eager: {torch.allclose(compiled.cache_v, eager_model.cache_v, atol=1e-4)}" +) + + +# %% +# Streaming inference loop +# ------------------------ +# Each call writes a new K/V at the compiled ``write_pos`` slot. In a +# real decoder you'd rotate the write position per step; this example +# just demonstrates that the cache state persists. + +compiled.cache_k.zero_() +compiled.cache_v.zero_() +for step in range(3): + h = torch.randn(1, 1, model.hidden, device="cuda") + out = compiled(h) + print( + f"step {step}: cache_k.norm()={compiled.cache_k.norm().item():.4f}, " + f"out.norm()={(out[0] if isinstance(out, tuple) else out).norm().item():.4f}" + ) diff --git a/examples/dynamo/aliased_io_user_inputs.py b/examples/dynamo/aliased_io_user_inputs.py new file mode 100644 index 0000000000..4c2367465d --- /dev/null +++ b/examples/dynamo/aliased_io_user_inputs.py @@ -0,0 +1,117 @@ +""" +.. _aliased_io_user_inputs_example: + +In-place aliased I/O: caller-owned tensors +============================================ + +This example shows the simplest in-place pattern: the caller owns a buffer +(e.g. a KV cache), passes it into the compiled module on every call, and +the TensorRT engine mutates it in place. No fresh allocation per call, no +post-engine copy. + +Mechanically: + +* The model writes into a slice of one of its inputs: + ``cache[:, :, t:t+1, :] = update``. +* The converter recognizes the ``slice_scatter`` pattern and emits a + ``IKVCacheUpdateLayer`` whose output is aliased to the cache input. +* The C++ runtime binds the aliased output to the input's ``data_ptr``, + skipping ``at::empty``. The engine writes through that pointer directly + into the caller's tensor storage. +* The user-facing return tuple contains only the model's explicit outputs + — the aliased "mutation output" is invisible to the caller. + +Constraints (from the TensorRT operator): + +* Cache shape must be 4-D ``[batch, heads, max_seq, head_dim]`` and fully + static. +* Write dimension must be ``2`` (the sequence axis). +* ``write_start + update_len <= max_seq``. + +If the constraints aren't met the converter silently falls back to a +regular scatter and aliasing isn't established (correctness is still +preserved, but no in-place benefit). +""" + +# %% +# Imports +# ------- +import torch +import torch_tensorrt +from torch.export import export + +# %% +# Model +# ----- +# A trivial "step" that writes one timestep into the cache and returns a +# summary statistic. The interesting line is ``cache[:, :, 3:4, :] = update`` +# — that single slice-write is what triggers the KV-cache fast path. + + +class KVStep(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + + +# %% +# Compile +# ------- +B, H, S_MAX, D = 1, 4, 16, 8 +cache_proto = torch.zeros(B, H, S_MAX, D, device="cuda") +update_proto = torch.ones(B, H, 1, D, device="cuda") + +ep = export(KVStep().cuda(), (cache_proto.clone(), update_proto.clone())) +compiled = torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=[cache_proto.clone(), update_proto.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, # aliased I/O requires the C++ runtime +) + +# %% +# Verify aliasing was established +# -------------------------------- +# Each compiled engine carries an ``aliased_io`` map of output binding +# name -> ``(input_binding_name, kind)``. ``kind`` is "kv_cache_update" +# when TensorRT itself enforces the alias (via ``IKVCacheUpdateLayer``). + +for _, mod in compiled.named_modules(): + if hasattr(mod, "aliased_io") and mod.aliased_io: + for out, (inp, kind) in mod.aliased_io.items(): + print(f" {out} <-aliased-> {inp} (kind={kind})") + + +# %% +# Run: cache mutates in place +# --------------------------- +# The caller owns ``cache``. After ``compiled(cache, update)`` returns, +# ``cache`` has been mutated; ``id(cache)`` and ``cache.data_ptr()`` are +# the same as before the call. The return value is just the model's +# explicit output (the sum). + +cache = torch.zeros(B, H, S_MAX, D, device="cuda") +update = torch.ones(B, H, 1, D, device="cuda") * 7.0 +cache_id_before = id(cache) +cache_ptr_before = cache.data_ptr() + +returned = compiled(cache, update) + +print(f"\nreturned: {returned.item()} (expected {7.0 * H * D})") +print(f"cache.sum(): {cache.sum().item()} (expected {7.0 * H * D})") +print(f"id preserved: {id(cache) == cache_id_before}") +print(f"data_ptr preserved: {cache.data_ptr() == cache_ptr_before}") + + +# %% +# Streaming: repeated calls accumulate state +# ------------------------------------------- +# Because the caller's cache tensor identity is preserved and the engine +# writes in place, each call sees the result of the previous one. + +cache = torch.zeros(B, H, S_MAX, D, device="cuda") +for step, scale in enumerate([1.0, 5.0, 0.0]): + compiled(cache, torch.ones(B, H, 1, D, device="cuda") * scale) + print(f"step {step}: cache.sum() = {cache.sum().item()}") diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index e7ff1cf812..1135400eec 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,7 +5,7 @@ import os import platform import warnings -from typing import Any, Collection, List, Optional, Sequence, Union +from typing import Any, Collection, List, Optional, Sequence, Tuple, Union import torch from torch.export import ExportedProgram @@ -40,6 +40,10 @@ post_lowering, pre_export_lowering, ) +from torch_tensorrt.dynamo.lowering._buffer_lifting import ( + inline_lifted_buffers_into_gm, + lift_mutated_buffers, +) from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( resource_partition, ) @@ -750,6 +754,25 @@ def compile( # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) + # Lift mutated buffers from get_attr to placeholders BEFORE post_lowering's + # constant_fold runs, so the engine sees them as input bindings (a + # prerequisite for IKVCacheUpdateLayer / aliased I/O to fire on a + # module-held cache). Returns a fresh GraphModule whose forward signature + # reflects the new placeholders. + gm, lifted_buffers = lift_mutated_buffers(gm) + if lifted_buffers: + # Append each lifted buffer as an engine input AFTER the user inputs. + # Buffer tensors live on the gm's state; prepare an Input spec for + # each so engine building knows their shape/dtype/device. + buffer_tensors = [t for _, _, t in lifted_buffers] + buffer_inputs = prepare_inputs(buffer_tensors) + trt_arg_inputs = list(trt_arg_inputs) + list(buffer_inputs) + logger.info( + "Lifted %d mutable buffer(s) into engine inputs: %s", + len(lifted_buffers), + [b for _, b, _ in lifted_buffers], + ) + # Apply lowering on the graph module. Note: constant_fold runs inside post_lowering and requires # module parameters to still be on GPU, so we must not deallocate before this call. gm = post_lowering(gm, settings) @@ -772,6 +795,14 @@ def compile( trt_gm = compile_module( gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache ) + if lifted_buffers: + # Inline buffers into the compiled gm as get_attr nodes + registered + # buffers. The resulting gm's forward takes only user inputs; buffers + # are read from module state on each call and threaded into the + # engine via get_attr nodes in the fx graph. This shape is naturally + # serializable by torch_tensorrt.save / torch.export (no external + # Python wrapper that would be lost on a round-trip). + trt_gm = inline_lifted_buffers_into_gm(trt_gm, lifted_buffers) return trt_gm @@ -1244,12 +1275,27 @@ def convert_exported_program_to_serialized_trt_engine( use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, + lift_mutable_buffers: bool = False, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings + When ``lift_mutable_buffers=True``, any module buffer that the model mutates + (a ``BUFFER_MUTATION`` in the EP's graph signature) is lifted from a baked-in + constant to an engine *input binding*. The resulting engine has additional + input bindings appended after the user-supplied inputs, in the order the + buffers appear in the EP. The caller is responsible for threading those + bindings at runtime — pass the current buffer values in on each call; the + engine writes through the binding via aliased I/O so the buffer's storage + is mutated in place. Use ``trt.ICudaEngine.get_aliased_input_tensor`` (or + the metadata exposed by ``TRTEngine`` via ``aliased_io``) to discover + which output binding aliases which input. The higher-level + :func:`torch_tensorrt.dynamo.compile` does this lifting and threading + automatically; this lower-level entry point exposes the same machinery + for callers that want to manage the bindings themselves. + Arguments: exported_program (torch.export.ExportedProgram): Source module, running torch.export on a ``torch.nn.Module`` inputs (Optional[Sequence[Sequence[Any]]]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using @@ -1497,6 +1543,25 @@ def convert_exported_program_to_serialized_trt_engine( # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) + # Optional: lift mutated module buffers from get_attr to placeholder so the + # engine treats them as input bindings (enabling KV-cache aliasing for + # module-held caches). The caller is responsible for threading the + # resulting bindings at runtime — they are appended after the user inputs + # in the order returned here. + lifted_buffers: List[Tuple[str, str, torch.Tensor]] = [] + if lift_mutable_buffers: + gm, lifted_buffers = lift_mutated_buffers(gm) + if lifted_buffers: + buffer_tensors = [t for _, _, t in lifted_buffers] + buffer_inputs = prepare_inputs(buffer_tensors) + trt_arg_inputs = list(trt_arg_inputs) + list(buffer_inputs) + logger.info( + "lift_mutable_buffers=True: lifted %d buffer(s) into engine " + "inputs (appended after user inputs): %s", + len(lifted_buffers), + [b for _, b, _ in lifted_buffers], + ) + # Apply lowering on the graph module gm = post_lowering(gm, settings) logger.debug("Lowered Input graph: " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index f5ffdafda2..bf077be087 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,10 +1,39 @@ from dataclasses import dataclass, field +from enum import Enum +from typing import NamedTuple import torch from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.types import TRTNetwork +class AliasKind(str, Enum): + """Origin of an aliased input/output binding pair. + + KV_CACHE_UPDATE: Aliasing is enforced by TensorRT itself via + ``IKVCacheUpdateLayer``; the engine reports it through + ``ICudaEngine.get_aliased_input_tensor``. Shape contract is enforced + by the layer. + + USER: Aliasing is declared by the Torch-TensorRT compile flow. TRT does + not enforce it; the runtime must validate shape compatibility and + bind both input and output to the same device pointer. + """ + + KV_CACHE_UPDATE = "kv_cache_update" + USER = "user" + + +class AliasedOutput(NamedTuple): + """One aliased output recorded during conversion.""" + + # The TRT ITensor that should be aliased to an input binding. + output_tensor: object # tensorrt.ITensor (avoid hard import here) + # The TRT input binding name the output should share device memory with. + input_binding_name: str + kind: AliasKind + + @dataclass class ConversionContext: """Class representing the context for conversion of a particular network @@ -25,6 +54,11 @@ class ConversionContext: requires_native_multidevice: bool = False weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict) cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list) + # Aliased outputs registered by converters during conversion. + # ``TRTInterpreter`` is responsible for ensuring each output_tensor that + # isn't already a user output is added to the network outputs, and for + # carrying the alias mapping forward into ``TRTInterpreterResult``. + aliased_outputs: list[AliasedOutput] = field(default_factory=list) def record_weight(self, name: str, weight: torch.Tensor) -> None: """ diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index e1f4d8bafb..bc0f0fa26a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -17,7 +17,6 @@ ) import numpy as np -import tensorrt as trt import torch import torch.fx from torch.fx.experimental.proxy_tensor import unset_fake_temporarily @@ -56,6 +55,8 @@ ) from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt + _LOGGER: logging.Logger = logging.getLogger(__name__) TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( @@ -74,6 +75,19 @@ class TRTInterpreterResult(NamedTuple): weight_name_map: Optional[dict[Any, Any]] requires_output_allocator: bool requires_native_multidevice: bool + # Per-output aliasing info. Map of engine output binding name -> tuple of + # (input binding name, alias-kind string). The kind is "kv_cache_update" + # for TRT-enforced aliasing (IKVCacheUpdateLayer) or "user" for + # Torch-TensorRT-declared aliasing. The runtime uses this to bind aliased + # outputs to their input device pointers and skip fresh allocation. + # + # By convention the interpreter appends side-effect aliased outputs + # (added to satisfy layers like IKVCacheUpdateLayer that require their + # output to be a network output) to the END of ``output_names``. The + # runtime derives the user-output count by walking that list backwards + # — see ``user_output_count`` in the runtime module — and hides the + # side-effect outputs from the caller's return tuple. + aliased_io: dict[str, tuple[str, str]] = {} @cls_supports_debugger @@ -135,6 +149,10 @@ def __init__( self._cur_node: Optional[torch.fx.Node] = None self._input_names: List[str] = [] self._output_names: List[str] = [] + # Per-output binding aliasing: output_name -> (input_binding_name, kind_str). + # `kind_str` is "kv_cache_update" or "user". Populated by aliased + # converters and reconciled with engine.get_aliased_input_tensor. + self._aliased_io: Dict[str, Tuple[str, str]] = {} self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( dict() ) @@ -665,6 +683,21 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) + # Reconcile against the engine: TRT exposes aliasing via + # get_aliased_input_tensor on ICudaEngine. The engine API is the + # source of truth for KV-cache-style aliasing; our build-time records + # are a fast cache. User-declared aliasing (kind="user") is preserved + # as-is since TRT doesn't know about it. TRT returns None / empty + # string for non-aliased outputs; any raised exception is a real + # error in the engine and propagates. + engine_aliased_io: Dict[str, Tuple[str, str]] = dict(self._aliased_io) + for out_name in self._output_names: + aliased_in = cuda_engine.get_aliased_input_tensor(out_name) + if aliased_in: + # Engine-reported aliasing is always KV-cache-update origin + # (the only TRT-enforced aliasing API in 10.x). + engine_aliased_io[out_name] = (aliased_in, "kv_cache_update") + return TRTInterpreterResult( cuda_engine, self._input_names, @@ -672,6 +705,7 @@ def run( self.weight_name_map, self.ctx.requires_output_allocator, self.ctx.requires_native_multidevice, + engine_aliased_io, ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: @@ -843,6 +877,28 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: else: outputs = (args[0],) + # Aliased outputs (e.g. IKVCacheUpdateLayer) must be marked as network + # outputs even if the user's source did not return them. We APPEND + # them after the user outputs — the runtime derives the user/ + # side-effect boundary by walking output_names backward and treating + # the contiguous in-aliased_io suffix as side-effects. + user_output_ids = {id(o) for o in outputs if isinstance(o, trt.ITensor)} + for entry in self.ctx.aliased_outputs: + aliased_tensor = entry.output_tensor + if id(aliased_tensor) not in user_output_ids: + outputs = outputs + (aliased_tensor,) + # Extend output_dtypes so the dtype-mismatch check passes and + # no cast is inserted (a cast would break engine-level aliasing). + if self.output_dtypes is not None: + aliased_dtype = dtype._from(aliased_tensor.dtype) + self.output_dtypes = list(self.output_dtypes) + [aliased_dtype] + # Map ITensor identity -> (input_binding_name, kind_str), used after + # rename below to populate self._aliased_io keyed by final binding name. + aliased_info_by_id = { + id(e.output_tensor): (e.input_binding_name, e.kind.value) + for e in self.ctx.aliased_outputs + } + for output_idx in range(len(outputs)): output = outputs[output_idx] @@ -892,6 +948,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: else: output_dtype = dtype.unknown + # Capture identity before any cast — we use it to find aliasing. + original_id = id(output) + if output_dtype is not dtype.unknown: output = self._cast_output_dtype( output, @@ -904,6 +963,30 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: self.ctx.net.mark_output(output) self._output_names.append(name) + + # If this output was emitted by an aliased layer (e.g. IKVCacheUpdateLayer), + # carry the alias to the final binding name. A cast layer inserted + # above would break engine-level aliasing — warn instead of recording. + alias_info = aliased_info_by_id.get(original_id) + if alias_info is not None: + aliased_input, kind_str = alias_info + if output_dtype is not dtype.unknown and id(output) != original_id: + _LOGGER.warning( + "Output %s was aliased to input %s (kind=%s) but a dtype cast " + "was inserted; engine-level aliasing is broken for this output.", + name, + aliased_input, + kind_str, + ) + else: + self._aliased_io[name] = (aliased_input, kind_str) + _LOGGER.debug( + "Output %s aliased to input %s (kind=%s)", + name, + aliased_input, + kind_str, + ) + _LOGGER.debug( f"Marking output {name} [shape={output.shape}, dtype={output.dtype}]" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 77e48fe92e..97aac6cd17 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,7 +2,7 @@ import io import logging -from typing import Any, Dict, List, NamedTuple, Optional, Sequence +from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple import torch from torch_tensorrt._enums import dtype @@ -38,6 +38,12 @@ class SerializedInterpreterResult(NamedTuple): requires_output_allocator: bool symbolic_shape_expressions: Dict[str, List[Dict[str, Any]]] requires_native_multidevice: bool + # Map of engine output binding name -> (input binding name, kind_str). The + # kind_str distinguishes "kv_cache_update" (TRT-enforced via + # IKVCacheUpdateLayer; reported by ICudaEngine.get_aliased_input_tensor) + # from "user" (Torch-TensorRT-declared; runtime must enforce shape match + # and bind the same device pointer). + aliased_io: Dict[str, Tuple[str, str]] = {} def infer_module_output_dtypes( @@ -323,6 +329,7 @@ def interpret_module_to_result( requires_output_allocator=interpreter_result.requires_output_allocator, requires_native_multidevice=interpreter_result.requires_native_multidevice, symbolic_shape_expressions=symbolic_shape_expressions, + aliased_io=interpreter_result.aliased_io, ) return serialized_interpreter_result @@ -389,4 +396,5 @@ def convert_module( requires_output_allocator=serialized_interpreter_result.requires_output_allocator, requires_native_multidevice=serialized_interpreter_result.requires_native_multidevice, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, + aliased_io=serialized_interpreter_result.aliased_io, ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 426c3e0708..47d8608210 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -6,7 +6,6 @@ import numpy as np import torch -from tensorrt import ITensor as TRTTensor from torch.fx.node import Argument, Node, Target from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._features import needs_not_tensorrt_rtx @@ -28,6 +27,8 @@ ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from tensorrt import ITensor as TRTTensor + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -1058,6 +1059,133 @@ def aten_ops_gather( ) +def _index_copy_kv_eligible( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + """Validator for the KV-cache fast path of ``aten.index_copy.default``. + + Returns True only for the narrow case our ``IKVCacheUpdateLayer`` + emitter can handle without a graph break: + + * Input is an FX placeholder (i.e. a network input — required for + aliasing). + * Input has rank 4 and fully static shape ``[b, d, s_max, h]``. + * ``dim`` argument is exactly ``2``. + * Source tensor has rank 4 with ``shape[2] == 1`` (single-position + write; matches HF's per-step ``StaticCache.update`` call). + * Batch is 1 (avoids writeIndices broadcasting; trivially extensible + to larger batches when needed). + + Cases that fail this validator fall through to + ``aten_ops_index_copy_fallback``. + """ + if len(node.args) < 4: + return False + input_node, dim, _index_node, src_node = node.args[:4] + + if not isinstance(input_node, Node) or input_node.op != "placeholder": + return False + input_val = input_node.meta.get("val") + if input_val is None: + return False + input_shape = tuple(input_val.shape) + if len(input_shape) != 4: + return False + if any(not isinstance(s, int) or s < 0 for s in input_shape): + return False + if input_shape[0] != 1: + return False # batch > 1 deferred; see index_copy.index_copy_kv + + if dim != 2: + return False + + if not isinstance(src_node, Node): + return False + src_val = src_node.meta.get("val") + if src_val is None: + return False + src_shape = tuple(src_val.shape) + if len(src_shape) != 4 or not isinstance(src_shape[2], int) or src_shape[2] != 1: + return False + + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.index_copy.default, + capability_validator=_index_copy_kv_eligible, + priority=ConverterPriority.HIGH, + supports_dynamic_shapes=False, +) +@enforce_tensor_types({0: (TRTTensor,), 2: (TRTTensor,), 3: (TRTTensor,)}) +def aten_ops_index_copy_kv( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.index_copy.index_copy_kv( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args[1], + index=args[2], + src=args[3], + ) + + +@dynamo_tensorrt_converter( + torch.ops.aten.index_copy.default, + supports_dynamic_shapes=False, +) +@enforce_tensor_types({0: (TRTTensor,), 2: (TRTTensor,), 3: (TRTTensor,)}) +def aten_ops_index_copy_fallback( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.index_copy.index_copy_fallback( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args[1], + index=args[2], + src=args[3], + ) + + +@dynamo_tensorrt_converter( + torch.ops.aten.slice_scatter.default, supports_dynamic_shapes=True +) +@enforce_tensor_types({0: (TRTTensor,), 1: (TRTTensor,)}) +def aten_ops_slice_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice_scatter.slice_scatter( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + src=args[1], + dim=args[2], + start=args_bounds_check(args, 3), + end=args_bounds_check(args, 4), + step=args_bounds_check(args, 5), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.scatter.src, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.scatter.value, supports_dynamic_shapes=True) @enforce_tensor_types( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index ffe3071cc8..0cc92c40d4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -13,6 +13,7 @@ embedding, full, grid, + index_copy, linear, matmul, nccl_ops, @@ -27,6 +28,7 @@ shape, shuffle, slice, + slice_scatter, split, squeeze, topk, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/index_copy.py b/py/torch_tensorrt/dynamo/conversion/impl/index_copy.py new file mode 100644 index 0000000000..7fb9ee97b2 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/index_copy.py @@ -0,0 +1,161 @@ +"""TensorRT conversion for ``aten.index_copy.default``. + +Mirrors the structure of ``slice_scatter`` but the eligibility check is +gated by a validator (declared in ``aten_ops_converters.py``) and there +are two registered converters for the same op: + +* ``aten_ops_index_copy_kv`` — HIGH priority, validator-gated. Fires only + when the input is a 4-D static cache, ``dim=2``, and the source has + ``shape[2] == 1`` (single-position write — the common + streaming-decoder pattern). Emits ``IKVCacheUpdateLayer`` whose output + is aliased to the cache input. + +* ``aten_ops_index_copy_fallback`` — STANDARD priority, always fires. + Implements the general semantics via scatter (equivalent to what the + torch decomposition would produce). + +Since both converters live in TRT, no graph break is introduced for +non-KV cases; they just take a less-efficient TRT path. + +The two functions here implement the bodies. Registration with the +validator lives next to the other aten converters for discoverability. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.impl import select +from torch_tensorrt.dynamo.conversion.impl.slice_scatter import ( + emit_kv_cache_update_layer, +) + +import tensorrt as trt +from tensorrt import ITensor as TRTTensor + +logger = logging.getLogger(__name__) + + +def index_copy_kv( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, + src: TRTTensor, +) -> TRTTensor: + """KV-cache fast path. Caller (the validator) has already verified + that this case is KV-eligible — 4-D static cache, dim=2, source + rank=4 with seq dim of size 1. + + ``index`` is the position tensor: shape ``[s_update]``. For the + eligible case (s_update == 1) this is a single value that we use + directly as the write start. KVCacheUpdate's ``writeIndices`` arg + expects shape ``[batch]``, so we broadcast/repeat the single index + value across the batch dimension. + """ + cache_shape = tuple(input.shape) + batch = cache_shape[0] + + # KVCacheUpdate accepts int32 / int64 writeIndices; TRT auto-promotes + # but be explicit to avoid surprises across version drift. + if index.dtype != trt.int32: + index = cast_trt_tensor(ctx, index, trt.int32, name + "_index_to_int32") + + # writeIndices shape must be [batch]. For batch=1 with index shape [1] + # this is already correct. For batch > 1, broadcast: emit a constant + # of shape [batch] filled with the single index value at runtime via + # a gather/expand pattern. The validator's job is to allow only cases + # where this is well-defined; for now we restrict to batch == 1 (see + # the validator). + write_indices = index + + # When batch > 1, we'd need to broadcast `index` (shape [1]) to + # [batch]. The validator currently keeps us in the batch==1 case. + if isinstance(batch, int) and batch > 1: + # Defensive: shouldn't happen if the validator is correct, but + # fall back rather than emit a wrong layer. + logger.debug( + "index_copy_kv: batch > 1 not yet supported for runtime indices; " + "falling back to scatter" + ) + return index_copy_fallback(ctx, target, source_ir, name, input, dim, index, src) + + out = emit_kv_cache_update_layer(ctx, name, input, src, write_indices) + if out is None: + # KV emission failed (e.g. input not a direct network input); + # fall through to scatter so correctness is preserved. + return index_copy_fallback(ctx, target, source_ir, name, input, dim, index, src) + return out + + +def index_copy_fallback( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, + src: TRTTensor, +) -> TRTTensor: + """General-purpose ``index_copy``: scatter ``src`` into ``input`` at + positions in ``index`` along ``dim``. Equivalent to the standard + torch decomposition: build a broadcast index tensor of the same + shape as ``src`` with ``index`` placed along ``dim`` and call + ``scatter``. + """ + rank = len(input.shape) + src_shape = tuple(src.shape) + + # Reshape `index` (1-D, length matches src.shape[dim]) so it broadcasts + # over the remaining dims of `src`. + reshape_to = [1] * rank + if isinstance(src_shape[dim], int): + reshape_to[dim] = src_shape[dim] + else: + # Dynamic seq dim — defer to a shape-aware reshape. Build it via + # a shape op so the reshape is dynamic-shape-safe. + # For now require static; raise instead of silently producing wrong results. + raise NotImplementedError( + "index_copy fallback with dynamic source shape on dim %d is not " + "yet supported." % dim + ) + + shuffle = ctx.net.add_shuffle(index) + shuffle.reshape_dims = trt.Dims(reshape_to) + reshaped_index = shuffle.get_output(0) + + # Broadcast/expand to src's shape. ``broadcast_to`` semantics: numpy + # array first then to TRT. + if all(isinstance(s, int) for s in src_shape): + # Static case: just expand via a slice with broadcast strides. + expand_layer = ctx.net.add_slice( + reshaped_index, + start=tuple(0 for _ in range(rank)), + shape=src_shape, + stride=tuple(0 if i != dim else 1 for i in range(rank)), + ) + index_broadcast = expand_layer.get_output(0) + else: + raise NotImplementedError( + "index_copy fallback with dynamic shapes is not yet supported." + ) + + return select.scatter( + ctx, + target, + source_ir, + name + "_fallback_scatter", + input, + dim, + index_broadcast, + src, + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice_scatter.py b/py/torch_tensorrt/dynamo/conversion/impl/slice_scatter.py new file mode 100644 index 0000000000..fe6621f62b --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice_scatter.py @@ -0,0 +1,230 @@ +"""TensorRT conversion for ``aten.slice_scatter.default``. + +Two paths: + +* **KV-cache fast path** (``IKVCacheUpdateLayer``) — fires when the input is + a direct network input, the layer's invariants hold (4-D static shape, write + on dim 2, ``start + update_len <= s_max``), and the batch dim is static. The + resulting output is recorded in ``ctx.aliased_io`` so the runtime can bind + it to the input's device pointer. + +* **Fallback** — equivalent to the previous Torch-TRT decomposition: build a + broadcast index tensor and emit a regular scatter. Used whenever the KV + constraints fail. + +Slice_scatter is intentionally NOT decomposed in the Torch-TRT decomposition +table; this converter is the single place that handles it. +""" + +from __future__ import annotations + +import logging +from typing import Optional, Tuple + +import numpy as np +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ( + AliasedOutput, + AliasKind, + ConversionContext, +) +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, + set_layer_name, +) +from torch_tensorrt.dynamo.conversion.impl import select + +import tensorrt as trt +from tensorrt import ITensor as TRTTensor + +logger = logging.getLogger(__name__) + + +def _kv_eligible( + cache_shape: Tuple[int, ...], dim: int, start: int, update_len: int +) -> Tuple[bool, str]: + """Apply IKVCacheUpdateLayer's invariants. + + Returns (eligible, reason). The reason is non-empty in both cases for logs. + """ + if any(not isinstance(s, int) or s < 0 for s in cache_shape): + return False, f"cache shape is dynamic ({cache_shape}); s_max must be static" + if len(cache_shape) != 4: + return ( + False, + f"cache rank is {len(cache_shape)}; KVCacheUpdate requires 4-D [b,d,s_max,h]", + ) + if dim != 2: + return False, f"write dim is {dim}; KVCacheUpdate requires dim=2" + s_max = cache_shape[2] + if start + update_len > s_max: + return ( + False, + f"write_start({start})+update_len({update_len}) > s_max({s_max})", + ) + return True, f"eligible (s_max={s_max}, write_start={start}, len={update_len})" + + +def input_binding_name(ctx: ConversionContext, tensor: TRTTensor) -> Optional[str]: + """If ``tensor`` is a direct network input, return its binding name, else None.""" + for i in range(ctx.net.num_inputs): + net_input = ctx.net.get_input(i) + if net_input is tensor or net_input.name == tensor.name: + return str(net_input.name) + return None + + +def emit_kv_cache_update_layer( + ctx: ConversionContext, + name: str, + cache: TRTTensor, + src: TRTTensor, + write_indices: TRTTensor, +) -> Optional[TRTTensor]: + """Lower-level KVCacheUpdate emission given a write_indices ITensor. + + Performs the binding-name lookup, calls ``add_kv_cache_update``, and + records the aliased output. Returns the layer output ITensor (which is + a network output, aliased to ``cache``) or None if the cache isn't a + direct network input or the layer can't be created. + + Validators upstream are expected to have already verified shape / + dtype / dim invariants; this function trusts its inputs. + """ + cache_input_name = input_binding_name(ctx, cache) + if cache_input_name is None: + logger.debug("KV cache update: skipped — input is not a direct network input") + return None + + layer = ctx.net.add_kv_cache_update( + cache, src, write_indices, trt.KVCacheMode.LINEAR + ) + if layer is None: + logger.debug("KV cache update: add_kv_cache_update returned None") + return None + set_layer_name(layer, "kv_cache_update", name + "_kv_cache_update", SourceIR.ATEN) + out = layer.get_output(0) + + ctx.aliased_outputs.append( + AliasedOutput( + output_tensor=out, + input_binding_name=cache_input_name, + kind=AliasKind.KV_CACHE_UPDATE, + ) + ) + logger.debug( + "KV cache update: emitted; output %s aliased to input %s", + out.name, + cache_input_name, + ) + return out + + +def try_emit_kv_cache_update( + ctx: ConversionContext, + name: str, + cache: TRTTensor, + src: TRTTensor, + dim: int, + start: int, + update_len: int, +) -> Optional[TRTTensor]: + """Emit IKVCacheUpdateLayer if all constraints are met. None otherwise. + + Shared by the slice_scatter and index_copy converters. ``start`` is the + constant write position (e.g. ``slice_scatter``'s ``start`` arg or the + single value from an ``index_copy`` index tensor). The resulting layer + writes ``update_len`` slots starting at ``start`` per batch element and + its output is recorded as aliased to the cache input. + """ + cache_shape = tuple(cache.shape) + eligible, reason = _kv_eligible(cache_shape, dim, start, update_len) + if not eligible: + logger.debug("slice_scatter: KV fast path skipped — %s", reason) + return None + + batch = cache_shape[0] + if not isinstance(batch, int) or batch < 0: + logger.debug( + "slice_scatter: KV fast path skipped — dynamic batch dim (%s); writeIndices " + "must be statically sized for now", + batch, + ) + return None + + write_indices_np: np.ndarray = np.full((batch,), start, dtype=np.int32) + write_indices = get_trt_tensor(ctx, write_indices_np, name + "_write_indices") + + return emit_kv_cache_update_layer(ctx, name, cache, src, write_indices) + + +def slice_scatter( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + src: TRTTensor, + dim: int, + start: Optional[int] = None, + end: Optional[int] = None, + step: Optional[int] = None, +) -> TRTTensor: + """Emit either an IKVCacheUpdateLayer (with aliased I/O) or a scatter sequence.""" + rank = len(input.shape) + dim_size = input.shape[dim] + + if start is None: + start = 0 + if start < 0 and isinstance(dim_size, int): + start = dim_size + start + if end is None: + end = dim_size + if isinstance(end, int) and end < 0 and isinstance(dim_size, int): + end = dim_size + end + if step is None: + step = 1 + + # Trivial: full overwrite. + if ( + isinstance(start, int) + and isinstance(end, int) + and isinstance(dim_size, int) + and start == 0 + and end == dim_size + and step == 1 + ): + return src + + if not (isinstance(start, int) and isinstance(end, int) and isinstance(step, int)): + raise NotImplementedError( + "slice_scatter with dynamic start/end/step is not yet supported" + ) + + update_len = end - start + + # KV fast path. + kv_out = try_emit_kv_cache_update(ctx, name, input, src, dim, start, update_len) + if kv_out is not None: + return kv_out + + # Fallback: build broadcast indices and scatter. + indices_np: np.ndarray = np.arange(start, end, step, dtype=np.int64) + target_shape = [1] * rank + target_shape[dim] = len(indices_np) + indices_np = indices_np.reshape(target_shape) + src_shape = tuple(src.shape) + indices_np = np.broadcast_to(indices_np, src_shape).astype(np.int64) + indices_tensor = get_trt_tensor(ctx, indices_np, name + "_fallback_indices") + + return select.scatter( + ctx, + target, + source_ir, + name + "_fallback_scatter", + input, + dim, + indices_tensor, + src, + ) diff --git a/py/torch_tensorrt/dynamo/lowering/_buffer_lifting.py b/py/torch_tensorrt/dynamo/lowering/_buffer_lifting.py new file mode 100644 index 0000000000..2f57e1bf23 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/_buffer_lifting.py @@ -0,0 +1,252 @@ +"""Lift mutated module buffers back to engine input bindings. + +PyTorch's ``ExportedProgram.module()`` converts BUFFER placeholders into +``get_attr`` nodes that read the buffer from the GraphModule's state, plus a +trailing ``aten.copy_(get_attr_buffer, new_value)`` per BUFFER_MUTATION +output. From Torch-TensorRT's point of view those ``get_attr`` nodes are +parameters that get constant-folded; the buffer becomes baked into the engine +and the trailing ``copy_`` is dropped. Per-call buffer state is lost and the +KV-cache aliasing path cannot fire (the cache isn't a network input). + +This module provides: + +* :func:`lift_mutated_buffers` — pre-compile rewrite that turns each mutated + buffer's ``get_attr`` into a ``placeholder`` and removes the trailing + ``copy_``. The buffer becomes an engine input binding; downstream the + slice_scatter converter's KV-cache fast path can recognize the cache as a + network input and emit ``IKVCacheUpdateLayer`` with aliased I/O. + +* :func:`inline_lifted_buffers_into_gm` — post-compile transform that + registers each lifted buffer as state on the compiled GraphModule and + rewrites the corresponding placeholder nodes to ``get_attr`` reads. The + resulting module's ``forward`` takes only user inputs (buffers are + threaded internally via the fx graph). Because everything is fx + + module state, the result serializes naturally through + ``torch_tensorrt.save`` / ``torch.export``. +""" + +from __future__ import annotations + +import logging +from typing import Dict, List, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def lift_mutated_buffers( + gm: torch.fx.GraphModule, +) -> Tuple[torch.fx.GraphModule, List[Tuple[str, str, torch.Tensor]]]: + """Lift each mutated buffer from a ``get_attr`` to a ``placeholder``. + + A mutated buffer is identified by a trailing + ``aten.copy_(get_attr_buffer, new_value)`` pattern, which is how + ``ExportedProgram.module()`` represents a BUFFER_MUTATION. + + Returns ``(new_gm, lifted)`` where: + + * ``new_gm`` is a plain ``torch.fx.GraphModule`` whose ``forward`` + signature reflects the updated placeholder set. Necessary because + ``ExportedProgram.module()`` produces a module whose forward is + fixed by a pytree spec — recompiling alone doesn't pick up new + placeholders. + * ``lifted`` is a list of ``(placeholder_name, buffer_name, buffer_tensor)`` + tuples, in the order placeholders were appended (which matches the + order they appear in the new gm's forward signature, after the + original user inputs). + """ + # Find all aten.copy_(get_attr_X, _) calls. The first arg's target is + # the buffer name. Some EPs emit copy_.default, others copy_. + mutation_pairs: List[Tuple[torch.fx.Node, torch.fx.Node]] = ( + [] + ) # (copy_node, get_attr_node) + for node in gm.graph.nodes: + if node.op != "call_function": + continue + if node.target not in (torch.ops.aten.copy_.default, torch.ops.aten.copy_): + continue + if not node.args: + continue + target = node.args[0] + if isinstance(target, torch.fx.Node) and target.op == "get_attr": + mutation_pairs.append((node, target)) + + if not mutation_pairs: + return gm, [] + + # Find the position to insert new placeholders (after the last existing placeholder). + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + insert_after = placeholders[-1] if placeholders else None + + lifted: List[Tuple[str, str, torch.Tensor]] = [] + seen_buffers: Dict[str, torch.fx.Node] = {} # buffer name -> new placeholder node + + for copy_node, get_attr_node in mutation_pairs: + buffer_name = get_attr_node.target + if not hasattr(gm, buffer_name): + logger.warning( + "lift_mutated_buffers: get_attr target %s not found on gm; skipping", + buffer_name, + ) + continue + buffer_tensor = getattr(gm, buffer_name) + if not isinstance(buffer_tensor, torch.Tensor): + logger.debug( + "lift_mutated_buffers: attribute %s is not a Tensor; skipping", + buffer_name, + ) + continue + + if buffer_name in seen_buffers: + # Same buffer mutated more than once — already lifted; just remove + # this copy_ node and rely on the existing placeholder. + replacement = seen_buffers[buffer_name] + else: + # Build a unique placeholder name from the buffer name. + placeholder_name = "buf_" + buffer_name.replace(".", "_") + base = placeholder_name + suffix = 0 + existing = {n.name for n in gm.graph.nodes} + while placeholder_name in existing: + suffix += 1 + placeholder_name = f"{base}_{suffix}" + + if insert_after is not None: + with gm.graph.inserting_after(insert_after): + new_ph = gm.graph.placeholder(placeholder_name) + else: + # No existing placeholders — insert at graph start. + with gm.graph.inserting_before(next(iter(gm.graph.nodes))): + new_ph = gm.graph.placeholder(placeholder_name) + new_ph.meta["val"] = get_attr_node.meta.get( + "val", + torch.empty_like(buffer_tensor, device="meta"), + ) + new_ph.meta["_lifted_buffer"] = buffer_name + insert_after = new_ph + seen_buffers[buffer_name] = new_ph + replacement = new_ph + lifted.append((placeholder_name, buffer_name, buffer_tensor)) + + # Re-route every use of the original get_attr (other than the copy_ itself) + # to the new placeholder. + get_attr_node.replace_all_uses_with(replacement) + + # Drop the trailing copy_ (it's now redundant — the mutation lands on the + # placeholder's storage via engine-level aliasing). + gm.graph.erase_node(copy_node) + + # Erase the now-unused get_attr. + if not get_attr_node.users: + gm.graph.erase_node(get_attr_node) + + gm.graph.lint() + + if not lifted: + return gm, [] + + # ExportedProgram.module() produces a GraphModule whose forward is + # generated by a ``_PyTreeCodeGen`` baked into the graph: the body + # unpacks args through a stored pytree spec, ignoring any added + # placeholders. Rebuild the gm with the default ``CodeGen`` so the + # forward signature reflects the placeholder set as written. + # First remove the call to ``_guards_fn`` (generated for the original + # arity; would fail after lifting). + for node in list(gm.graph.nodes): + if ( + node.op == "call_module" + and isinstance(node.target, str) + and node.target == "_guards_fn" + ): + gm.graph.erase_node(node) + break + + # Reset codegen to the plain CodeGen so the forward args = placeholders. + gm.graph.set_codegen(torch.fx.graph.CodeGen()) + gm.graph.lint() + + new_gm = torch.fx.GraphModule(gm, gm.graph) + for attr in ("_in_spec", "_out_spec"): + if hasattr(new_gm, attr): + try: + delattr(new_gm, attr) + except AttributeError: + pass + new_gm.recompile() + + logger.debug( + "Lifted %d mutated buffer(s) to placeholders: %s", + len(lifted), + [(p, b) for p, b, _ in lifted], + ) + + return new_gm, lifted + + +def inline_lifted_buffers_into_gm( + gm: torch.fx.GraphModule, + lifted_buffers: List[Tuple[str, str, torch.Tensor]], +) -> torch.fx.GraphModule: + """Inline lifted buffers as ``get_attr`` reads on the compiled GraphModule. + + After ``lift_mutated_buffers`` + ``compile_module``, ``gm`` is a + ``torch.fx.GraphModule`` whose top-level ``forward`` takes the user's + inputs *plus* the lifted buffers as placeholders. To make the result + look like a normal ``nn.Module`` (and to make it serializable via + ``torch_tensorrt.save`` / ``torch.export``) we: + + 1. Register each lifted buffer as a ``register_buffer`` on ``gm``. + 2. Replace each buffer-placeholder node with a ``get_attr`` node that + reads from ``gm.``. + 3. Recompile. + + The resulting GraphModule's ``forward`` takes only the user's inputs; + the buffers are threaded through internally via the get_attr nodes. + The engine still sees the buffers as input bindings (and writes through + them via aliased I/O); the buffer storage lives on ``gm`` so subsequent + calls reuse the mutated state. + + This transform is a no-op if ``lifted_buffers`` is empty (returns + ``gm`` unchanged). + """ + if not lifted_buffers: + return gm + + placeholder_to_buf: Dict[str, str] = { + ph_name: buf_name for ph_name, buf_name, _ in lifted_buffers + } + # Register buffers as module state. Clone so the gm owns its own storage. + for _ph_name, buf_name, tensor in lifted_buffers: + if not hasattr(gm, buf_name): + gm.register_buffer(buf_name, tensor.clone()) + + # Find placeholders we need to replace. Insert get_attr nodes BEFORE + # removing the placeholders so the graph remains valid throughout. + placeholders_to_remove = [] + for node in list(gm.graph.nodes): + if node.op != "placeholder": + continue + if node.name not in placeholder_to_buf: + continue + buf_name = placeholder_to_buf[node.name] + with gm.graph.inserting_after(node): + get_attr_node = gm.graph.get_attr(buf_name) + # Carry over fake-tensor metadata so downstream passes see the right + # shape/dtype. + if "val" in node.meta: + get_attr_node.meta["val"] = node.meta["val"] + node.replace_all_uses_with(get_attr_node) + placeholders_to_remove.append(node) + + for node in placeholders_to_remove: + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + logger.debug( + "Inlined %d lifted buffer(s) into gm as get_attr reads: %s", + len(lifted_buffers), + [b for _, b, _ in lifted_buffers], + ) + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index a46e0c9d01..9b1ace344e 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -59,8 +59,11 @@ aten.im2col, aten.index_add, aten.index_add_, - aten.index_copy, - aten.index_copy_, + # aten.index_copy / aten.index_copy_ are intentionally NOT decomposed — + # they are handled by two converters: a KV-cache fast path that emits + # IKVCacheUpdateLayer with aliased I/O when the input is a 4-D static + # cache and a fallback that emits the standard scatter sequence. + # See py/torch_tensorrt/dynamo/conversion/impl/index_copy.py. aten.index_fill, aten.index_fill_, aten.isneginf, diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index a930755a50..38b599feb8 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -183,9 +183,11 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: # type: igno return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) -@register_torch_trt_decomposition( - torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS -) +# NOTE: slice_scatter is intentionally NOT registered as a Torch-TRT decomposition. +# It is handled by the slice_scatter converter, which can emit either +# IKVCacheUpdateLayer (with aliased I/O) for KV-cache update patterns, or fall +# back to a scatter sequence equivalent to this function for the general case. +# The converter calls slice_scatter_decomposition directly for the fallback path. def slice_scatter_decomposition( input_tensor: torch.Tensor, src_tensor: torch.Tensor, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3c454933bb..e2cfb298cf 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -147,6 +147,7 @@ def __init__( requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, + aliased_io: Optional[Dict[str, Tuple[str, str]]] = None, _debugger_config: Optional[DebuggerConfig] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs @@ -206,6 +207,19 @@ def __init__( self.output_names = ( output_binding_names if output_binding_names is not None else [] ) + # aliased_io is accepted for API compatibility with TorchTensorRTModule + # but the Python runtime does NOT honor aliasing — outputs are always + # allocated fresh and not bound to input device pointers. Use the C++ + # runtime (use_python_runtime=False) for true in-place aliasing. + self.aliased_io: Dict[str, Tuple[str, str]] = dict(aliased_io or {}) + if self.aliased_io: + logger.warning( + "Aliased I/O is set on the engine (%d outputs) but the Python " + "runtime does not implement aliasing. The engine will fail at " + "execution time with a memory-address mismatch. Use the C++ " + "runtime (use_python_runtime=False) for aliased I/O support.", + len(self.aliased_io), + ) self.initialized = False self.target_device_id = ( settings.device.gpu_id diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 833fdee639..bb5325f53a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -37,6 +37,7 @@ REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented REQUIRES_NATIVE_MULTIDEVICE_IDX = -1 # Not implemented +ALIASED_IO_IDX = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX() # 0 @@ -57,7 +58,32 @@ REQUIRES_NATIVE_MULTIDEVICE_IDX = ( torch.ops.tensorrt.REQUIRES_NATIVE_MULTIDEVICE_IDX() ) # 11 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 12 + ALIASED_IO_IDX = torch.ops.tensorrt.ALIASED_IO_IDX() # 12 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 13 + + +def user_output_count( + output_binding_names: List[str], aliased_io: Dict[str, Tuple[str, str]] +) -> int: + """Derive the boundary between user-visible outputs and side-effect + aliased outputs. + + By convention the TRTInterpreter appends side-effect aliased outputs + (those added on behalf of layers like ``IKVCacheUpdateLayer`` that + require their output to be a network output) to the END of + ``output_binding_names``. So the user-output count is one past the + index of the last name that is *not* in ``aliased_io``. + + Returns ``len(output_binding_names)`` if no outputs are aliased OR + if every output is aliased (in which case we conservatively treat + them all as user-returned — a pure-side-effect engine is degenerate). + """ + if not aliased_io: + return len(output_binding_names) + for i in range(len(output_binding_names) - 1, -1, -1): + if output_binding_names[i] not in aliased_io: + return i + 1 + return len(output_binding_names) @for_all_methods(needs_torch_tensorrt_runtime) @@ -93,6 +119,7 @@ def __init__( requires_output_allocator: bool = False, requires_native_multidevice: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, + aliased_io: Optional[Dict[str, Tuple[str, str]]] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -153,6 +180,8 @@ def __init__( self.dynamically_allocate_resources = settings.dynamically_allocate_resources self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice + # Map of output binding name -> (input binding name, kind_str) + self.aliased_io: Dict[str, Tuple[str, str]] = dict(aliased_io or {}) if ( serialized_engine @@ -228,6 +257,9 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = str( int(self.requires_native_multidevice) ) + engine_info[ALIASED_IO_IDX] = TorchTensorRTModule._pack_aliased_io( + self.aliased_io + ) # rank/world_size are runtime facts; queried from ProcessGroup at execution time return engine_info @@ -424,9 +456,24 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: list(input_tensors), self.engine ) + # The interpreter may have appended extra "side-effect" outputs to + # satisfy the network-output requirement of aliased layers (e.g. + # IKVCacheUpdateLayer). Truncate to the user-facing count; mutation + # effects are visible on the corresponding input tensors. The + # boundary is derived from output_binding_names + aliased_io — no + # extra state needed. + if self.aliased_io: + n = user_output_count(self.output_binding_names, self.aliased_io) + if n < len(outputs): + outputs = outputs[:n] + if len(outputs) == 1: return outputs[0] + if len(outputs) == 0: + # Pure side-effect engine (model has no return value). + return () + return tuple(outputs) def enable_profiling( @@ -488,3 +535,21 @@ def _pack_binding_names(binding_names: List[str]) -> str: delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()[0] packed_bindings: str = delim.join(binding_names) return packed_bindings + + @staticmethod + def _pack_aliased_io(aliased_io: Dict[str, Tuple[str, str]]) -> str: + """Encode the aliased_io map into the wire format consumed by the + C++ runtime. One record per (output_binding -> input_binding, kind) + entry. Records are joined by ``SERIALIZED_ENGINE_BINDING_DELIM`` + ('%') — same convention as ``_pack_binding_names``. Within a + record the field separator is '@' (TRT binding names are + alphanumeric + underscore so '@' cannot collide).""" + if not aliased_io: + return "" + delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()[0] + records = [ + f"{out_name}@{in_name}@{kind_str}" + for out_name, (in_name, kind_str) in aliased_io.items() + ] + packed: str = delim.join(records) + return packed diff --git a/tests/py/dynamo/conversion/test_slice_scatter_aten.py b/tests/py/dynamo/conversion/test_slice_scatter_aten.py new file mode 100644 index 0000000000..b27a756f79 --- /dev/null +++ b/tests/py/dynamo/conversion/test_slice_scatter_aten.py @@ -0,0 +1,97 @@ +# type: ignore +"""Converter tests for the fallback (non-KV) path of aten.slice_scatter.default. + +The converter has two paths: + +1. **KV-cache fast path** — emits ``IKVCacheUpdateLayer`` with aliased I/O. + Aliasing requires the C++ runtime, so these cases are tested end-to-end + in ``tests/py/dynamo/runtime/test_aliased_io.py`` (the Python-runtime + converter harness can't bind aliased addresses). + +2. **Scatter fallback** — equivalent to the historical Torch-TensorRT + decomposition (``arange + scatter``). Used for any shape that doesn't + meet KVCacheUpdate's invariants. + +This file covers the fallback path. To force the fallback regardless of +shape we add a small no-op (``+ 0``) to the cache so it isn't a direct +network input — the converter's "input is a placeholder" check fails and +falls through to scatter. +""" +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class _SliceScatterNotInputModule(torch.nn.Module): + """Helper: forces the fallback path by making `cache` not a direct + network input (the converter's KV fast path requires placeholder input). + """ + + def __init__(self, dim, start, end, step=1): + super().__init__() + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, cache_in, update): + # `cache_in + 0` produces a non-placeholder ITensor, forcing the + # converter to take the scatter fallback rather than KVCacheUpdate. + cache = cache_in + 0 + return torch.ops.aten.slice_scatter.default( + cache, update, self.dim, self.start, self.end, self.step + ) + + +class TestSliceScatterFallback(DispatchTestCase): + @parameterized.expand( + [ + # (name, cache_shape, update_shape, dim, start, end) + # 3-D + ("rank3_dim1", (4, 8, 16), (4, 2, 16), 1, 3, 5), + # 4-D writing on dim != 2 (not eligible for KVCacheUpdate) + ("rank4_dim1", (2, 8, 4, 16), (2, 2, 4, 16), 1, 2, 4), + ("rank4_dim3", (2, 4, 16, 8), (2, 4, 16, 2), 3, 1, 3), + # 2-D + ("rank2_dim0", (8, 16), (3, 16), 0, 2, 5), + # 5-D + ("rank5_dim2", (2, 3, 8, 4, 16), (2, 3, 2, 4, 16), 2, 1, 3), + # 4-D dim=2 — the eligible shape, but forced via non-placeholder + # input. Tests that the fallback handles the same shape correctly. + ("rank4_dim2_forced", (2, 4, 16, 8), (2, 4, 1, 8), 2, 3, 4), + ] + ) + def test_fallback(self, _, cache_shape, update_shape, dim, start, end): + module = _SliceScatterNotInputModule(dim, start, end) + cache = torch.randn(cache_shape) + update = torch.randn(update_shape) + self.run_test(module, [cache, update]) + + def test_fallback_step_two(self): + module = _SliceScatterNotInputModule(2, 0, 16, step=2) + cache = torch.randn(2, 4, 16, 8) + update = torch.randn(2, 4, 8, 8) + self.run_test(module, [cache, update]) + + def test_full_overwrite_is_identity(self): + """When start=0, end=dim_size, step=1, the converter short-circuits + and returns ``src`` directly. Wrap the returned tensor in a small op + so it isn't simultaneously a network input and a network output — + which TRT rejects (handled by ``repair_input_as_output`` in + production but bypassed in this lower-level harness).""" + + class M(torch.nn.Module): + def forward(self, cache_in, update): + cache = cache_in + 0 # force non-placeholder + out = torch.ops.aten.slice_scatter.default(cache, update, 2, 0, 16) + return out + 0 # avoid placeholder-as-output + + cache = torch.randn(2, 4, 16, 8) + update = torch.randn(2, 4, 16, 8) + self.run_test(M(), [cache, update]) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/lowering/test_buffer_lifting.py b/tests/py/dynamo/lowering/test_buffer_lifting.py new file mode 100644 index 0000000000..2d8a00114f --- /dev/null +++ b/tests/py/dynamo/lowering/test_buffer_lifting.py @@ -0,0 +1,248 @@ +# type: ignore +"""Unit tests for ``lift_mutated_buffers`` and ``inline_lifted_buffers_into_gm``. + +``lift_mutated_buffers`` is a pre-compile rewrite that detects mutated +buffers (the trailing ``aten.copy_(get_attr, _)`` pattern that +``ExportedProgram.module()`` generates for each ``BUFFER_MUTATION``) and +lifts each one from a ``get_attr`` to a ``placeholder``. The rebuilt +GraphModule's ``forward`` signature reflects the new placeholder set — +which requires resetting the graph's ``_codegen`` from the +``_PyTreeCodeGen`` baked in by ``ep.module()`` to the default ``CodeGen``. + +These tests verify: + +* Buffers ARE lifted when mutated. +* Buffers are NOT lifted when only read. +* The rebuilt GraphModule's ``forward`` accepts the new placeholders. +* The rebuilt GraphModule produces the same outputs as the original + pre-lift gm when both are given the same inputs (buffers + user inputs). +* The original buffer tensors are returned alongside the placeholder + names for downstream wiring. +* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer + placeholders into ``get_attr`` reads and registers the buffers as + module state. The result is a plain ``fx.GraphModule`` that + serializes via ``torch_tensorrt.save`` without an external wrapper. +""" +import inspect + +import torch +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo.lowering._buffer_lifting import ( + inline_lifted_buffers_into_gm, + lift_mutated_buffers, +) + + +def _ep_module_decomposed(model, args): + """Run the prefix of the compile pipeline up through ``ep.module()``.""" + ep = export(model, tuple(args)) + ep = ep.run_decompositions({}) + return ep.module() + + +class TestLiftMutatedBuffers(TestCase): + def test_no_mutation_no_lift(self): + """A module that reads buffers but doesn't mutate them returns + ``(gm, [])`` — no rewrite happens.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("table", torch.arange(16, dtype=torch.float32)) + + def forward(self, x): + return x + self.table.sum() + + gm = _ep_module_decomposed(M(), (torch.zeros(4),)) + new_gm, lifted = lift_mutated_buffers(gm) + self.assertEqual(lifted, []) + # The same gm is returned when nothing is lifted. + self.assertIs(new_gm, gm) + + def test_single_buffer_lifted(self): + """A buffer that's mutated should be lifted to a placeholder, the + trailing copy_ removed, and the rebuilt forward should accept it + as an argument.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(2, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + gm = _ep_module_decomposed(M(), (torch.ones(2, 4, 1, 8),)) + new_gm, lifted = lift_mutated_buffers(gm) + + # Exactly one buffer was lifted. + self.assertEqual(len(lifted), 1) + ph_name, buf_name, tensor = lifted[0] + self.assertEqual(buf_name, "cache") + self.assertEqual(tuple(tensor.shape), (2, 4, 16, 8)) + self.assertEqual(ph_name, "buf_cache") + + # The rebuilt forward should now accept (x, buf_cache). + sig = inspect.signature(new_gm.forward) + param_names = list(sig.parameters.keys()) + self.assertEqual(param_names, ["x", "buf_cache"]) + + # No get_attr nodes for `cache` remain in the graph. + for node in new_gm.graph.nodes: + if node.op == "get_attr": + self.assertNotEqual(node.target, "cache") + # No trailing aten.copy_ to the (now removed) cache get_attr. + for node in new_gm.graph.nodes: + self.assertNotEqual(node.target, torch.ops.aten.copy_.default) + + def test_paired_buffers_lifted(self): + """Two mutated buffers are both lifted; placeholders appear in a + stable order so callers can match them positionally.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache_k", torch.zeros(2, 4, 16, 8)) + self.register_buffer("cache_v", torch.zeros(2, 4, 16, 8)) + + def forward(self, x_k, x_v): + self.cache_k[:, :, 3:4, :] = x_k + self.cache_v[:, :, 3:4, :] = x_v + return self.cache_k.sum() + self.cache_v.sum() + + gm = _ep_module_decomposed( + M(), (torch.ones(2, 4, 1, 8), torch.ones(2, 4, 1, 8)) + ) + new_gm, lifted = lift_mutated_buffers(gm) + self.assertEqual(len(lifted), 2) + buf_names = {b for _, b, _ in lifted} + self.assertEqual(buf_names, {"cache_k", "cache_v"}) + + # forward signature should have all 4 params (2 user + 2 buffer). + sig = inspect.signature(new_gm.forward) + self.assertEqual(len(sig.parameters), 4) + + def test_rebuilt_forward_matches_original(self): + """The rebuilt GraphModule, when given (user_args..., buffers...), + should produce the same outputs as the original ep.module() when + given the same user_args (with buffers used from internal state).""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(2, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() * 2.0 + + x = torch.randn(2, 4, 1, 8) + gm_original = _ep_module_decomposed(M(), (x.clone(),)) + original_out = gm_original(x.clone()) + # ep.module() returns its outputs as a tuple. Take the first element + # to compare against the rebuilt gm (whose default CodeGen returns + # tuples too, but possibly with a different surrounding shape). + if isinstance(original_out, tuple): + original_out = original_out[0] + + # Re-create gm for the lift (in-place mutation of the first gm's + # graph would change its forward behavior). + gm_for_lift = _ep_module_decomposed(M(), (x.clone(),)) + new_gm, lifted = lift_mutated_buffers(gm_for_lift) + _, _, buf_tensor = lifted[0] + # Call rebuilt gm with the original buffer state. + new_out = new_gm(x.clone(), buf_tensor.clone()) + if isinstance(new_out, tuple): + new_out = new_out[0] + self.assertTrue(torch.allclose(new_out, original_out)) + + +class TestInlineLiftedBuffers(TestCase): + """``inline_lifted_buffers_into_gm`` should register each lifted + buffer as module state on the gm and rewrite the corresponding + placeholder node into a ``get_attr`` read. After inlining, the gm's + forward should accept only the user inputs.""" + + def _build_simple_gm(self): + """Construct an fx GraphModule with two placeholders (x, buf) and + a body that sums them, matching what ``lift_mutated_buffers`` + would produce.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + buf = graph.placeholder("buf_cache") + out = graph.call_function(torch.add, args=(x, buf)) + graph.output(out) + gm = torch.fx.GraphModule({}, graph) + gm.recompile() + return gm + + def test_inline_registers_buffer_and_rewrites_placeholder(self): + gm = self._build_simple_gm() + buf_tensor = torch.tensor([1.0, 2.0, 3.0]) + + new_gm = inline_lifted_buffers_into_gm( + gm, lifted_buffers=[("buf_cache", "cache", buf_tensor)] + ) + + # Buffer registered as module state. + self.assertTrue(hasattr(new_gm, "cache")) + self.assertTrue(torch.allclose(new_gm.cache, buf_tensor)) + + # Placeholder count is now 1 (only `x`); buffer is a get_attr. + placeholders = [n for n in new_gm.graph.nodes if n.op == "placeholder"] + self.assertEqual(len(placeholders), 1) + self.assertEqual(placeholders[0].name, "x") + get_attrs = [n for n in new_gm.graph.nodes if n.op == "get_attr"] + self.assertEqual(len(get_attrs), 1) + self.assertEqual(get_attrs[0].target, "cache") + + # forward(x) computes x + cache via the inlined get_attr. + x = torch.tensor([10.0, 20.0, 30.0]) + out = new_gm(x) + if isinstance(out, tuple): + out = out[0] + self.assertTrue(torch.allclose(out, x + buf_tensor)) + + def test_inline_is_noop_for_empty_lifted(self): + gm = self._build_simple_gm() + ph_before = [n.name for n in gm.graph.nodes if n.op == "placeholder"] + result = inline_lifted_buffers_into_gm(gm, lifted_buffers=[]) + self.assertIs(result, gm) + ph_after = [n.name for n in result.graph.nodes if n.op == "placeholder"] + self.assertEqual(ph_before, ph_after) + + def test_inline_preserves_user_input_order(self): + """When multiple buffers are inlined, the user inputs come first + and are unchanged; the buffers become get_attr reads.""" + graph = torch.fx.Graph() + u1 = graph.placeholder("u1") + u2 = graph.placeholder("u2") + b1 = graph.placeholder("buf_a") + b2 = graph.placeholder("buf_b") + s1 = graph.call_function(torch.add, args=(u1, b1)) + s2 = graph.call_function(torch.add, args=(u2, b2)) + out = graph.call_function(torch.add, args=(s1, s2)) + graph.output(out) + gm = torch.fx.GraphModule({}, graph) + gm.recompile() + + new_gm = inline_lifted_buffers_into_gm( + gm, + lifted_buffers=[ + ("buf_a", "a", torch.tensor(1.0)), + ("buf_b", "b", torch.tensor(2.0)), + ], + ) + placeholders = [n.name for n in new_gm.graph.nodes if n.op == "placeholder"] + self.assertEqual(placeholders, ["u1", "u2"]) + # Numerical: (10 + 1) + (20 + 2) = 33 + out = new_gm(torch.tensor(10.0), torch.tensor(20.0)) + if isinstance(out, tuple): + out = out[0] + self.assertEqual(out.item(), 33.0) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 3e488fe678..00efbdb8dd 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -626,10 +626,10 @@ def forward(self, x, src, dim, start=None, end=None, step=1): y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = { - torch.ops.aten.scatter.src, - } + # slice_scatter is no longer decomposed — the converter handles it + # directly (emits IKVCacheUpdateLayer when KV-eligible, scatter + # otherwise). select_scatter is still decomposed via slice_scatter. + expected_ops = {torch.ops.aten.slice_scatter.default} unexpected_ops = {torch.ops.aten.select_scatter} inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1] @@ -688,11 +688,9 @@ def forward(self, x, src, dim, start, end, step): y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = { - torch.ops.aten.scatter.src, - } - unexpected_ops = {torch.ops.aten.slice_scatter} + # slice_scatter is no longer decomposed — survives to the converter. + expected_ops = {torch.ops.aten.slice_scatter.default} + unexpected_ops: set = set() inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2] @@ -751,11 +749,9 @@ def forward(self, x, src, dim, start, end, step): y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = { - torch.ops.aten.scatter.src, - } - unexpected_ops = {torch.ops.aten.slice_scatter} + # slice_scatter is no longer decomposed — survives to the converter. + expected_ops = {torch.ops.aten.slice_scatter.default} + unexpected_ops: set = set() inputs = [ torch.zeros(8, 8, 8).cuda(), @@ -853,12 +849,14 @@ def forward(self, x, src, dim, index): y = torch.ops.aten.select_scatter.default(x, src, dim, index) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default} - unexpected_ops = { - torch.ops.aten.select_scatter.default, + # select_scatter is still decomposed (to slice_scatter via the + # Torch-TRT decomposition); slice_scatter is no longer decomposed + # further and survives to the converter. + expected_ops = { torch.ops.aten.slice_scatter.default, + torch.ops.aten.unsqueeze.default, } + unexpected_ops = {torch.ops.aten.select_scatter.default} inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0] @@ -916,12 +914,14 @@ def forward(self, x, src, dim, index): y = torch.ops.aten.select_scatter.default(x, src, dim, index) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default} - unexpected_ops = { - torch.ops.aten.select_scatter.default, + # select_scatter is still decomposed (to slice_scatter via the + # Torch-TRT decomposition); slice_scatter is no longer decomposed + # further and survives to the converter. + expected_ops = { torch.ops.aten.slice_scatter.default, + torch.ops.aten.unsqueeze.default, } + unexpected_ops = {torch.ops.aten.select_scatter.default} inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0] @@ -979,12 +979,14 @@ def forward(self, x, src, dim, index): y = torch.ops.aten.select_scatter.default(x, src, dim, index) return y - # Operations expected to be removed in the traced graph after decompositions - expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default} - unexpected_ops = { - torch.ops.aten.select_scatter.default, + # select_scatter is still decomposed (to slice_scatter via the + # Torch-TRT decomposition); slice_scatter is no longer decomposed + # further and survives to the converter. + expected_ops = { torch.ops.aten.slice_scatter.default, + torch.ops.aten.unsqueeze.default, } + unexpected_ops = {torch.ops.aten.select_scatter.default} inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0] diff --git a/tests/py/dynamo/runtime/test_aliased_io.py b/tests/py/dynamo/runtime/test_aliased_io.py new file mode 100644 index 0000000000..20cff216dd --- /dev/null +++ b/tests/py/dynamo/runtime/test_aliased_io.py @@ -0,0 +1,272 @@ +# type: ignore +"""End-to-end tests for aliased I/O (KV-cache fast path + buffer lifting). + +The KV-cache fast path emits ``IKVCacheUpdateLayer`` and binds the layer's +output to the cache input via aliased I/O. The C++ runtime honors the +aliasing by binding both bindings to the same device pointer; the engine +writes through the binding into the user's input storage. The Python +runtime does NOT support aliasing, so all tests here force the C++ runtime +(``use_python_runtime=False``). + +These tests exercise the full pipeline: + +* Converter emits the layer with aliased output. +* ``TRTInterpreter`` carries the aliased_io map (with kind tag) through. +* ``SerializedInterpreterResult`` plumbs it to ``TorchTensorRTModule``. +* C++ ``TRTEngine`` reconciles against ``getAliasedInputTensor`` at load. +* ``execute_engine`` binds the aliased output to the input ``data_ptr``, + skipping fresh allocation. +* ``TorchTensorRTModule.forward`` filters aliased outputs from the user + return tuple. +* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and + ``BufferThreadingModule`` threads buffers through each call. +""" +import torch +import torch_tensorrt +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests + + +def _compile_cpp(model, args): + """Convenience: torch.export + torch_tensorrt.compile with C++ runtime.""" + ep = export(model, tuple(args)) + return torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=list(args), + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, + ) + + +def _find_aliased_io(compiled): + """Return the aliased_io map from the first inner TRT module, or {}.""" + for _name, mod in compiled.named_modules(): + if hasattr(mod, "aliased_io") and mod.aliased_io: + return dict(mod.aliased_io) + return {} + + +class TestUserInputKVCache(TestCase): + """User passes the cache tensor on every call; engine mutates in place.""" + + def test_single_slice_write_in_place(self): + class M(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + + cache = torch.zeros(2, 4, 16, 8, device="cuda") + update = torch.ones(2, 4, 1, 8, device="cuda") * 7.0 + compiled = _compile_cpp(M().cuda(), (cache.clone(), update.clone())) + + # Aliasing is recorded with kv_cache_update kind. + aliased = _find_aliased_io(compiled) + self.assertEqual(len(aliased), 1) + _, kind = next(iter(aliased.values())) + self.assertEqual(kind, "kv_cache_update") + + cache_run = cache.clone() + cache_id, cache_ptr = id(cache_run), cache_run.data_ptr() + ret = compiled(cache_run, update) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + # Numerical match against eager. + eager = cache.clone() + eager[:, :, 3:4, :] = update + self.assertTrue(torch.allclose(cache_run, eager)) + self.assertTrue(torch.allclose(ret_val, eager.sum())) + + # Identity preserved. + self.assertEqual(id(cache_run), cache_id) + self.assertEqual(cache_run.data_ptr(), cache_ptr) + + def test_paired_kv_caches(self): + class M(torch.nn.Module): + def forward(self, ck, cv, k, v): + ck[:, :, 3:4, :] = k + cv[:, :, 5:6, :] = v + return ck.sum() + cv.sum() + + ck = torch.zeros(2, 4, 16, 8, device="cuda") + cv = torch.zeros(2, 4, 16, 8, device="cuda") + k = torch.ones(2, 4, 1, 8, device="cuda") * 3.0 + v = torch.ones(2, 4, 1, 8, device="cuda") * 5.0 + compiled = _compile_cpp( + M().cuda(), (ck.clone(), cv.clone(), k.clone(), v.clone()) + ) + + # Both K and V should be aliased. + aliased = _find_aliased_io(compiled) + self.assertEqual(len(aliased), 2) + + ck_run, cv_run = ck.clone(), cv.clone() + ret = compiled(ck_run, cv_run, k, v) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + ck_eager, cv_eager = ck.clone(), cv.clone() + ck_eager[:, :, 3:4, :] = k + cv_eager[:, :, 5:6, :] = v + self.assertTrue(torch.allclose(ck_run, ck_eager)) + self.assertTrue(torch.allclose(cv_run, cv_eager)) + self.assertTrue(torch.allclose(ret_val, ck_eager.sum() + cv_eager.sum())) + + def test_streaming_state_accumulates(self): + """Repeated calls on the same cache tensor should observe the + previous call's mutation.""" + + class M(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + + proto = torch.zeros(2, 4, 16, 8, device="cuda") + upd_proto = torch.ones(2, 4, 1, 8, device="cuda") + compiled = _compile_cpp(M().cuda(), (proto.clone(), upd_proto.clone())) + + cache = torch.zeros(2, 4, 16, 8, device="cuda") + # Step 1: write 1s -> 64 elements of 1 at position 3 + compiled(cache, torch.ones(2, 4, 1, 8, device="cuda") * 1.0) + self.assertAlmostEqual(cache.sum().item(), 64.0, places=3) + # Step 2: overwrite with 5s -> 64 * 5 = 320 + compiled(cache, torch.ones(2, 4, 1, 8, device="cuda") * 5.0) + self.assertAlmostEqual(cache.sum().item(), 320.0, places=3) + # Step 3: write 0s + compiled(cache, torch.zeros(2, 4, 1, 8, device="cuda")) + self.assertAlmostEqual(cache.sum().item(), 0.0, places=3) + + +class TestBufferBackedKVCache(TestCase): + """Buffers held by the module via ``register_buffer``. The compile flow + lifts the buffers to engine inputs and wraps the compiled module to + thread them in automatically.""" + + def test_buffer_mutation_in_place(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(2, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x = torch.ones(2, 4, 1, 8, device="cuda") * 7.0 + compiled = _compile_cpp(m, (x.clone(),)) + + # The lifted buffer should be aliased. + aliased = _find_aliased_io(compiled) + self.assertEqual(len(aliased), 1) + _, kind = next(iter(aliased.values())) + self.assertEqual(kind, "kv_cache_update") + + # The compiled module owns the buffer (BufferThreadingModule). + self.assertTrue(hasattr(compiled, "cache")) + self.assertAlmostEqual(compiled.cache.sum().item(), 0.0) + + # Call the compiled module the same way the user wrote it: model(x). + ret = compiled(x) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + # Buffer should be mutated; sum matches eager. + eager_m = M().cuda() + eager_ret = eager_m(x) + self.assertTrue(torch.allclose(compiled.cache, eager_m.cache)) + self.assertTrue(torch.allclose(ret_val, eager_ret)) + + def test_paired_buffer_caches(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache_k", torch.zeros(2, 4, 16, 8)) + self.register_buffer("cache_v", torch.zeros(2, 4, 16, 8)) + + def forward(self, x_k, x_v): + self.cache_k[:, :, 3:4, :] = x_k + self.cache_v[:, :, 3:4, :] = x_v + return self.cache_k.sum() + self.cache_v.sum() + + m = M().cuda() + x_k = torch.ones(2, 4, 1, 8, device="cuda") * 3.0 + x_v = torch.ones(2, 4, 1, 8, device="cuda") * 5.0 + compiled = _compile_cpp(m, (x_k.clone(), x_v.clone())) + + aliased = _find_aliased_io(compiled) + self.assertEqual(len(aliased), 2) + + ret = compiled(x_k, x_v) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + eager_m = M().cuda() + eager_ret = eager_m(x_k, x_v) + self.assertTrue(torch.allclose(compiled.cache_k, eager_m.cache_k)) + self.assertTrue(torch.allclose(compiled.cache_v, eager_m.cache_v)) + self.assertTrue(torch.allclose(ret_val, eager_ret)) + + def test_buffer_streaming_persists_across_calls(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(2, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x_proto = torch.ones(2, 4, 1, 8, device="cuda") + compiled = _compile_cpp(m, (x_proto.clone(),)) + + compiled(torch.ones(2, 4, 1, 8, device="cuda") * 1.0) + self.assertAlmostEqual(compiled.cache.sum().item(), 64.0, places=3) + compiled(torch.ones(2, 4, 1, 8, device="cuda") * 5.0) + self.assertAlmostEqual(compiled.cache.sum().item(), 320.0, places=3) + compiled(torch.zeros(2, 4, 1, 8, device="cuda")) + self.assertAlmostEqual(compiled.cache.sum().item(), 0.0, places=3) + + +class TestAliasedIORegressions(TestCase): + """Models without aliased I/O should be unaffected by these changes.""" + + def test_no_aliasing_path_untouched(self): + class M(torch.nn.Module): + def forward(self, x, y): + return (x + y) * 2.0 + + x = torch.randn(4, 8, device="cuda") + y = torch.randn(4, 8, device="cuda") + compiled = _compile_cpp(M().cuda(), (x, y)) + + aliased = _find_aliased_io(compiled) + self.assertEqual(aliased, {}) + + got = compiled(x, y) + expected = M().cuda()(x, y) + self.assertTrue(torch.allclose(got, expected, atol=1e-4)) + + def test_slice_scatter_fallback_path(self): + """A slice_scatter that doesn't qualify for KVCacheUpdate should + still produce correct results via the scatter fallback.""" + + class M(torch.nn.Module): + def forward(self, x, y): + z = x.clone() + z[:, 2:4, :] = y # 3-D — wrong rank + return z.sum() + + x = torch.ones(2, 8, 4, device="cuda") + y = torch.zeros(2, 2, 4, device="cuda") + compiled = _compile_cpp(M().cuda(), (x, y)) + + # No aliasing since the cache isn't 4-D. + self.assertEqual(_find_aliased_io(compiled), {}) + + got = compiled(x, y).item() + expected = M().cuda()(x, y).item() + self.assertAlmostEqual(got, expected, places=3) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py b/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py new file mode 100644 index 0000000000..36f4145e9b --- /dev/null +++ b/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py @@ -0,0 +1,145 @@ +# type: ignore +"""CUDA Graphs + aliased I/O. + +CUDA Graphs normally clones each input into a persistent buffer so binding +addresses stay stable across replays. That mechanism is incompatible with +aliased I/O — the engine would write to the persistent clone and the +user's input tensor wouldn't observe the mutation. + +The runtime handles this by: + +* For each input binding that appears as the target of an aliased output, + *bypass* the persistent-buffer copy and bind directly to the user's + tensor. The caller is already required to pass stable input addresses + under cudagraphs; aliased I/O just makes that contract observable. +* Skip aliased outputs in the post-execution copy-back loop (their + ``output_buffers`` slot is intentionally never populated; the mutation + is already visible on the user's input). + +These tests cover capture + replay correctness for both KV-cache patterns +(user-input and buffer-style). +""" +import unittest + +import torch +import torch_tensorrt +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests + + +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", +) +class TestCudagraphsAliasedIO(TestCase): + def setUp(self): + # Ensure clean state regardless of prior test ordering. + torch_tensorrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torch_tensorrt.runtime.set_cudagraphs_mode(False) + + def _compile(self, model, args): + ep = export(model, tuple(args)) + return torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=list(args), + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, + ) + + def test_user_input_kv_capture_and_replay(self): + """User passes the same cache tensor across multiple cudagraph + replays; mutation should land on that tensor each time.""" + + class M(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + + cache_sample = torch.zeros(1, 4, 16, 8, device="cuda") + update_sample = torch.ones(1, 4, 1, 8, device="cuda") + compiled = self._compile( + M().cuda(), (cache_sample.clone(), update_sample.clone()) + ) + + with torch_tensorrt.runtime.enable_cudagraphs(compiled) as cg: + # Use the SAME cache tensor across calls (the cudagraphs + # contract). Each call overwrites position 3 with the new + # update value. + cache = torch.zeros(1, 4, 16, 8, device="cuda") + cache_id, cache_ptr = id(cache), cache.data_ptr() + + # Step 1: capture. cache[3, :] becomes all 1s — sum = 32. + cg(cache, torch.ones(1, 4, 1, 8, device="cuda") * 1.0) + self.assertAlmostEqual(cache.sum().item(), 32.0, places=3) + self.assertEqual(id(cache), cache_id) + self.assertEqual(cache.data_ptr(), cache_ptr) + + # Step 2: replay. cache[3, :] becomes all 5s — sum = 160. + cg(cache, torch.ones(1, 4, 1, 8, device="cuda") * 5.0) + self.assertAlmostEqual(cache.sum().item(), 160.0, places=3) + + # Step 3: replay again. cache[3, :] becomes all 0s — sum = 0. + cg(cache, torch.zeros(1, 4, 1, 8, device="cuda")) + self.assertAlmostEqual(cache.sum().item(), 0.0, places=3) + + def test_buffer_kv_capture_and_replay(self): + """Buffer-backed KV cache: the buffer lives on the compiled + module. Cudagraphs should still capture+replay and mutate the + buffer in place.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + compiled = self._compile(M().cuda(), (torch.ones(1, 4, 1, 8, device="cuda"),)) + + with torch_tensorrt.runtime.enable_cudagraphs(compiled) as cg: + cg(torch.ones(1, 4, 1, 8, device="cuda") * 1.0) + self.assertAlmostEqual(cg.cache.sum().item(), 32.0, places=3) + cg(torch.ones(1, 4, 1, 8, device="cuda") * 5.0) + self.assertAlmostEqual(cg.cache.sum().item(), 160.0, places=3) + cg(torch.zeros(1, 4, 1, 8, device="cuda")) + self.assertAlmostEqual(cg.cache.sum().item(), 0.0, places=3) + + def test_matches_non_cudagraphs(self): + """Same inputs, same model — cudagraphs vs no cudagraphs should + produce identical cache state and return values.""" + + class M(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + cache.mean() + + cache_sample = torch.zeros(1, 4, 16, 8, device="cuda") + update_sample = torch.ones(1, 4, 1, 8, device="cuda") * 3.0 + compiled = self._compile( + M().cuda(), (cache_sample.clone(), update_sample.clone()) + ) + + # No cudagraphs. + cache_plain = torch.zeros(1, 4, 16, 8, device="cuda") + update = torch.ones(1, 4, 1, 8, device="cuda") * 7.0 + ret_plain = compiled(cache_plain, update) + ret_plain_val = ret_plain[0] if isinstance(ret_plain, tuple) else ret_plain + + # With cudagraphs. + with torch_tensorrt.runtime.enable_cudagraphs(compiled) as cg: + cache_cg = torch.zeros(1, 4, 16, 8, device="cuda") + ret_cg = cg(cache_cg, update) + ret_cg_val = ret_cg[0] if isinstance(ret_cg, tuple) else ret_cg + + self.assertTrue(torch.allclose(cache_plain, cache_cg)) + self.assertTrue(torch.allclose(ret_plain_val, ret_cg_val)) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_aliased_io_serialization.py b/tests/py/dynamo/runtime/test_aliased_io_serialization.py new file mode 100644 index 0000000000..d13dd637be --- /dev/null +++ b/tests/py/dynamo/runtime/test_aliased_io_serialization.py @@ -0,0 +1,146 @@ +# type: ignore +"""Serialization tests for aliased I/O. + +Verifies that compiled modules with aliased I/O survive a round-trip +through ``torch_tensorrt.save`` / ``torch_tensorrt.load``: + +* **User-input style** — the engine's ``aliased_io`` map is part of the + C++ engine's serialized form (``ALIASED_IO_IDX`` in the wire format). + After load, ``execute_engine`` reconstructs aliasing from those bytes + and the runtime binds outputs to input ``data_ptr`` as before. + +* **Buffer-backed style** — additionally requires that the lifted + buffers (registered as ``nn.Module`` state on the compiled GraphModule + and read via ``get_attr`` nodes in the fx graph) survive + ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile + transform replaces what used to be an external ``BufferThreadingModule`` + wrapper — making the result a plain ``fx.GraphModule`` that exports + naturally without a custom wrapper class. +""" +import tempfile + +import torch +import torch_tensorrt +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests + + +def _compile_and_roundtrip(model, args): + """Compile, save, load, return (compiled, loaded_gm).""" + ep = export(model, tuple(args)) + compiled = torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=list(args), + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, + ) + with tempfile.NamedTemporaryFile(suffix=".ep", delete=False) as f: + path = f.name + torch_tensorrt.save(compiled, path, arg_inputs=list(args)) + loaded_ep = torch_tensorrt.load(path) + loaded = loaded_ep.module() if hasattr(loaded_ep, "module") else loaded_ep + return compiled, loaded + + +class TestUserInputAliasingSurvivesSaveLoad(TestCase): + """User passes the cache each call. The engine's aliased_io map is + serialized in the engine bytes; after load, runtime aliasing still + works.""" + + def test_kv_cache_user_input_save_load(self): + class M(torch.nn.Module): + def forward(self, cache, update): + cache[:, :, 3:4, :] = update + return cache.sum() + + cache_sample = torch.zeros(1, 4, 16, 8, device="cuda") + update_sample = torch.ones(1, 4, 1, 8, device="cuda") * 7.0 + compiled, loaded = _compile_and_roundtrip( + M().cuda(), (cache_sample.clone(), update_sample.clone()) + ) + + # Run loaded module; cache should be mutated in place via aliasing. + cache_run = torch.zeros(1, 4, 16, 8, device="cuda") + cache_id, cache_ptr = id(cache_run), cache_run.data_ptr() + ret = loaded(cache_run, update_sample) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + eager = torch.zeros(1, 4, 16, 8, device="cuda") + eager[:, :, 3:4, :] = update_sample + self.assertTrue(torch.allclose(cache_run, eager)) + self.assertTrue(torch.allclose(ret_val, eager.sum())) + # Aliased pointer identity is preserved post-load too. + self.assertEqual(id(cache_run), cache_id) + self.assertEqual(cache_run.data_ptr(), cache_ptr) + + +class TestBufferAliasingSurvivesSaveLoad(TestCase): + """Module-held buffer (BUFFER + BUFFER_MUTATION). The post-compile + transform registers the buffer on the compiled GraphModule and + rewrites the lifted-buffer placeholder to a ``get_attr`` read. That + structure exports cleanly; after load the buffer is still a + ``nn.Module`` buffer on the loaded gm and the engine still aliases + it in place.""" + + def test_buffer_kv_save_load(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x = torch.ones(1, 4, 1, 8, device="cuda") * 3.0 + compiled, loaded = _compile_and_roundtrip(m, (x.clone(),)) + + # The compiled module already had the buffer; the loaded one + # should still have it (registered as nn.Module state, saved + # natively through torch.export). + self.assertTrue(hasattr(compiled, "cache")) + self.assertTrue(hasattr(loaded, "cache")) + self.assertEqual(tuple(loaded.cache.shape), tuple(compiled.cache.shape)) + + # Reset to zero so the comparison is clean. + loaded.cache.zero_() + ret = loaded(x) + ret_val = ret[0] if isinstance(ret, tuple) else ret + + eager = M().cuda() + eager_ret = eager(x.clone()) + self.assertTrue(torch.allclose(ret_val, eager_ret)) + self.assertTrue(torch.allclose(loaded.cache, eager.cache)) + + def test_buffer_kv_save_load_streaming(self): + """Repeated calls on the LOADED module accumulate state on the + loaded module's buffer.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x = torch.ones(1, 4, 1, 8, device="cuda") + _, loaded = _compile_and_roundtrip(m, (x.clone(),)) + + loaded.cache.zero_() + # Each step overwrites position 3 (32 elements). + loaded(torch.ones(1, 4, 1, 8, device="cuda") * 1.0) + self.assertAlmostEqual(loaded.cache.sum().item(), 32.0, places=3) + loaded(torch.ones(1, 4, 1, 8, device="cuda") * 5.0) + self.assertAlmostEqual(loaded.cache.sum().item(), 160.0, places=3) + loaded(torch.zeros(1, 4, 1, 8, device="cuda")) + self.assertAlmostEqual(loaded.cache.sum().item(), 0.0, places=3) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py b/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py new file mode 100644 index 0000000000..21e27540d3 --- /dev/null +++ b/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py @@ -0,0 +1,137 @@ +# type: ignore +"""HuggingFace decoder with ``StaticCache`` — currently expected to fail. + +This file documents the current gap between Torch-TensorRT's aliased-I/O +support and stock HuggingFace decoder-only LMs. The test compiles +``TorchExportableModuleWithStaticCache(GPT2)`` (HF's recommended export +path for static-cache models) and asserts that it fails with the specific +known error so we notice if/when it starts working. + +What works today (covered by ``test_aliased_io.py``): + +* ``register_buffer`` + slice-write KV cache (the pattern we built for). + +What fails with stock HF + ``StaticCache``: + +1. ``ExportedProgram.run_decompositions`` fails inside torch with + ``AssertionError: expected compiled_fn to be GraphModule, got + `` (``_functorch/_aot_autograd/graph_compile.py``). + This is a torch.export internal that surfaces when running + decompositions on an EP whose body contains certain ATen patterns + produced by the HF wrapper. Not Torch-TensorRT's fault. + +2. Even bypassing the decomp issue, HF's ``StaticCache.update`` writes + via ``aten.index_copy_(cache, dim=2, idx, k_or_v)``, NOT + ``aten.slice_scatter``. Our converter today only matches + ``slice_scatter``. The cache tensors also show up as ``c_*_lifted`` + constants rather than BUFFER inputs in the graph_signature. + +To make this end-to-end we'd need: + +* An ``index_copy`` converter with the same KV-eligibility check + alias + recording. +* Extension of ``lift_mutated_buffers`` to recognize mutated lifted + constants (in addition to mutated buffers). +* Either an upstream fix for the ``aot_stage2_export`` issue or a + workaround that skips ``run_decompositions`` for already-decomposed EPs. + +When the upstream issues are resolved or those features land, this +xfail test should start passing — flip it to a real test then. +""" +import unittest + +import torch +import torch_tensorrt + + +def _can_import_hf(): + try: + from transformers import GPT2Config # noqa: F401 + from transformers.integrations.executorch import ( # noqa: F401 + TorchExportableModuleWithStaticCache, + ) + + return True + except Exception: + return False + + +@unittest.skipUnless(_can_import_hf(), "transformers not installed") +class TestHFStaticCacheCurrentLimitations(unittest.TestCase): + def _make_wrapped(self): + from transformers import GPT2Config, GPT2LMHeadModel + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + config = GPT2Config( + vocab_size=128, n_positions=64, n_embd=64, n_layer=2, n_head=4 + ) + model = GPT2LMHeadModel(config).eval().cuda() + model.config.use_cache = True + model.generation_config.use_cache = True + model.generation_config.cache_implementation = "static" + model.generation_config.cache_config = { + "batch_size": 1, + "max_cache_len": 32, + "device": "cuda", + } + return TorchExportableModuleWithStaticCache( + model, batch_size=1, max_cache_len=32, device="cuda" + ).cuda() + + def test_compile_fails_with_known_error(self): + """Compiling GPT2 + StaticCache currently raises during + ``run_decompositions``. We assert the error so a future + upstream/internal fix that unblocks compilation shows up as a + test failure (signaling we should remove this xfail and write a + real test).""" + wrapped = self._make_wrapped() + input_ids = torch.tensor([[1]], dtype=torch.long, device="cuda") + cache_position = torch.tensor([0], dtype=torch.long, device="cuda") + ep = torch.export.export( + wrapped, + (), + {"input_ids": input_ids, "cache_position": cache_position}, + strict=False, + ) + + # Confirm the EP looks like we expect (HF emits index_copy_ for + # the cache write, not slice_scatter, and treats caches as + # lifted constants rather than BUFFER inputs). + has_index_copy = any( + n.op == "call_function" and "index_copy" in str(n.target) + for n in ep.graph.nodes + ) + self.assertTrue( + has_index_copy, + "Expected aten.index_copy_ in HF GPT2 + StaticCache EP — if this " + "test starts failing here, HF may have switched to slice_scatter " + "and our existing converter might now handle the model directly.", + ) + + with self.assertRaises(Exception) as ctx: + torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=[input_ids, cache_position], + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, + ) + # The torch internal raises an AssertionError with a specific + # message. Match loosely so the test isn't brittle to phrasing + # changes — we only want to detect that the failure is the + # known one rather than something new. + msg = str(ctx.exception) + self.assertTrue( + "compiled_fn" in msg + or "GraphModule" in msg + or "index_copy" in msg.lower() + or "aot_stage2_export" in msg, + f"Compile failed but not with the known error pattern: {msg}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/dynamo/runtime/test_index_copy_kv.py b/tests/py/dynamo/runtime/test_index_copy_kv.py new file mode 100644 index 0000000000..7fb67eaeb9 --- /dev/null +++ b/tests/py/dynamo/runtime/test_index_copy_kv.py @@ -0,0 +1,135 @@ +# type: ignore +"""End-to-end tests for ``aten.index_copy`` KV-cache aliasing. + +Two converters are registered for ``aten.index_copy.default``: + +* ``aten_ops_index_copy_kv`` — HIGH priority, validator-gated. Fires for + the narrow KV-eligible case (4-D static cache, dim=2, source seq dim + size 1, batch=1) and emits ``IKVCacheUpdateLayer`` with aliased I/O. + +* ``aten_ops_index_copy_fallback`` — STANDARD priority. Fires for + everything else; produces correct results via the scatter path. No + graph break. + +These tests verify both paths end-to-end via the C++ runtime: the +fast path mutates in place, the fallback produces correct numerical +results without aliasing. +""" +import torch +import torch_tensorrt +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests + + +def _compile(model, args): + ep = export(model, tuple(args)) + return torch_tensorrt.compile( + ep, + ir="dynamo", + inputs=list(args), + enabled_precisions={torch.float32}, + min_block_size=1, + use_python_runtime=False, + ) + + +def _aliased_io(compiled): + for _name, mod in compiled.named_modules(): + if hasattr(mod, "aliased_io") and mod.aliased_io: + return dict(mod.aliased_io) + return {} + + +class TestIndexCopyKVFastPath(TestCase): + """KV-eligible: 4-D static cache, dim=2, batch=1, single-position + write. The validator passes and the fast path emits + ``IKVCacheUpdateLayer`` with aliased output.""" + + def test_single_position_write_aliased(self): + class M(torch.nn.Module): + def forward(self, cache, index, update): + return torch.ops.aten.index_copy.default(cache, 2, index, update) + + cache = torch.zeros(1, 4, 16, 8, device="cuda") + index = torch.tensor([3], dtype=torch.int64, device="cuda") + update = torch.ones(1, 4, 1, 8, device="cuda") * 7.0 + + compiled = _compile(M().cuda(), (cache.clone(), index, update.clone())) + + # Fast path fired — aliasing recorded. + aliased = _aliased_io(compiled) + self.assertEqual(len(aliased), 1) + _, kind = next(iter(aliased.values())) + self.assertEqual(kind, "kv_cache_update") + + # Numerical match against eager. + cache_run = cache.clone() + out = compiled(cache_run, index, update) + out_val = out[0] if isinstance(out, tuple) else out + eager = cache.clone() + eager_out = torch.ops.aten.index_copy.default(eager, 2, index, update) + self.assertTrue(torch.allclose(out_val, eager_out)) + + +class TestIndexCopyFallback(TestCase): + """Cases where the validator denies the KV fast path. The fallback + converter must produce correct results without aliasing.""" + + def test_rank_3_input_uses_fallback(self): + class M(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_copy.default(x, 1, index, update) + + x = torch.zeros(2, 8, 16, device="cuda") + index = torch.tensor([1, 3, 5], dtype=torch.int64, device="cuda") + update = torch.randn(2, 3, 16, device="cuda") + + compiled = _compile(M().cuda(), (x.clone(), index, update.clone())) + + # No aliasing (validator rejected the KV path). + self.assertEqual(_aliased_io(compiled), {}) + + out = compiled(x.clone(), index, update) + eager = torch.ops.aten.index_copy.default(x.clone(), 1, index, update) + self.assertTrue(torch.allclose(out, eager)) + + def test_dim_other_than_two_uses_fallback(self): + class M(torch.nn.Module): + def forward(self, cache, index, update): + return torch.ops.aten.index_copy.default(cache, 1, index, update) + + cache = torch.zeros(1, 16, 4, 8, device="cuda") + index = torch.tensor([3], dtype=torch.int64, device="cuda") + update = torch.ones(1, 1, 4, 8, device="cuda") * 5.0 + + compiled = _compile(M().cuda(), (cache.clone(), index, update.clone())) + self.assertEqual(_aliased_io(compiled), {}) + + cache_run = cache.clone() + out = compiled(cache_run, index, update) + eager = torch.ops.aten.index_copy.default(cache.clone(), 1, index, update) + self.assertTrue(torch.allclose(out, eager)) + + def test_batch_gt_one_uses_fallback(self): + """Batch > 1 currently routes to fallback (broadcasting writeIndices + is a Phase-2 extension).""" + + class M(torch.nn.Module): + def forward(self, cache, index, update): + return torch.ops.aten.index_copy.default(cache, 2, index, update) + + cache = torch.zeros(4, 4, 16, 8, device="cuda") + index = torch.tensor([3], dtype=torch.int64, device="cuda") + update = torch.ones(4, 4, 1, 8, device="cuda") * 7.0 + + compiled = _compile(M().cuda(), (cache.clone(), index, update.clone())) + self.assertEqual(_aliased_io(compiled), {}) + + cache_run = cache.clone() + out = compiled(cache_run, index, update) + eager = torch.ops.aten.index_copy.default(cache.clone(), 2, index, update) + self.assertTrue(torch.allclose(out, eager)) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py b/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py new file mode 100644 index 0000000000..3608fd14fa --- /dev/null +++ b/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py @@ -0,0 +1,217 @@ +# type: ignore +"""Integration tests for the ``lift_mutable_buffers`` flag on +``convert_exported_program_to_serialized_trt_engine``. + +The low-level entry point returns a serialized engine. The high-level +``torch_tensorrt.compile`` automatically lifts buffers and wraps the +result in ``BufferThreadingModule``; this lower-level surface exposes +the same lifting machinery but leaves runtime binding management to the +caller. These tests exercise that caller-managed workflow end-to-end: + +1. Compile via the low-level API with ``lift_mutable_buffers=True``. +2. Introspect the resulting engine — confirm it has additional input + bindings for each mutated buffer and an aliased output per binding. +3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for + aliased I/O) with the discovered bindings. +4. Thread the buffer values in on each call and verify in-place + mutation works (cache state persists across calls). +""" +import torch +import torch_tensorrt +from torch.export import export +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine +from torch_tensorrt.dynamo.runtime import TorchTensorRTModule + +import tensorrt as trt + + +def _introspect_engine(engine_bytes): + """Deserialize and return (input_names, output_names, aliased_io).""" + rt = trt.Runtime(trt.Logger(trt.Logger.WARNING)) + eng = rt.deserialize_cuda_engine(engine_bytes) + input_names = [] + output_names = [] + aliased_io = {} + for i in range(eng.num_io_tensors): + name = eng.get_tensor_name(i) + if eng.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + input_names.append(name) + else: + output_names.append(name) + try: + aliased_in = eng.get_aliased_input_tensor(name) + except Exception: + aliased_in = None + if aliased_in: + aliased_io[name] = (aliased_in, "kv_cache_update") + return input_names, output_names, aliased_io + + +def _build_module(engine_bytes, input_names, output_names, aliased_io): + """Wrap engine bytes in a TorchTensorRTModule (C++ runtime path). + + The user-output boundary is derived inside the module from + ``output_binding_names`` + ``aliased_io`` (side-effect aliased + outputs always live at the end of the binding list). + """ + return TorchTensorRTModule( + serialized_engine=engine_bytes, + input_binding_names=input_names, + output_binding_names=output_names, + aliased_io=aliased_io, + ) + + +class TestLiftMutableBuffersAPI(TestCase): + def test_flag_off_no_buffer_bindings(self): + """Default ``lift_mutable_buffers=False`` keeps the buffer baked + into the engine; bindings only cover user inputs/outputs.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x = torch.ones(1, 4, 1, 8, device="cuda") + ep = export(m, (x.clone(),)) + engine_bytes = convert_exported_program_to_serialized_trt_engine( + ep, + inputs=[x.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + inputs, outputs, aliased = _introspect_engine(engine_bytes) + # Only the user input survives; the buffer is folded into the engine. + self.assertEqual(inputs, ["x"]) + self.assertEqual(len(outputs), 1) + self.assertEqual(aliased, {}) + + def test_flag_on_adds_buffer_binding_and_alias(self): + """``lift_mutable_buffers=True`` adds the buffer as an input + binding and the engine reports an aliased output for it.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x = torch.ones(1, 4, 1, 8, device="cuda") + ep = export(m, (x.clone(),)) + engine_bytes = convert_exported_program_to_serialized_trt_engine( + ep, + inputs=[x.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + lift_mutable_buffers=True, + ) + inputs, outputs, aliased = _introspect_engine(engine_bytes) + + # User input first, lifted buffer appended after. + self.assertEqual(inputs, ["x", "buf_cache"]) + # One user output + one aliased mutation output. + self.assertEqual(len(outputs), 2) + # The aliased output should point at the buffer binding. + self.assertEqual(len(aliased), 1) + out_name, (in_name, kind) = next(iter(aliased.items())) + self.assertEqual(in_name, "buf_cache") + self.assertEqual(kind, "kv_cache_update") + + def test_caller_threads_buffer_for_in_place_mutation(self): + """End-to-end: caller takes the engine, threads the buffer in, + observes in-place mutation.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x_sample = torch.ones(1, 4, 1, 8, device="cuda") + ep = export(m, (x_sample.clone(),)) + + engine_bytes = convert_exported_program_to_serialized_trt_engine( + ep, + inputs=[x_sample.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + lift_mutable_buffers=True, + ) + inputs, outputs, aliased = _introspect_engine(engine_bytes) + module = _build_module(engine_bytes, inputs, outputs, aliased) + + # The caller owns the buffer. Pass it explicitly each call; the + # engine writes through the aliased binding into its storage. + cache = torch.zeros(1, 4, 16, 8, device="cuda") + x = torch.ones(1, 4, 1, 8, device="cuda") * 7.0 + + # Bindings order: ["x", "buf_cache"], so call as (x, cache). + cache_id, cache_ptr = id(cache), cache.data_ptr() + ret = module(x, cache) + + # Numerical: cache should now have 7s at position [3:4, :]. + eager_cache = torch.zeros(1, 4, 16, 8, device="cuda") + eager_cache[:, :, 3:4, :] = x + eager_ret = eager_cache.sum() + + self.assertTrue(torch.allclose(cache, eager_cache)) + ret_val = ret[0] if isinstance(ret, tuple) else ret + self.assertTrue(torch.allclose(ret_val, eager_ret)) + # Identity preserved — same tensor, same storage. + self.assertEqual(id(cache), cache_id) + self.assertEqual(cache.data_ptr(), cache_ptr) + + def test_streaming_state_persists_across_calls(self): + """Caller-managed buffer state should persist across repeated + calls when the same tensor is threaded in each time.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("cache", torch.zeros(1, 4, 16, 8)) + + def forward(self, x): + self.cache[:, :, 3:4, :] = x + return self.cache.sum() + + m = M().cuda() + x_sample = torch.ones(1, 4, 1, 8, device="cuda") + ep = export(m, (x_sample.clone(),)) + engine_bytes = convert_exported_program_to_serialized_trt_engine( + ep, + inputs=[x_sample.clone()], + enabled_precisions={torch.float32}, + min_block_size=1, + lift_mutable_buffers=True, + ) + inputs, outputs, aliased = _introspect_engine(engine_bytes) + module = _build_module(engine_bytes, inputs, outputs, aliased) + + # Caller's cache lives across iterations; engine writes through + # the aliased binding each call. Cache slice at position 3 has + # shape (1,4,1,8) = 32 elements, so sum = 32 * scalar. + cache = torch.zeros(1, 4, 16, 8, device="cuda") + module(torch.ones(1, 4, 1, 8, device="cuda") * 1.0, cache) + self.assertAlmostEqual(cache.sum().item(), 32.0, places=3) + module(torch.ones(1, 4, 1, 8, device="cuda") * 5.0, cache) + self.assertAlmostEqual(cache.sum().item(), 160.0, places=3) + module(torch.zeros(1, 4, 1, 8, device="cuda"), cache) + self.assertAlmostEqual(cache.sum().item(), 0.0, places=3) + + +if __name__ == "__main__": + run_tests()