Source code for autorag.nodes.queryexpansion.hyde
from typing import List
import pandas as pd
from autorag.nodes.queryexpansion.base import BaseQueryExpansion
from autorag.utils import result_to_dataframe
hyde_prompt = "Please write a passage to answer the question"
[docs]
class HyDE(BaseQueryExpansion):
[docs]
@result_to_dataframe(["queries"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
queries = self.cast_to_run(previous_result, *args, **kwargs)
# pop prompt from kwargs
prompt = kwargs.pop("prompt", hyde_prompt)
kwargs.pop("generator_module_type", None)
expanded_queries = self._pure(queries, prompt, **kwargs)
return self._check_expanded_query(queries, expanded_queries)
def _pure(self, queries: List[str], prompt: str = hyde_prompt, **generator_params):
"""
HyDE, which inspired by "Precise Zero-shot Dense Retrieval without Relevance Labels" (https://arxiv.org/pdf/2212.10496.pdf)
LLM model creates a hypothetical passage.
And then, retrieve passages using hypothetical passage as a query.
:param queries: List[str], queries to retrieve.
:param prompt: Prompt to use when generating hypothetical passage
:return: List[List[str]], List of hyde results.
"""
full_prompts = list(
map(
lambda x: (prompt if not bool(prompt) else hyde_prompt)
+ f"\nQuestion: {x}\nPassage:",
queries,
)
)
input_df = pd.DataFrame({"prompts": full_prompts})
result_df = self.generator.pure(previous_result=input_df, **generator_params)
answers = result_df["generated_texts"].tolist()
results = list(map(lambda x: [x], answers))
return results