Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ cc_library(
hdrs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBindingNames.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
defines = if_torch_nccl(["USE_C10D_NCCL"]),
linkopts = [
"-lstdc++fs",
],
Expand Down Expand Up @@ -135,9 +135,9 @@ cc_library(
hdrs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBindingNames.h",
"runtime.h",
],
deps = [
Expand All @@ -151,9 +151,9 @@ filegroup(
srcs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBindingNames.h",
"runtime.h",
],
visibility = ["//visibility:public"],
Expand Down
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def _populate_trt_builder_config(
) -> trt.IBuilderConfig:
builder_config = self.builder.create_builder_config()

# Enable TRT's native multi-device runtime preview feature when the
# Torch-TRT runtime was built with NCCL collectives support. Without
# this, IBuilder::buildEngineWithConfig() rejects networks that contain
# IDistCollectiveLayer with "PreviewFeature::kMULTIDEVICE_RUNTIME_10_16
# is not enabled in the builder config".
if ENABLED_FEATURES.native_trt_collectives and hasattr(
trt.PreviewFeature, "MULTIDEVICE_RUNTIME_10_16"
):
builder_config.set_preview_feature(
trt.PreviewFeature.MULTIDEVICE_RUNTIME_10_16, True
)

if self._debugger_config and self._debugger_config.engine_builder_monitor:
builder_config.progress_monitor = TRTBulderMonitor()

Expand Down
184 changes: 153 additions & 31 deletions tools/llm/tensor_parallel_llama_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@
import logging
import os
import sys
import timeit
from contextlib import ExitStack
from pathlib import Path

import torch
import torch.distributed as dist
import torch.distributed.tensor._dtensor_spec
import torch.utils._pytree
from utils import (
generate,
generate_with_static_cache,
record_stats_split,
time_generate_split,
)

# DTensorSpec must be a pytree constant before torch.export traces a TP model.
torch.utils._pytree.register_constant(
Expand All @@ -59,13 +67,17 @@

setup_nccl_for_torch_tensorrt()

from torchtrt_ext import register_sdpa
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
level=logging.INFO,
format=f"[Rank {rank}] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)
# Quiet torch_tensorrt's default verbose output. `--debug` still re-enables
# Debug-level logging during compile via torch_tensorrt.logging.debug().
torch_tensorrt.logging.set_level(logging.ERROR)
logger.info(f"dist init OK rank={rank}/{world_size} device={DEVICE}")


Expand All @@ -87,18 +99,39 @@ def _extract_logits(outputs):
return outputs


def generate_greedy(model, input_ids, max_len, eos_token_id):
"""Greedy decode that works with any model output format."""
seq = input_ids.clone()
for _ in range(max_len - input_ids.shape[1]):
position_ids = torch.arange(seq.shape[1]).unsqueeze(0).to(seq.device)
outputs = model(seq, position_ids=position_ids)
logits = _extract_logits(outputs)
next_token = logits[:, -1, :].argmax(dim=-1)
seq = torch.cat([seq, next_token[:, None]], dim=-1)
if (next_token == eos_token_id).all():
break
return seq
def time_generate(generate_fn, model, input_ids, max_len, eos_token_id, iterations=5):
"""Measure end-to-end generation latency over multiple iterations."""
timings = []
for _ in range(iterations):
start = timeit.default_timer()
generate_fn(model, input_ids.clone(), max_len, eos_token_id)
torch.cuda.synchronize()
timings.append(timeit.default_timer() - start)
return timings


def _pick_generate_fn(args):
"""Pick the right greedy-decode helper based on the --cache flag."""
if args.cache in ("static_v1", "static_v2"):
return generate_with_static_cache
return generate


def record_stats(backend, timings, precision, batch_size=1):
import numpy as np

times = np.array(timings)
speeds = batch_size / times
return {
"Backend": backend,
"Model Precision": precision,
"Batch size": batch_size,
"Median(FPS)": float(np.median(speeds)),
"Mean(FPS)": float(np.mean(speeds)),
"Median-Latency(ms)": float(np.median(times)) * 1000,
"Mean-Latency(ms)": float(np.mean(times)) * 1000,
"Latency-StdDev(ms)": float(np.std(times)) * 1000,
}


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -143,6 +176,9 @@ def get_exportable_model(args, rank, world_size):
.eval()
.to(DEVICE)
)
# Keep SDPA as a single op so the TRT custom converter (and optional
# static KV cache lowering pass) can pattern-match it.
register_sdpa.enable_sdpa_converter(args.model, model.config)

# Get the default process group name for NCCL all-reduce.
default_pg = dist.distributed_c10d._get_default_group()
Expand Down Expand Up @@ -234,12 +270,20 @@ def export_and_save(input_ids, args):
)
logger.info("Export succeeded.")

# Importing these modules registers the static KV cache lowering passes
# via @_aten_lowering_pass. Must happen before torch_tensorrt.dynamo.compile.
if args.cache == "static_v1":
import static_cache_v1 # noqa: F401
elif args.cache == "static_v2":
import static_cache_v2 # noqa: F401

logger.info("Compiling exported program with TRT (AOT) ...")
with (
torch_tensorrt.logging.debug()
if args.debug
else torch.autocast("cuda", dtype=torch.float16)
):
# Always run compile under FP16 autocast (matches the verify path below);
# additionally enable verbose Torch-TRT logging when --debug is set.
with ExitStack() as _compile_stack:
_compile_stack.enter_context(torch.autocast("cuda", dtype=torch.float16))
if args.debug:
_compile_stack.enter_context(torch_tensorrt.logging.debug())
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=[
Expand Down Expand Up @@ -280,11 +324,19 @@ def export_and_save(input_ids, args):
logger.info("NCCL communicator eagerly initialized for export verification")

# Verify
logger.info("Verifying compiled model ...")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
ref = _extract_logits(model(input_ids, position_ids=position_ids))
trt = _extract_logits(trt_model(input_ids, position_ids=position_ids))
logger.info(f"Max logit diff: {(ref.float() - trt.float()).abs().max().item():.6f}")
if args.cache:
# With KV cache, the engine takes (input_ids, position_ids, *kv_cache,
# start_idx, end_idx) and the reference model doesn't — logit
# comparison requires a separate KV-aware driver. Skip.
logger.info("Skipping logit-diff verify (KV cache enabled).")
else:
logger.info("Verifying compiled model ...")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
ref = _extract_logits(model(input_ids, position_ids=position_ids))
trt = _extract_logits(trt_model(input_ids, position_ids=position_ids))
logger.info(
f"Max logit diff: {(ref.float() - trt.float()).abs().max().item():.6f}"
)

# Save outside autocast — serialization doesn't need it and retrace=True
# would fail (execute_engine has no AutocastCUDA kernel for torch.export).
Expand Down Expand Up @@ -326,7 +378,8 @@ def load_and_run(input_ids, tokenizer, args):
logger.info("Engine loaded.")

max_len = input_ids.shape[1] + args.num_tokens
loaded_tokens = generate_greedy(
gen_fn = _pick_generate_fn(args)
loaded_tokens = gen_fn(
trt_model,
input_ids.clone(),
max_len,
Expand All @@ -351,44 +404,113 @@ def load_and_run(input_ids, tokenizer, args):
)
parser.add_argument("--model", default="meta-llama/Llama-3.2-1B-Instruct")
parser.add_argument("--prompt", default="What is tensor parallelism?")
parser.add_argument("--num_tokens", type=int, default=64)
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--num_tokens", type=int, default=128)
parser.add_argument(
"--mode",
required=True,
choices=["export", "load"],
help="export: AOT compile + save engines | load: load engines + infer",
)
parser.add_argument("--save_dir", default="/tmp/llama_tp_engines")
parser.add_argument(
"--cache",
choices=["", "static_v1", "static_v2"],
default="",
help="KV cache lowering pass. '' disables cache (full-seq recompute).",
)
parser.add_argument("--debug", action="store_true")
parser.add_argument(
"--benchmark", action="store_true", help="Measure generation latency"
)
parser.add_argument(
"--iterations", type=int, default=5, help="Benchmark iterations"
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for benchmarking"
)
parser.add_argument(
"--isl", type=int, default=2048, help="Input sequence length for benchmarking"
)
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
if args.benchmark:
input_ids = torch.randint(
1, 10000, (args.batch_size, args.isl), dtype=torch.int64
).to(DEVICE)
else:
input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
max_len = input_ids.shape[1] + args.num_tokens
args.max_seq_len = max_len

trt_model = None
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
gen_fn = _pick_generate_fn(args)

if args.mode == "export":
trt_model = export_and_save(input_ids.clone(), args)

logger.info("Running freshly compiled model ...")
trt_tokens = generate_greedy(
trt_tokens = gen_fn(
trt_model,
input_ids.clone(),
max_len,
tokenizer.eos_token_id,
)
if rank == 0:
print("\n===== TensorRT-TP (freshly compiled) =====")
print(tokenizer.decode(trt_tokens[0], skip_special_tokens=True))
sys.stdout.flush()
if not args.benchmark:
if rank == 0:
print("\n===== TensorRT-TP (freshly compiled) =====")
print(tokenizer.decode(trt_tokens[0], skip_special_tokens=True))
sys.stdout.flush()
else:
# All ranks must participate in the benchmark loop.
use_cache = args.cache in ("static_v1", "static_v2")
trt_results = time_generate_split(
trt_model,
input_ids.clone(),
max_len,
tokenizer.eos_token_id,
iterations=args.iterations,
use_cache=use_cache,
)
if rank == 0:
stats = record_stats_split(
"TensorRT-TP (export)",
trt_results,
"FP16",
batch_size=args.batch_size,
)
print("\n=========TensorRT-TP (export) PERFORMANCE============")
print(stats)
sys.stdout.flush()

elif args.mode == "load":
trt_model, _ = load_and_run(input_ids, tokenizer, args)
if args.benchmark:
# All ranks must participate — the decode loop contains NCCL
# all-reduce ops that require every rank to call in lockstep.
use_cache = args.cache in ("static_v1", "static_v2")
trt_results = time_generate_split(
trt_model,
input_ids.clone(),
max_len,
tokenizer.eos_token_id,
iterations=args.iterations,
use_cache=use_cache,
)
if rank == 0:
stats = record_stats_split(
"TensorRT-TP (load)",
trt_results,
"FP16",
batch_size=args.batch_size,
)
print("\n=========TensorRT-TP (load) PERFORMANCE============")
print(stats)
sys.stdout.flush()

# Delete the TRT engine before destroying the process group — the engine
# holds a reference to the NCCL communicator and will segfault if NCCL is
Expand Down
Loading
Loading