# Source snapshot: commit cca4f79 by mduppes — "Add dataset selector for examples" (7.89 kB).
import json
import logging
import mimetypes
import os
import typing as tp
from io import StringIO
from urllib.parse import unquote
import pandas as pd
import requests
from backend.config import ABS_DATASET_DOMAIN, get_dataset_config, get_datasets
from backend.descriptions import (
DATASET_DESCRIPTIONS,
DESCRIPTIONS,
METRIC_DESCRIPTIONS,
MODEL_DESCRIPTIONS,
)
from backend.examples import get_examples_tab
from flask import Flask, request, Response, send_from_directory
from flask_cors import CORS
from tools import (
get_leaderboard_filters,
get_old_format_dataframe,
) # Import your function
logger = logging.getLogger(__name__)
# Install a console handler only if no handler is reachable through the logger
# hierarchy (e.g. one configured by a hosting server), to avoid duplicate output.
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    log_format = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
    handler.setFormatter(log_format)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.warning("Starting the Flask app...")

# The built frontend (../frontend/dist) is exposed at the site root
# (static_url_path=""); flask-cors is applied app-wide with its default options.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
CORS(app)
@app.route("/")
def index():
    """Serve the frontend entry point, ``index.html``, from the static folder."""
    logger.warning("Serving index.html")
    static_root = app.static_folder
    return send_from_directory(static_root, "index.html")
@app.route("/datasets")
def datasets():
    """
    Return the dataset configs produced by ``get_datasets()`` as JSON
    (grouped by audio / image / video).
    """
    payload = json.dumps(get_datasets())
    return Response(payload, mimetype="application/json")
@app.route("/data/<path:dataset_name>")
def data_files(dataset_name):
    """
    Serves csv files from S3 or locally based on config.

    The CSV read is ``<config["path"]>/<dataset_name>_<dataset_type>.csv``, where
    ``config = get_dataset_config(dataset_name)`` and ``dataset_type`` is a
    required query parameter.

    Query params:
        dataset_type: ``"benchmark"`` (payload built by ``get_leaderboard``) or
            ``"attacks_variations"`` (payload built by ``get_chart``).

    Returns:
        The JSON response of the matching builder; 400 if ``dataset_type`` is
        missing or unsupported; 404 if the CSV cannot be read. Errors raised
        while *processing* a successfully read CSV propagate (HTTP 500) instead
        of being misreported as a missing file.
    """
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400

    # Validate the type before it is spliced into a file path. This also stops
    # values such as "x/../y" from steering which file is read, and avoids
    # falling off the end of the function (returning None -> Flask 500).
    builders = {"benchmark": get_leaderboard, "attacks_variations": get_chart}
    build_response = builders.get(dataset_type)
    if build_response is None:
        logger.error(f"Unsupported dataset_type: {dataset_type}")
        return f"Unsupported dataset type: {dataset_type}", 400

    config = get_dataset_config(dataset_name)
    file_path = os.path.join(config["path"], dataset_name) + f"_{dataset_type}.csv"
    logger.info(f"Looking for dataset file: {file_path}")

    # Guard only the read: the exception type for an unavailable file depends on
    # the backend pandas uses (local file, HTTP, S3), and all mean "not found" here.
    try:
        df = pd.read_csv(file_path)
    except Exception:
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    logger.info(f"Processing dataset: {dataset_name}")
    return build_response(config, df)
@app.route("/files/<path:file_path>")
def serve_file_path(file_path):
    """
    Serves a local file's raw bytes as ``application/octet-stream``.

    ``file_path`` is passed to ``open()`` unchanged, so relative paths resolve
    against the process's current working directory (no S3 access happens here).

    Returns:
        The file content, or 404 when the path does not exist, names a
        directory, or has a non-directory component.
    """
    # NOTE(review): the path comes straight from the URL and is not confined to
    # any root directory, so any file readable by the server process can be
    # fetched. Consider restricting it to the dataset/example roots.
    local_path = file_path
    logger.info(f"Looking for file: {local_path}")
    try:
        with open(local_path, "rb") as f:
            content = f.read()
    except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
        # Previously only FileNotFoundError was handled; a directory path
        # escaped as an unhandled 500.
        logger.error(f"Failed to fetch file: {local_path}")
        return "File not found", 404
    return Response(content, mimetype="application/octet-stream")
@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return, as JSON, what ``get_examples_tab(type, dataset)`` produces for the
    ``type`` path segment and the required ``dataset`` query parameter.

    A missing/empty ``dataset`` parameter, or a ``ValueError`` raised while
    building or serializing the examples, yields a 400 with an
    ``{"error": ...}`` body.
    """
    dataset_name = request.args.get("dataset")
    if dataset_name:
        try:
            payload = json.dumps(get_examples_tab(type, dataset_name))
            return Response(payload, mimetype="application/json")
        except ValueError as exc:
            return {"error": str(exc)}, 400
    return {"error": "Dataset parameter is required"}, 400
@app.route("/descriptions")
def descriptions():
    """
    Return the description tables imported from ``backend.descriptions``
    (general, metric, model and dataset descriptions) as one JSON object.
    """
    bundle = {
        "descriptions": DESCRIPTIONS,
        "metric_descriptions": METRIC_DESCRIPTIONS,
        "model_descriptions": MODEL_DESCRIPTIONS,
        "dataset_descriptions": DATASET_DESCRIPTIONS,
    }
    return Response(json.dumps(bundle), mimetype="application/json")
# Add a proxy endpoint to bypass CORS issues
@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Proxy endpoint to fetch remote files and serve them to the frontend.
    This helps bypass CORS restrictions on remote resources.

    Only targets under ``ABS_DATASET_DOMAIN`` are served:

    * ``http://`` / ``https://`` URLs must start with ``ABS_DATASET_DOMAIN`` *and*
      resolve to the same host, so prefix tricks such as ``<domain>@evil.example``
      or ``<domain>.evil.example`` are refused;
    * any other value is treated as a local file path, which must start with
      ``ABS_DATASET_DOMAIN`` and may not contain ``..`` segments after it.

    Returns:
        The proxied bytes with ``Access-Control-Allow-Origin: *``; 403 for a
        disallowed target; 404 for a missing local file; the upstream status
        code when a remote fetch does not return 200; 500 on any other error
        (including an upstream timeout).
    """
    try:
        # Decode the URL parameter (it may arrive percent-encoded).
        url = unquote(url)
        if url.startswith(("http://", "https://")):
            if not _is_allowed_remote_url(url):
                return {"error": "Only proxying from allowed domains is permitted"}, 403
            return _proxy_remote(url)
        if not _is_allowed_local_path(url):
            return {"error": "Only proxying from allowed domains is permitted"}, 403
        # Serve a local file if the URL is not a network resource
        return _serve_local_file(url)
    except Exception as e:
        logger.exception(f"Proxy request failed for: {url}")
        return {"error": str(e)}, 500


# Seconds requests waits to connect and between received bytes upstream.
_PROXY_TIMEOUT_SECONDS = 30

# Upstream headers not relayed: requests has already decoded the body (so
# content-encoding/length no longer match it) and Flask sets its own framing
# and connection headers.
_PROXY_EXCLUDED_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)


def _is_allowed_remote_url(url):
    """Return True if *url* starts with ABS_DATASET_DOMAIN and targets the same host."""
    from urllib.parse import urlsplit

    if not url.startswith(ABS_DATASET_DOMAIN):
        return False
    # A bare prefix test is not enough when the domain has no trailing slash:
    # compare parsed hostnames so "@" user-info and suffix-host tricks fail.
    return urlsplit(url).hostname == urlsplit(ABS_DATASET_DOMAIN).hostname


def _is_allowed_local_path(path):
    """Return True if *path* starts with ABS_DATASET_DOMAIN and never climbs out via ``..``."""
    if not path.startswith(ABS_DATASET_DOMAIN):
        return False
    remainder = path[len(ABS_DATASET_DOMAIN):].replace("\\", "/")
    return ".." not in remainder.split("/")


def _proxy_remote(url):
    """Fetch *url* and relay its body and headers, adding a permissive CORS header."""
    # The context manager releases the pooled connection even when the body is
    # never read (non-200 case); the timeout stops a stalled upstream from
    # blocking the worker forever.
    with requests.get(url, stream=True, timeout=_PROXY_TIMEOUT_SECONDS) as response:
        if response.status_code != 200:
            return {"error": f"Failed to fetch from {url}"}, response.status_code
        headers = {
            name: value
            for name, value in response.headers.items()
            if name.lower() not in _PROXY_EXCLUDED_HEADERS
        }
        # Add CORS headers
        headers["Access-Control-Allow-Origin"] = "*"
        return Response(response.content, response.status_code, headers)


def _serve_local_file(local_path):
    """Return the file at *local_path* with a guessed mimetype, or a 404 error payload."""
    if not os.path.exists(local_path):
        return {"error": f"Local file not found: {local_path}"}, 404
    with open(local_path, "rb") as f:
        content = f.read()
    # Guess content type based on file extension
    mime_type, _ = mimetypes.guess_type(local_path)
    return Response(
        content,
        mimetype=mime_type or "application/octet-stream",
        headers={"Access-Control-Allow-Origin": "*"},
    )
def get_leaderboard(config, df):
    """
    Build the leaderboard JSON response for a benchmark CSV.

    Args:
        config: dataset config; its ``first_cols``, ``attack_scores`` and
            ``categories`` entries are read.
        df: the raw benchmark DataFrame. After ``get_old_format_dataframe`` it
            must contain a ``model`` column (used as the transpose index).

    Returns:
        A JSON ``Response`` with ``groups`` (group name -> list of metrics),
        ``default_selected_metrics`` (list) and ``rows`` (one record per
        metric: a ``metric`` key plus one key per model).
    """
    logger.warning(f"Processing dataset with config: {config}")

    # Add on all the derived columns (see get_old_format_dataframe).
    expanded = get_old_format_dataframe(df, config["first_cols"], config["attack_scores"])
    groups, default_selection = get_leaderboard_filters(expanded, config["categories"])

    # Missing values become the *string* "NaN": json.dumps would otherwise emit
    # a bare NaN token, which is not valid JSON.
    expanded = expanded.fillna(value="NaN")

    # Transpose so that each metric becomes a row and each model a column.
    per_metric = (
        expanded.set_index("model").T.reset_index().rename(columns={"index": "metric"})
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": per_metric.to_dict(orient="records"),
    }
    return Response(json.dumps(payload), mimetype="application/json")
def get_chart(config, df):
    """
    Build the attack-variation chart JSON response.

    Args:
        config: dataset config; only its ``attacks_with_variations`` entry is read.
        df: attack-variation DataFrame; every row is sent to the client.

    Returns:
        A JSON ``Response`` with ``metrics`` (fixed list of plottable metric
        names), ``attacks_with_variations`` (copied from *config*) and
        ``all_attacks_df`` (the DataFrame as a list of row records).
    """
    # Missing values become the string "NaN" because json.dumps would
    # otherwise emit a bare NaN token, which is not valid JSON.
    records = df.fillna(value="NaN").to_dict(orient="records")
    chart_data = {
        "metrics": ["bit_acc", "log10_p_value", "TPR", "FPR", "watermark_det_score"],
        "attacks_with_variations": config["attacks_with_variations"],
        "all_attacks_df": records,
    }
    return Response(json.dumps(chart_data), mimetype="application/json")
@app.errorhandler(404)
def not_found(e):
    """
    SPA fallback: answer any unmatched route with the frontend's index.html.

    The error object ``e`` is not used.
    """
    entry_point = "index.html"
    return send_from_directory(app.static_folder, entry_point)
if __name__ == "__main__":
    # Development-server entry point. Debug mode (interactive debugger and
    # reloader) stays on by default as before, but it can now be switched off
    # with FLASK_DEBUG=0/false/no (empty also counts as off) -- the Werkzeug
    # debugger should not be reachable on a public interface such as 0.0.0.0.
    debug = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in {"", "0", "false", "no"}
    app.run(host="0.0.0.0", port=7860, debug=debug, use_reloader=True)