# Source code for autorag.nodes.passagereranker.base

import abc
import logging
from pathlib import Path
from typing import Union

import pandas as pd

from autorag.schema import BaseModule
from autorag.utils import validate_qa_dataset
from autorag.utils.cast import cast_retrieve_infos

logger = logging.getLogger("AutoRAG")


class BasePassageReranker(BaseModule, metaclass=abc.ABCMeta):
    """Abstract base class for passage reranker node modules.

    Provides lifecycle logging (construction / deletion) and
    :meth:`cast_to_run`, which unpacks a previous node's result DataFrame
    into the queries and retrieval results a reranker operates on.
    """

    def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
        """Log that a reranker module is being initialized.

        :param project_dir: Project directory. Not used by this base class
            itself; subclasses may rely on it.
        :param args: Ignored by this base class.
        :param kwargs: Ignored by this base class.
        """
        logger.info(
            f"Initialize passage reranker node - {self.__class__.__name__} module..."
        )

    def __del__(self):
        """Log that a reranker module instance is being destroyed."""
        logger.info(
            f"Deleting passage reranker node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Extract reranker inputs from the previous node's result.

        :param previous_result: Result DataFrame from the previous node. It is
            passed to ``validate_qa_dataset`` and must contain a ``query``
            column.
        :param args: Ignored.
        :param kwargs: Ignored.
        :return: A 4-tuple ``(queries, retrieved_contents, retrieve_scores,
            retrieved_ids)``. ``queries`` is the ``query`` column as a list;
            the other three are taken from ``cast_retrieve_infos``'s result
            (their exact container types are defined by that helper).
        :raises AssertionError: If ``previous_result`` has no ``query`` column.
        """
        logger.info(
            f"Running passage reranker node - {self.__class__.__name__} module..."
        )
        validate_qa_dataset(previous_result)

        # An explicit raise (rather than a bare ``assert``) keeps this check
        # active under ``python -O``. AssertionError is preserved so existing
        # callers that expect it keep working.
        if "query" not in previous_result.columns:
            raise AssertionError("previous_result must have query column.")

        queries = previous_result["query"].tolist()
        retrieve_infos = cast_retrieve_infos(previous_result)
        return (
            queries,
            retrieve_infos["retrieved_contents"],
            retrieve_infos["retrieve_scores"],
            retrieve_infos["retrieved_ids"],
        )