Source code for autorag.nodes.queryexpansion.multi_query_expansion
from typing import List
import pandas as pd
from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe
# Default prompt template for MultiQueryExpansion. The ``{query}`` placeholder
# is filled with each input query via ``str.format`` in
# ``MultiQueryExpansion._pure``, so any other literal braces in a custom prompt
# passed in its place would have to be doubled (``{{`` / ``}}``).
# Wording follows LangChain MultiQueryRetriever's default query prompt.
multi_query_expansion_prompt = """You are an AI language model assistant.
Your task is to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {query}"""
[docs]
class MultiQueryExpansion(BaseQueryExpansion):
    """Query expansion node that asks an LLM for alternative phrasings of each query.

    Every input query is substituted into a prompt template (by default
    ``multi_query_expansion_prompt``) and sent to ``self.generator``. The
    generated text is split into lines, and the original query is prepended to
    the resulting list.
    """

    @result_to_dataframe(["queries"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Expand the queries taken from ``previous_result``.

        :param previous_result: DataFrame produced by the previous node; the
            queries are extracted from it by ``self.cast_to_run``.
        :param args: Forwarded to ``self.cast_to_run``.
        :param kwargs: All keyword arguments are first passed to
            ``self.cast_to_run``. Then an optional ``prompt`` overrides the
            default expansion prompt, and ``generator_module_type`` is discarded.
            Everything else is forwarded to ``self.generator.pure``.
        :return: The output of ``self._check_expanded_query`` (not defined in
            this class; presumably inherited from ``BaseQueryExpansion``). The
            ``result_to_dataframe(["queries"])`` decorator presumably turns it
            into a ``queries`` column.
        """
        queries = self.cast_to_run(previous_result, *args, **kwargs)
        # Remove node-level options so they are not forwarded to
        # ``self.generator.pure`` through ``_pure``'s ``**kwargs``.
        prompt = kwargs.pop("prompt", multi_query_expansion_prompt)
        kwargs.pop("generator_module_type", None)
        expanded_queries = self._pure(queries, prompt, **kwargs)
        return self._check_expanded_query(queries, expanded_queries)

    def _pure(
        self, queries, prompt: str = multi_query_expansion_prompt, **kwargs
    ) -> List[List[str]]:
        """
        Expand a list of queries using a multi-query expansion approach.

        The LLM generator is asked to produce alternative versions of each
        input query (the default prompt requests three of them).

        :param queries: List[str], queries to expand.
        :param prompt: str, prompt template for multi-query expansion. Its
            ``{query}`` placeholder is filled with each query via
            ``str.format``. The default prompt comes from LangChain
            MultiQueryRetriever's default query prompt.
        :param kwargs: Forwarded to ``self.generator.pure``.
        :return: List[List[str]], one list per input query: the original query
            followed by the lines of the corresponding generated answer.
        """
        full_prompts = [prompt.format(query=query) for query in queries]
        input_df = pd.DataFrame({"prompts": full_prompts})
        result_df = self.generator.pure(previous_result=input_df, **kwargs)
        answers = result_df["generated_texts"].tolist()
        # zip pairs each query with its own generated answer, in order.
        return [
            get_multi_query_expansion(query, answer)
            for query, answer in zip(queries, answers)
        ]
[docs]
def get_multi_query_expansion(query: str, answer: str) -> List[str]:
    """Build the expanded-query list for one query from the LLM's answer.

    :param query: The original user query; it is always the first element of
        the returned list.
    :param answer: Generated text expected to hold one alternative query per
        line.
    :return: ``query`` followed by every newline-separated line of ``answer``.
        If ``answer`` cannot be split on a newline string (e.g. ``None``, a
        float NaN, or ``bytes``), only ``[query]`` is returned.
    """
    try:
        lines = answer.split("\n")
    except (AttributeError, TypeError):
        # Non-str answer (e.g. a missing generation): fall back to the original
        # query alone. Only the errors ``split`` raises for such values are
        # caught. The former bare ``except:`` also swallowed
        # KeyboardInterrupt/SystemExit and real bugs.
        return [query]
    return [query, *lines]