Source code for autorag.evaluation.metric.retrieval
import itertools
import math
from autorag.evaluation.metric.util import autorag_metric
from autorag.schema.metricinput import MetricInput
@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_f1(metric_input: MetricInput):
    """
    Compute the F1 score for retrieval.

    :param metric_input: The MetricInput schema for AutoRAG metric.
    :return: The F1 score.
    """
    recall_score = retrieval_recall.__wrapped__(metric_input)
    precision_score = retrieval_precision.__wrapped__(metric_input)
    if recall_score + precision_score == 0:
        return 0.0
    else:
        return 2 * (recall_score * precision_score) / (recall_score + precision_score)
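

# Illustrative sketch, not part of the upstream module: how the F1 metric combines
# recall and precision for a single query. The keyword-argument construction of
# MetricInput is an assumption (a pydantic-style schema); the field names and the
# __wrapped__ call follow the usage inside this module.
def _example_retrieval_f1() -> float:
    example_input = MetricInput(
        retrieval_gt=[["doc-1"], ["doc-2", "doc-3"]],  # two ground-truth groups
        retrieved_ids=["doc-1", "doc-4"],  # hits the first group only
    )
    # recall = 1/2, precision = 1/2, so F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5
    return retrieval_f1.__wrapped__(example_input)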


@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_recall(metric_input: MetricInput) -> float:
    """
    Compute recall for retrieval.

    Each inner list of ``retrieval_gt`` is one ground-truth group; a group counts as
    retrieved when any of its ids appears in ``retrieved_ids``.

    :param metric_input: The MetricInput schema for AutoRAG metric.
    :return: The recall score.
    """
    gt, pred = metric_input.retrieval_gt, metric_input.retrieved_ids
    gt_sets = [frozenset(g) for g in gt]
    pred_set = set(pred)
    hits = sum(any(pred_id in gt_set for pred_id in pred_set) for gt_set in gt_sets)
    recall = hits / len(gt) if len(gt) > 0 else 0.0
    return recall


@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_precision(metric_input: MetricInput) -> float:
    """
    Compute precision for retrieval.

    A retrieved id counts as relevant when it appears in any ground-truth group.

    :param metric_input: The MetricInput schema for AutoRAG metric.
    :return: The precision score.
    """
    gt, pred = metric_input.retrieval_gt, metric_input.retrieved_ids
    gt_sets = [frozenset(g) for g in gt]
    pred_set = set(pred)
    hits = sum(any(pred_id in gt_set for gt_set in gt_sets) for pred_id in pred_set)
    precision = hits / len(pred) if len(pred) > 0 else 0.0
    return precision
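

# Illustrative sketch, not part of the upstream module: recall and precision answer
# different questions for the same query, as the worked values below show. The
# keyword-argument construction of MetricInput is an assumption; field names follow
# the usage inside this module.
def _example_recall_vs_precision() -> tuple:
    example_input = MetricInput(
        retrieval_gt=[["a", "b"], ["c"]],  # two groups: {a, b} and {c}
        retrieved_ids=["a", "b", "d"],
    )
    # recall: only the {a, b} group is covered -> 1 / 2 = 0.5
    # precision: "a" and "b" are relevant, "d" is not -> 2 / 3 ~= 0.667
    return (
        retrieval_recall.__wrapped__(example_input),
        retrieval_precision.__wrapped__(example_input),
    )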


@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_ndcg(metric_input: MetricInput) -> float:
    """
    Compute NDCG for retrieval with binary relevance.

    :param metric_input: The MetricInput schema for AutoRAG metric.
    :return: The NDCG score.
    """
    gt, pred = metric_input.retrieval_gt, metric_input.retrieved_ids
    gt_sets = [frozenset(g) for g in gt]
    pred_set = set(pred)
    relevance_scores = {
        pred_id: 1 if any(pred_id in gt_set for gt_set in gt_sets) else 0
        for pred_id in pred_set
    }

    # DCG over the retrieved ranking with binary gains.
    dcg = sum(
        (2 ** relevance_scores[doc_id] - 1) / math.log2(i + 2)
        for i, doc_id in enumerate(pred)
    )

    # The ideal ranking puts every relevant document at the top of the list.
    len_flatten_gt = len(list(itertools.chain.from_iterable(gt)))
    len_pred = len(pred)
    ideal_pred = [1] * min(len_flatten_gt, len_pred) + [0] * max(
        0, len_pred - len_flatten_gt
    )
    idcg = sum(relevance / math.log2(i + 2) for i, relevance in enumerate(ideal_pred))
    ndcg = dcg / idcg if idcg > 0 else 0.0
    return ndcg
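

# Illustrative sketch, not part of the upstream module: a worked NDCG value for a
# three-document ranking. The keyword-argument construction of MetricInput is an
# assumption; field names follow the usage inside this module.
def _example_retrieval_ndcg() -> float:
    example_input = MetricInput(
        retrieval_gt=[["d1"], ["d2"]],
        retrieved_ids=["d3", "d1", "d2"],  # relevant documents sit at ranks 2 and 3
    )
    # DCG  = 0 / log2(2) + 1 / log2(3) + 1 / log2(4) ~= 1.131
    # IDCG = 1 / log2(2) + 1 / log2(3) + 0 / log2(4) ~= 1.631  (ideal ranking [1, 1, 0])
    # NDCG ~= 1.131 / 1.631 ~= 0.693
    return retrieval_ndcg.__wrapped__(example_input)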


@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_mrr(metric_input: MetricInput) -> float:
    """
    Reciprocal Rank (RR) is the reciprocal of the rank of the first relevant item;
    the mean of RR over all queries is the MRR.
    Within a single query, an RR is computed for each ground-truth group and the
    mean over the groups is returned.
    """
    gt, pred = metric_input.retrieval_gt, metric_input.retrieved_ids
    gt_sets = [frozenset(g) for g in gt]
    rr_list = []
    for gt_set in gt_sets:
        # Rank of the first retrieved id that belongs to this ground-truth group.
        for i, pred_id in enumerate(pred):
            if pred_id in gt_set:
                rr_list.append(1.0 / (i + 1))
                break
    return sum(rr_list) / len(gt_sets) if rr_list else 0.0
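

# Illustrative sketch, not part of the upstream module: reciprocal ranks for two
# ground-truth groups. The keyword-argument construction of MetricInput is an
# assumption; field names follow the usage inside this module.
def _example_retrieval_mrr() -> float:
    example_input = MetricInput(
        retrieval_gt=[["d1"], ["d2", "d3"]],
        retrieved_ids=["d4", "d2", "d1"],
    )
    # group {d1}: first hit at rank 3 -> RR = 1/3
    # group {d2, d3}: first hit at rank 2 -> RR = 1/2
    # returned value = (1/3 + 1/2) / 2 ~= 0.417
    return retrieval_mrr.__wrapped__(example_input)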


@autorag_metric(fields_to_check=["retrieval_gt", "retrieved_ids"])
def retrieval_map(metric_input: MetricInput) -> float:
    """
    Mean Average Precision (MAP) is the mean of Average Precision (AP) over all queries.
    Within a single query, an AP is computed for each ground-truth group and the
    mean over the groups is returned.
    """
    gt, pred = metric_input.retrieval_gt, metric_input.retrieved_ids
    gt_sets = [frozenset(g) for g in gt]
    ap_list = []
    for gt_set in gt_sets:
        pred_hits = [1 if pred_id in gt_set else 0 for pred_id in pred]
        # Precision at every rank where a document from this group is retrieved.
        precision_list = [
            sum(pred_hits[: i + 1]) / (i + 1)
            for i, hit in enumerate(pred_hits)
            if hit == 1
        ]
        ap_list.append(
            sum(precision_list) / len(precision_list) if precision_list else 0.0
        )
    return sum(ap_list) / len(gt_sets) if ap_list else 0.0
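

# Illustrative sketch, not part of the upstream module: average precision for two
# ground-truth groups. The keyword-argument construction of MetricInput is an
# assumption; field names follow the usage inside this module.
def _example_retrieval_map() -> float:
    example_input = MetricInput(
        retrieval_gt=[["d1"], ["d2"]],
        retrieved_ids=["d1", "d3", "d2"],
    )
    # group {d1}: hit at rank 1 -> AP = 1/1 = 1.0
    # group {d2}: hit at rank 3 -> AP = 1/3
    # returned value = (1.0 + 1/3) / 2 ~= 0.667
    return retrieval_map.__wrapped__(example_input)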