"""
Module for managing and conducting surveys using LLM models.
This module provides functions to conduct surveys in different ways:
- single-item
- battery
- sequential
Usage example:
--------------
.. code-block:: python
from qstn import survey_manager
from qstn.prompt_builder import LLMPrompt
from qstn.utilities import placeholder
from vllm import LLM
import pandas as pd
questionnaire = [
{"questionnaire_item_id": 1, "question_content": "The Democratic Party?"},
{"questionnaire_item_id": 2, "question_content": "The Republican Party?"},
]
party_questionnaire = pd.DataFrame(questionnaire)
system_prompt = (
"Act as if you were a black middle aged man from New York! "
"Answer in a single short sentence!"
)
prompt = (
"Please tell us how you feel about the following parties: "
+ placeholder.PROMPT_QUESTIONS
)
questionnaire = LLMPrompt(
questionnaire_name="political_parties",
questionnaire_source=party_questionnaire,
system_prompt=system_prompt,
prompt=prompt,
)
model_id = "meta-llama/Llama-3.2-3B-Instruct"
chat_generator = LLM(model_id, max_model_len=5000, seed=42)
results = survey_manager.conduct_survey_single_item(
chat_generator,
questionnaire,
client_model_name=model_id,
print_conversation=True,
# Generation arguments are passed through just as they would be for vllm or OpenAI
temperature=0.8,
max_tokens=5000,
)
"""
import os
from collections.abc import Iterable
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Union,
)
import pandas as pd
from openai import AsyncOpenAI
from tqdm.auto import tqdm
from .inference.response_generation import (
ResponseGenerationMethod,
resolve_battery_response_generation_method,
)
from .inference.survey_inference import batch_generation, batch_turn_by_turn_generation
from .inference.utils import InferenceMode
from .logger import get_logger
from .parser.llm_answer_parser import raw_responses
from .prompt_builder import LLMPrompt, QuestionnairePresentation
from .utilities import constants, utils
from .utilities.survey_objects import (
InferenceResult,
QuestionLLMResponseTuple,
)
if TYPE_CHECKING:
from vllm import LLM # pyright: ignore[reportMissingImports]
# Module-level logger; the survey entry points below use it for start/finish messages.
logger = get_logger(__name__)
# @dataclass
# class GenerationFailure:
# batch_index: int
# error_type: str
# error_message: str
def _normalize_llm_prompts(
    llm_prompts: LLMPrompt | list[LLMPrompt],
) -> list[LLMPrompt]:
    """Return ``llm_prompts`` in list form.

    A lone ``LLMPrompt`` is wrapped in a new one-element list; any other value
    is handed back unchanged (same object, no copy).
    """
    return [llm_prompts] if isinstance(llm_prompts, LLMPrompt) else llm_prompts
def _initialize_question_response_pairs(
llm_prompts: list[LLMPrompt],
) -> list[dict[int, QuestionLLMResponseTuple]]:
"""Initialize mutable per-questionnaire response maps."""
return [{} for _ in llm_prompts]
def _iter_survey_steps(max_survey_length: int, print_progress: bool):
"""Yield step indices, optionally wrapped in tqdm."""
if print_progress:
return tqdm(range(max_survey_length), desc="Running survey steps")
return range(max_survey_length)
def _get_current_batch(llm_prompts: list[LLMPrompt], i: int) -> dict[int, LLMPrompt]:
"""Return questionnaire batch that still has a question at step index i."""
return {
pos: questionnaire
for pos, questionnaire in enumerate(llm_prompts)
if len(questionnaire) > i
}
def _normalize_generation_outputs(
output: list[str],
logprobs: list[Any] | None,
reasoning_output: list[Any] | None,
expected_size: int,
) -> tuple[list[str], list[Any], list[Any]]:
"""Normalize optional generation outputs to zip-safe lists."""
if logprobs is None:
logprobs = [None] * expected_size
if reasoning_output is None:
reasoning_output = [None] * expected_size
return output, logprobs, reasoning_output
def _validate_lengths_against_expected(
expected_size: int,
error_prefix: str,
**values: Any,
) -> None:
"""Raise a ValueError when one or more values don't match expected length."""
mismatches = {name: len(value) for name, value in values.items() if len(value) != expected_size}
if mismatches:
mismatch_details = ", ".join(f"{name}={size}" for name, size in mismatches.items())
raise ValueError(f"{error_prefix}: expected={expected_size}, {mismatch_details}")
def _run_batch_generation(
    model: Union["LLM", AsyncOpenAI],
    system_messages: list[str | None],
    prompts: list[str],
    response_generation_methods: list[ResponseGenerationMethod] | None,
    client_model_name: str | None,
    api_concurrency: int,
    print_conversation: bool,
    print_progress: bool,
    seed: int,
    inference_mode: InferenceMode,
    **generation_kwargs: Any,
) -> tuple[list[str], list[Any], list[Any]]:
    """Call `batch_generation` and normalize its optional outputs.

    All arguments are forwarded unchanged; note that
    ``response_generation_methods`` is passed on under the singular keyword
    ``response_generation_method``. The logprobs and reasoning results are run
    through `_normalize_generation_outputs`, sized by ``len(system_messages)``.

    Returns:
        ``(output, logprobs, reasoning_output)``.
    """
    raw_output, raw_logprobs, raw_reasoning = batch_generation(
        model=model,
        system_messages=system_messages,
        prompts=prompts,
        response_generation_method=response_generation_methods,
        client_model_name=client_model_name,
        api_concurrency=api_concurrency,
        print_conversation=print_conversation,
        print_progress=print_progress,
        seed=seed,
        inference_mode=inference_mode,
        **generation_kwargs,
    )
    batch_size = len(system_messages)
    return _normalize_generation_outputs(raw_output, raw_logprobs, raw_reasoning, batch_size)
def _finalize_survey_results(
llm_prompts: list[LLMPrompt],
question_llm_response_pairs: list[dict[int, QuestionLLMResponseTuple]],
) -> list[InferenceResult]:
"""Convert internal response maps to public `InferenceResult` objects."""
return [
InferenceResult(survey, question_llm_response_pairs[i])
for i, survey in enumerate(llm_prompts)
]
def _store_question_responses(
question_llm_response_pairs: list[dict[int, QuestionLLMResponseTuple]],
survey_ids: Iterable[int],
questions: list[str],
answers: list[str],
logprobs: list[Any],
reasoning_output: list[Any],
item_ids: Iterable[Any],
) -> None:
"""Store question-level answers in the mutable response map."""
for survey_id, question, answer, logprob_answer, reasoning, item_id in zip(
survey_ids,
questions,
answers,
logprobs,
reasoning_output,
item_ids,
):
question_llm_response_pairs[survey_id].update(
{item_id: QuestionLLMResponseTuple(question, answer, logprob_answer, reasoning)}
)
def _prepare_single_item_batch(
    current_batch: dict[int, LLMPrompt], i: int, inference_mode: InferenceMode = "chat"
) -> tuple[list[str | None], list[str], list[str], list[ResponseGenerationMethod | None]]:
    """Build the model inputs for step ``i`` of a single-item survey.

    Returns:
        ``(system_messages, prompts, questions, response_generation_methods)``,
        each ordered like ``current_batch``. A method entry is ``None`` when
        the question has no (truthy) ``answer_options``.
    """
    # The prompt builder is asked for "generation" prompts in completion mode.
    inference_type = "generation" if inference_mode == "completion" else "chat"
    active_prompts = list(current_batch.values())

    rendered = []
    for questionnaire in active_prompts:
        rendered.append(
            questionnaire.get_prompt_for_questionnaire_type(
                QuestionnairePresentation.SINGLE_ITEM,
                questionnaire.get_question_item_id(i),
                inference_type=inference_type,
            )
        )
    # Transpose [(system_message, prompt), ...] into two parallel sequences.
    system_messages, prompts = zip(*rendered)

    questions = []
    for questionnaire in active_prompts:
        questions.append(
            questionnaire.generate_question_prompt(
                questionnaire_items=questionnaire.get_question(i)
            )
        )

    response_generation_methods = []
    for questionnaire in active_prompts:
        answer_options = questionnaire.get_question(i).answer_options
        response_generation_methods.append(
            answer_options.response_generation_method if answer_options else None
        )

    return list(system_messages), list(prompts), questions, response_generation_methods
def _prepare_battery_batch(
    current_batch: dict[int, LLMPrompt],
    i: int,
    item_separator: str,
    inference_mode: InferenceMode = "chat",
) -> tuple[list[str | None], list[str], list[ResponseGenerationMethod | None]]:
    """Build the model inputs for a battery-mode survey step.

    Returns:
        ``(system_messages, prompts, response_generation_methods)``, each
        ordered like ``current_batch``. The method for each questionnaire is
        resolved by `resolve_battery_response_generation_method` from all of
        its questions.
    """
    # The prompt builder is asked for "generation" prompts in completion mode.
    inference_type = "generation" if inference_mode == "completion" else "chat"
    active_prompts = list(current_batch.values())

    rendered = []
    for questionnaire in active_prompts:
        rendered.append(
            questionnaire.get_prompt_for_questionnaire_type(
                QuestionnairePresentation.BATTERY,
                questionnaire.get_question_item_id(i),
                item_separator=item_separator,
                inference_type=inference_type,
            )
        )
    # Transpose [(system_message, prompt), ...] into two parallel sequences.
    system_messages, prompts = zip(*rendered)

    response_generation_methods: list[ResponseGenerationMethod | None] = [
        resolve_battery_response_generation_method(
            questions=list(questionnaire.questions),
            item_position=i,
        )
        for questionnaire in active_prompts
    ]
    return list(system_messages), list(prompts), response_generation_methods
def _prepare_sequential_step(
    current_batch: dict[int, LLMPrompt], i: int
) -> tuple[list[str | None], list[str], list[str], list[ResponseGenerationMethod | None]]:
    """Build system messages, prompts, questions and methods for sequential step ``i``.

    On the first step (``i == 0``) the prompt is the one rendered by
    ``get_prompt_for_questionnaire_type``; on later steps the prompt is just
    the question text, so ``prompts`` and ``questions`` hold the same strings.

    Returns:
        ``(system_messages, prompts, questions, response_generation_methods)``,
        each ordered like ``current_batch``.
    """
    active_prompts = list(current_batch.values())

    rendered = [
        questionnaire.get_prompt_for_questionnaire_type(
            QuestionnairePresentation.SEQUENTIAL,
            questionnaire.get_question_item_id(i),
        )
        for questionnaire in active_prompts
    ]
    # Transpose [(system_message, prompt), ...] into two parallel sequences.
    system_messages, opening_prompts = zip(*rendered)

    questions = [
        questionnaire.generate_question_prompt(questionnaire_items=questionnaire.get_question(i))
        for questionnaire in active_prompts
    ]

    # Only the opening turn uses the fully rendered prompt.
    prompts = opening_prompts if i == 0 else questions

    response_generation_methods = []
    for questionnaire in active_prompts:
        answer_options = questionnaire.get_question(i).answer_options
        response_generation_methods.append(
            answer_options.response_generation_method if answer_options else None
        )

    return list(system_messages), list(prompts), questions, response_generation_methods
[docs]
def conduct_survey_single_item(
    model: Union["LLM", AsyncOpenAI],
    llm_prompts: LLMPrompt | list[LLMPrompt],
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    print_conversation: bool = False,
    print_progress: bool = True,
    n_save_step: int | None = None,
    intermediate_save_file: str | None = None,
    seed: int = 42,
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> list[InferenceResult]:
    """
    Conducts a survey by asking each question in a new context (single item presentation).

    System Prompt -> User Prompt with one question -> LLM Answer for one question
    -> Reset Context -> New instance with System Prompt

    Args:
        model (LLM or AsyncOpenAI): vllm.LLM instance or AsyncOpenAI client.
        llm_prompts (LLMPrompt or List(LLMPrompt)): Single LLMPrompt or list of LLMPrompt objects
            to conduct as a survey. An empty list yields an empty result list.
        client_model_name (str, optional): Name of model when using OpenAI client.
        api_concurrency (int): Number of concurrent API requests. Defaults to 10.
        print_conversation (bool): If True, prints all conversations to stdout. Default False.
        print_progress (bool): If True, shows a tqdm progress bar. Default True.
        n_save_step (int, optional): Save intermediate results every n steps.
        intermediate_save_file (str, optional): Path to save intermediate results.
            Has to be provided if n_save_step.
        seed (int): Random seed for reproducibility. Defaults to 42.
        inference_mode (str): Use "chat" for message-based models or
            "completion" for base-model text generation. Defaults to "chat".
        generation_kwargs: Additional generation parameters that will be given to vllm,
            vllm.SamplingParams, or OpenAI-compatible generation calls.

    Returns:
        List(InferenceResult): A list of results containing the survey data and
        LLM responses for each provided prompt.

    Raises:
        ValueError: If the intermediate-save settings are invalid.
    """
    _intermediate_save_path_check(n_save_step, intermediate_save_file)
    llm_prompts = _normalize_llm_prompts(llm_prompts)

    # ``default=0`` makes an empty prompt list run zero steps instead of
    # failing inside ``max()``.
    max_survey_length: int = max((len(questionnaire) for questionnaire in llm_prompts), default=0)
    question_llm_response_pairs = _initialize_question_response_pairs(llm_prompts)

    logger.info(
        "Starting single-item survey for %s questionnaires and %s steps.",
        len(llm_prompts),
        max_survey_length,
    )

    for i in _iter_survey_steps(max_survey_length, print_progress):
        # Only questionnaires that still have a question at position i take part.
        current_batch = _get_current_batch(llm_prompts, i)
        (
            system_messages,
            prompts,
            questions,
            response_generation_methods,
        ) = _prepare_single_item_batch(current_batch, i, inference_mode)

        # TODO: Implement retrying for errors. A failing batch currently
        # propagates its exception and aborts the whole survey.
        output, logprobs, reasoning_output = _run_batch_generation(
            model=model,
            system_messages=system_messages,
            prompts=prompts,
            response_generation_methods=response_generation_methods,
            client_model_name=client_model_name,
            api_concurrency=api_concurrency,
            print_conversation=print_conversation,
            print_progress=print_progress,
            seed=seed,
            inference_mode=inference_mode,
            **generation_kwargs,
        )

        _store_question_responses(
            question_llm_response_pairs=question_llm_response_pairs,
            survey_ids=current_batch.keys(),
            questions=questions,
            answers=output,
            logprobs=logprobs,
            reasoning_output=reasoning_output,
            item_ids=[item.get_question_item_id(i) for item in current_batch.values()],
        )

        _intermediate_saves(
            llm_prompts,
            n_save_step,
            intermediate_save_file,
            question_llm_response_pairs,
            i,
        )

    results = _finalize_survey_results(llm_prompts, question_llm_response_pairs)
    logger.info("Completed single-item survey.")
    return results
def _intermediate_saves(
questionnaires: list[LLMPrompt],
n_save_step: int,
intermediate_save_file: str,
question_llm_response_pairs: QuestionLLMResponseTuple,
i: int,
):
"""
Internal helper to save intermediate survey results.
Args:
questionnaires: List of questionnaires being conducted.
n_save_step: Save frequency in steps.
intermediate_save_file: Path to save file.
question_llm_response_pairs: Current responses.
i: Current step number.
"""
if n_save_step:
if i % n_save_step == 0:
intermediate_survey_results: list[InferenceResult] = []
for j, questionnaire in enumerate(questionnaires):
intermediate_survey_results.append(
InferenceResult(questionnaire, question_llm_response_pairs[j])
)
parsed_results = raw_responses(intermediate_survey_results)
utils.create_one_dataframe(parsed_results).to_csv(intermediate_save_file)
def _intermediate_save_path_check(n_save_step: int, intermediate_save_path: str):
"""
Internal helper to validate intermediate save path.
Args:
n_save_step: Save frequency in steps.
intermediate_save_path: Path to check.
"""
if n_save_step:
if not isinstance(n_save_step, int) or n_save_step <= 0:
raise ValueError("`n_save_step` must be a positive integer.")
if not intermediate_save_path:
raise ValueError("`intermediate_save_file` must be provided if saving is enabled.")
if not intermediate_save_path.endswith(".csv"):
raise ValueError("`intermediate_save_file` should be a .csv file.")
# Ensure it's a directory that exists or can be created
parent_dir = Path(intermediate_save_path).parent
if not parent_dir.exists():
try:
parent_dir.mkdir(parents=True, exist_ok=True)
except Exception as e:
raise ValueError(
f"Invalid intermediate save path: {intermediate_save_path}. Error: {e}"
) from e
# Optional: Check it's writable
if not os.access(parent_dir, os.W_OK):
raise ValueError(f"Save path '{intermediate_save_path}' is not writable.")
[docs]
def conduct_survey_battery(
    model: Union["LLM", AsyncOpenAI],
    llm_prompts: LLMPrompt | list[LLMPrompt],
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    n_save_step: int | None = None,
    intermediate_save_file: str | None = None,
    print_conversation: bool = False,
    print_progress: bool = True,
    seed: int = 42,
    item_separator: str = "\n",
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> list[InferenceResult]:
    """
    Conducts the entire survey in one single LLM prompt (battery presentation).

    System Prompt -> User Prompt with all questions -> LLM Answers all questions

    Each answer is stored under the item id ``-1``, with the full battery
    prompt recorded as its question. Questionnaires without any question are
    not sent to the model and end up with an empty response map.

    Args:
        model (LLM or AsyncOpenAI): vllm.LLM instance or AsyncOpenAI client.
        llm_prompts (LLMPrompt or List(LLMPrompt)): Single LLMPrompt or list
            of LLMPrompt objects to conduct as a survey. An empty list yields
            an empty result list.
        client_model_name (str, optional): Name of model when using OpenAI client.
        api_concurrency (int): Number of concurrent API requests. Defaults to 10.
        print_conversation (bool): If True, prints all conversations to stdout. Default False.
        print_progress (bool): If True, shows a tqdm progress bar. Default True.
        n_save_step (int, optional): Save intermediate results every n steps.
        intermediate_save_file (str, optional): Path to save intermediate
            results. Has to be provided if n_save_step.
        seed (int): Random seed for reproducibility. Defaults to 42.
        item_separator (str): The str that separates each question. Defaults to a newline.
        inference_mode (str): Use "chat" for message-based models or
            "completion" for base-model text generation. Defaults to "chat".
        generation_kwargs: Additional generation parameters that will be given
            to vllm, vllm.SamplingParams, or OpenAI-compatible generation calls.

    Returns:
        List(InferenceResult): A list of results containing the survey data and
        LLM responses for each provided prompt.

    Raises:
        ValueError: If the intermediate-save settings are invalid.
    """
    _intermediate_save_path_check(n_save_step, intermediate_save_file)
    llm_prompts = _normalize_llm_prompts(llm_prompts)

    # We always conduct the survey in one prompt
    max_survey_length: int = 1
    question_llm_response_pairs = _initialize_question_response_pairs(llm_prompts)

    logger.info("Starting battery survey for %s questionnaires.", len(llm_prompts))

    for i in _iter_survey_steps(max_survey_length, print_progress):
        current_batch = _get_current_batch(llm_prompts, i)
        if not current_batch:
            # No questionnaire has a question to ask. Preparing an empty batch
            # would fail while unpacking ``zip(*[])``, so skip inference.
            continue

        (
            system_messages,
            prompts,
            response_generation_methods,
        ) = _prepare_battery_batch(current_batch, i, item_separator, inference_mode)

        # TODO: Implement retrying for errors. A failing batch currently
        # propagates its exception and aborts the whole survey.
        output, logprobs, reasoning_output = _run_batch_generation(
            model=model,
            system_messages=system_messages,
            prompts=prompts,
            response_generation_methods=response_generation_methods,
            client_model_name=client_model_name,
            api_concurrency=api_concurrency,
            print_conversation=print_conversation,
            print_progress=print_progress,
            seed=seed,
            inference_mode=inference_mode,
            **generation_kwargs,
        )

        _store_question_responses(
            question_llm_response_pairs=question_llm_response_pairs,
            survey_ids=current_batch.keys(),
            questions=prompts,
            answers=output,
            logprobs=logprobs,
            reasoning_output=reasoning_output,
            # One answer covers the whole battery, so a fixed item id is used.
            item_ids=[-1] * len(output),
        )

        _intermediate_saves(
            llm_prompts,
            n_save_step,
            intermediate_save_file,
            question_llm_response_pairs,
            i,
        )

    results = _finalize_survey_results(llm_prompts, question_llm_response_pairs)
    logger.info("Completed battery survey.")
    return results
[docs]
def conduct_survey_sequential(
    model: Union["LLM", AsyncOpenAI],
    llm_prompts: LLMPrompt | list[LLMPrompt],
    client_model_name: str | None = None,
    api_concurrency: int = 10,
    print_conversation: bool = False,
    print_progress: bool = True,
    n_save_step: int | None = None,
    intermediate_save_file: str | None = None,
    seed: int = 42,
    inference_mode: InferenceMode = "chat",
    **generation_kwargs: Any,
) -> list[InferenceResult]:
    """
    Conducts the survey in multiple chat calls, where all questions and answers
    are kept in context (sequential presentation).

    System Prompt -> User Prompt with first question -> LLM Answer to first
    question -> User Prompt with second question -> ....

    Questions that carry a ``prefilled_response`` are not sent to the model:
    the prefilled text is stored as the answer (with no logprobs or reasoning)
    and added to the conversation history like a generated answer.

    Args:
        model (LLM or AsyncOpenAI): vllm.LLM instance or AsyncOpenAI client.
        llm_prompts (LLMPrompt or List(LLMPrompt)): Single LLMPrompt or list
            of LLMPrompt objects to conduct as a survey. An empty list yields
            an empty result list.
        client_model_name (str, optional): Name of model when using OpenAI client.
        api_concurrency (int): Number of concurrent API requests. Defaults to 10.
        print_conversation (bool): If True, prints all conversations to stdout. Default False.
        print_progress (bool): If True, shows a tqdm progress bar. Default True.
        n_save_step (int, optional): Save intermediate results every n steps.
        intermediate_save_file (str, optional): Path to save intermediate
            results. Has to be provided if n_save_step.
        seed (int): Random seed for reproducibility. Defaults to 42.
        inference_mode (str): Use "chat" for message-based models or
            "completion" for base-model text generation. Defaults to "chat".
        generation_kwargs: Additional generation parameters that will be given
            to vllm, vllm.SamplingParams, or OpenAI-compatible generation calls.

    Returns:
        List(InferenceResult): A list of results containing the survey data and
        LLM responses for each provided prompt.

    Raises:
        ValueError: If the intermediate-save settings are invalid or the
            per-step inputs/outputs do not line up in length.
    """
    _intermediate_save_path_check(n_save_step, intermediate_save_file)
    llm_prompts = _normalize_llm_prompts(llm_prompts)

    # ``default=0`` makes an empty prompt list run zero steps instead of
    # failing inside ``max()``.
    max_survey_length: int = max((len(questionnaire) for questionnaire in llm_prompts), default=0)
    question_llm_response = _initialize_question_response_pairs(llm_prompts)

    # Per-survey conversation so far: user turns and assistant turns.
    prompt_history: dict[int, list[str]] = {idx: [] for idx, _ in enumerate(llm_prompts)}
    assistant_history: dict[int, list[str]] = {idx: [] for idx, _ in enumerate(llm_prompts)}

    logger.info(
        "Starting sequential survey for %s questionnaires and %s steps.",
        len(llm_prompts),
        max_survey_length,
    )

    for i in _iter_survey_steps(max_survey_length, print_progress):
        current_batch = _get_current_batch(llm_prompts, i)
        current_survey_ids = list(current_batch.keys())
        (
            system_messages,
            prompts,
            questions,
            response_generation_methods,
        ) = _prepare_sequential_step(current_batch, i)
        current_batch_size = len(current_batch)

        _validate_lengths_against_expected(
            current_batch_size,
            "Sequential step inputs are misaligned",
            system_messages=system_messages,
            prompts=prompts,
            questions=questions,
            response_generation_methods=response_generation_methods,
        )

        for survey_id, prompt in zip(current_survey_ids, prompts):
            prompt_history[survey_id].append(prompt)

        # Split the batch into prefilled answers and surveys that need inference.
        prefilled_by_position: dict[int, str] = {}
        needed_survey_ids: list[int] = []
        needed_system_messages: list[str | None] = []
        needed_response_generation_methods: list[ResponseGenerationMethod | None] = []

        for local_index, survey_id in enumerate(current_survey_ids):
            prefilled_answer = current_batch[survey_id].get_question(i).prefilled_response
            if prefilled_answer is not None:
                prefilled_by_position[local_index] = prefilled_answer
            else:
                needed_survey_ids.append(survey_id)
                needed_system_messages.append(system_messages[local_index])
                needed_response_generation_methods.append(response_generation_methods[local_index])

        needed_batch_size = len(needed_survey_ids)
        _validate_lengths_against_expected(
            needed_batch_size,
            "Filtered sequential inference inputs are misaligned",
            needed_system_messages=needed_system_messages,
            needed_response_generation_methods=needed_response_generation_methods,
        )

        # survey_id -> (answer, logprobs, reasoning) for generated answers.
        # Stays empty when every prompt in the batch is prefilled, in which
        # case the model is not called at all.
        generated_by_survey: dict[int, tuple[str, Any, Any]] = {}

        if needed_batch_size > 0:
            needed_prompt_history = [prompt_history[survey_id] for survey_id in needed_survey_ids]
            needed_assistant_history = [
                assistant_history[survey_id] for survey_id in needed_survey_ids
            ]

            if inference_mode == "completion":
                # Base models get the whole conversation rendered as one prompt.
                rendered_prompts = [
                    llm_prompts[survey_id].render_base_model_prompt(
                        system_message=system_message,
                        prompts=prompt_history_for_survey,
                        assistant_messages=assistant_history_for_survey,
                    )
                    for (
                        survey_id,
                        system_message,
                        prompt_history_for_survey,
                        assistant_history_for_survey,
                    ) in zip(
                        needed_survey_ids,
                        needed_system_messages,
                        needed_prompt_history,
                        needed_assistant_history,
                    )
                ]

                needed_output, logprobs, reasoning_output = _run_batch_generation(
                    model=model,
                    system_messages=[None] * needed_batch_size,
                    prompts=rendered_prompts,
                    response_generation_methods=needed_response_generation_methods,
                    client_model_name=client_model_name,
                    api_concurrency=api_concurrency,
                    print_conversation=print_conversation,
                    print_progress=print_progress,
                    seed=seed,
                    inference_mode=inference_mode,
                    **generation_kwargs,
                )
            else:
                needed_output, logprobs, reasoning_output = batch_turn_by_turn_generation(
                    model=model,
                    system_messages=needed_system_messages,
                    prompts=needed_prompt_history,
                    assistant_messages=needed_assistant_history,
                    response_generation_method=needed_response_generation_methods,
                    client_model_name=client_model_name,
                    api_concurrency=api_concurrency,
                    print_conversation=print_conversation,
                    print_progress=print_progress,
                    seed=seed,
                    inference_mode=inference_mode,
                    **generation_kwargs,
                )

            # Make sure that outputs are available
            if reasoning_output is None:
                reasoning_output = [None] * needed_batch_size
            if logprobs is None or len(logprobs) == 0:
                logprobs = [None] * needed_batch_size

            _validate_lengths_against_expected(
                needed_batch_size,
                "Sequential inference outputs are misaligned",
                needed_output=needed_output,
                logprobs=logprobs,
                reasoning_output=reasoning_output,
            )

            generated_by_survey = dict(
                zip(needed_survey_ids, zip(needed_output, logprobs, reasoning_output))
            )

        # Merge prefilled and generated answers back into batch order.
        output: list[str] = []
        full_logprobs: list[Any] = []
        full_reasoning_output: list[Any] = []

        for local_index, survey_id in enumerate(current_survey_ids):
            if local_index in prefilled_by_position:
                answer, answer_logprobs, answer_reasoning = (
                    prefilled_by_position[local_index],
                    None,
                    None,
                )
            else:
                answer, answer_logprobs, answer_reasoning = generated_by_survey[survey_id]

            output.append(answer)
            full_logprobs.append(answer_logprobs)
            full_reasoning_output.append(answer_reasoning)
            # Every answer becomes part of the context of the next turn.
            assistant_history[survey_id].append(answer)

        _store_question_responses(
            question_llm_response_pairs=question_llm_response,
            survey_ids=current_batch.keys(),
            questions=questions,
            answers=output,
            logprobs=full_logprobs,
            reasoning_output=full_reasoning_output,
            item_ids=[item.get_question_item_id(i) for item in current_batch.values()],
        )

        # Runs for every step, including steps answered entirely from
        # prefilled responses.
        _intermediate_saves(
            llm_prompts, n_save_step, intermediate_save_file, question_llm_response, i
        )

    results = _finalize_survey_results(llm_prompts, question_llm_response)
    logger.info("Completed sequential survey.")
    return results
[docs]
class SurveyCreator:
    """Helper class to create LLM prompts from a population CSV/DataFrame and questionnaire.

    One ``LLMPrompt`` is built per row of the population (survey) table. Every
    prompt receives the same questionnaire DataFrame object as its
    ``questionnaire_source``.
    """

    @classmethod
    def from_path(cls, survey_path: str, questionnaire_path: str) -> list[LLMPrompt]:
        """
        Generates LLMPrompt objects from two CSV files (population/survey and questionnaire).

        Args:
            survey_path (str): The path to the survey CSV file.
            questionnaire_path (str): The path to the questionnaire CSV file.

        Returns:
            A list of LLMPrompt objects, one per row of the survey file.
        """
        df = pd.read_csv(survey_path)
        df_questionnaire = pd.read_csv(questionnaire_path)
        return cls._from_dataframe(df, df_questionnaire)

    @classmethod
    def from_dataframe(
        cls, survey_dataframe: pd.DataFrame, questionnaire_dataframe: pd.DataFrame
    ) -> list[LLMPrompt]:
        """
        Generates LLMPrompt objects from two pandas DataFrames.

        Args:
            survey_dataframe (pandas.DataFrame): A DataFrame containing survey
                data (questionnaire_name, system_prompt, and
                questionnaire_instruction).
            questionnaire_dataframe (pandas.DataFrame): A DataFrame containing the questions.

        Returns:
            A list of LLMPrompt objects, one per row of ``survey_dataframe``.
        """
        return cls._from_dataframe(survey_dataframe, questionnaire_dataframe)

    @classmethod
    def _create_questionnaire(cls, row: pd.Series, df_questionnaire: pd.DataFrame) -> LLMPrompt:
        """
        Internal helper method to create the LLM prompt for one survey row.

        The system prompt column is optional: a missing column or a missing
        (NaN) value results in ``system_prompt=None``. The questionnaire name
        and instruction columns are read by key and must be present.
        """
        system_prompt = row.get(constants.SYSTEM_PROMPT_FIELD, None)
        if pd.isna(system_prompt):
            system_prompt = None
        return LLMPrompt(
            questionnaire_source=df_questionnaire,
            questionnaire_name=row[constants.QUESTIONNAIRE_NAME],
            system_prompt=system_prompt,
            prompt=row[constants.QUESTIONNAIRE_INSTRUCTION_FIELD],
        )

    @classmethod
    def _from_dataframe(cls, df: pd.DataFrame, df_questionnaire: pd.DataFrame) -> list[LLMPrompt]:
        """
        Internal helper method to turn every row of ``df`` into an LLMPrompt.

        Rows are iterated explicitly instead of going through
        ``DataFrame.apply(axis=1)``: on an empty frame ``apply`` probes the
        callback with a dummy row and may return a DataFrame, which has no
        ``to_list``. An empty ``df`` simply yields an empty list here.
        """
        return [cls._create_questionnaire(row, df_questionnaire) for _, row in df.iterrows()]
if __name__ == "__main__":
    # Running this module directly is a no-op.
    pass