Source code for autorag.nodes.queryexpansion.base
import abc
import logging
from pathlib import Path
from typing import List, Union
import pandas as pd
from autorag.nodes.util import make_generator_callable_param
from autorag.schema import BaseModule
from autorag.utils import validate_qa_dataset
logger = logging.getLogger("AutoRAG")
[docs]
class BaseQueryExpansion(BaseModule, metaclass=abc.ABCMeta):
    """Base class for query expansion node modules.

    On construction it builds a generator module from the keyword arguments
    (via ``make_generator_callable_param``) and stores it on
    ``self.generator``. Subclasses can use :meth:`cast_to_run` to pull the
    query list out of the previous node's result and
    :meth:`_check_expanded_query` to sanitize the expanded queries they
    produce.
    """

    def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
        """Create the generator used for query expansion.

        :param project_dir: Project directory, passed as the first argument
            to the generator class constructor.
        :param args: Accepted for interface compatibility; unused here.
        :param kwargs: Module parameters. ``make_generator_callable_param``
            derives the generator class and its constructor parameters from
            them (presumably from generator-related keys — confirm in
            ``autorag.nodes.util``).
        """
        logger.info(
            f"Initialize query expansion node - {self.__class__.__name__} module..."
        )
        # set generator module for query expansion
        generator_class, generator_param = make_generator_callable_param(kwargs)
        self.generator = generator_class(project_dir, **generator_param)

    def __del__(self):
        """Drop the reference to the generator and log the module teardown."""
        # ``__init__`` may have raised before ``self.generator`` was assigned
        # (e.g. inside ``make_generator_callable_param`` or the generator
        # constructor). Without this guard ``del`` raises AttributeError,
        # which Python reports as an ignored exception during finalization
        # and which also skipped the log line below.
        try:
            del self.generator
        except AttributeError:
            pass
        logger.info(
            f"Delete query expansion node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Extract the queries to expand from the previous node's result.

        :param previous_result: QA dataframe from the previous node. It is
            checked with ``validate_qa_dataset`` and must contain a
            ``query`` column.
        :param args: Accepted for interface compatibility; unused here.
        :param kwargs: Accepted for interface compatibility; unused here.
        :return: The values of the ``query`` column as a Python list.
        :raises AssertionError: If ``previous_result`` has no ``query``
            column (note that ``assert`` is skipped under ``python -O``).
        """
        logger.info(
            f"Running query expansion node - {self.__class__.__name__} module..."
        )
        validate_qa_dataset(previous_result)
        # find queries columns
        assert "query" in previous_result.columns, (
            "previous_result must have query column."
        )
        queries = previous_result["query"].tolist()
        return queries

    @staticmethod
    def _check_expanded_query(queries: List[str], expanded_queries: List[List[str]]):
        """Sanitize each query's expanded-query list.

        Pairs ``queries[i]`` with ``expanded_queries[i]`` and applies
        ``check_expanded_query`` to each pair. As with ``zip``, pairing stops
        at the shorter of the two inputs.

        :param queries: Original queries.
        :param expanded_queries: One list of expanded queries per original
            query.
        :return: A list with one sanitized expanded-query list per pair.
        """
        return [
            check_expanded_query(query, expanded_query_list)
            for query, expanded_query_list in zip(queries, expanded_queries)
        ]
[docs]
def check_expanded_query(query: str, expanded_query_list: List[str]) -> List[str]:
    """Strip expanded queries and replace blank ones with the original query.

    Each entry of ``expanded_query_list`` has surrounding whitespace removed.
    An entry that is empty after stripping, or is ``None``, is replaced by
    ``query``. Presumably this prevents downstream nodes from receiving an
    empty query string (confirm with callers).

    :param query: The original (un-expanded) query; used as-is, not stripped.
    :param expanded_query_list: Expanded queries produced for ``query``.
    :return: A new list of the same length as ``expanded_query_list``.
    """
    # Treat a missing (None) expansion the same as a blank one instead of
    # crashing on ``None.strip()``.
    stripped = (
        "" if expanded_query is None else expanded_query.strip()
        for expanded_query in expanded_query_list
    )
    return [expanded_query or query for expanded_query in stripped]