# Source code for qstn.inference.survey_inference

import random
from typing import TYPE_CHECKING, Any

from ..logger import get_logger, tqdm_write
from .response_generation import (
    LogprobResponseGenerationMethod,
    ResponseGenerationMethod,
)
from .utils import InferenceMode, normalize_system_messages, validate_inference_mode

# Module-level logger, named after this module.
logger = get_logger(__name__)

# These imports are only seen by static type checkers; the runtime imports
# happen in the guarded try/except blocks below.
if TYPE_CHECKING:
    from openai import AsyncOpenAI
    from vllm import LLM  # pyright: ignore[reportMissingImports]

# Backend availability flags. Each one is flipped to True only when the
# matching import block below finishes without raising ImportError.
HAS_VLLM = False
HAS_OPENAI = False

# vLLM backend: the flag is set only if both `vllm` itself and the local batch
# runners import successfully.
try:
    from vllm import LLM  # pyright: ignore[reportMissingImports]

    from .local_inference import run_vllm_batch, run_vllm_batch_conversation

    HAS_VLLM = True
except ImportError:
    # Keep the name bound so later references to LLM do not raise NameError.
    LLM = None

# OpenAI backend: same pattern as above for the `openai` client and the remote
# batch runners.
try:
    from openai import AsyncOpenAI

    from .remote_inference import run_openai_batch, run_openai_batch_conversation

    HAS_OPENAI = True
except ImportError:
    # Keep the name bound so later references to AsyncOpenAI do not raise NameError.
    AsyncOpenAI = None


def _print_conversation(
    system_messages: list[str | None],
    prompts: list[str] | list[list[str]],
    assistant_messages: list[list[str]] | None,
    plain_results: list[str],
    reasoning_output: list[str] | None,
    logprob_result: list | None,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ),
    number_of_printed_conversations: int = 2,
):
    """
    Emit the first few conversations of a batch for inspection.

    Each rendered conversation is sent to ``logger.debug`` and to
    ``tqdm_write``. Rendering stops after ``number_of_printed_conversations``
    conversations or when the shortest of the zipped inputs is exhausted.

    Args:
        system_messages: One system message per conversation; ``None``
            entries are omitted from the output.
        prompts: One entry per conversation. When ``assistant_messages`` is
            non-empty each entry is iterated as a sequence of user turns;
            otherwise each entry is rendered as a single user message.
        assistant_messages: Optional prefilled assistant turns per
            conversation. A falsy value selects the single-turn layout.
        plain_results: Generated answer per conversation.
        reasoning_output: Reasoning text per conversation, or ``None``.
            Falsy entries are not rendered.
        logprob_result: Logprob output per conversation, or ``None``.
        response_generation_method: ``None``, a single method, or one method
            per conversation. Logprobs are rendered only for conversations
            whose method is a ``LogprobResponseGenerationMethod``.
        number_of_printed_conversations: Maximum number of conversations to
            emit. Defaults to 2.
    """
    batch_size = len(system_messages)

    # The public entry points forward `response_generation_method` unchanged,
    # and their signatures allow a single method. Indexing such an object used
    # to fail, so normalise it to one entry per conversation.
    # NOTE(review): a single method is assumed to apply to every conversation.
    if not response_generation_method:
        methods = []
    elif isinstance(response_generation_method, (list, tuple)):
        methods = list(response_generation_method)
    else:
        methods = [response_generation_method] * batch_size

    reasonings = [None] * batch_size if reasoning_output is None else reasoning_output
    logprobs = [None] * batch_size if logprob_result is None else logprob_result

    multi_turn = bool(assistant_messages)
    # In the single-turn layout there are no prefills; pad with None so the
    # zip below is still bounded by the other inputs only.
    prefills = assistant_messages if multi_turn else [None] * batch_size

    rows = zip(system_messages, prompts, plain_results, reasonings, logprobs, prefills)
    for index, (system_message, prompt, answer, reasoning, logprob_answer, prefill) in enumerate(
        rows
    ):
        if index >= number_of_printed_conversations:
            break

        sections = ["--- Conversation ---"]
        if system_message is not None:
            sections.append(f"-- System Message --\n{system_message}")

        if multi_turn:
            for turn, user_message in enumerate(prompt):
                sections.append(f"-- User Message --\n{user_message}")
                # Only turns that have a non-empty prefill show an assistant message.
                if turn < len(prefill) and prefill[turn]:
                    sections.append(f"-- Assistant Message --\n{prefill[turn]}")
            sections.append(f"-- Generated Answer --\n{answer}")
        else:
            sections.append(f"-- User Message ---\n{prompt}")
            sections.append(f"-- Generated Message --\n{answer}")

        if reasoning:
            sections.append(f"-- Reasoning --\n{reasoning!s}")

        if index < len(methods) and isinstance(
            methods[index], LogprobResponseGenerationMethod
        ):
            sections.append(f"-- Logprobs --\n{logprob_answer!s}")

        rendered = "\n".join(sections)
        logger.debug(rendered)
        tqdm_write(rendered)


def batch_generation(
    model: "LLM | AsyncOpenAI",  # pyright: ignore[reportInvalidTypeForm]
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[str] = ("Hi there! What is your name?",),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    print_conversation: bool = False,
    number_of_printed_conversations: int = 2,
    print_progress: bool = True,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    space_char: str = "Ġ",
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> tuple:
    """
    Generate responses for a batch of prompts.

    Handles both vLLM and OpenAI API generation with support for:

    - Structured output (JSON or choice format)
    - Conversation printing
    - Progress tracking
    - Concurrent API requests

    Args:
        model (LLM or AsyncOpenAI): vLLM model or AsyncOpenAI client.
            Subclasses of either are accepted.
        system_messages (List(str)): System prompts for each conversation.
        prompts (List(str)): User prompts to generate responses for.
        response_generation_method (
            ResponseGenerationMethod or List(ResponseGenerationMethod), optional
        ): Configuration for structured output.
        seed (int): Random seed for reproducibility, defaults to 42.
        client_model_name (str, optional): Model name when using OpenAI API.
        api_concurrency (int): Max concurrent API requests when using OpenAI API.
        print_conversation (bool): If True, prints conversations. Defaults to False.
        number_of_printed_conversations (int): How many conversations should
            be printed. Defaults to 2.
        print_progress (bool): If True, shows progress bar. Defaults to True.
        reasoning_start_token (str): Special token at the beginning of
            reasoning models' output. Used for manual parsing if automatic
            parsing fails.
        reasoning_end_token (str): Special token to separate reasoning from
            regular model output. Used for manual parsing if automatic
            parsing fails.
        space_char (str): Special char to encode spaces in tokens ("Ġ" for
            most byte-pair tokenizers). Only forwarded to the vLLM backend.
        inference_mode (str): Use "chat" for message-based models or
            "completion" for base-model text generation. Defaults to "chat".
        generation_kwargs: Additional generation parameters.

    Returns:
        Tuple[List[str], List[str], List[str]]: Generated Response, Logprobs,
        Reasoning.

    Raises:
        ImportError: If the model belongs to a backend whose package is not
            installed.
        ValueError: If the model is neither a vLLM ``LLM`` nor an
            ``AsyncOpenAI`` client, or no matching backend can be dispatched.
    """
    model_type = type(model).__name__

    # Compare against every class name in the MRO rather than only the
    # concrete type name, so that subclasses of a supported client (for
    # example openai's AsyncAzureOpenAI) pass validation just like they pass
    # the isinstance() dispatch further down.
    lineage = {cls.__name__ for cls in type(model).__mro__}
    if "LLM" in lineage and not HAS_VLLM:
        raise ImportError("You are trying to use a vLLM model, but 'vllm' is not installed.")
    if "AsyncOpenAI" in lineage and not HAS_OPENAI:
        raise ImportError("You are trying to use OpenAI, but 'openai' is not installed.")
    if "LLM" not in lineage and "AsyncOpenAI" not in lineage:
        raise ValueError(f"Unsupported model type: {type(model)}")

    inference_mode = validate_inference_mode(inference_mode)

    random.seed(seed)

    normalized_system_messages = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )

    logger.debug("Generating %s responses with %s backend.", len(prompts), model_type)

    # Inference
    if HAS_VLLM and isinstance(model, LLM):
        plain_results, logprob_result, reasoning_outputs = run_vllm_batch(
            model,
            system_messages=normalized_system_messages,
            prompts=prompts,
            response_generation_method=response_generation_method,
            seed=seed,
            print_progress=print_progress,
            reasoning_start_token=reasoning_start_token,
            reasoning_end_token=reasoning_end_token,
            space_char=space_char,
            inference_mode=inference_mode,
            **generation_kwargs,
        )
    elif HAS_OPENAI and isinstance(model, AsyncOpenAI):
        plain_results, logprob_result, reasoning_outputs = run_openai_batch(
            model,
            system_messages=normalized_system_messages,
            prompts=prompts,
            response_generation_method=response_generation_method,
            seed=seed,
            print_progress=print_progress,
            client_model_name=client_model_name,
            api_concurrency=api_concurrency,
            reasoning_start_token=reasoning_start_token,
            reasoning_end_token=reasoning_end_token,
            inference_mode=inference_mode,
            **generation_kwargs,
        )
    else:
        raise ValueError("Inference cannot be run without OpenAI or vllm installed.")

    if print_conversation:
        _print_conversation(
            system_messages=normalized_system_messages,
            prompts=prompts,
            assistant_messages=None,
            plain_results=plain_results,
            reasoning_output=reasoning_outputs,
            logprob_result=logprob_result,
            response_generation_method=response_generation_method,
            number_of_printed_conversations=number_of_printed_conversations,
        )

    return (plain_results, logprob_result, reasoning_outputs)
def batch_turn_by_turn_generation(
    model: "LLM | AsyncOpenAI",  # type: ignore
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[list[str]] = (
        (
            "Hi there! What is your name?",
            "Interesting",
        ),
    ),
    assistant_messages: list[list[str]] | None = None,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    print_conversation: bool = False,
    number_of_printed_conversations: int = 2,
    print_progress: bool = True,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    space_char: str = "Ġ",
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> tuple:
    """
    Generate responses for multi-turn conversations.

    Handles conversations with multiple back-and-forth exchanges between user
    and assistant. Supports:

    - Structured output formats
    - Pre-filled assistant messages
    - Conversation printing
    - Progress tracking

    Args:
        model (LLM or AsyncOpenAI): vLLM model or AsyncOpenAI client.
            Subclasses of either are accepted.
        system_messages (List(str)): System prompts for each conversation.
        prompts (List(List(str))): User prompts to generate responses for.
            Can include multiple requests per system prompt.
        assistant_messages (List(List(str)), optional): Prefilled assistant
            responses. For example, if the first list contains one entry, the
            first assistant turn is prefilled and not inferred.
        response_generation_method (
            ResponseGenerationMethod or List(ResponseGenerationMethod), optional
        ): Configuration for structured output.
        seed (int): Random seed for reproducibility, defaults to 42.
        client_model_name (str, optional): Model name when using OpenAI API.
        api_concurrency (int): Max concurrent API requests when using OpenAI API.
        print_conversation (bool): If True, prints conversations. Defaults to False.
        number_of_printed_conversations (int): How many conversations should
            be printed. Defaults to 2.
        print_progress (bool): If True, shows progress bar. Defaults to True.
        reasoning_start_token (str): Special token at the beginning of
            reasoning models' output. Used for manual parsing if automatic
            parsing fails.
        reasoning_end_token (str): Special token to separate reasoning from
            regular model output. Used for manual parsing if automatic
            parsing fails.
        space_char (str): Special char to encode spaces in tokens ("Ġ" for
            most byte-pair tokenizers). Only forwarded to the vLLM backend.
        inference_mode (str): Use "chat" for message-based models or
            "completion" for base-model text generation. Defaults to "chat".
        generation_kwargs: Additional generation parameters.

    Returns:
        Tuple[List[str], List[str], List[str]]: Generated Response, Logprobs,
        Reasoning.

    Raises:
        ImportError: If the model belongs to a backend whose package is not
            installed.
        ValueError: If the model is neither a vLLM ``LLM`` nor an
            ``AsyncOpenAI`` client, or no matching backend can be dispatched.
    """
    model_type = type(model).__name__

    # Compare against every class name in the MRO rather than only the
    # concrete type name, so that subclasses of a supported client (for
    # example openai's AsyncAzureOpenAI) pass validation just like they pass
    # the isinstance() dispatch further down.
    lineage = {cls.__name__ for cls in type(model).__mro__}
    if "LLM" in lineage and not HAS_VLLM:
        raise ImportError("You are trying to use a vLLM model, but 'vllm' is not installed.")
    elif "AsyncOpenAI" in lineage and not HAS_OPENAI:
        raise ImportError("You are trying to use OpenAI, but 'openai' is not installed.")
    elif "LLM" not in lineage and "AsyncOpenAI" not in lineage:
        raise ValueError(f"Unsupported model type: {type(model)}")

    inference_mode = validate_inference_mode(inference_mode)

    random.seed(seed)

    normalized_system_messages = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )

    logger.debug("Generating %s conversations with %s backend.", len(prompts), model_type)

    # Inference
    if HAS_VLLM and isinstance(model, LLM):
        plain_results, logprob_result, reasoning_outputs = run_vllm_batch_conversation(
            model,
            system_messages=normalized_system_messages,
            prompts=prompts,
            assistant_messages=assistant_messages,
            response_generation_method=response_generation_method,
            seed=seed,
            print_progress=print_progress,
            reasoning_start_token=reasoning_start_token,
            reasoning_end_token=reasoning_end_token,
            space_char=space_char,
            inference_mode=inference_mode,
            **generation_kwargs,
        )
    elif HAS_OPENAI and isinstance(model, AsyncOpenAI):
        plain_results, logprob_result, reasoning_outputs = run_openai_batch_conversation(
            model,
            system_messages=normalized_system_messages,
            prompts=prompts,
            assistant_messages=assistant_messages,
            response_generation_method=response_generation_method,
            client_model_name=client_model_name,
            seed=seed,
            print_progress=print_progress,
            api_concurrency=api_concurrency,
            reasoning_start_token=reasoning_start_token,
            reasoning_end_token=reasoning_end_token,
            inference_mode=inference_mode,
            **generation_kwargs,
        )
    else:
        raise ValueError("Inference cannot be run without OpenAI or vllm installed.")

    if print_conversation:
        _print_conversation(
            system_messages=normalized_system_messages,
            prompts=prompts,
            assistant_messages=assistant_messages,
            plain_results=plain_results,
            reasoning_output=reasoning_outputs,
            logprob_result=logprob_result,
            response_generation_method=response_generation_method,
            number_of_printed_conversations=number_of_printed_conversations,
        )

    return (plain_results, logprob_result, reasoning_outputs)
# def batch_decoding( # model: Union[LLM, AsyncOpenAI], # prompts: List[str] = ["Hi there! What is your name?"], # stop_tokens: List[str] = ["\nA:"], # structured_output_options: Optional[ # Union[ResponseGenerationMethod, List[ResponseGenerationMethod]] # ] = None, # seed: int = 42, # client_model_name: Optional[str] = None, # api_concurrency: int = 10, # print_conversation: bool = False, # print_progress: bool = True, # **generation_kwargs: Any, # ): # """ # Generate responses for a batch of prompts. # Handles both vLLM and OpenAI API generation with support for: # - Structured output (JSON or choice format) # - Conversation printing # - Progress tracking # - Concurrent API requests # Args: # model: vLLM model or AsyncOpenAI client # system_messages: System prompts for each conversation # prompts: User prompts to generate responses for # structured_output_options: Configuration for structured output # seed: Random seed for reproducibility # client_model_name: Model name when using OpenAI API # api_concurrency: Max concurrent API requests # print_conversation: If True, prints conversations # print_progress: If True, shows progress bar # **generation_kwargs: Additional generation parameters # Returns: # List[str]: Generated responses # """ # random.seed(seed) # batch_size: int = len(prompts) # seeds = _generate_seeds(seed, batch_size=batch_size) # if isinstance(model, LLM): # sampling_params_list = _create_sampling_params( # batch_size=batch_size, # seeds=seeds, # structured_output_options=structured_output_options, # stop_tokens=stop_tokens, # **generation_kwargs, # ) # outputs: List[RequestOutput] = model.generate( # prompts, # sampling_params=sampling_params_list, # use_tqdm=print_progress, # ) # result = [output.outputs[0].text for output in outputs] # else: # result = _run_async_in_thread( # client=model, # client_model_name=client_model_name, # batch_messages=prompts, # seeds=seeds, # concurrency_limit=api_concurrency, # 
structured_output_options=structured_output_options, # **generation_kwargs, # ) # # TODO add argument to specify how many conversations should be printed. # # The base argument should be reasonable. # if print_conversation: # conversation_print = "Conversation:" # for prompt, answer in zip(prompts, result): # round_print = ( # f"{conversation_print}\nUser Message:\n{prompt}" # f"\nGenerated Message\n{answer}" # ) # print(round_print, flush=True) # break # return result