import asyncio
import atexit
import threading
import weakref
from concurrent.futures import Future
from typing import Any
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from ..logger import get_logger
from ..utilities.utils import _make_cache_key, generate_seeds
from .dynamic_pydantic import build_pydantic_model_from_json_object
from .reasoning_parser import parse_reasoning
from .response_generation import (
ChoiceResponseGenerationMethod,
JSONResponseGenerationMethod,
LogprobResponseGenerationMethod,
ResponseGenerationMethod,
)
from .utils import normalize_system_messages
# Module-level logger from the package's get_logger helper; used to report
# requests that failed.
logger = get_logger(__name__)
class _ClientLoopRunner:
    """Persistent event loop running in a background thread.

    Each instance starts one daemon thread whose event loop runs until
    :meth:`shutdown` is called. Synchronous code hands coroutines to that loop
    with :meth:`submit` and receives a ``concurrent.futures.Future``, so a
    long-lived loop can be reused across calls instead of creating one per call.
    """

    def __init__(self):
        self._loop: asyncio.AbstractEventLoop | None = None
        self._loop_ready = threading.Event()
        self._thread = threading.Thread(target=self._thread_main, daemon=True)
        self._thread.start()
        # Block until the worker thread has published its loop (or failed to).
        self._loop_ready.wait()
        if self._loop is None:
            # The worker could not create its loop; its traceback has already
            # been reported by the threading machinery.
            raise RuntimeError("Loop runner failed to start its event loop.")

    def _thread_main(self) -> None:
        """Thread body: create the loop, run it until stopped, then tear it down."""
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
        finally:
            # Always release __init__, even when loop creation fails; otherwise
            # the constructor would wait forever.
            self._loop_ready.set()
        try:
            loop.run_forever()
            # run_forever() returned because shutdown() stopped the loop:
            # cancel whatever is still pending and let it unwind.
            pending = asyncio.all_tasks(loop=loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()

    def submit(self, coro: Any) -> Future:
        """Schedule *coro* on the background loop and return a thread-safe future.

        Raises:
            RuntimeError: if the loop is not initialized (or, via asyncio, if it
                has already been closed).
        """
        if self._loop is None:
            raise RuntimeError("Loop runner is not initialized.")
        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def shutdown(self) -> None:
        """Stop the loop and wait (up to 5 s) for the worker thread to finish.

        Safe to call more than once. When called from the worker thread itself,
        the loop is stopped but not joined (a thread cannot join itself).
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was closed between the check above and this call
            # (e.g. a concurrent shutdown); nothing left to do.
            return
        if threading.current_thread() is not self._thread:
            self._thread.join(timeout=5)
# Registry mapping each AsyncOpenAI client to its background loop runner.
# Keys are held weakly, so an entry disappears once its client is garbage
# collected.
_CLIENT_LOOP_RUNNERS: weakref.WeakKeyDictionary[AsyncOpenAI, _ClientLoopRunner] = (
    weakref.WeakKeyDictionary()
)
# Serializes lookups/insertions into, and clearing of, _CLIENT_LOOP_RUNNERS.
_CLIENT_LOOP_RUNNERS_LOCK = threading.Lock()
def _get_or_create_runner(client: AsyncOpenAI) -> _ClientLoopRunner:
    """Return the background loop runner for *client*, creating it on first use.

    Every batch issued through the same client therefore runs on the same event
    loop. When a runner is created, a ``weakref.finalize`` hook is attached to
    the client so the runner's thread and loop are shut down once the client is
    garbage collected. The weak-keyed registry alone would only drop its entry,
    leaving the (self-sustaining) worker thread running forever.
    """
    with _CLIENT_LOOP_RUNNERS_LOCK:
        runner = _CLIENT_LOOP_RUNNERS.get(client)
        if runner is None:
            runner = _ClientLoopRunner()
            _CLIENT_LOOP_RUNNERS[client] = runner
            finalizer = weakref.finalize(client, runner.shutdown)
            # Interpreter exit is already handled by the module's atexit hook
            # (_shutdown_all_client_loop_runners); avoid a second shutdown there.
            finalizer.atexit = False
    return runner
@atexit.register
def _shutdown_all_client_loop_runners() -> None:
    """Stop every registered loop runner; runs automatically at interpreter exit.

    The registry is snapshotted and emptied while holding the lock, and the
    runners are shut down only afterwards, so the thread joins performed by
    ``shutdown`` never happen under the lock.
    """
    with _CLIENT_LOOP_RUNNERS_LOCK:
        snapshot = [*_CLIENT_LOOP_RUNNERS.values()]
        _CLIENT_LOOP_RUNNERS.clear()
    for loop_runner in snapshot:
        loop_runner.shutdown()
[docs]
def run_openai_batch(
    model: AsyncOpenAI,
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[str] = ("Hi there! What is your name?",),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    print_progress: bool = True,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Send one single-turn chat request per prompt through an ``AsyncOpenAI`` client.

    Each request contains an optional system message followed by the user
    prompt. The system messages are expanded to one per prompt by
    ``normalize_system_messages`` (presumably broadcasting a single message —
    see that helper); a ``None`` entry omits the system turn for that prompt.

    Args:
        model: The async client used to issue the requests.
        system_messages: System message(s) for the batch.
        prompts: One user prompt per request.
        response_generation_method: Optional output-structuring method, either
            shared by all prompts or one per prompt.
        seed: Base seed from which per-request seeds are derived via
            ``generate_seeds``.
        client_model_name: Model name passed as ``model`` on every request.
        api_concurrency: Maximum number of requests in flight at once.
        reasoning_start_token: Opening delimiter of inline reasoning.
        reasoning_end_token: Closing delimiter of inline reasoning.
        print_progress: Whether to display a progress bar.
        **generation_kwargs: Extra keyword arguments forwarded to every request.

    Returns:
        The ``(answers, logprobs, reasoning)`` triple produced by the batch
        runner, one entry per prompt.
    """
    per_prompt_system = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )

    # Build the chat transcript for each prompt: [system?, user].
    batch_messages: list[list[dict[str, str]]] = [
        ([] if sys_text is None else [{"role": "system", "content": sys_text}])
        + [{"role": "user", "content": user_text}]
        for sys_text, user_text in zip(per_prompt_system, prompts)
    ]

    per_request_seeds = generate_seeds(seed, batch_size=len(batch_messages))
    answers, logprobs, reasoning = _run_async_in_thread(
        client=model,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=per_request_seeds,
        concurrency_limit=api_concurrency,
        response_generation_method=response_generation_method,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )
    return answers, logprobs, reasoning
[docs]
def run_openai_batch_conversation(
    model: AsyncOpenAI,
    system_messages: list[str | None] | None = ("You are a helpful assistant.",),
    prompts: list[list[str]] = (("Hi there! What is your name?",),),
    assistant_messages: list[list[str]] | None = None,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    seed: int = 42,
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    print_progress: bool = True,
    **generation_kwargs: Any,
) -> tuple[list[str], list[str], list[str]]:
    """Send one multi-turn chat request per conversation through an ``AsyncOpenAI`` client.

    Conversation ``i`` is assembled as an optional system message followed by
    alternating turns ``user[0], assistant[0], user[1], assistant[1], ...``.
    Assistant turns beyond the number of user turns are ignored; if there are
    as many assistant turns as user turns, the transcript ends on an assistant
    turn.

    Args:
        model: The async client used to issue the requests.
        system_messages: System message(s), expanded to one per conversation by
            ``normalize_system_messages``; a ``None`` entry omits the system turn.
        prompts: For each conversation, its user turns in order.
        assistant_messages: For each conversation, its prior assistant turns;
            ``None`` means no assistant turns anywhere.
        response_generation_method: Optional output-structuring method, either
            shared by all conversations or one per conversation.
        seed: Base seed from which per-request seeds are derived.
        client_model_name: Model name passed as ``model`` on every request.
        api_concurrency: Maximum number of requests in flight at once.
        reasoning_start_token: Opening delimiter of inline reasoning.
        reasoning_end_token: Closing delimiter of inline reasoning.
        print_progress: Whether to display a progress bar.
        **generation_kwargs: Extra keyword arguments forwarded to every request.

    Returns:
        The ``(answers, logprobs, reasoning)`` triple produced by the batch
        runner, one entry per conversation.
    """
    per_conversation_system = normalize_system_messages(
        system_messages=system_messages,
        batch_size=len(prompts),
    )
    n_conversations = len(per_conversation_system)
    if assistant_messages is None:
        assistant_messages = [[] for _ in range(n_conversations)]

    batch_messages = []
    for idx, sys_text in enumerate(per_conversation_system):
        user_turns = prompts[idx]
        assistant_turns = assistant_messages[idx]
        n_assistant = len(assistant_turns)
        transcript = [] if sys_text is None else [{"role": "system", "content": sys_text}]
        for turn, user_text in enumerate(user_turns):
            transcript.append({"role": "user", "content": user_text})
            if turn < n_assistant:
                transcript.append({"role": "assistant", "content": assistant_turns[turn]})
        batch_messages.append(transcript)

    per_request_seeds = generate_seeds(seed, batch_size=n_conversations)
    answers, logprobs, reasoning = _run_async_in_thread(
        client=model,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=per_request_seeds,
        concurrency_limit=api_concurrency,
        response_generation_method=response_generation_method,
        print_progress=print_progress,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )
    return answers, logprobs, reasoning
def _run_async_in_thread(
    client: AsyncOpenAI,
    client_model_name: str | None,
    batch_messages: list[list[dict[str, str]]],
    seeds: list[int],
    concurrency_limit: int = 10,
    print_progress: bool = True,
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    **generation_kwargs,
):
    """Run a batch of chat requests on *client*'s background loop and block for the result.

    Logprob options and structured-output constraints are derived from
    *response_generation_method*. The async batch is then submitted to the
    per-client loop runner, and this function waits for it to finish.

    Returns:
        Whatever ``_run_api_batch_async`` returns: the
        ``(final_results, logprob_result, reasoning_output)`` triple.
    """
    # May add logprob-related entries to generation_kwargs in place.
    logprob_config = _update_logprob_kwargs(response_generation_method, generation_kwargs)
    sampling_params = _create_structured_output(
        batch_size=len(batch_messages),
        response_generation_method=response_generation_method,
    )
    # Acquire the runner before creating the coroutine so that a failure here
    # cannot leave a never-awaited coroutine behind.
    runner = _get_or_create_runner(client)
    coro = _run_api_batch_async(
        client=client,
        client_model_name=client_model_name,
        batch_messages=batch_messages,
        seeds=seeds,
        concurrency_limit=concurrency_limit,
        print_progress=print_progress,
        response_generation_method=response_generation_method,
        sampling_params=sampling_params,
        logprob_config=logprob_config,
        reasoning_start_token=reasoning_start_token,
        reasoning_end_token=reasoning_end_token,
        **generation_kwargs,
    )
    try:
        future = runner.submit(coro)
    except Exception:
        coro.close()
        raise
    try:
        return future.result()
    except BaseException:
        # e.g. KeyboardInterrupt while blocked: cancel the batch on the loop
        # thread so requests stop being issued in the background. This is a
        # no-op if the future already finished (the batch itself raised).
        future.cancel()
        raise
async def _run_api_batch_async(
    client: AsyncOpenAI,
    client_model_name: str | None,
    batch_messages: list[list[dict[str, str]]],
    seeds: list[int],
    concurrency_limit: int = 10,
    print_progress: bool = True,
    sampling_params: list[Any] | None = (),
    response_generation_method: (
        ResponseGenerationMethod | list[ResponseGenerationMethod] | None
    ) = None,
    logprob_config: LogprobResponseGenerationMethod | None = None,
    reasoning_start_token: str = "<think>",
    reasoning_end_token: str = "</think>",
    **generation_kwargs,
) -> tuple[list[str | None], list[list[list[dict[str, Any]]] | None], list[str | None]]:
    """Send every conversation in *batch_messages* to *client* concurrently.

    Args:
        client: Async client whose ``chat.completions.create`` is called.
        client_model_name: Value passed as ``model`` on every request.
        batch_messages: One chat message list per request.
        seeds: One seed per request, passed as ``seed``.
        concurrency_limit: Maximum number of requests in flight at once.
        print_progress: Show a tqdm progress bar while waiting.
        sampling_params: Per-request structured-output constraint (a JSON
            schema, a list of allowed choices, or ``None``). Empty or ``None``
            disables structured output for the whole batch.
        response_generation_method: A single method applied to every request,
            or a list with one method per request.
        logprob_config: When set, per-token top logprobs are collected from
            responses that contain them.
        reasoning_start_token: Opening delimiter used to split reasoning out of
            the content when the API does not return it separately.
        reasoning_end_token: Matching closing delimiter.
        **generation_kwargs: Extra keyword arguments forwarded to every request.

    Returns:
        ``(final_results, logprob_result, reasoning_output)``, each
        index-aligned with *batch_messages*. A failed request yields
        ``"Error: <exception>"`` in ``final_results`` and ``None`` in the other
        two lists.
    """
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def get_completion(messages, seed, structured=None, method=None):
        """Issue one request, attaching structured-output options for *method*."""
        async with semaphore:
            request_kwargs = {
                "model": client_model_name,
                "messages": messages,
                "seed": seed,
                **generation_kwargs,
            }
            # ``structured`` is None for methods without a schema/allowed
            # choices; then there is nothing to constrain and no
            # structured-output field is sent.
            if method and structured is not None:
                if isinstance(method, JSONResponseGenerationMethod):
                    request_kwargs["response_format"] = {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "json_schema",
                            "schema": structured,
                        },
                    }
                elif isinstance(method, ChoiceResponseGenerationMethod):
                    import warnings

                    warnings.warn(
                        "Strict Choice Response Generation is only supported for vllm APIs.",
                        stacklevel=2,
                    )
                    if True:  # We could use this if we can ensure that the api is vllm.
                        request_kwargs["extra_body"] = {
                            "structured_outputs": {"choice": structured}
                        }
                    else:
                        # OpenAI-style alternative, kept for non-vLLM servers.
                        request_kwargs["response_format"] = {
                            "type": "json_schema",
                            "json_schema": {
                                "name": "Choice",
                                "strict": True,
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "selection": {
                                            "type": "string",
                                            "enum": structured,
                                        }
                                    },
                                    "required": ["selection"],
                                    "additionalProperties": False,
                                },
                            },
                        }
            return await client.chat.completions.create(**request_kwargs)

    async def capture_failure(request):
        # Return a failed request's exception as its result so one failure
        # cannot abort the whole batch (tqdm_asyncio.gather has no
        # return_exceptions option). Cancellation still propagates.
        try:
            return await request
        except Exception as exc:
            return exc

    if sampling_params:
        if isinstance(response_generation_method, ResponseGenerationMethod):
            # A single method applies to every request; zip() needs one per request.
            methods = [response_generation_method] * len(batch_messages)
        else:
            methods = response_generation_method
        requests = [
            get_completion(messages, seed, structured, method)
            for messages, seed, structured, method in zip(
                batch_messages, seeds, sampling_params, methods
            )
        ]
    else:
        requests = [
            get_completion(messages, seed) for messages, seed in zip(batch_messages, seeds)
        ]
    tasks = [capture_failure(request) for request in requests]

    if print_progress:
        responses = await tqdm_asyncio.gather(
            *tasks, total=len(tasks), desc="Generating responses"
        )
    else:
        responses = await asyncio.gather(*tasks)

    final_results = []
    reasoning_output = []
    logprob_result = []
    patterns = [(reasoning_start_token, reasoning_end_token)]

    for response in responses:
        if isinstance(response, Exception):
            logger.warning("A request failed permanently after all retries: %s", response)
            final_results.append(f"Error: {response}")
            # Keep all three outputs index-aligned with the requests.
            reasoning_output.append(None)
            logprob_result.append(None)
            continue

        choice = response.choices[0]
        msg = choice.message
        # Prefer reasoning returned separately by the API.
        reasoning = getattr(msg, "reasoning", None) or getattr(msg, "reasoning_content", None)
        if reasoning is None:
            # Fall back to splitting delimited reasoning out of the content.
            final_answer, extracted_reasoning = parse_reasoning(msg.content, patterns=patterns)
            final_results.append(final_answer)
            reasoning_output.append(
                extracted_reasoning.strip() if extracted_reasoning else extracted_reasoning
            )
        else:
            final_results.append(msg.content)
            reasoning_output.append(reasoning)

        logprobs = choice.logprobs
        if logprob_config and logprobs and logprobs.content is not None:
            logprob_result.append(
                [
                    [{"token": top.token, "logprob": top.logprob} for top in lp.top_logprobs]
                    for lp in logprobs.content
                ]
            )
        else:
            logprob_result.append(None)

    return final_results, logprob_result, reasoning_output
def _update_logprob_kwargs(response_generation_method, generation_kwargs):
    """Locate a logprob configuration and enable logprob output for the requests.

    The configuration is *response_generation_method* itself when it is a
    ``LogprobResponseGenerationMethod``, or the first such element when it is a
    list. When one is found, *generation_kwargs* is modified in place:
    ``logprobs=True``, ``top_logprobs`` from the config and, if the config's
    ``token_limit`` is not None, ``max_tokens``.

    Returns:
        The logprob configuration found, or ``None``.
    """
    if isinstance(response_generation_method, LogprobResponseGenerationMethod):
        candidates = (response_generation_method,)
    elif isinstance(response_generation_method, list):
        candidates = response_generation_method
    else:
        candidates = ()

    logprob_config = next(
        (item for item in candidates if isinstance(item, LogprobResponseGenerationMethod)),
        None,
    )
    if not logprob_config:
        return logprob_config

    generation_kwargs.update(logprobs=True, top_logprobs=logprob_config.top_logprobs)
    if logprob_config.token_limit is not None:
        generation_kwargs["max_tokens"] = logprob_config.token_limit
    return logprob_config
def _create_structured_output(
    batch_size: int,
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod] | None,
) -> list[Any] | None:
    """Build per-request structured-output constraints for an API batch.

    Args:
        batch_size: Number of requests in the batch.
        response_generation_method: A single method applied to every request,
            a list with one method per request, or ``None``.

    Returns:
        The list built by ``_create_structured_params`` when a truthy method or
        method list is given; otherwise ``None`` (no structured output).
    """
    if not response_generation_method:
        return None
    if not isinstance(response_generation_method, (list, ResponseGenerationMethod)):
        return None
    return _create_structured_params(
        batch_size=batch_size,
        response_generation_method=response_generation_method,
    )
def _create_structured_params(
    batch_size: int,
    response_generation_method: ResponseGenerationMethod | list[ResponseGenerationMethod],
) -> list[Any]:
    """Build the structured-output constraint for each request of a batch.

    Each entry is one of:

    * a JSON schema (``model_json_schema()`` of a pydantic model built from the
      method's ``json_object``) for a ``JSONResponseGenerationMethod``;
    * the stringified ``allowed_choices`` for a choice or logprob method whose
      ``allowed_choices`` is not None;
    * ``None`` (per-request list form only) for any other method.

    Args:
        batch_size: Number of requests in the batch.
        response_generation_method: Either one method shared by every request,
            or a sequence indexed by request position (it must have at least
            ``batch_size`` entries).

    Returns:
        In the per-request form, ``batch_size`` entries; identical constraints
        are built once and shared. In the single-method form, ``batch_size``
        references to the same constraint object, or an empty list when the
        method carries no constraint (an empty list is falsy, which the batch
        runner treats as "unconstrained").
    """
    choice_like = (ChoiceResponseGenerationMethod, LogprobResponseGenerationMethod)

    # Same method for all calls
    if isinstance(response_generation_method, ResponseGenerationMethod):
        method = response_generation_method
        if isinstance(method, JSONResponseGenerationMethod):
            pydantic_model = build_pydantic_model_from_json_object(
                json_object=method.json_object,
            )
            return [pydantic_model.model_json_schema()] * batch_size
        if isinstance(method, choice_like) and method.allowed_choices is not None:
            return [[str(c) for c in method.allowed_choices]] * batch_size
        return []

    # Different response generation methods for each question
    structured_output: list[Any] = []
    cache: dict[str, Any] = {}
    for i in range(batch_size):
        current_method = response_generation_method[i]
        if isinstance(current_method, JSONResponseGenerationMethod):
            key = _make_cache_key(current_method.get_json_prompt(), None)
            if key not in cache:
                pydantic_model = build_pydantic_model_from_json_object(
                    json_object=current_method.json_object,
                )
                cache[key] = pydantic_model.model_json_schema()
            structured_output.append(cache[key])
        elif isinstance(current_method, choice_like) and current_method.allowed_choices is not None:
            allowed_choices = [str(c) for c in current_method.allowed_choices]
            key = _make_cache_key(allowed_choices, None)
            structured_output.append(cache.setdefault(key, allowed_choices))
        else:
            structured_output.append(None)
    return structured_output