Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo_rl/models/generation/vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class VllmSpecificArgs(TypedDict):
precision: NotRequired[str]
kv_cache_dtype: Literal["auto", "fp8", "fp8_e4m3"]
enforce_eager: NotRequired[bool]
# Whether to show a tqdm progress bar during generation. Defaults to vLLM's own default (True) when absent.
use_tqdm: NotRequired[bool]
Comment thread
yuki-97 marked this conversation as resolved.
# By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. The expose_http_server flag here will expose that generation engine as an HTTP server.
# Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date.
# Currently it will expose the /tokenize and /v1/chat/completions endpoints. Later on we may expose /v1/completions or /v1/responses.
Expand Down
6 changes: 4 additions & 2 deletions nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,8 @@ def generate(
assert self.llm is not None, (
"Attempting to generate with either an uninitialized vLLM or non-model-owner"
)
outputs = self.llm.generate(prompts, sampling_params)
use_tqdm = self.cfg["vllm_cfg"].get("use_tqdm", True)
outputs = self.llm.generate(prompts, sampling_params, use_tqdm=use_tqdm)

# Process the outputs - but preserve the original input padding structure
output_ids_list = []
Expand Down Expand Up @@ -886,7 +887,8 @@ def generate_text(
assert self.llm is not None, (
"Attempting to generate with either an uninitialized vLLM or non-model-owner"
)
outputs = self.llm.generate(data["prompts"], sampling_params)
use_tqdm = self.cfg["vllm_cfg"].get("use_tqdm", True)
outputs = self.llm.generate(data["prompts"], sampling_params, use_tqdm=use_tqdm)
texts = [output.outputs[0].text for output in outputs]

# Convert to BatchedDataDict
Expand Down
Loading