Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ py/torch_tensorrt/bin
py/torch_tensorrt/BUILD
py/torch_tensorrt/LICENSE
py/torch_tensorrt/WORKSPACE
# Build copies these into the package dir for wheel packaging; sources of
# record live at the repo root (core/, csrc/).
py/torch_tensorrt/CMakeLists.txt
py/torch_tensorrt/README.md
py/torch_tensorrt/cmake/
py/torch_tensorrt/core/
py/torch_tensorrt/examples/
py/wheelhouse
py/.eggs
notebooks/.ipynb_checkpoints/
Expand Down
50 changes: 45 additions & 5 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy)
const ResourceAllocationStrategy resource_allocation_strategy,
const std::unordered_map<std::string, AliasedIOSpec>& aliased_io)
: TRTEngine(
"deserialized_trt",
serialized_engine,
Expand All @@ -80,7 +81,8 @@ TRTEngine::TRTEngine(
hardware_compatible,
requires_output_allocator,
serialized_metadata,
resource_allocation_strategy) {}
resource_allocation_strategy,
aliased_io) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
Expand All @@ -95,7 +97,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {
: ResourceAllocationStrategy::kStatic),
deserialize_aliased_io(serialized_info[ALIASED_IO_IDX])) {
this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]);
if (this->requires_native_multidevice) {
LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution");
Expand All @@ -112,7 +115,8 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy) {
const ResourceAllocationStrategy resource_allocation_strategy,
const std::unordered_map<std::string, AliasedIOSpec>& aliased_io) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
Expand Down Expand Up @@ -269,6 +273,40 @@ TRTEngine::TRTEngine(
num_io = std::make_pair(inputs_size, outputs);
}

// Store the build-time aliased_io map as the starting point. Then reconcile
// against ICudaEngine::getAliasedInputTensor — that API is the source of
// truth for KV-cache-style aliasing (TRT-enforced via IKVCacheUpdateLayer)
// and may report aliases the build-time path didn't record (e.g. for
// engines built outside Torch-TensorRT). User-declared aliases (kind=kUser)
// are preserved as-is since TRT doesn't know about them.
this->aliased_io = aliased_io;
for (const auto& out_name : this->out_binding_names) {
// TRT returns nullptr / empty string for non-aliased outputs; any thrown
// exception is a real error in the engine state and propagates.
const char* aliased_in = cuda_engine->getAliasedInputTensor(out_name.c_str());
if (aliased_in == nullptr || aliased_in[0] == '\0') {
continue;
}
auto it = this->aliased_io.find(out_name);
if (it == this->aliased_io.end()) {
this->aliased_io[out_name] = AliasedIOSpec{std::string(aliased_in), AliasKind::kKVCacheUpdate};
LOG_DEBUG("aliased_io reconciliation: discovered " << out_name << " -> " << aliased_in << " (kv_cache_update)");
} else if (it->second.input_binding_name != std::string(aliased_in)) {
LOG_WARNING(
"aliased_io: build-time map disagrees with engine for output "
<< out_name << " (build: " << it->second.input_binding_name << ", engine: " << aliased_in
<< "); using engine value.");
it->second = AliasedIOSpec{std::string(aliased_in), AliasKind::kKVCacheUpdate};
}
// Validation: aliased outputs must not also require an output allocator,
// since aliasing requires the output shape to match the input's static
// shape, which is incompatible with the dynamic-allocation path.
TORCHTRT_CHECK(
!this->requires_output_allocator,
"Aliased output " << out_name
<< " is incompatible with dynamic output allocator. Aliasing requires fixed output shape.");
}

#ifndef NDEBUG
this->enable_profiling();
#endif
Expand Down Expand Up @@ -505,7 +543,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]),
std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]));
std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]),
std::tuple("aliased_io", serialized_info[ALIASED_IO_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
Expand All @@ -531,6 +570,7 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0";
serialized_info[ALIASED_IO_IDX] = serialize_aliased_io(this->aliased_io);
// rank/world_size are runtime facts (may differ at load time); not serialized.

return serialized_info;
Expand Down
52 changes: 49 additions & 3 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <map>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "ATen/core/function_schema.h"
Expand Down Expand Up @@ -33,6 +34,40 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

// Origin of an aliased input/output binding pair. KV_CACHE_UPDATE is enforced
// by TensorRT itself (via IKVCacheUpdateLayer; reported through
// ICudaEngine::getAliasedInputTensor); USER is declared by the Torch-TensorRT
// compile flow (TRT doesn't know about it; runtime validates and binds).
enum class AliasKind : int8_t {
kKVCacheUpdate = 0,
kUser = 1,
};

struct AliasedIOSpec {
std::string input_binding_name;
AliasKind kind;
};

inline std::string alias_kind_to_string(AliasKind k) {
switch (k) {
case AliasKind::kKVCacheUpdate:
return "kv_cache_update";
case AliasKind::kUser:
return "user";
}
return "unknown";
}

inline AliasKind alias_kind_from_string(const std::string& s) {
if (s == "kv_cache_update")
return AliasKind::kKVCacheUpdate;
if (s == "user")
return AliasKind::kUser;
// Unknown kinds are conservatively treated as KV-cache-update — TRT enforces
// those without extra runtime work, so worst case the runtime silently no-ops.
return AliasKind::kKVCacheUpdate;
}

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
Expand All @@ -45,7 +80,8 @@ using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, std::string>, // Resource Allocation Strategy
std::tuple<std::string, std::string>>; // requires_native_multidevice
std::tuple<std::string, std::string>, // requires_native_multidevice
std::tuple<std::string, std::string>>; // aliased_io

struct TorchTRTRuntimeStates {
// Indicates whether CUDAGraphs were enabled in the previous execute_engine
Expand Down Expand Up @@ -133,6 +169,14 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX

// For each output binding name that aliases an input binding, the alias spec.
// Populated either by build-time conversion records (forwarded from
// TRTInterpreterResult) or by reconciliation against the engine's own
// ICudaEngine::getAliasedInputTensor at construction time. The runtime
// consults this map in the output-binding loop to skip allocation and bind
// the same device pointer as the source input.
std::unordered_map<std::string, AliasedIOSpec> aliased_io = {};

bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
// in compilation
Expand All @@ -149,7 +193,8 @@ struct TRTEngine : torch::CustomClassHolder {
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);
TRTEngine::ResourceAllocationStrategy::kStatic,
const std::unordered_map<std::string, AliasedIOSpec>& aliased_io = {});

TRTEngine(std::vector<std::string> serialized_info);

Expand All @@ -164,7 +209,8 @@ struct TRTEngine : torch::CustomClassHolder {
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);
TRTEngine::ResourceAllocationStrategy::kStatic,
const std::unordered_map<std::string, AliasedIOSpec>& aliased_io = {});

std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
Expand Down
Loading
Loading