Source code for qstn.inference.remote_inference

import asyncio
import atexit
import threading
import weakref
from concurrent.futures import Future
from typing import Any

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

from ..logger import get_logger
from ..utilities.utils import _make_cache_key, generate_seeds
from .dynamic_pydantic import build_pydantic_model_from_json_object
from .reasoning_parser import parse_reasoning
from .response_generation import (
    ChoiceResponseGenerationMethod,
    JSONResponseGenerationMethod,
    LogprobResponseGenerationMethod,
    ResponseGenerationMethod,
)
from .utils import normalize_system_messages

logger = get_logger(__name__)


class _ClientLoopRunner:
    """Persistent event loop running in a background thread.

    The constructor starts a daemon thread that creates its own asyncio event
    loop and runs it until :meth:`shutdown` is called. Coroutines are handed
    to that loop from other threads through :meth:`submit`.
    """

    def __init__(self):
        """Start the worker thread and block until its loop is available.

        Raises:
            RuntimeError: If the worker thread could not create its event loop.
        """
        self._loop: asyncio.AbstractEventLoop | None = None
        self._loop_ready = threading.Event()
        # Populated by the worker thread when loop creation fails, so the
        # constructor can report the failure instead of waiting forever.
        self._startup_error: Exception | None = None

        self._thread = threading.Thread(target=self._thread_main, daemon=True)
        self._thread.start()

        self._loop_ready.wait()

        if self._startup_error is not None:
            raise RuntimeError(
                "Loop runner failed to start its event loop."
            ) from self._startup_error

    def _thread_main(self) -> None:
        """Create the loop, run it until stopped, then cancel leftovers and close it."""
        loop = None
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
        except Exception as exc:
            self._startup_error = exc
            if loop is not None:
                loop.close()
            return
        finally:
            # Always release the constructor, whether startup succeeded or not.
            self._loop_ready.set()

        try:
            loop.run_forever()
        finally:
            try:
                pending = asyncio.all_tasks(loop=loop)

                for task in pending:
                    task.cancel()

                if pending:
                    # Let the cancelled tasks unwind before the loop is closed.
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            finally:
                loop.close()

    def submit(self, coro: Any) -> Future:
        """Schedule ``coro`` on the background loop.

        Args:
            coro: Coroutine object to run on the runner's loop.

        Returns:
            A ``concurrent.futures.Future`` resolving to the coroutine's result.

        Raises:
            RuntimeError: If the loop is not initialized (or already closed).
        """
        if self._loop is None:
            raise RuntimeError("Loop runner is not initialized.")

        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def shutdown(self) -> None:
        """Stop the loop and wait up to five seconds for the thread to exit.

        Safe to call more than once and from the runner's own thread (in which
        case the thread is not joined).
        """
        loop = self._loop

        if loop is None or loop.is_closed():
            return

        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was closed between the check above and this call
            # (e.g. a concurrent shutdown); there is nothing left to stop.
            pass

        if threading.current_thread() is not self._thread:
            self._thread.join(timeout=5)


# Registry mapping each client to the loop runner that executes its
# coroutines. Keys are held weakly, so an entry disappears once its client
# is garbage collected.
_CLIENT_LOOP_RUNNERS: weakref.WeakKeyDictionary[AsyncOpenAI, _ClientLoopRunner] = (
    weakref.WeakKeyDictionary()
)
# Serializes lookups and mutations of the registry across threads.
_CLIENT_LOOP_RUNNERS_LOCK = threading.Lock()


def _get_or_create_runner(client: AsyncOpenAI) -> _ClientLoopRunner:
    """Return the loop runner bound to ``client``, creating it on first use.

    The lookup and the creation happen under the registry lock, so concurrent
    callers passing the same client receive the same runner.

    Args:
        client: Client whose coroutines should run on a dedicated loop.

    Returns:
        The runner associated with ``client``.
    """
    with _CLIENT_LOOP_RUNNERS_LOCK:
        runner = _CLIENT_LOOP_RUNNERS.get(client)

        if runner is None:
            runner = _ClientLoopRunner()
            _CLIENT_LOOP_RUNNERS[client] = runner

            # The registry only holds the client weakly, so without this the
            # runner's thread and loop would keep running after the client is
            # garbage collected. Stop the runner when the client goes away.
            finalizer = weakref.finalize(client, runner.shutdown)
            # Interpreter exit is already covered by the module's atexit hook.
            finalizer.atexit = False

        return runner


@atexit.register
def _shutdown_all_client_loop_runners() -> None:
    """Shut down every registered loop runner; registered to run at interpreter exit.

    The registry is snapshotted and emptied while holding the lock; the
    runners themselves are stopped afterwards, outside the lock.
    """
    with _CLIENT_LOOP_RUNNERS_LOCK:
        active_runners = [*_CLIENT_LOOP_RUNNERS.values()]
        _CLIENT_LOOP_RUNNERS.clear()

    for active_runner in active_runners:
        active_runner.shutdown()


def run_openai_batch(
    model: AsyncOpenAI,
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[str] = ("Hi there! What is your name?",),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    print_progress: bool = True,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Send a batch of single-turn prompts to an OpenAI-compatible chat API.

    Each prompt becomes one conversation consisting of an optional system
    message followed by the user prompt.

    Args:
        model: Async client used to issue the chat completion requests.
        system_messages: System messages, passed through
            ``normalize_system_messages`` together with the number of prompts.
            A ``None`` entry produces a conversation without system message.
        prompts: User prompts, one per request.
        response_generation_method: Optional output constraint, either one
            method or a list of methods; forwarded unchanged.
        seed: Base seed handed to ``generate_seeds``.
        client_model_name: Model name sent with every request.
        api_concurrency: Maximum number of requests in flight at once.
        reasoning_start_token: Opening marker used when reasoning has to be
            parsed out of the message text.
        reasoning_end_token: Closing marker for the same purpose.
        print_progress: Whether to display a progress bar.
        **generation_kwargs: Extra arguments forwarded to the completion call.

    Returns:
        A tuple ``(answers, logprobs, reasoning)`` as produced by
        ``_run_async_in_thread``.
    """
    per_prompt_system = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )

    # One conversation per prompt; the system turn is only present when given.
    batch_messages: list[list[dict[str, str]]] = []
    for system_text, user_text in zip(per_prompt_system, prompts):
        conversation: list[dict[str, str]] = []
        if system_text is not None:
            conversation.append({"role": "system", "content": system_text})
        conversation.append({"role": "user", "content": user_text})
        batch_messages.append(conversation)

    request_count: int = len(batch_messages)
    seeds = generate_seeds(seed, batch_size=request_count)

    answers, logprobs, reasoning = _run_async_in_thread(
        client=model,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=seeds,
        concurrency_limit=api_concurrency,
        response_generation_method=response_generation_method,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )

    return (answers, logprobs, reasoning)
def run_openai_batch_conversation(
    model: AsyncOpenAI,
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[list[str]] = (("Hi there! What is your name?",),),
    assistant_messages: list[list[str]] | None = None,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    print_progress: bool = True,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Send a batch of multi-turn conversations to an OpenAI-compatible chat API.

    For conversation ``i`` the user turns in ``prompts[i]`` are interleaved
    with the assistant turns in ``assistant_messages[i]``: user turn ``j`` is
    followed by assistant turn ``j`` when that assistant turn exists.

    Args:
        model: Async client used to issue the chat completion requests.
        system_messages: System messages, passed through
            ``normalize_system_messages`` together with the number of
            conversations. A ``None`` entry omits the system turn.
        prompts: For each conversation, its user turns in order.
        assistant_messages: For each conversation, the earlier assistant
            replies in order. ``None`` means no conversation has any.
        response_generation_method: Optional output constraint, either one
            method or a list of methods; forwarded unchanged.
        seed: Base seed handed to ``generate_seeds``.
        client_model_name: Model name sent with every request.
        api_concurrency: Maximum number of requests in flight at once.
        reasoning_start_token: Opening marker used when reasoning has to be
            parsed out of the message text.
        reasoning_end_token: Closing marker for the same purpose.
        print_progress: Whether to display a progress bar.
        **generation_kwargs: Extra arguments forwarded to the completion call.

    Returns:
        A tuple ``(answers, logprobs, reasoning)`` as produced by
        ``_run_async_in_thread``.
    """
    normalized_system_messages = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )
    batch_size = len(normalized_system_messages)

    if assistant_messages is None:
        assistant_messages = [[] for _ in range(batch_size)]

    batch_messages = []
    for idx, system_text in enumerate(normalized_system_messages):
        conversation = []

        if system_text is not None:
            conversation.append({"role": "system", "content": normalized_system_messages[idx]})

        user_turns = prompts[idx]
        assistant_turns = assistant_messages[idx]
        user_turn_count = len(user_turns)
        assistant_turn_count = len(assistant_turns)

        for turn in range(user_turn_count):
            conversation.append({"role": "user", "content": user_turns[turn]})
            # The final user turn usually has no assistant reply yet.
            if turn < assistant_turn_count:
                conversation.append({"role": "assistant", "content": assistant_turns[turn]})

        batch_messages.append(conversation)

    seeds = generate_seeds(seed, batch_size=batch_size)

    answers, logprobs, reasoning = _run_async_in_thread(
        client=model,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=seeds,
        concurrency_limit=api_concurrency,
        response_generation_method=response_generation_method,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )

    return (answers, logprobs, reasoning)
def _run_async_in_thread(
    client: AsyncOpenAI,
    client_model_name: str | None,
    batch_messages: list[list[dict[str, str]]],
    seeds: list[int],
    concurrency_limit: int = 10,
    print_progress: bool = True,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    **generation_kwargs,
):
    """Run a batch of chat requests on the client's background loop and wait for it.

    Args:
        client: Async client used for the requests; also selects the loop runner.
        client_model_name: Model name sent with every request.
        batch_messages: One chat message list per request.
        seeds: One seed per request.
        concurrency_limit: Maximum number of requests in flight at once.
        print_progress: Whether to display a progress bar.
        response_generation_method: Optional output constraint(s).
        reasoning_start_token: Opening marker for manual reasoning parsing.
        reasoning_end_token: Closing marker for manual reasoning parsing.
        **generation_kwargs: Extra arguments for the completion call. Logprob
            settings are added to this dict when a logprob method is given.

    Returns:
        The ``(answers, logprobs, reasoning)`` tuple from ``_run_api_batch_async``.
    """
    logprob_config = _update_logprob_kwargs(response_generation_method, generation_kwargs)

    sampling_params = _create_structured_output(
        batch_size=len(batch_messages),
        response_generation_method=response_generation_method,
    )

    # Obtain the runner before building the coroutine so that a failure here
    # cannot leave behind a coroutine object that is never awaited.
    runner = _get_or_create_runner(client)

    coro = _run_api_batch_async(
        client=client,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=seeds,
        concurrency_limit=concurrency_limit,
        print_progress=print_progress,
        response_generation_method=response_generation_method,
        sampling_params=sampling_params,
        logprob_config=logprob_config,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )

    try:
        future = runner.submit(coro)
    except Exception:
        # The coroutine was never scheduled; close it to avoid a
        # "coroutine was never awaited" warning.
        coro.close()
        raise

    return future.result()


async def _run_api_batch_async(
    client: AsyncOpenAI,
    client_model_name: str | None,
    batch_messages: list[list[dict[str, str]]],
    seeds: list[int],
    concurrency_limit: int = 10,
    print_progress: bool = True,
    sampling_params: list[Any] | None = (),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    logprob_config: LogprobResponseGenerationMethod | None = None,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    **generation_kwargs,
) -> tuple[list[str], list[Any], list[Any]]:
    """Issue all chat completion requests concurrently and post-process the replies.

    A failed request does not abort the batch: its answer becomes
    ``"Error: <exception>"`` and its logprob and reasoning entries are ``None``,
    so the three returned lists always stay aligned with ``batch_messages``.

    Args:
        client: Async client used for the requests.
        client_model_name: Model name sent with every request.
        batch_messages: One chat message list per request.
        seeds: One seed per request.
        concurrency_limit: Maximum number of requests in flight at once.
        print_progress: Whether to display a progress bar.
        sampling_params: Per-request structured-output payload (JSON schema or
            list of allowed choices); falsy disables structured output.
        response_generation_method: A single method (applied to every request)
            or one method per request.
        logprob_config: When set, top logprobs are extracted from the replies.
        reasoning_start_token: Opening marker for manual reasoning parsing.
        reasoning_end_token: Closing marker for manual reasoning parsing.
        **generation_kwargs: Extra arguments for the completion call.

    Returns:
        ``(answers, logprobs, reasoning)``, each with one entry per request.
    """
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def get_completion(
        messages: list,
        seed: int,
        sampling_params: dict[str, Any] | list[str] | None = None,
        response_generation_method: ResponseGenerationMethod | None = None,
        **generation_kwargs,
    ):
        """Send one request while holding a slot of the concurrency semaphore."""
        async with semaphore:
            request_kwargs = {
                "model": client_model_name,
                "messages": messages,
                "seed": seed,
                **generation_kwargs,
            }

            if response_generation_method:
                if isinstance(response_generation_method, JSONResponseGenerationMethod):
                    request_kwargs["response_format"] = {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "json_schema",
                            "schema": sampling_params,
                        },
                    }
                elif isinstance(response_generation_method, ChoiceResponseGenerationMethod):
                    import warnings

                    warnings.warn(
                        "Strict Choice Response Generation is only supported for " "vllm APIs.",
                        stacklevel=2,
                    )
                    # vLLM-specific payload. A strict ``json_schema`` response
                    # format with an ``enum`` would be the alternative for
                    # APIs that are not vLLM.
                    request_kwargs["extra_body"] = {
                        "structured_outputs": {"choice": sampling_params}
                    }

            return await client.chat.completions.create(**request_kwargs)

    async def capture_failure(request):
        """Await ``request`` and return the raised exception instead of propagating it."""
        try:
            return await request
        except Exception as exc:
            return exc

    if sampling_params:
        # A single method applies to every request; broadcast it so it can be
        # zipped with the per-request structured-output payloads.
        if isinstance(response_generation_method, ResponseGenerationMethod) or (
            response_generation_method is None
        ):
            methods = [response_generation_method] * len(batch_messages)
        else:
            methods = response_generation_method

        tasks = [
            get_completion(messages, seed, struct_output, rgm, **generation_kwargs)
            for messages, seed, struct_output, rgm in zip(
                batch_messages,
                seeds,
                sampling_params,
                methods,
            )
        ]
    else:
        tasks = [
            get_completion(messages, seed, **generation_kwargs)
            for messages, seed in zip(batch_messages, seeds)
        ]

    if print_progress:
        # tqdm's gather has no ``return_exceptions`` option, so each request is
        # wrapped to keep a single failure from aborting the whole batch.
        responses = await tqdm_asyncio.gather(
            *(capture_failure(task) for task in tasks),
            total=len(tasks),
            desc="Generating responses",
        )
    else:
        responses = await asyncio.gather(*tasks, return_exceptions=True)

    final_results = []
    reasoning_output = []
    logprob_result = []

    patterns = [
        (reasoning_start_token, reasoning_end_token),
    ]

    for response in responses:
        if isinstance(response, BaseException):
            logger.warning("A request failed permanently after all retries: %s", response)
            final_results.append(f"Error: {response}")
            # Keep all three result lists aligned with the batch.
            reasoning_output.append(None)
            logprob_result.append(None)
            continue

        choice = response.choices[0]
        msg = choice.message

        # Prefer reasoning already separated by the server.
        reasoning = getattr(msg, "reasoning", None) or getattr(msg, "reasoning_content", None)

        if reasoning is None:
            # Fall back to extracting the reasoning from the message text.
            final_answer, extracted_reasoning = parse_reasoning(msg.content, patterns=patterns)
            final_results.append(final_answer)
            if extracted_reasoning:
                reasoning_output.append(extracted_reasoning.strip())
            else:
                reasoning_output.append(extracted_reasoning)
        else:
            final_results.append(msg.content)
            reasoning_output.append(reasoning)

        if logprob_config and choice.logprobs:
            logprob_result.append(
                [
                    [{"token": top.token, "logprob": top.logprob} for top in lp.top_logprobs]
                    for lp in choice.logprobs.content
                ]
            )
        else:
            logprob_result.append(None)

    return final_results, logprob_result, reasoning_output


def _update_logprob_kwargs(response_generation_method, generation_kwargs):
    """Find the logprob method, if any, and add its settings to ``generation_kwargs``.

    Args:
        response_generation_method: A single method, a list of methods or
            ``None``. In a list, the first logprob method found is used.
        generation_kwargs: Request arguments; modified in place. ``logprobs``
            and ``top_logprobs`` are set, and ``max_tokens`` is overwritten
            when the method defines a token limit.

    Returns:
        The logprob method that was found, or ``None``.
    """
    logprob_config = None

    if isinstance(response_generation_method, LogprobResponseGenerationMethod):
        logprob_config = response_generation_method
    elif isinstance(response_generation_method, list):
        logprob_config = next(
            (
                item
                for item in response_generation_method
                if isinstance(item, LogprobResponseGenerationMethod)
            ),
            None,
        )

    if logprob_config:
        generation_kwargs["logprobs"] = True
        generation_kwargs["top_logprobs"] = logprob_config.top_logprobs
        if logprob_config.token_limit is not None:
            generation_kwargs["max_tokens"] = logprob_config.token_limit

    return logprob_config


def _create_structured_output(
    batch_size: int,
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
) -> list[Any] | None:
    """Build the per-request structured-output payloads, if any are requested.

    Args:
        batch_size: Number of requests in the batch.
        response_generation_method: A single method, a list of methods, or
            ``None``.

    Returns:
        The list produced by ``_create_structured_params`` when a non-empty
        method (or list of methods) is given, otherwise ``None``.
    """
    use_structured = response_generation_method and isinstance(
        response_generation_method, (list, ResponseGenerationMethod)
    )

    if use_structured:
        return _create_structured_params(
            batch_size=batch_size,
            response_generation_method=response_generation_method,
        )

    return None


def _create_structured_params(
    batch_size: int,
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod],
) -> list[Any]:
    """Translate response generation method(s) into structured-output payloads.

    JSON methods yield the JSON schema of a model built from their
    ``json_object``; choice and logprob methods with ``allowed_choices`` yield
    the choices as a list of strings.

    Args:
        batch_size: Number of requests in the batch.
        response_generation_method: One method shared by all requests, or a
            list indexed by request position.

    Returns:
        For a single method: the payload repeated ``batch_size`` times, or an
        empty list when the method produces no payload. For a list: one entry
        per request, ``None`` where the method produces no payload.
    """
    structured_output = []

    # Same for all calls
    if isinstance(response_generation_method, ResponseGenerationMethod):
        if isinstance(response_generation_method, JSONResponseGenerationMethod):
            pydantic_model = build_pydantic_model_from_json_object(
                json_object=response_generation_method.json_object,
            )
            json_schema = pydantic_model.model_json_schema()
            structured_output = [json_schema] * batch_size
        elif (
            isinstance(
                response_generation_method,
                (ChoiceResponseGenerationMethod, LogprobResponseGenerationMethod),
            )
            and response_generation_method.allowed_choices is not None
        ):
            _allowed_choices = [str(c) for c in response_generation_method.allowed_choices]
            structured_output = [_allowed_choices] * batch_size

    # Different response generation methods for each question
    else:
        structured_output = []
        # Reuse payloads across requests whose methods map to the same cache key.
        cache: dict[str, Any] = {}

        for i in range(batch_size):
            current_method = response_generation_method[i]

            if isinstance(current_method, JSONResponseGenerationMethod):
                key = _make_cache_key(current_method.get_json_prompt(), None)
                if key not in cache:
                    pydantic_model = build_pydantic_model_from_json_object(
                        json_object=current_method.json_object,
                    )
                    json_schema = pydantic_model.model_json_schema()
                    cache[key] = json_schema
                structured_output.append(cache[key])
            elif (
                isinstance(
                    current_method,
                    (ChoiceResponseGenerationMethod, LogprobResponseGenerationMethod),
                )
                and current_method.allowed_choices is not None
            ):
                _allowed_choices = [str(c) for c in current_method.allowed_choices]
                key = _make_cache_key(_allowed_choices, None)
                if key not in cache:
                    cache[key] = _allowed_choices
                structured_output.append(cache[key])
            else:
                structured_output.append(None)

    return structured_output