Source code for autorag.nodes.passageaugmenter.base
import abc
import logging
import os
import pandas as pd
from autorag.schema import BaseModule
from autorag.utils import (
validate_qa_dataset,
sort_by_scores,
validate_corpus_dataset,
cast_corpus_dataset,
)
from autorag.utils.cast import cast_retrieved_ids
from autorag.utils.util import select_top_k
logger = logging.getLogger("AutoRAG")
[docs]
class BasePassageAugmenter(BaseModule, metaclass=abc.ABCMeta):
def __init__(self, project_dir: str, *args, **kwargs):
logger.info(
f"Initialize passage augmenter node - {self.__class__.__name__} module..."
)
data_dir = os.path.join(project_dir, "data")
corpus_df = pd.read_parquet(
os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
)
validate_corpus_dataset(corpus_df)
corpus_df = cast_corpus_dataset(corpus_df)
self.corpus_df = corpus_df
def __del__(self):
logger.info(
f"Initialize passage augmenter node - {self.__class__.__name__} module..."
)
[docs]
def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
logger.info(
f"Running passage augmenter node - {self.__class__.__name__} module..."
)
validate_qa_dataset(previous_result)
# find ids columns
return cast_retrieved_ids(previous_result)
[docs]
@staticmethod
def sort_by_scores(
augmented_contents,
augmented_ids,
augmented_scores,
top_k: int,
reverse: bool = True,
):
# sort by scores
df = pd.DataFrame(
{
"contents": augmented_contents,
"ids": augmented_ids,
"scores": augmented_scores,
}
)
df[["contents", "ids", "scores"]] = df.apply(
lambda row: sort_by_scores(row, reverse=reverse),
axis=1,
result_type="expand",
)
# select by top_k
results = select_top_k(df, ["contents", "ids", "scores"], top_k)
return (
results["contents"].tolist(),
results["ids"].tolist(),
results["scores"].tolist(),
)