# NOTE: removed Hugging Face Spaces page banner residue ("Spaces: Running Running")
# that was accidentally captured above the source; it is not part of the program.
from __future__ import annotations | |
import os | |
import gradio as gr | |
import json | |
import random | |
from datetime import datetime | |
from typing import Dict, List, Tuple | |
import hashlib | |
import itertools | |
from datasets import load_dataset, Dataset, DatasetDict | |
from huggingface_hub import HfApi, create_repo, repo_exists, Repository | |
from huggingface_hub import HfFolder | |
import shutil | |
import threading | |
import json | |
from collections.abc import Iterable | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts, sizes | |
# --- Authentication / access configuration ---------------------------------
# HF_TOKEN authenticates dataset reads/writes; USER_IDS and USER_IDS_2 are
# JSON-encoded lists of allowed annotator IDs.  Missing variables now degrade
# gracefully (no token, empty allow-list) instead of crashing at import time
# with json.loads(None) / os.environ[...] = None.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    os.environ["HF_AUTH"] = HF_TOKEN
# NOTE: the original constructed `HfApi(token=HF_TOKEN)` here and discarded
# the instance; building the client has no side effects, so it was removed.
USER_IDS = set(
    json.loads(os.environ.get("USER_IDS", "[]"))
    + json.loads(os.environ.get("USER_IDS_2", "[]"))
)
class Soft(Base):
    """Custom Gradio theme derived from the stock ``Base`` theme.

    Uses an indigo palette with softer drop shadows, white block/input
    backgrounds and borderless blocks; all concrete overrides live in the
    ``super().set(...)`` call in ``__init__``.
    """
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.indigo,
        secondary_hue: colors.Color | str = colors.indigo,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            # fonts.LocalFont("Montserrat"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            # fonts.LocalFont("IBM Plex Mono"),
            "ui-monospace",
            "Consolas",
            "monospace",
        ),
    ):
        """Build the theme; parameters mirror ``gradio.themes.Base``."""
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        self.name = "soft"
        # Override individual CSS variables on top of the Base defaults.
        super().set(
            # Colors
            background_fill_primary="*neutral_50",
            slider_color="*primary_500",
            slider_color_dark="*primary_600",
            # Shadows
            shadow_drop="0 1px 4px 0 rgb(0 0 0 / 0.1)",
            shadow_drop_lg="0 2px 5px 0 rgb(0 0 0 / 0.2)",
            # Block Labels
            block_background_fill="white",
            block_label_padding="*spacing_sm *spacing_md",
            block_label_background_fill="*primary_100",
            block_label_background_fill_dark="*primary_600",
            block_label_radius="*radius_md",
            block_label_text_size="*text_md",
            block_label_text_weight="600",
            block_label_text_color="*primary_500",
            block_label_text_color_dark="white",
            block_title_radius="*block_label_radius",
            block_title_padding="*block_label_padding",
            block_title_background_fill="*block_label_background_fill",
            block_title_text_weight="600",
            block_title_text_color="*primary_500",
            block_title_text_color_dark="white",
            block_label_margin="*spacing_md",
            # Inputs
            input_background_fill="white",
            input_border_color="*neutral_100",
            input_shadow="*shadow_drop",
            input_shadow_focus="*shadow_drop_lg",
            checkbox_shadow="none",
            # Buttons
            shadow_spread="6px",
            button_primary_shadow="*shadow_drop_lg",
            button_primary_shadow_hover="*shadow_drop_lg",
            button_primary_shadow_active="*shadow_inset",
            button_secondary_shadow="*shadow_drop_lg",
            button_secondary_shadow_hover="*shadow_drop_lg",
            button_secondary_shadow_active="*shadow_inset",
            checkbox_label_shadow="*shadow_drop_lg",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_400",
            button_primary_background_fill_hover_dark="*primary_500",
            button_primary_text_color="white",
            button_secondary_background_fill="white",
            button_secondary_background_fill_hover="*neutral_100",
            button_secondary_background_fill_hover_dark="*primary_500",
            button_secondary_text_color="*neutral_800",
            button_cancel_background_fill="*button_secondary_background_fill",
            button_cancel_background_fill_hover="*button_secondary_background_fill_hover",
            button_cancel_background_fill_hover_dark="*button_secondary_background_fill_hover",
            button_cancel_text_color="*button_secondary_text_color",
            checkbox_label_background_fill_selected="*primary_500",
            checkbox_label_background_fill_selected_dark="*primary_600",
            checkbox_border_width="1px",
            checkbox_border_color="*neutral_100",
            checkbox_border_color_dark="*neutral_600",
            checkbox_background_color_selected="*primary_600",
            checkbox_background_color_selected_dark="*primary_700",
            checkbox_border_color_focus="*primary_500",
            checkbox_border_color_focus_dark="*primary_600",
            checkbox_border_color_selected="*primary_600",
            checkbox_border_color_selected_dark="*primary_700",
            checkbox_label_text_color_selected="white",
            # Borders
            block_border_width="0px",
            panel_border_width="0px",
        )
# Annotation guidelines shown in the UI accordion.  Read with a context
# manager (the original `open(...).read()` leaked the file handle) and an
# explicit UTF-8 encoding; fall back to an empty string when the file is
# missing so the module can still be imported (e.g. during development).
try:
    with open("guidelines.md", encoding="utf-8") as _fh:
        guideline = _fh.read().strip()
except FileNotFoundError:
    print("Warning: guidelines.md not found; guidelines accordion will be empty")
    guideline = ""
# Configuration for the output dataset
ANNOTATIONS_REPO = "ltg/fluency-annotations"  # Change to your repo name
DATA_DIR = "annotation_data"
ANNOTATIONS_FILE = os.path.join(DATA_DIR, "train.jsonl")
# Model names for the three responses
MODEL_NAMES = ["mistral-Nemo", "translated-SFT", "on-policy-RL"]
# Create all pairwise comparisons (3 models -> 3 unordered pairs)
MODEL_PAIRS = list(itertools.combinations(MODEL_NAMES, 2))
# Repository bootstrap: clone the annotations dataset repo once at startup.
def init_repository():
    """Clone (or reuse) the annotations dataset repository.

    Returns the ``Repository`` handle on success, or ``None`` when cloning
    fails — in that case only a bare local data directory is created and
    annotations are kept locally.
    """
    try:
        cloned = Repository(
            repo_type="dataset",
            local_dir=DATA_DIR,
            clone_from=ANNOTATIONS_REPO,
            use_auth_token=HF_TOKEN,
        )
        cloned.git_pull()
    except Exception as e:
        print(f"Error initializing repository: {e}")
        # Repo unavailable: make sure the local data directory at least exists.
        os.makedirs(DATA_DIR, exist_ok=True)
        return None
    return cloned
# Clone once at import time so later saves can pull/push.
annotation_repo = init_repository()
def load_existing_annotations(path: str | None = None):
    """Load existing annotations from a jsonl file, grouped by user.

    Args:
        path: Optional override of the file to read; defaults to the
            module-level ``ANNOTATIONS_FILE`` (backward compatible).

    Returns:
        Dict mapping ``user_id`` -> list of annotation dicts.  Empty dict
        when the file does not exist or cannot be parsed.
    """
    if path is None:
        path = ANNOTATIONS_FILE
    annotations = {}
    if os.path.exists(path):
        try:
            # Explicit UTF-8: the file is written with ensure_ascii=False,
            # so relying on the platform default encoding would be unsafe.
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        ann = json.loads(line)
                        user_id = ann.get("user_id")
                        if user_id:
                            annotations.setdefault(user_id, []).append(ann)
            print(f"Loaded {sum(len(v) for v in annotations.values())} existing annotations")
        except Exception as e:
            # Best-effort load: a corrupt file yields whatever parsed so far.
            print(f"Error loading annotations: {e}")
    return annotations
def _append_jsonl(path, record):
    """Append *record* as one JSON line to *path* (UTF-8, non-ASCII kept)."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
def save_annotation_to_file(annotation_data):
    """Save a single annotation to the jsonl file and push to hub.

    Pulls the latest repo state first, appends the record, then pushes
    asynchronously.  On failure the local clone is wiped and re-cloned and
    the save is retried once; a second failure is logged and swallowed so
    the UI never crashes on persistence errors.
    """
    global annotation_repo
    try:
        # Pull latest changes so the append lands on top of remote state.
        if annotation_repo:
            annotation_repo.git_pull()
        _append_jsonl(ANNOTATIONS_FILE, annotation_data)
        # Push to hub asynchronously (non-blocking keeps the UI responsive).
        if annotation_repo:
            annotation_repo.push_to_hub(blocking=False)
    except Exception as e:
        print(f"Error saving annotation: {e}")
        # Try to reinitialize repository from scratch and retry once.
        try:
            shutil.rmtree(DATA_DIR)
            annotation_repo = init_repository()
            _append_jsonl(ANNOTATIONS_FILE, annotation_data)
            if annotation_repo:
                annotation_repo.push_to_hub(blocking=False)
        except Exception as e2:
            print(f"Failed to save annotation after retry: {e2}")
def load_dataset_samples():
    """Load and prepare dataset samples with pairwise comparisons"""
    def build_pair(item, model_a, model_b, dataset_name):
        # One A-vs-B comparison record assembled from a dataset row.
        return {
            "id": f"{item['sample_id']}_{model_a}_vs_{model_b}",
            "original_id": item["sample_id"],
            "prompt": item["prompt"],
            "response_a": item["responses"][model_a],
            "response_b": item["responses"][model_b],
            "model_a": model_a,
            "model_b": model_b,
            "dataset": dataset_name,
        }
    try:
        # Load the private dataset (requires authentication); every row gets
        # all three pairwise comparisons.
        train_split = load_dataset("ltg/fluency-generations", split="train", token=HF_TOKEN)
        pairwise_samples = [
            build_pair(item, first, second, "NTNU")
            for item in train_split
            for first, second in MODEL_PAIRS
        ]
        # The test split provides warm-up examples: one rotating pair per
        # row, with A/B order alternating on even/odd rows.
        test_split = load_dataset("ltg/fluency-generations", split="test", token=HF_TOKEN)
        extra_pairwise_samples = []
        for i, item in enumerate(test_split):
            first, second = MODEL_PAIRS[i % len(MODEL_PAIRS)]
            if i % 2 != 0:
                first, second = second, first
            extra_pairwise_samples.append(build_pair(item, first, second, "training_examples"))
        return pairwise_samples, extra_pairwise_samples
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Using dummy data for testing...")
        # Fallback to dummy data for testing
        return [
            {
                "id": "dummy_001_modelA_vs_modelB",
                "original_id": "dummy_001",
                "prompt": "Test prompt for development",
                "response_a": "This is response A for testing.",
                "response_b": "This is response B for testing.",
                "model_a": "modelA",
                "model_b": "modelB",
                "dataset": "test"
            }
        ], []
def swap_sample(sample):
    """Return a copy of *sample* with the A and B sides exchanged.

    Used to counterbalance presentation order: the response shown as "A"
    becomes "B" and vice versa, and the comparison id is rebuilt to match
    the new ordering.
    """
    model_a, model_b = sample["model_a"], sample["model_b"]
    return {
        "id": f"{sample['original_id']}_{model_b}_vs_{model_a}",
        "original_id": sample["original_id"],
        "prompt": sample["prompt"],
        "response_a": sample["response_b"],
        "response_b": sample["response_a"],
        "model_a": model_b,
        "model_b": model_a,
        "dataset": sample["dataset"],
    }
# Load dataset on startup (at import time) so all sessions share one copy;
# falls back to dummy data inside load_dataset_samples() if loading fails.
DATASET_SAMPLES, EXTRA_DATASET_SAMPLES = load_dataset_samples()
class AnnotationManager:
    """In-memory registry of per-user annotation state.

    Wraps the persisted jsonl annotations (loaded once at startup) plus a
    per-user deterministic sample ordering derived from the user ID, so each
    annotator sees the same sequence across sessions and restarts.

    Attributes:
        annotations: user_id -> list of saved annotation dicts.
        user_states: user_id -> {"current_index": int, "annotations": [sample_id, ...]}.
    """
    def __init__(self):
        # Load existing annotations from file
        self.annotations = load_existing_annotations()
        self.user_states = {}
        # Rebuild user states from loaded annotations
        for user_id, user_annotations in self.annotations.items():
            annotated_ids = [ann["sample_id"] for ann in user_annotations]
            self.user_states[user_id] = {
                "current_index": 0,
                "annotations": annotated_ids
            }
    def get_user_seed(self, user_id: str) -> int:
        """Generate consistent seed for user"""
        # md5 is used only as a stable hash of the ID, not for security.
        return int(hashlib.md5(user_id.encode()).hexdigest(), 16)
    def get_user_samples(self, user_id: str) -> List[Dict]:
        """Get shuffled samples for user based on their ID"""
        seed = self.get_user_seed(user_id)
        samples = DATASET_SAMPLES.copy()
        # Deterministic per-user shuffle: same ID -> same order every session.
        random.Random(seed).shuffle(samples)
        # Per-position deterministic coin flip decides whether the A/B sides
        # are swapped, counterbalancing presentation order.
        samples = [
            sample if random.Random(seed + i).randint(0, 1) == 0 else swap_sample(sample)
            for i, sample in enumerate(samples)
        ]
        # Warm-up/training examples are always shown first, unshuffled.
        samples = EXTRA_DATASET_SAMPLES.copy() + samples
        return samples
    def get_next_sample(self, user_id: str) -> Tuple[Dict, int, int]:
        """Get next unannotated sample for user.

        Returns (sample, current_position, total); sample is None when the
        user has annotated everything.
        """
        if user_id not in self.user_states:
            # Check if user has existing annotations
            if user_id in self.annotations:
                annotated_ids = [ann["sample_id"] for ann in self.annotations[user_id]]
                self.user_states[user_id] = {
                    "current_index": 0,
                    "annotations": annotated_ids
                }
            else:
                self.user_states[user_id] = {
                    "current_index": 0,
                    "annotations": []
                }
        samples = self.get_user_samples(user_id)
        state = self.user_states[user_id]
        # Count total annotations for this user
        total_annotated = len(state["annotations"])
        # Find next unannotated sample (first one not already judged).
        for idx, sample in enumerate(samples):
            if not self.is_annotated(user_id, sample["id"]):
                return sample, total_annotated + 1, len(samples)
        # All samples annotated
        return None, len(samples), len(samples)
    def is_annotated(self, user_id: str, sample_id: str) -> bool:
        """Check if user has annotated this sample"""
        if user_id not in self.annotations:
            return False
        return any(ann["sample_id"] == sample_id for ann in self.annotations[user_id])
    def save_annotation(self, user_id: str, sample_id: str, choice: str,
                        model_a: str = None, model_b: str = None,
                        original_id: str = None, dataset_name: str = None):
        """Save user's annotation and persist to file"""
        if user_id not in self.annotations:
            self.annotations[user_id] = []
        annotation = {
            "user_id": user_id,
            "sample_id": sample_id,
            "original_sample_id": original_id,
            "dataset": dataset_name,
            "model_a": model_a,
            "model_b": model_b,
            "choice": choice,
            "timestamp": datetime.now().isoformat()
        }
        # Save to memory
        self.annotations[user_id].append(annotation)
        # Update user state
        if user_id in self.user_states:
            self.user_states[user_id]["annotations"].append(sample_id)
        else:
            self.user_states[user_id] = {
                "current_index": 0,
                "annotations": [sample_id]
            }
        # Save to file asynchronously so the Gradio handler returns promptly.
        threading.Thread(
            target=save_annotation_to_file,
            args=(annotation,)
        ).start()
        print(f"Saved annotation: {annotation}")
    def get_user_progress(self, user_id: str) -> Dict:
        """Get user's annotation progress.

        NOTE(review): the total here counts only DATASET_SAMPLES, while
        get_next_sample's total also includes EXTRA_DATASET_SAMPLES — confirm
        whether the warm-up examples should count toward progress.
        """
        if user_id not in self.annotations:
            return {"completed": 0, "total": len(DATASET_SAMPLES)}
        completed = len(self.annotations[user_id])
        return {"completed": completed, "total": len(DATASET_SAMPLES)}
# Initialize manager: single shared instance; all Gradio handlers close over it.
manager = AnnotationManager()
def login(user_id: str) -> Tuple:
    """Handle user login.

    Validates the annotator ID against USER_IDS and, on success, switches
    to the annotation view showing the user's first unannotated sample.

    Returns updates for (login_interface, annotation_interface, user_state,
    login_status, prompt, response_a, response_b, progress) — the exact
    order wired up in the event handlers below.
    """
    # Reject empty or unknown IDs and stay on the login screen.
    if not user_id or user_id.strip() == "" or user_id.strip() not in USER_IDS:
        return (
            gr.update(visible=True),  # login_interface
            gr.update(visible=False),  # annotation_interface
            "",  # user_state
            gr.update(value="Please enter a valid ID"),  # login_status
            gr.update(),  # prompt
            gr.update(),  # response_a
            gr.update(),  # response_b
            gr.update()  # progress
        )
    user_id = user_id.strip()
    sample, current, total = manager.get_next_sample(user_id)
    if sample is None:
        # User has already annotated everything: stay on login with a note.
        return (
            gr.update(visible=True),  # login_interface
            gr.update(visible=False),  # annotation_interface
            user_id,  # user_state
            gr.update(value=f"All {total} samples completed for user: {user_id}! 🎉"),  # login_status
            gr.update(),  # prompt
            gr.update(),  # response_a
            gr.update(),  # response_b
            gr.update()  # progress
        )
    # (Removed an unused `model_info` local that was computed but never shown.)
    return (
        gr.update(visible=False),  # login_interface
        gr.update(visible=True),  # annotation_interface
        user_id,  # user_state
        gr.update(value=""),  # login_status
        gr.update(value=sample["prompt"]),  # prompt
        gr.update(value=sample["response_a"]),  # response_a
        gr.update(value=sample["response_b"]),  # response_b
        gr.update(value=f"Progress: {current}/{total}")  # progress
    )
def annotate(choice: str, user_id: str) -> Tuple:
    """Handle annotation submission.

    Saves the user's judgement ("a_better" / "b_better" / "equal") for the
    currently displayed pair, then advances to the next unannotated pair.

    Returns updates for (prompt, response_a, response_b, progress, status)
    matching the outputs wired to the three buttons.
    """
    if not user_id:
        return (
            gr.update(),  # prompt
            gr.update(),  # response_a
            gr.update(),  # response_b
            gr.update(),  # progress
            gr.update(value="Error: No user logged in", visible=True)  # status
        )
    # The displayed sample is always the first unannotated one, so
    # re-querying get_next_sample() returns exactly what the user judged.
    sample, _, _ = manager.get_next_sample(user_id)
    if sample:
        # Map button choice to the stored annotation label.
        choice_map = {
            "a_better": "A is more fluent",
            "b_better": "B is more fluent",
            "equal": "Equally fluent"
        }
        # Save with all metadata
        manager.save_annotation(
            user_id=user_id,
            sample_id=sample["id"],
            choice=choice_map[choice],
            model_a=sample.get("model_a"),
            model_b=sample.get("model_b"),
            original_id=sample.get("original_id"),
            dataset_name=sample.get("dataset")
        )
    # Get next sample
    next_sample, current, total = manager.get_next_sample(user_id)
    if next_sample is None:
        return (
            gr.update(value="All samples completed! Thank you for your annotations."),  # prompt
            gr.update(value=""),  # response_a
            gr.update(value=""),  # response_b
            gr.update(value=f"Progress: {total}/{total} - Complete!"),  # progress
            gr.update(value="All annotations complete!", visible=True)  # status
        )
    # (Removed an unused `model_info` local that was computed but never shown.)
    return (
        gr.update(value=next_sample["prompt"]),  # prompt
        gr.update(value=next_sample["response_a"]),  # response_a
        gr.update(value=next_sample["response_b"]),  # response_b
        gr.update(value=f"Progress: {current}/{total}"),  # progress
        gr.update(value="Annotation saved!", visible=True)  # status
    )
def logout() -> Tuple:
    """Reset the UI back to the login screen, clearing all displayed text.

    Returns updates for (login_interface, annotation_interface, user_state,
    login_status, prompt, response_a, response_b, progress).
    """
    # The last five outputs (status, prompt, both responses, progress) are
    # all cleared the same way.
    cleared = tuple(gr.update(value="") for _ in range(5))
    return (
        gr.update(visible=True),   # login_interface
        gr.update(visible=False),  # annotation_interface
        "",                        # user_state
        *cleared,
    )
# Custom CSS injected into the Gradio app: white login card, soft shadows,
# and removal of Gradio's default textbox borders/shadows.
custom_css = """
#login-group {
    background-color: white !important;
}
#login-group > * {
    background-color: white !important;
}
#login-group .gr-group {
    background-color: white !important;
}
#login-group .gr-form {
    background-color: white !important;
}
.light-shadow {
    box-shadow: 0 1px 4px 0 rgb(0 0 0 / 0.1) !important;
}
/* Target the textbox container */
.no-style-textbox {
    border: none !important;
    box-shadow: none !important;
}
/* Target both input and textarea elements */
.no-style-textbox input,
.no-style-textbox textarea {
    border: none !important;
    box-shadow: none !important;
    padding: 0 !important;
    outline: none !important;
}
/* Target the Gradio textbox wrapper */
.no-style-textbox .gr-textbox {
    border: none !important;
    box-shadow: none !important;
}
/* Target focus states */
.no-style-textbox input:focus,
.no-style-textbox textarea:focus {
    border: none !important;
    box-shadow: none !important;
    outline: none !important;
}
/* Additional targeting for stubborn Gradio elements */
.no-style-textbox .gr-form,
.no-style-textbox .gr-input {
    border: none !important;
    box-shadow: none !important;
}
"""
# Create Gradio interface.  Layout: a login column (visible first) and an
# annotation column (hidden until a valid ID is entered).  The output lists
# of the event handlers below are positional and must stay in sync with the
# tuples returned by login()/annotate()/logout().
with gr.Blocks(theme=Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial"]), title="Dataset Annotation Tool", css=custom_css) as app:
    gr.Markdown("# Norwegian Fluency Annotation")
    # Collapsible guidelines loaded from guidelines.md at startup.
    with gr.Accordion("Click here to see the full annotation guidelines:", open=False, elem_classes="light-shadow"):
        gr.Markdown(guideline, padding=True)
    # Per-session annotator ID; empty string means "not logged in".
    user_state = gr.State("")
    # Login Interface
    with gr.Column(visible=True) as login_interface:
        with gr.Column(variant="panel", elem_id="login-group", elem_classes="light-shadow"):
            gr.Markdown("## Log in", padding=True)
            user_id_input = gr.Textbox(
                label="Enter your unique annotator ID to begin",
                placeholder="Annotator ID"
            )
            with gr.Row():
                login_btn = gr.Button("Login", variant="primary", scale=0.2, min_width=100)
                gr.HTML("")
            login_status = gr.Markdown("", padding=True)
    # Annotation Interface
    with gr.Column(visible=False, elem_id="annotation-group") as annotation_interface:
        progress_label = gr.Markdown("")
        # Row 1: Prompt
        with gr.Row(elem_classes="light-shadow"):
            prompt_display = gr.Textbox(
                label="Prompt",
                interactive=False,
                lines=1,
                elem_classes="no-style-textbox",
                autoscroll=False
            )
        # Row 2: Responses (side by side, equal width)
        with gr.Row(elem_classes="light-shadow"):
            response_a_display = gr.Textbox(
                label="Response A",
                interactive=False,
                lines=1,
                scale=1,
                elem_classes="no-style-textbox",
                autoscroll=False,
                max_lines=100
            )
            response_b_display = gr.Textbox(
                label="Response B",
                interactive=False,
                lines=1,
                scale=1,
                elem_classes="no-style-textbox",
                autoscroll=False,
                max_lines=100
            )
        # Row 3: Buttons (one per possible judgement)
        with gr.Row():
            btn_a = gr.Button("A is more fluent", variant="primary")
            btn_equal = gr.Button("Equally fluent", variant="primary")
            btn_b = gr.Button("B is more fluent", variant="primary")
        status_message = gr.Markdown("", visible=False)
        # Logout row kept hidden; the button still exists so its handler stays wired.
        with gr.Row(visible=False):
            logout_btn = gr.Button("Logout", variant="stop", size="sm")
    # Event handlers
    login_btn.click(
        fn=login,
        inputs=[user_id_input],
        outputs=[
            login_interface,
            annotation_interface,
            user_state,
            login_status,
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label
        ]
    )
    # Pressing Enter in the ID box triggers the same login flow.
    user_id_input.submit(
        fn=login,
        inputs=[user_id_input],
        outputs=[
            login_interface,
            annotation_interface,
            user_state,
            login_status,
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label
        ]
    )
    btn_a.click(
        fn=lambda user_id: annotate("a_better", user_id),
        inputs=[user_state],
        outputs=[
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label,
            status_message
        ]
    )
    btn_b.click(
        fn=lambda user_id: annotate("b_better", user_id),
        inputs=[user_state],
        outputs=[
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label,
            status_message
        ]
    )
    btn_equal.click(
        fn=lambda user_id: annotate("equal", user_id),
        inputs=[user_state],
        outputs=[
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label,
            status_message
        ]
    )
    logout_btn.click(
        fn=logout,
        inputs=[],
        outputs=[
            login_interface,
            annotation_interface,
            user_state,
            login_status,
            prompt_display,
            response_a_display,
            response_b_display,
            progress_label
        ]
    )
# Script entry point: launch the Gradio server when run directly
# (importing this module — e.g. by a Space runner — does not auto-launch).
if __name__ == "__main__":
    app.launch()