Source code for autorag.data.qa.generation_gt.openai_gen_gt
import itertools
from typing import Dict

from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.generation_gt.base import add_gen_gt
from autorag.data.qa.generation_gt.prompt import GEN_GT_SYSTEM_PROMPT


class Response(BaseModel):
    """Structured output schema: the model's reply is parsed into a single answer string."""

    answer: str


async def make_gen_gt_openai(
    row: Dict,
    client: AsyncClient,
    system_prompt: str,
    model_name: str = "gpt-4o-2024-08-06",
):
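    """
    Generate a generation_gt answer for a single QA row using the OpenAI
    Structured Outputs API, from the row's query and its ground-truth
    retrieval passages, and attach it to the row via add_gen_gt.

    :param row: The input row of the qa dataframe.
        It must contain "query" and "retrieval_gt_contents".
    :param client: The OpenAI async client.
    :param system_prompt: The system prompt that steers the answer style.
    :param model_name: The model name that supports structured output.
    :return: The output row of the qa dataframe with added "generation_gt" in it.
    """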
    # Flatten the nested ground-truth retrieval passages into a single list.
    retrieval_gt_contents = list(
        itertools.chain.from_iterable(row["retrieval_gt_contents"])
    )
    query = row["query"]
    passage_str = "\n".join(retrieval_gt_contents)
    user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"
    # Structured Outputs: the reply is parsed directly into the Response model.
    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
        response_format=Response,
    )
    response: Response = completion.choices[0].message.parsed
    return add_gen_gt(row, response.answer)


async def make_concise_gen_gt(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
):
"""
Generate concise generation_gt using OpenAI Structured Output for preventing errors.
It generates a concise answer, so it is generally a word or just a phrase.
:param row: The input row of the qa dataframe.
:param client: The OpenAI async client.
:param model_name: The model name that supports structured output.
It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
:param lang: The language code of the prompt.
Default is "en".
:return: The output row of the qa dataframe with added "generation_gt" in it.
"""
return await make_gen_gt_openai(
row, client, GEN_GT_SYSTEM_PROMPT["concise"][lang], model_name
)


async def make_basic_gen_gt(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-2024-08-06",
    lang: str = "en",
):
"""
Generate basic generation_gt using OpenAI Structured Output for preventing errors.
It generates a "basic" answer, and its prompt is simple.
:param row: The input row of the qa dataframe.
:param client: The OpenAI async client.
:param model_name: The model name that supports structured output.
It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
:param lang: The language code of the prompt.
Default is "en".
:return: The output row of the qa dataframe with added "generation_gt" in it.
"""
return await make_gen_gt_openai(
row, client, GEN_GT_SYSTEM_PROMPT["basic"][lang], model_name
)
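

# Usage sketch (illustrative, not part of the original module): generating a
# concise ground-truth answer for one QA row. Assumptions: the row dict carries
# "query" and "retrieval_gt_contents" (a list of lists of passage strings, as
# consumed by make_gen_gt_openai above), and OPENAI_API_KEY is set in the
# environment so AsyncClient() can authenticate. make_basic_gen_gt is called
# the same way.
if __name__ == "__main__":
    import asyncio

    async def _usage_example():
        client = AsyncClient()
        row = {
            "query": "Who wrote 'Pride and Prejudice'?",
            "retrieval_gt_contents": [
                ["Pride and Prejudice is an 1813 novel by Jane Austen."]
            ],
        }
        row = await make_concise_gen_gt(row, client, lang="en")
        # The row now carries "generation_gt" (see the docstrings above).
        print(row["generation_gt"])

    asyncio.run(_usage_example())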