import random
from collections.abc import Sequence
from typing import Any
import torch # pyright: ignore[reportMissingImports]
from vllm import LLM, SamplingParams # pyright: ignore[reportMissingImports]
from vllm.outputs import RequestOutput # pyright: ignore[reportMissingImports]
from vllm.sampling_params import StructuredOutputsParams # pyright: ignore[reportMissingImports]
from ..logger import get_logger
from ..utilities.utils import _make_cache_key, generate_seeds
from .dynamic_pydantic import build_pydantic_model_from_json_object
from .reasoning_parser import parse_reasoning
from .response_generation import (
ChoiceResponseGenerationMethod,
JSONResponseGenerationMethod,
LogprobResponseGenerationMethod,
ResponseGenerationMethod,
)
from .utils import InferenceMode, normalize_system_messages, validate_inference_mode
# Module-level logger, named after this module (used by default_model_init).
logger = get_logger(__name__)
def _prepare_vllm_generation(
    batch_size: int,
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
    seed: int,
    print_progress: bool,
    generation_kwargs: dict[str, Any],
) -> tuple[list[SamplingParams], bool, dict[str, Any], LogprobResponseGenerationMethod | None]:
    """Prepare shared vLLM sampling and execution options.

    Derives per-request seeds via ``generate_seeds``, applies logprob settings,
    and splits the user keyword arguments into ``SamplingParams`` fields and
    arguments for the vLLM entry point (``LLM.chat`` / ``LLM.generate``).

    Note: ``generation_kwargs`` is modified in place. ``use_tqdm`` and
    ``sampling_params`` are removed, and ``logprobs`` / ``max_tokens`` may be
    set by ``_update_logprob_kwargs``.

    Args:
        batch_size: Number of requests in the batch.
        response_generation_method: One method for all requests, a
            per-request list, or ``None``.
        seed: Base seed handed to ``generate_seeds``.
        print_progress: Default progress-bar flag. A ``use_tqdm`` entry in
            ``generation_kwargs`` takes precedence.
        generation_kwargs: Extra user keyword arguments.

    Returns:
        ``(sampling_params_list, print_progress, call_kwargs, logprob_config)``.
        ``call_kwargs`` holds the kwargs that are not ``SamplingParams``
        fields. ``logprob_config`` is the logprob method found, or ``None``.
    """
    seeds = generate_seeds(seed, batch_size=batch_size)
    logprob_config = _update_logprob_kwargs(response_generation_method, generation_kwargs)
    # If users specify use_tqdm themselves, we use that flag instead.
    print_progress = generation_kwargs.pop("use_tqdm", print_progress)
    if "sampling_params" in generation_kwargs:
        import warnings

        # stacklevel=4 attributes the warning to user code. The frames are:
        # warn <- _prepare_vllm_generation <- _run_vllm_*_pipeline
        #      <- run_vllm_batch* <- caller.
        warnings.warn(
            "Do not specify sampling_params for vllm inference. "
            "If you want to use hyperparameters, add them directly to the "
            "generation kwargs. Given argument sampling_params will be ignored.",
            stacklevel=4,
        )
        generation_kwargs.pop("sampling_params")
    gen_kwargs, call_kwargs = _split_kwargs(generation_kwargs)
    sampling_params_list = _create_sampling_params(
        batch_size=batch_size,
        seeds=seeds,
        response_generation_method=response_generation_method,
        **gen_kwargs,
    )
    return sampling_params_list, print_progress, call_kwargs, logprob_config
def _finalize_vllm_outputs(
    model: LLM,
    outputs: list[RequestOutput],
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
    logprob_config: LogprobResponseGenerationMethod | None,
    reasoning_start_token: str,
    reasoning_end_token: str,
    space_char: str,
) -> tuple[list[str], list[str], list[str]]:
    """Convert raw vLLM request outputs into ``(answers, logprobs, reasonings)``.

    Reasoning is separated from the answer first. Logprobs are only
    collected when a logprob method was configured. Otherwise every request
    gets ``None`` in the logprob list.

    NOTE(review): despite the return annotation, the logprob list holds
    per-request dicts (or ``None`` entries), and the reasoning list may
    contain ``None`` where no reasoning was found.
    """
    raw_reasonings, reasoning_outputs, plain_results = _extract_reasoning_and_answer(
        reasoning_start_token, reasoning_end_token, outputs
    )
    if not logprob_config:
        return plain_results, [None] * len(plain_results), reasoning_outputs

    logprobs_per_request = _get_logprobs(
        model,
        response_generation_method,
        reasoning_start_token,
        reasoning_end_token,
        space_char,
        outputs,
        raw_reasonings,
    )
    return plain_results, logprobs_per_request, reasoning_outputs
def _run_vllm_chat_pipeline(
    model: LLM,
    batch_messages: list[list[dict[str, str]]],
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
    seed: int,
    print_progress: bool,
    reasoning_start_token: str,
    reasoning_end_token: str,
    space_char: str,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Generate chat responses with ``LLM.chat`` and post-process them.

    Each entry of ``batch_messages`` is one conversation (a list of
    role/content dicts). Keyword arguments that are not ``SamplingParams``
    fields are forwarded to ``LLM.chat``.
    """
    prepared = _prepare_vllm_generation(
        batch_size=len(batch_messages),
        response_generation_method=response_generation_method,
        seed=seed,
        print_progress=print_progress,
        generation_kwargs=generation_kwargs,
    )
    sampling_params_list, use_tqdm, chat_kwargs, logprob_config = prepared

    outputs: list[RequestOutput] = model.chat(
        batch_messages,
        sampling_params=sampling_params_list,
        use_tqdm=use_tqdm,
        **chat_kwargs,
    )
    return _finalize_vllm_outputs(
        model=model,
        outputs=outputs,
        response_generation_method=response_generation_method,
        logprob_config=logprob_config,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        space_char=space_char,
    )
def _run_vllm_completion_pipeline(
    model: LLM,
    batch_messages: list[list[dict[str, str]]],
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
    seed: int,
    print_progress: bool,
    reasoning_start_token: str,
    reasoning_end_token: str,
    space_char: str,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Run raw-text completion with ``LLM.generate`` (for base models).

    Only the ``content`` of the LAST message of each conversation is sent as
    the prompt. System messages and earlier turns are not included.
    NOTE(review): presumably callers pass already-rendered prompts in this
    mode. Confirm that dropping the other messages is intended.
    """
    sampling_params_list, use_tqdm, generate_kwargs, logprob_config = _prepare_vllm_generation(
        batch_size=len(batch_messages),
        response_generation_method=response_generation_method,
        seed=seed,
        print_progress=print_progress,
        generation_kwargs=generation_kwargs,
    )

    prompt_texts = [conversation[-1]["content"] for conversation in batch_messages]
    outputs: list[RequestOutput] = model.generate(
        prompt_texts,
        sampling_params=sampling_params_list,
        use_tqdm=use_tqdm,
        **generate_kwargs,
    )
    return _finalize_vllm_outputs(
        model=model,
        outputs=outputs,
        response_generation_method=response_generation_method,
        logprob_config=logprob_config,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        space_char=space_char,
    )
[docs]
def run_vllm_batch(
    model: LLM,
    system_messages: Sequence[str | None] | None = ("You are a helpful assistant.",),
    prompts: Sequence[str] = ("Hi there! What is your name?",),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    print_progress: bool = True,
    # <think>...</think> tokens are used by Qwen3 to separate reasoning
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    space_char: str = "Ġ",
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Run one single-turn request per prompt through a vLLM model.

    Args:
        model: The vLLM ``LLM`` instance.
        system_messages: System prompts, passed through
            ``normalize_system_messages`` with ``batch_size=len(prompts)``.
            A ``None`` entry omits the system turn for that request.
        prompts: One user prompt per request.
        response_generation_method: Output constraint / logprob config. Use
            one method for all requests or a per-request list.
        seed: Base seed for per-request sampling seeds.
        print_progress: Show vLLM's progress bar (overridden by a
            ``use_tqdm`` kwarg).
        reasoning_start_token: Opening delimiter of a reasoning block.
        reasoning_end_token: Closing delimiter of a reasoning block.
        space_char: Token-space marker stripped from logprob token keys.
        inference_mode: ``"chat"`` uses ``LLM.chat``. ``"completion"`` uses
            ``LLM.generate`` on the last message's content only.
        **generation_kwargs: ``SamplingParams`` fields go into the sampling
            params. Anything else is forwarded to the vLLM call.
            ``sampling_params`` is ignored with a warning.

    Returns:
        ``(answers, logprobs, reasonings)``, one entry per prompt. Despite
        the annotation, logprob entries are dicts (token -> logprob) or
        ``None``, and reasoning entries may be ``None``.
    """
    inference_mode = validate_inference_mode(inference_mode)
    system_per_prompt = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )

    def _to_messages(system_message: str | None, prompt: str) -> list[dict[str, str]]:
        user_turn = {"role": "user", "content": prompt}
        if system_message is None:
            return [user_turn]
        return [{"role": "system", "content": system_message}, user_turn]

    batch_messages: list[list[dict[str, str]]] = [
        _to_messages(system_message, prompt)
        for system_message, prompt in zip(system_per_prompt, prompts)
    ]

    pipeline = (
        _run_vllm_completion_pipeline
        if inference_mode == "completion"
        else _run_vllm_chat_pipeline
    )
    return pipeline(
        model=model,
        batch_messages=batch_messages,
        response_generation_method=response_generation_method,
        seed=seed,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        space_char=space_char,
        **generation_kwargs,
    )
[docs]
def run_vllm_batch_conversation(
    model: LLM,
    system_messages: Sequence[str | None] | None = ("You are a helpful assistant.",),
    prompts: Sequence[Sequence[str]] = (("Hi there! What is your name?",),),
    assistant_messages: Sequence[Sequence[str]] = (),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    print_progress: bool = True,
    # <think>...</think> tokens are used by Qwen3 to separate reasoning
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    space_char: str = "Ġ",
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Run multi-turn conversations through a vLLM model.

    Conversation ``i`` is built as an optional system turn, then user turn
    ``j`` followed by assistant turn ``j`` for each ``j`` in
    ``prompts[i]``. An assistant turn is only added when
    ``assistant_messages[i]`` has an entry at ``j``, so extra assistant turns
    are ignored. An empty ``assistant_messages`` means no assistant turns.

    Args:
        model: The vLLM ``LLM`` instance.
        system_messages: System prompts, passed through
            ``normalize_system_messages`` with ``batch_size=len(prompts)``.
            ``None`` entries omit the system turn.
        prompts: Per conversation, the sequence of user turns.
        assistant_messages: Per conversation, prior assistant turns.
        response_generation_method: One method for all conversations or a
            per-conversation list.
        seed: Base seed for per-request sampling seeds.
        print_progress: Show vLLM's progress bar (overridden by ``use_tqdm``).
        reasoning_start_token: Opening delimiter of a reasoning block.
        reasoning_end_token: Closing delimiter of a reasoning block.
        space_char: Token-space marker stripped from logprob token keys.
        inference_mode: ``"chat"`` uses ``LLM.chat``. ``"completion"`` uses
            ``LLM.generate`` on the last message's content only. If a
            conversation ends with an assistant turn, that text is the prompt.
        **generation_kwargs: ``SamplingParams`` fields or vLLM call kwargs.

    Returns:
        ``(answers, logprobs, reasonings)``, one entry per conversation.
        Logprob entries are dicts or ``None``. Reasoning entries may be
        ``None``.
    """
    inference_mode = validate_inference_mode(inference_mode)
    system_per_conversation = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )
    batch_size = len(system_per_conversation)
    if not assistant_messages:
        assistant_messages = tuple(() for _ in range(batch_size))

    batch_messages = []
    for index, system_message in enumerate(system_per_conversation):
        conversation = (
            [] if system_message is None else [{"role": "system", "content": system_message}]
        )
        user_turns = prompts[index]
        assistant_turns = assistant_messages[index]
        assistant_count = len(assistant_turns)
        for turn, user_text in enumerate(user_turns):
            conversation.append({"role": "user", "content": user_text})
            if turn < assistant_count:
                conversation.append({"role": "assistant", "content": assistant_turns[turn]})
        batch_messages.append(conversation)

    pipeline = (
        _run_vllm_completion_pipeline
        if inference_mode == "completion"
        else _run_vllm_chat_pipeline
    )
    return pipeline(
        model=model,
        batch_messages=batch_messages,
        response_generation_method=response_generation_method,
        seed=seed,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        space_char=space_char,
        **generation_kwargs,
    )
[docs]
def default_model_init(model_id: str, seed: int = 42, **model_keywords) -> LLM:
    """
    Initialize a vLLM model with default settings.

    Seeds Python's ``random`` module and torch, then builds an ``LLM``.
    ``tensor_parallel_size`` defaults to the number of visible CUDA devices,
    with a minimum of 1.

    Args:
        model_id: HuggingFace model identifier
        seed: Random seed for reproducibility (also passed to ``LLM``)
        **model_keywords: Additional keywords passed to LLM constructor.
            A ``tensor_parallel_size`` entry overrides the device-count
            default.

    Returns:
        LLM: Initialized vLLM model instance
    """
    random.seed(seed)
    torch.manual_seed(seed)
    device_count = torch.cuda.device_count()
    logger.info("Initializing vLLM model with %s CUDA devices.", device_count)
    logger.debug("vLLM model initialization kwargs: %s", model_keywords)
    # Pop rather than read: passing tensor_parallel_size both here and via
    # **model_keywords would raise "got multiple values for keyword argument".
    # Clamp the default to 1 because a tensor-parallel size of 0 (no CUDA
    # devices visible) is not a usable value.
    tensor_parallel_size = model_keywords.pop("tensor_parallel_size", max(device_count, 1))
    return LLM(
        model=model_id,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        **model_keywords,
    )
def _get_sampling_field_names() -> set[str]:
    """
    Return the parameter names accepted by the ``SamplingParams`` constructor.

    The names are read from the constructor signature on every call, so
    they track the installed vLLM version.
    """
    from inspect import signature

    return {name for name in signature(SamplingParams).parameters}
def _split_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Partition ``kwargs`` into ``(generation_args, chat_args)``.

    Keys that name a ``SamplingParams`` constructor parameter go into
    ``generation_args``. All other keys go into ``chat_args``. Insertion
    order is preserved in both dicts, and ``kwargs`` itself is not modified.
    """
    sampling_keys = _get_sampling_field_names()
    generation_args = {name: value for name, value in kwargs.items() if name in sampling_keys}
    chat_args = {name: value for name, value in kwargs.items() if name not in sampling_keys}
    return generation_args, chat_args
def _structured_sampling_params(
    batch_size: int,
    seeds: list[int],
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod],
    **generation_kwargs: Any,
) -> list[SamplingParams]:
    """Build ``SamplingParams`` with vLLM structured-output constraints.

    JSON methods are constrained to the JSON schema of the pydantic model
    built from their ``json_object``. Choice and logprob methods with
    ``allowed_choices`` are constrained to the stringified choices. Any other
    method gets no constraint.

    With a single method, one constraint is built and shared by all
    ``batch_size`` requests. If that method needs no constraint,
    ``structured_outputs`` is not passed at all. With a per-request list,
    ``response_generation_method[i]`` configures request ``i``. Identical
    constraints are memoized via ``_make_cache_key``, and unconstrained
    requests get ``structured_outputs=None``.
    """
    choice_method_types = (ChoiceResponseGenerationMethod, LogprobResponseGenerationMethod)

    def json_constraint(method: JSONResponseGenerationMethod) -> StructuredOutputsParams:
        pydantic_model = build_pydantic_model_from_json_object(json_object=method.json_object)
        return StructuredOutputsParams(json=pydantic_model.model_json_schema())

    def has_choices(method: ResponseGenerationMethod) -> bool:
        return isinstance(method, choice_method_types) and method.allowed_choices is not None

    def stringified_choices(method: ResponseGenerationMethod) -> list[str]:
        return [str(choice) for choice in method.allowed_choices]

    constraints: list[StructuredOutputsParams | None]
    if isinstance(response_generation_method, ResponseGenerationMethod):
        # One method shared by every request.
        if isinstance(response_generation_method, JSONResponseGenerationMethod):
            constraints = [json_constraint(response_generation_method)] * batch_size
        elif has_choices(response_generation_method):
            shared = StructuredOutputsParams(
                choice=stringified_choices(response_generation_method)
            )
            constraints = [shared] * batch_size
        else:
            constraints = []
    else:
        # One method per request; reuse identical constraints across the batch.
        memo: dict[str, StructuredOutputsParams] = {}
        constraints = []
        for index in range(batch_size):
            method = response_generation_method[index]
            if isinstance(method, JSONResponseGenerationMethod):
                key = _make_cache_key(method.get_json_prompt(), None)
                if key not in memo:
                    memo[key] = json_constraint(method)
            elif has_choices(method):
                choices = stringified_choices(method)
                key = _make_cache_key(choices, None)
                if key not in memo:
                    memo[key] = StructuredOutputsParams(choice=choices)
            else:
                constraints.append(None)
                continue
            constraints.append(memo[key])

    if len(constraints) != batch_size:
        # No constraint was produced: plain seeded sampling params.
        return [SamplingParams(seed=seeds[i], **generation_kwargs) for i in range(batch_size)]
    return [
        SamplingParams(
            seed=seeds[i],
            structured_outputs=constraints[i],
            **generation_kwargs,
        )
        for i in range(batch_size)
    ]
def _create_sampling_params(
    batch_size: int,
    seeds: list[int],
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
    **generation_kwargs: Any,
) -> list[SamplingParams]:
    """
    Create one vLLM ``SamplingParams`` per request.

    Args:
        batch_size: Number of prompts in batch.
        seeds: Per-request random seeds; ``seeds[i]`` seeds request ``i``.
        response_generation_method: Output structure configuration. A single
            method or a non-empty list goes through
            ``_structured_sampling_params``. ``None`` or an empty list yields
            unconstrained sampling params.
        **generation_kwargs: Additional ``SamplingParams`` fields.

    Returns:
        A list of ``batch_size`` ``SamplingParams``.
    """
    # bool(...) keeps the declared type honest; the truthiness test is the
    # same as before (an empty list means "no structured output").
    use_structured: bool = bool(response_generation_method) and isinstance(
        response_generation_method, (list, ResponseGenerationMethod)
    )
    if use_structured:
        return _structured_sampling_params(
            batch_size=batch_size,
            seeds=seeds,
            response_generation_method=response_generation_method,
            **generation_kwargs,
        )
    return [SamplingParams(seed=seeds[i], **generation_kwargs) for i in range(batch_size)]
def _logprob_methods_per_output(response_generation_method, num_outputs):
    """Return the logprob method to apply to each of ``num_outputs`` outputs, or ``[]``."""
    if isinstance(response_generation_method, LogprobResponseGenerationMethod):
        return [response_generation_method] * num_outputs
    if not isinstance(response_generation_method, list):
        return []
    # Same selection rule as _update_logprob_kwargs: the first logprob method.
    fallback = next(
        (m for m in response_generation_method if isinstance(m, LogprobResponseGenerationMethod)),
        None,
    )
    if fallback is None:
        return []
    methods = [
        m if isinstance(m, LogprobResponseGenerationMethod) else fallback
        for m in response_generation_method[:num_outputs]
    ]
    # Pad so that every output has a method even if the list is short.
    methods.extend([fallback] * (num_outputs - len(methods)))
    return methods


def _get_logprobs(
    model,
    response_generation_method,
    reasoning_start_token,
    reasoning_end_token,
    space_char,
    outputs,
    raw_reasonings,
):
    """Collect the top logprobs at the configured answer-token position.

    Returns exactly one entry per element of ``outputs``, aligned with it.
    Each entry is a dict mapping every candidate token (with ``space_char``
    characters and leading whitespace stripped) to its logprob. It is ``{}``
    when the output has no token at the requested position. If no
    ``LogprobResponseGenerationMethod`` is configured, ``[]`` is returned.

    Output ``i`` uses the single logprob method if one was given. Otherwise
    it uses ``response_generation_method[i]`` from the per-request list, and
    non-logprob entries fall back to the first logprob method in the list.

    If that method has ``ignore_reasoning`` set and a reasoning block was
    extracted for the output, the position is shifted past the tokenized
    ``reasoning_start_token + reasoning + reasoning_end_token`` plus one token.

    Args:
        model: vLLM ``LLM``; its tokenizer is fetched only when needed.
        response_generation_method: Single method or per-request list.
        reasoning_start_token: Opening reasoning delimiter.
        reasoning_end_token: Closing reasoning delimiter.
        space_char: Characters stripped from the left of decoded tokens.
        outputs: vLLM ``RequestOutput`` objects.
        raw_reasonings: Unstripped reasoning per output (``None`` if absent).
    """
    methods = _logprob_methods_per_output(response_generation_method, len(outputs))
    if not methods:
        return []

    tokenizer = None  # fetched lazily, once, and only if reasoning is skipped
    logprob_result = []
    for req_output, raw_reasoning, rgm in zip(outputs, raw_reasonings, methods):
        position = rgm.token_position
        if rgm.ignore_reasoning and raw_reasoning is not None:
            if tokenizer is None:
                tokenizer = model.get_tokenizer()
            reasoning_tokens = tokenizer.tokenize(
                f"{reasoning_start_token}{raw_reasoning}{reasoning_end_token}"
            )
            # The extra +1 skips one more token after the reasoning block,
            # presumably a separator emitted after the end token (TODO confirm).
            position += len(reasoning_tokens) + 1
        try:
            top_logprobs = req_output.outputs[0].logprobs[position]
        except IndexError:  # fewer than `position` tokens in the output
            logprob_result.append({})
            continue
        # Strip space token and any leading whitespace from tokenization.
        logprob_result.append(
            {
                x.decoded_token.lstrip(space_char).lstrip(): x.logprob
                for x in top_logprobs.values()
            }
        )
    return logprob_result
def _update_logprob_kwargs(response_generation_method, generation_kwargs):
    """Find the logprob method and inject its settings into ``generation_kwargs``.

    A single ``LogprobResponseGenerationMethod`` is used as-is. For a list,
    the first logprob method in it is chosen. When one is found,
    ``generation_kwargs["logprobs"]`` is set to its ``top_logprobs``. If its
    ``token_limit`` is not ``None``, ``generation_kwargs["max_tokens"]`` is
    set to that limit. ``generation_kwargs`` is modified in place.

    Returns:
        The chosen logprob method, or ``None`` if there is none.
    """
    if isinstance(response_generation_method, LogprobResponseGenerationMethod):
        candidates = [response_generation_method]
    elif isinstance(response_generation_method, list):
        candidates = [
            item
            for item in response_generation_method
            if isinstance(item, LogprobResponseGenerationMethod)
        ]
    else:
        candidates = []

    logprob_config = candidates[0] if candidates else None
    if not logprob_config:
        return logprob_config

    generation_kwargs["logprobs"] = logprob_config.top_logprobs
    if logprob_config.token_limit is not None:
        generation_kwargs["max_tokens"] = logprob_config.token_limit
    return logprob_config
def _extract_reasoning_and_answer(
    reasoning_start_token: str, reasoning_end_token: str, outputs: list[RequestOutput]
):
    """Split each request's first completion into reasoning and final answer.

    Reasoning is taken from the completion's ``reasoning`` attribute if
    truthy, else from ``reasoning_content``. If that yields a value other than
    ``None``, the answer is the completion's ``content`` when present, else
    its full ``text``. Otherwise ``parse_reasoning`` splits ``text`` using the
    start/end delimiters.

    Returns:
        ``(raw_reasonings, reasoning_output, plain_results)``.
        ``raw_reasonings`` keeps the reasoning exactly as extracted
        (whitespace kept for token-length calculations). ``reasoning_output``
        is the same text stripped, with ``None`` where absent.
        ``plain_results`` holds the answers.
    """
    delimiters = [(reasoning_start_token, reasoning_end_token)]
    raw_reasonings = []  # keep the whitespace for length calculations
    reasoning_output = []
    plain_results = []
    for request_output in outputs:
        completion = request_output.outputs[0]
        text = getattr(completion, "text", "") or ""
        reasoning = getattr(completion, "reasoning", None) or getattr(
            completion, "reasoning_content", None
        )
        content = getattr(completion, "content", None)

        if reasoning is not None:
            answer = text if content is None else content
        else:
            answer, reasoning = parse_reasoning(text, patterns=delimiters)

        raw_reasonings.append(reasoning)
        reasoning_output.append(None if reasoning is None else reasoning.strip())
        plain_results.append(answer)
    return raw_reasonings, reasoning_output, plain_results