diff --git a/core/runtime/BUILD b/core/runtime/BUILD
index 7f594ecea7..40acc5840d 100644
--- a/core/runtime/BUILD
+++ b/core/runtime/BUILD
@@ -95,12 +95,12 @@ cc_library(
     hdrs = [
         "Platform.h",
         "RTDevice.h",
-        "TensorRTBindingNames.h",
         "TRTEngine.h",
         "TRTEngineProfiler.h",
+        "TensorRTBindingNames.h",
         "runtime.h",
     ],
-    copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
+    defines = if_torch_nccl(["USE_C10D_NCCL"]),
     linkopts = [
         "-lstdc++fs",
     ],
@@ -135,9 +135,9 @@ cc_library(
     hdrs = [
         "Platform.h",
         "RTDevice.h",
-        "TensorRTBindingNames.h",
         "TRTEngine.h",
         "TRTEngineProfiler.h",
+        "TensorRTBindingNames.h",
         "runtime.h",
     ],
     deps = [
@@ -151,9 +151,9 @@ filegroup(
     srcs = [
         "Platform.h",
         "RTDevice.h",
-        "TensorRTBindingNames.h",
         "TRTEngine.h",
         "TRTEngineProfiler.h",
+        "TensorRTBindingNames.h",
         "runtime.h",
     ],
     visibility = ["//visibility:public"],
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index e1f4d8bafb..157d327131 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -199,6 +199,18 @@ def _populate_trt_builder_config(
     ) -> trt.IBuilderConfig:
         builder_config = self.builder.create_builder_config()
 
+        # Enable TRT's native multi-device runtime preview feature when the
+        # Torch-TRT runtime was built with NCCL collectives support. Without
+        # this, IBuilder::buildEngineWithConfig() rejects networks that contain
+        # IDistCollectiveLayer with "PreviewFeature::kMULTIDEVICE_RUNTIME_10_16
+        # is not enabled in the builder config".
+        if ENABLED_FEATURES.native_trt_collectives and hasattr(
+            trt.PreviewFeature, "MULTIDEVICE_RUNTIME_10_16"
+        ):
+            builder_config.set_preview_feature(
+                trt.PreviewFeature.MULTIDEVICE_RUNTIME_10_16, True
+            )
+
         if self._debugger_config and self._debugger_config.engine_builder_monitor:
             builder_config.progress_monitor = TRTBulderMonitor()
 
diff --git a/tools/llm/tensor_parallel_llama_export.py b/tools/llm/tensor_parallel_llama_export.py
index 19b401cd9a..f7a5b00a4f 100644
--- a/tools/llm/tensor_parallel_llama_export.py
+++ b/tools/llm/tensor_parallel_llama_export.py
@@ -31,12 +31,20 @@
 import logging
 import os
 import sys
+import timeit
+from contextlib import ExitStack
 from pathlib import Path
 
 import torch
 import torch.distributed as dist
 import torch.distributed.tensor._dtensor_spec
 import torch.utils._pytree
+from utils import (
+    generate,
+    generate_with_static_cache,
+    record_stats_split,
+    time_generate_split,
+)
 
 # DTensorSpec must be a pytree constant before torch.export traces a TP model.
 torch.utils._pytree.register_constant(
@@ -59,6 +67,7 @@
 
 setup_nccl_for_torch_tensorrt()
 
+from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 logging.basicConfig(
@@ -66,6 +75,9 @@
     format=f"[Rank {rank}] %(levelname)s: %(message)s",
 )
 logger = logging.getLogger(__name__)
+# Quiet torch_tensorrt's default verbose output. `--debug` still re-enables
+# Debug-level logging during compile via torch_tensorrt.logging.debug().
+torch_tensorrt.logging.set_level(logging.ERROR)
 
 logger.info(f"dist init OK rank={rank}/{world_size} device={DEVICE}")
 
@@ -87,18 +99,39 @@ def _extract_logits(outputs):
     return outputs
 
 
-def generate_greedy(model, input_ids, max_len, eos_token_id):
-    """Greedy decode that works with any model output format."""
-    seq = input_ids.clone()
-    for _ in range(max_len - input_ids.shape[1]):
-        position_ids = torch.arange(seq.shape[1]).unsqueeze(0).to(seq.device)
-        outputs = model(seq, position_ids=position_ids)
-        logits = _extract_logits(outputs)
-        next_token = logits[:, -1, :].argmax(dim=-1)
-        seq = torch.cat([seq, next_token[:, None]], dim=-1)
-        if (next_token == eos_token_id).all():
-            break
-    return seq
+def time_generate(generate_fn, model, input_ids, max_len, eos_token_id, iterations=5):
+    """Measure end-to-end generation latency over multiple iterations."""
+    timings = []
+    for _ in range(iterations):
+        start = timeit.default_timer()
+        generate_fn(model, input_ids.clone(), max_len, eos_token_id)
+        torch.cuda.synchronize()
+        timings.append(timeit.default_timer() - start)
+    return timings
+
+
+def _pick_generate_fn(args):
+    """Pick the right greedy-decode helper based on the --cache flag."""
+    if args.cache in ("static_v1", "static_v2"):
+        return generate_with_static_cache
+    return generate
+
+
+def record_stats(backend, timings, precision, batch_size=1):
+    import numpy as np
+
+    times = np.array(timings)
+    speeds = batch_size / times
+    return {
+        "Backend": backend,
+        "Model Precision": precision,
+        "Batch size": batch_size,
+        "Median(FPS)": float(np.median(speeds)),
+        "Mean(FPS)": float(np.mean(speeds)),
+        "Median-Latency(ms)": float(np.median(times)) * 1000,
+        "Mean-Latency(ms)": float(np.mean(times)) * 1000,
+        "Latency-StdDev(ms)": float(np.std(times)) * 1000,
+    }
 
 
 # ---------------------------------------------------------------------------
@@ -143,6 +176,9 @@ def get_exportable_model(args, rank, world_size):
             .eval()
             .to(DEVICE)
         )
+    # Keep SDPA as a single op so the TRT custom converter (and optional
+    # static KV cache lowering pass) can pattern-match it.
+    register_sdpa.enable_sdpa_converter(args.model, model.config)
 
     # Get the default process group name for NCCL all-reduce.
     default_pg = dist.distributed_c10d._get_default_group()
@@ -234,12 +270,20 @@ def export_and_save(input_ids, args):
     )
     logger.info("Export succeeded.")
 
+    # Importing these modules registers the static KV cache lowering passes
+    # via @_aten_lowering_pass. Must happen before torch_tensorrt.dynamo.compile.
+    if args.cache == "static_v1":
+        import static_cache_v1  # noqa: F401
+    elif args.cache == "static_v2":
+        import static_cache_v2  # noqa: F401
+
     logger.info("Compiling exported program with TRT (AOT) ...")
-    with (
-        torch_tensorrt.logging.debug()
-        if args.debug
-        else torch.autocast("cuda", dtype=torch.float16)
-    ):
+    # Always run compile under FP16 autocast (matches the verify path below);
+    # additionally enable verbose Torch-TRT logging when --debug is set.
+    with ExitStack() as _compile_stack:
+        _compile_stack.enter_context(torch.autocast("cuda", dtype=torch.float16))
+        if args.debug:
+            _compile_stack.enter_context(torch_tensorrt.logging.debug())
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[
@@ -280,11 +324,19 @@ def export_and_save(input_ids, args):
     logger.info("NCCL communicator eagerly initialized for export verification")
 
     # Verify
-    logger.info("Verifying compiled model ...")
-    with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
-        ref = _extract_logits(model(input_ids, position_ids=position_ids))
-        trt = _extract_logits(trt_model(input_ids, position_ids=position_ids))
-    logger.info(f"Max logit diff: {(ref.float() - trt.float()).abs().max().item():.6f}")
+    if args.cache:
+        # With KV cache, the engine takes (input_ids, position_ids, *kv_cache,
+        # start_idx, end_idx) and the reference model doesn't — logit
+        # comparison requires a separate KV-aware driver. Skip.
+        logger.info("Skipping logit-diff verify (KV cache enabled).")
+    else:
+        logger.info("Verifying compiled model ...")
+        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
+            ref = _extract_logits(model(input_ids, position_ids=position_ids))
+            trt = _extract_logits(trt_model(input_ids, position_ids=position_ids))
+        logger.info(
+            f"Max logit diff: {(ref.float() - trt.float()).abs().max().item():.6f}"
+        )
 
     # Save outside autocast — serialization doesn't need it and retrace=True
     # would fail (execute_engine has no AutocastCUDA kernel for torch.export).
@@ -326,7 +378,8 @@ def load_and_run(input_ids, tokenizer, args):
     logger.info("Engine loaded.")
 
     max_len = input_ids.shape[1] + args.num_tokens
-    loaded_tokens = generate_greedy(
+    gen_fn = _pick_generate_fn(args)
+    loaded_tokens = gen_fn(
         trt_model,
         input_ids.clone(),
         max_len,
@@ -351,8 +404,7 @@ def load_and_run(input_ids, tokenizer, args):
     )
     parser.add_argument("--model", default="meta-llama/Llama-3.2-1B-Instruct")
     parser.add_argument("--prompt", default="What is tensor parallelism?")
-    parser.add_argument("--num_tokens", type=int, default=64)
-    parser.add_argument("--max_seq_len", type=int, default=128)
+    parser.add_argument("--num_tokens", type=int, default=128)
     parser.add_argument(
         "--mode",
         required=True,
@@ -360,35 +412,105 @@ def load_and_run(input_ids, tokenizer, args):
         help="export: AOT compile + save engines | load: load engines + infer",
     )
     parser.add_argument("--save_dir", default="/tmp/llama_tp_engines")
+    parser.add_argument(
+        "--cache",
+        choices=["", "static_v1", "static_v2"],
+        default="",
+        help="KV cache lowering pass. '' disables cache (full-seq recompute).",
+    )
     parser.add_argument("--debug", action="store_true")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Measure generation latency"
+    )
+    parser.add_argument(
+        "--iterations", type=int, default=5, help="Benchmark iterations"
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size for benchmarking"
+    )
+    parser.add_argument(
+        "--isl", type=int, default=2048, help="Input sequence length for benchmarking"
+    )
     args = parser.parse_args()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+    if args.benchmark:
+        input_ids = torch.randint(
+            1, 10000, (args.batch_size, args.isl), dtype=torch.int64
+        ).to(DEVICE)
+    else:
+        input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
     max_len = input_ids.shape[1] + args.num_tokens
+    args.max_seq_len = max_len
 
     trt_model = None
     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
+        gen_fn = _pick_generate_fn(args)
+
         if args.mode == "export":
             trt_model = export_and_save(input_ids.clone(), args)
 
             logger.info("Running freshly compiled model ...")
-            trt_tokens = generate_greedy(
+            trt_tokens = gen_fn(
                 trt_model,
                 input_ids.clone(),
                 max_len,
                 tokenizer.eos_token_id,
             )
-            if rank == 0:
-                print("\n===== TensorRT-TP (freshly compiled) =====")
-                print(tokenizer.decode(trt_tokens[0], skip_special_tokens=True))
-                sys.stdout.flush()
+            if not args.benchmark:
+                if rank == 0:
+                    print("\n===== TensorRT-TP (freshly compiled) =====")
+                    print(tokenizer.decode(trt_tokens[0], skip_special_tokens=True))
+                    sys.stdout.flush()
+            else:
+                # All ranks must participate in the benchmark loop.
+                use_cache = args.cache in ("static_v1", "static_v2")
+                trt_results = time_generate_split(
+                    model=trt_model,
+                    inputs=input_ids.clone(),
+                    output_seq_length=max_len,
+                    eos_token_id=tokenizer.eos_token_id,
+                    iterations=args.iterations,
+                    use_cache=use_cache,
+                )
+                if rank == 0:
+                    stats = record_stats_split(
+                        "TensorRT-TP (export)",
+                        trt_results,
+                        "FP16",
+                        batch_size=args.batch_size,
+                    )
+                    print("\n=========TensorRT-TP (export) PERFORMANCE============")
+                    print(stats)
+                    sys.stdout.flush()
 
         elif args.mode == "load":
             trt_model, _ = load_and_run(input_ids, tokenizer, args)
+            if args.benchmark:
+                # All ranks must participate — the decode loop contains NCCL
+                # all-reduce ops that require every rank to call in lockstep.
+                use_cache = args.cache in ("static_v1", "static_v2")
+                trt_results = time_generate_split(
+                    model=trt_model,
+                    inputs=input_ids.clone(),
+                    output_seq_length=max_len,
+                    eos_token_id=tokenizer.eos_token_id,
+                    iterations=args.iterations,
+                    use_cache=use_cache,
+                )
+                if rank == 0:
+                    stats = record_stats_split(
+                        "TensorRT-TP (load)",
+                        trt_results,
+                        "FP16",
+                        batch_size=args.batch_size,
+                    )
+                    print("\n=========TensorRT-TP (load) PERFORMANCE============")
+                    print(stats)
+                    sys.stdout.flush()
 
     # Delete the TRT engine before destroying the process group — the engine
     # holds a reference to the NCCL communicator and will segfault if NCCL is
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index 5c3197356d..c7079f434e 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -1,5 +1,6 @@
 import copy
 import timeit
+from typing import Any, Callable, Optional, TypedDict
 
 import numpy as np
 import torch
@@ -10,6 +11,25 @@
 )
 
 
+class IterTiming(TypedDict):
+    """Per-iteration prefill/decode timing record.
+
+    Fields:
+        ttft_s: Time-to-first-token (prefill latency) in seconds.
+        decode_s: Total decode-phase latency (all decode steps) in seconds.
+        total_s: ttft_s + decode_s; convenience for downstream aggregation.
+        prefill_tokens: Number of input tokens fed to the prefill forward pass.
+        decode_tokens: Number of tokens produced by the decode loop (does not
+            include the first token produced by the prefill step).
+    """
+
+    ttft_s: float
+    decode_s: float
+    total_s: float
+    prefill_tokens: int
+    decode_tokens: int
+
+
 def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
     """
     Exports the LLM model into an ExportedProgram with dynamic shapes.
@@ -289,6 +309,228 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None)
     return stats
 
 
+def _timed_generate_static_cache(
+    model: torch.nn.Module,
+    input_seq: torch.Tensor,
+    max_output_seq_length: int,
+    eos_token_id: int,
+) -> IterTiming:
+    """Single-iteration timed greedy decode with a static KV cache.
+
+    Splits timing into a prefill phase (TTFT - time to first token) and a
+    decode phase (all subsequent single-token steps). Both phases are
+    bracketed by ``torch.cuda.synchronize()`` so the measured wall-clock
+    time reflects GPU completion, not just kernel launches.
+    """
+    start_idx = 0
+    end_idx = input_seq.shape[1]
+    prefill_tokens = end_idx
+    position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+    kv_cache = get_zeroed_static_cache_inputs(model)
+
+    torch.cuda.synchronize()
+    prefill_start = timeit.default_timer()
+    input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx)
+    logits_keys_values = model(*input_signature)
+    torch.cuda.synchronize()
+    ttft_s = timeit.default_timer() - prefill_start
+
+    logits = logits_keys_values[0]
+    kv_cache = logits_keys_values[1:]
+    next_tokens = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+    input_seq = next_tokens
+    start_idx = end_idx
+    end_idx = start_idx + 1
+
+    decode_tokens = 0
+    decode_start = timeit.default_timer()
+    while end_idx < max_output_seq_length:
+        position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+        input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx)
+        logits_keys_values = model(*input_signature)
+        logits = logits_keys_values[0]
+        kv_cache = logits_keys_values[1:]
+        next_tokens = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+        input_seq = next_tokens
+        start_idx = end_idx
+        end_idx += 1
+        decode_tokens += 1
+    torch.cuda.synchronize()
+    decode_s = timeit.default_timer() - decode_start
+
+    return {
+        "ttft_s": ttft_s,
+        "decode_s": decode_s,
+        "total_s": ttft_s + decode_s,
+        "prefill_tokens": prefill_tokens,
+        "decode_tokens": decode_tokens,
+    }
+
+
+def _timed_generate_no_cache(
+    model: torch.nn.Module,
+    input_seq: torch.Tensor,
+    max_output_seq_length: int,
+    eos_token_id: int,
+) -> IterTiming:
+    """Single-iteration timed greedy decode without a KV cache.
+
+    Each decode step re-runs the full forward pass over the growing input
+    sequence. Prefill (first forward) and decode (remaining forwards) are
+    timed separately under explicit CUDA synchronization.
+    """
+    prefill_tokens = input_seq.shape[1]
+    target_decode_tokens = max_output_seq_length - prefill_tokens
+
+    position_ids = torch.arange(input_seq.shape[1], device=input_seq.device).unsqueeze(
+        0
+    )
+
+    torch.cuda.synchronize()
+    prefill_start = timeit.default_timer()
+    outputs = model(input_seq, position_ids=position_ids)
+    torch.cuda.synchronize()
+    ttft_s = timeit.default_timer() - prefill_start
+
+    logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
+    next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
+    input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
+
+    decode_tokens = 0
+    decode_start = timeit.default_timer()
+    while decode_tokens + 1 < target_decode_tokens:
+        position_ids = torch.arange(
+            input_seq.shape[1], device=input_seq.device
+        ).unsqueeze(0)
+        outputs = model(input_seq, position_ids=position_ids)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
+        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
+        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
+        decode_tokens += 1
+    torch.cuda.synchronize()
+    decode_s = timeit.default_timer() - decode_start
+
+    return {
+        "ttft_s": ttft_s,
+        "decode_s": decode_s,
+        "total_s": ttft_s + decode_s,
+        "prefill_tokens": prefill_tokens,
+        "decode_tokens": decode_tokens,
+    }
+
+
+def time_generate_split(
+    *,
+    model: torch.nn.Module,
+    inputs: torch.Tensor,
+    output_seq_length: int,
+    eos_token_id: int,
+    use_cache: bool = False,
+    iterations: int = 5,
+) -> list[IterTiming]:
+    """Measure per-iteration prefill / decode timings over ``iterations`` runs.
+
+    A single warmup pass (using the un-timed helper) is run before the
+    timed iterations. The warmup-fn / timed-fn pair is built up-front so
+    the warmup loop is no longer duplicated inside each ``use_cache``
+    branch.
+
+    Args:
+        model: The model (or compiled TRT module) to benchmark.
+        inputs: Input token tensor. ``.clone()`` is taken before each call
+            so the original tensor is not mutated.
+        output_seq_length: Total target sequence length (prefill + decode).
+        eos_token_id: EOS id; forwarded to the underlying generate helpers
+            (they ignore it when ``benchmark=True``).
+        use_cache: If True, use static-cache prefill/decode; else recompute
+            the full sequence each step.
+        iterations: Number of timed iterations.
+
+    Returns:
+        List of per-iteration ``IterTiming`` dicts.
+    """
+    if use_cache:
+        warmup_fn: Callable[[], Any] = lambda: generate_with_static_cache(
+            model, inputs.clone(), output_seq_length, eos_token_id
+        )
+        timed_fn: Callable[[], IterTiming] = lambda: _timed_generate_static_cache(
+            model, inputs.clone(), output_seq_length, eos_token_id
+        )
+    else:
+        warmup_fn = lambda: generate(
+            model, inputs.clone(), output_seq_length, eos_token_id
+        )
+        timed_fn = lambda: _timed_generate_no_cache(
+            model, inputs.clone(), output_seq_length, eos_token_id
+        )
+
+    _ = warmup_fn()
+    torch.cuda.synchronize()
+
+    timings: list[IterTiming] = []
+    for _ in range(iterations):
+        timings.append(timed_fn())
+    return timings
+
+
+def record_stats_split(
+    backend: str,
+    timings: list[IterTiming],
+    precision: str,
+    *,
+    batch_size: int = 1,
+    compile_time_s: Optional[float] = None,
+) -> dict[str, Any]:
+    """Aggregate per-iteration prefill/decode timings into a summary dict.
+
+    ``output_tokens`` is ``decode_tokens + 1`` — the prefill step itself
+    produces the first output token, then the decode loop produces
+    ``decode_tokens`` more.
+
+    Args:
+        backend: Free-form backend label (e.g. ``"TensorRT"``, ``"PyTorch"``).
+        timings: Per-iteration timings produced by ``time_generate_split``.
+        precision: Free-form precision label (e.g. ``"FP16"``).
+        batch_size: Batch size used during the run.
+        compile_time_s: Optional compile time in seconds; passthrough.
+
+    Returns:
+        Dict of summary stats (latencies in ms, throughputs in tok/s).
+    """
+    ttfts = np.array([t["ttft_s"] for t in timings])
+    decodes = np.array([t["decode_s"] for t in timings])
+    totals = np.array([t["total_s"] for t in timings])
+
+    prefill_tokens = timings[0]["prefill_tokens"]
+    decode_tokens = timings[0]["decode_tokens"]
+    output_tokens = decode_tokens + 1
+
+    decode_tps_mean = (
+        float((decode_tokens / decodes).mean()) if decode_tokens > 0 else 0.0
+    )
+    output_tps_mean = (
+        float((output_tokens / totals).mean()) if output_tokens > 0 else 0.0
+    )
+
+    return {
+        "Backend": backend,
+        "Model Precision": precision,
+        "Batch size": batch_size,
+        "Prefill tokens": prefill_tokens,
+        "Decode tokens": decode_tokens,
+        "Output tokens": output_tokens,
+        "Median TTFT(ms)": float(np.median(ttfts)) * 1000,
+        "Mean TTFT(ms)": float(ttfts.mean()) * 1000,
+        "Median Decode(ms)": float(np.median(decodes)) * 1000,
+        "Mean Decode(ms)": float(decodes.mean()) * 1000,
+        "Median Total(ms)": float(np.median(totals)) * 1000,
+        "Mean Total(ms)": float(totals.mean()) * 1000,
+        "Decode tokens/s (mean)": decode_tps_mean,
+        "Output tokens/s (mean)": output_tps_mean,
+        "Compile Time(s)": compile_time_s,
+    }
+
+
 def _prepare_mm_inputs(
     model,
     pixel_values: torch.Tensor | None,