|
import argparse |
|
import json |
|
from tqdm.auto import tqdm |
|
|
|
import pandas as pd |
|
import numpy as np |
|
|
|
from app.recommendations import RecommenderSystem |
|
|
|
|
|
def precision_at_k(recommended_items, relevant_items, k):
    """Fraction of the top-k recommendations that are relevant.

    Args:
        recommended_items: Ranked sequence of recommended item ids.
        relevant_items: Set of item ids considered relevant (must support ``&``).
        k: Cutoff rank; the divisor even when fewer items were recommended.

    Returns:
        Precision@k as a float in [0, 1].
    """
    top_k = set(recommended_items[:k])
    hits = top_k & relevant_items
    return len(hits) / k
|
|
|
|
|
def average_precision_at_k(recommended_items, relevant_items, k):
    """Average precision at cutoff ``k`` (the per-query term of MAP@k).

    For each relevant item found in the top-k recommendations, the precision
    at that rank is accumulated; the sum is normalized by
    ``min(k, len(relevant_items))`` so a perfect top-k ranking scores 1.0.

    Bug fixes vs. the previous version:
      * Only the top ``k`` recommendations are scored — hits beyond rank k
        no longer (incorrectly) contribute to AP@k.
      * An empty relevant set returns 0.0 instead of raising
        ``ZeroDivisionError``.

    Args:
        recommended_items: Ranked sequence of recommended item ids.
        relevant_items: Iterable of relevant item ids (deduplicated internally).
        k: Cutoff rank.

    Returns:
        AP@k as a float in [0, 1].
    """
    relevant = set(relevant_items)
    if not relevant:
        # No relevant items: nothing can be retrieved, define AP@k as 0.
        return 0.0

    top_k = recommended_items[:k]
    apk_sum = 0.0
    for m, item in enumerate(top_k):
        if item in relevant:
            # Precision@(m+1), computed via set intersection so duplicate
            # recommendations are not double-counted.
            apk_sum += len(set(top_k[:m + 1]) & relevant) / (m + 1)

    return apk_sum / min(k, len(relevant))
|
|
|
|
|
def evaluate_recsys(
    val_ratings_path,
    faiss_index_path,
    db_path,
    n_recommend_items=10,
    metrics_savepath=None
):
    """Evaluate the recommender with leave-one-out MAP@5 on validation ratings.

    For every user with at least two rated items, each rated item is used in
    turn as the query; the user's remaining items are the relevant set. The
    per-item AP@5 scores are averaged per user, then across users.

    Args:
        val_ratings_path: CSV with at least ``user_id`` and ``item_id`` columns.
        faiss_index_path: Path to the FAISS index backing the recommender.
        db_path: Path to the recommender's database file.
        n_recommend_items: Number of recommendations requested per query item.
        metrics_savepath: Optional path; when given, metrics are dumped as JSON.

    Returns:
        Dict mapping metric name (e.g. ``"ap@5"``) to its mean value.
    """
    recsys = RecommenderSystem(
        faiss_index_path=faiss_index_path,
        db_path=db_path)

    ratings = pd.read_csv(val_ratings_path)
    # One list of rated item ids per user.
    user_item_lists = (
        ratings.groupby("user_id")["item_id"]
        .apply(list)
        .reset_index()["item_id"]
        .tolist()
    )

    metric_arrays = {
        "ap@5": [],
    }

    for items in tqdm(user_item_lists):
        # A single rating leaves no relevant items to retrieve — skip.
        if len(items) == 1:
            continue

        per_item_scores = {name: [] for name in metric_arrays}

        for query_item in items:
            recs = list(recsys.recommend_items(query_item, n_recommend_items))
            # Leave-one-out: everything else the user rated counts as relevant.
            relevant = set(items) - {query_item}

            per_item_scores["ap@5"].append(
                average_precision_at_k(recs, relevant, k=5))

        # Average over the user's query items before pooling across users.
        for name in metric_arrays:
            metric_arrays[name].append(np.mean(per_item_scores[name]))

    metrics = {name: np.mean(values) for name, values in metric_arrays.items()}

    if metrics_savepath is not None:
        with open(metrics_savepath, "w") as f:
            json.dump(metrics, f)
        print(f"Saved metrics to {metrics_savepath}")
    return metrics
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description="Evaluate a recommendation system.") |
|
parser.add_argument("--metrics_savepath", required=True, type=str, help="Path to save the evaluation metrics.") |
|
parser.add_argument("--val_ratings_path", required=True, type=str, help="Path to the csv file with validation ratings.") |
|
parser.add_argument("--faiss_index_path", required=True, type=str, help="Path to the FAISS index.") |
|
parser.add_argument("--db_path", required=True, type=str, help="Path to the database file.") |
|
parser.add_argument("--n_recommend_items", type=int, default=10, help="Number of items to recommend.") |
|
args = parser.parse_args() |
|
evaluate_recsys(**vars(args)) |