|
""" |
|
OCR Arena - Main Application |
|
A Gradio web application for comparing OCR results from different AI models. |
|
""" |
|
|
|
import gradio as gr |
|
import logging |
|
import os |
|
import datetime |
|
from dotenv import load_dotenv |
|
from storage import upload_file_to_bucket |
|
from db import add_vote, get_all_votes, calculate_elo_ratings_from_votes |
|
from ocr_models import process_model_ocr, initialize_gemini, initialize_mistral, initialize_openai |
|
from ui_helpers import ( |
|
get_model_display_name, select_random_models, format_votes_table, |
|
format_elo_leaderboard |
|
) |
|
|
|
|
|
# Load environment variables from a local .env file (Supabase credentials, model API keys).
load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Initialize each OCR model client once at startup.
initialize_gemini()
initialize_mistral()
initialize_openai()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

# Module-level state for the current comparison round.
# NOTE: this is process-global, so concurrent users share one round —
# presumably acceptable for a demo; confirm before scaling out.
current_gemini_output = ""
current_mistral_output = ""
current_openai_output = ""
current_gpt5_output = ""  # duplicate assignment of this variable removed
current_image_url = ""
current_voted_users = set()  # usernames that already voted on the current image
current_model_a = ""  # model id randomly assigned to slot A
current_model_b = ""  # model id randomly assigned to slot B
|
|
|
|
|
def get_default_username(profile: gr.OAuthProfile | None) -> str:
    """Return the logged-in user's name, or "" when nobody is logged in."""
    return "" if profile is None else profile.username
|
|
|
def get_current_username(profile_or_username) -> str:
    """Returns the username from login or "Anonymous" if not logged in."""
    # An OAuth-profile-like object with a non-empty username wins outright.
    attr_name = getattr(profile_or_username, 'username', None)
    if attr_name:
        return attr_name

    if isinstance(profile_or_username, str) and profile_or_username.strip():
        raw = profile_or_username
        # Login-button label of the form "Logout (<name>)" — extract <name>.
        if raw.startswith("Logout (") and raw.endswith(")"):
            return raw[8:-1]
        # Any other label except the sign-in prompt is treated as a username.
        if raw != "Sign in with Hugging Face":
            return raw.strip()

    return "Anonymous"
|
|
|
def process_image(image):
    """Process uploaded image and select random models for comparison.

    Resets per-image vote state, randomly assigns two models to slots A/B,
    uploads the image to Supabase storage (recording its public URL), and
    returns placeholder texts plus hidden-vote-button updates for the UI.

    Returns a 4-tuple: (model A text, model B text, vote-btn-A update,
    vote-btn-B update).
    """
    global current_gemini_output, current_mistral_output, current_openai_output, current_image_url, current_voted_users, current_model_a, current_model_b

    if image is None:
        return (
            "Please upload an image.",
            "Please upload an image.",
            gr.update(visible=False),
            gr.update(visible=False)
        )

    # New image: everyone may vote again.
    current_voted_users.clear()

    # Pick the two anonymous contenders for this round.
    model_a, model_b = select_random_models()
    current_model_a = model_a
    current_model_b = model_b

    logger.info(f"🎲 Randomly selected two models for comparison")

    # Timestamped name keeps concurrent/successive uploads distinct.
    temp_filename = f"temp_image_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    try:
        image.save(temp_filename)

        logger.info(f"📤 Uploading image to Supabase storage: {temp_filename}")
        upload_result = upload_file_to_bucket(
            file_path=temp_filename,
            bucket_name="images",
            storage_path=f"ocr_images/{temp_filename}",
            file_options={"cache-control": "3600", "upsert": "false"}
        )

        if upload_result["success"]:
            logger.info(f"✅ Image uploaded successfully: {upload_result['storage_path']}")
            logger.info(f"🔗 Public URL: {upload_result['public_url']}")
            # Fall back to the deterministic public-bucket URL if the client
            # did not return one.
            current_image_url = upload_result.get('public_url') or f"{SUPABASE_URL}/storage/v1/object/public/images/ocr_images/{temp_filename}"
        else:
            logger.error(f"❌ Image upload failed: {upload_result['error']}")
            current_image_url = ""

        return (
            "Please click 'Run OCR' to start processing.",
            "Please click 'Run OCR' to start processing.",
            gr.update(visible=False),
            gr.update(visible=False)
        )

    except Exception as e:
        logger.error(f"Error processing image: {e}")
        return (
            f"Error processing image: {e}",
            f"Error processing image: {e}",
            gr.update(visible=False),
            gr.update(visible=False)
        )

    finally:
        # Always remove the temp file — the original cleanup only ran on the
        # success path, so a failed save/upload leaked temp files on disk.
        try:
            os.remove(temp_filename)
            logger.info(f"🗑️ Cleaned up temporary file: {temp_filename}")
        except Exception as e:
            logger.warning(f"⚠️ Could not remove temporary file {temp_filename}: {e}")
|
|
|
def check_ocr_completion(model_a_output, model_b_output):
    """Check if both OCR results are ready and update UI accordingly."""
    # Texts that mean "no real OCR output yet".
    placeholders = (
        "Please upload an image.",
        "Processing OCR...",
        "Please click 'Run OCR' to start processing.",
    )

    def _is_ready(text):
        # Real output: non-empty, not a placeholder, not an error message.
        return bool(text) and text not in placeholders and not text.startswith("OCR error:")

    def _store(model, text):
        # Cache the finished output under the global slot for this model id.
        global current_gemini_output, current_mistral_output, current_openai_output, current_gpt5_output
        if model == "gemini":
            current_gemini_output = text
        elif model == "mistral":
            current_mistral_output = text
        elif model == "openai":
            current_openai_output = text
        elif model == "gpt5":
            current_gpt5_output = text

    a_done = _is_ready(model_a_output)
    b_done = _is_ready(model_b_output)

    if a_done:
        _store(current_model_a, model_a_output)
    if b_done:
        _store(current_model_b, model_b_output)

    # Vote buttons appear only once both results are in.
    show_buttons = a_done and b_done
    return (
        gr.update(visible=show_buttons),
        gr.update(visible=show_buttons)
    )
|
|
|
def load_vote_data():
    """Fetch every recorded vote and render it as an HTML table."""
    try:
        return format_votes_table(get_all_votes())
    except Exception as e:
        # Surface the failure in the UI instead of crashing the tab.
        logger.error(f"Error loading vote data: {e}")
        return f"<p style='color: red;'>Error loading data: {e}</p>"
|
|
|
def load_elo_leaderboard():
    """Build the ELO leaderboard HTML from the full vote history."""
    try:
        votes = get_all_votes()
        elo_ratings = calculate_elo_ratings_from_votes(votes)

        # Tally head-to-head wins per model; unknown model ids are ignored.
        vote_counts = dict.fromkeys(("gemini", "mistral", "openai", "gpt5"), 0)
        for vote in votes:
            choice = vote.get('vote')
            if choice == 'model_a':
                winner = vote.get('model_a')
            elif choice == 'model_b':
                winner = vote.get('model_b')
            else:
                winner = None
            if winner in vote_counts:
                vote_counts[winner] += 1

        return format_elo_leaderboard(elo_ratings, vote_counts)

    except Exception as e:
        # Surface the failure in the UI instead of crashing the tab.
        logger.error(f"Error loading ELO leaderboard: {e}")
        return f"<p style='color: red;'>Error loading ELO leaderboard: {e}</p>"
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: three tabs (Arena, Data, Leaderboard). The custom CSS styles the
# two OCR output panels and the votes table.
# ---------------------------------------------------------------------------
with gr.Blocks(title="OCR Comparison", css="""
.output-box {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    background-color: #f9f9f9;
    min-height: 200px;
}
.output-box:hover {
    border-color: #007bff;
    box-shadow: 0 2px 8px rgba(0,123,255,0.1);
}
.vote-table {
    border-collapse: collapse;
    width: 100%;
    margin: 10px 0;
    min-width: 800px;
}
.vote-table th, .vote-table td {
    border: 1px solid #ddd;
    padding: 6px;
    text-align: left;
    vertical-align: top;
}
.vote-table th {
    background-color: #f2f2f2;
    font-weight: bold;
    position: sticky;
    top: 0;
    z-index: 10;
}
.vote-table tr:nth-child(even) {
    background-color: #f9f9f9;
}
.vote-table tr:hover {
    background-color: #f5f5f5;
}
.vote-table img {
    transition: transform 0.2s ease;
    max-width: 100%;
    height: auto;
}
.vote-table img:hover {
    transform: scale(1.1);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
""") as demo:

    with gr.Tabs():

        # --- Tab 1: the arena where two anonymous models are compared. ---
        with gr.Tab("⚔️ Arena", id=0):
            gr.Markdown("# ⚔️ OCR Arena: Random Model Selection")
            gr.Markdown("Upload an image to compare two randomly selected OCR models.")

            # Login row: current-user display plus Hugging Face login button.
            with gr.Row():
                with gr.Column(scale=3):
                    username_display = gr.Textbox(
                        label="Current User",
                        placeholder="Login with Hugging Face to vote (optional) - Anonymous users welcome!",
                        interactive=False,
                        show_label=False
                    )
                with gr.Column(scale=1):
                    login_button = gr.LoginButton()

            # Side-by-side outputs. NOTE: despite the legacy variable names
            # (gemini_*/mistral_*), these components show "Model A"/"Model B";
            # the real model behind each slot is chosen randomly per image.
            with gr.Row():
                with gr.Column():
                    gemini_vote_btn = gr.Button("A is better", variant="primary", size="sm", visible=False)
                    gemini_output = gr.Markdown(label="Model A Output", elem_classes=["output-box"])

                    image_input = gr.Image(type="pil", label="Upload or Paste Image")

                with gr.Column():
                    mistral_vote_btn = gr.Button("B is better", variant="primary", size="sm", visible=False)
                    mistral_output = gr.Markdown(label="Model B Output", elem_classes=["output-box"])

            with gr.Row():
                process_btn = gr.Button("🔍 Run OCR", variant="primary")

        # --- Tab 2: raw vote data. ---
        with gr.Tab("📊 Data", id=1):
            gr.Markdown("# 📊 Vote Data")
            gr.Markdown("View all votes from the OCR Arena")

            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh Data", variant="secondary")

            with gr.Row():
                votes_table = gr.HTML(
                    value="<p>Loading vote data...</p>",
                    label="📋 All Votes (Latest First)"
                )

        # --- Tab 3: ELO leaderboard. ---
        with gr.Tab("🏆 Leaderboard", id=2):
            gr.Markdown("# 🏆 ELO Leaderboard")
            gr.Markdown("See how the models rank based on their ELO ratings from head-to-head comparisons.")

            with gr.Row():
                refresh_leaderboard_btn = gr.Button("🔄 Refresh Leaderboard", variant="secondary")

            with gr.Row():
                leaderboard_display = gr.HTML(
                    value="<p>Loading ELO leaderboard...</p>",
                    label="🏆 Model Rankings"
                )
|
    def vote_model_a(profile_or_username):
        """Record a vote for the model in slot A of the current comparison.

        Persists the vote (both models' cached outputs plus the image URL)
        via ``add_vote`` and enforces one vote per user per image.
        """
        global current_gemini_output, current_mistral_output, current_openai_output, current_gpt5_output, current_image_url, current_voted_users, current_model_a, current_model_b

        # Resolve a display name from the login button/profile.
        username = get_current_username(profile_or_username)

        # Defensive fallback; get_current_username already returns "Anonymous".
        if not username:
            username = "Anonymous"

        # One vote per user per uploaded image.
        if username in current_voted_users:
            gr.Info(f"You have already voted for this image, {username}!")
            return

        try:
            # "no_image" marks votes cast before/without a successful upload.
            image_url = current_image_url if current_image_url else "no_image"

            logger.info(f"📊 Adding Model A vote for user: {username}")
            def output_for(model: str) -> str:
                # Map a model id to its cached OCR output ("" if unknown).
                return {
                    "gemini": current_gemini_output,
                    "mistral": current_mistral_output,
                    "openai": current_openai_output,
                    "gpt5": current_gpt5_output,
                }.get(model, "")

            add_vote(
                username=username,
                model_a=current_model_a,
                model_b=current_model_b,
                model_a_output=output_for(current_model_a),
                model_b_output=output_for(current_model_b),
                vote="model_a",
                image_url=image_url
            )

            # Only mark the user after the vote was stored successfully.
            current_voted_users.add(username)

            # Reveal which models were behind slots A (green) and B (blue).
            model_a_name = get_model_display_name(current_model_a)
            model_b_name = get_model_display_name(current_model_b)
            info_message = (
                f"<p>You voted for <strong style='color:green;'>{model_a_name}</strong>.</p>"
                f"<p><span style='color:green;'>{model_a_name}</span> - "
                f"<span style='color:blue;'>{model_b_name}</span></p>"
            )
            gr.Info(info_message)

        except Exception as e:
            logger.error(f"❌ Error adding Model A vote: {e}")
            gr.Info(f"Error recording vote: {e}")
|
|
|
    def vote_model_b(profile_or_username):
        """Record a vote for the model in slot B of the current comparison.

        Mirror image of ``vote_model_a``; stores ``vote="model_b"``.
        """
        global current_gemini_output, current_mistral_output, current_openai_output, current_gpt5_output, current_image_url, current_voted_users, current_model_a, current_model_b

        # Resolve a display name from the login button/profile.
        username = get_current_username(profile_or_username)

        # Defensive fallback; get_current_username already returns "Anonymous".
        if not username:
            username = "Anonymous"

        # One vote per user per uploaded image.
        if username in current_voted_users:
            gr.Info(f"You have already voted for this image, {username}!")
            return

        try:
            # "no_image" marks votes cast before/without a successful upload.
            image_url = current_image_url if current_image_url else "no_image"

            logger.info(f"📊 Adding Model B vote for user: {username}")
            def output_for(model: str) -> str:
                # Map a model id to its cached OCR output ("" if unknown).
                return {
                    "gemini": current_gemini_output,
                    "mistral": current_mistral_output,
                    "openai": current_openai_output,
                    "gpt5": current_gpt5_output,
                }.get(model, "")

            add_vote(
                username=username,
                model_a=current_model_a,
                model_b=current_model_b,
                model_a_output=output_for(current_model_a),
                model_b_output=output_for(current_model_b),
                vote="model_b",
                image_url=image_url
            )

            # Only mark the user after the vote was stored successfully.
            current_voted_users.add(username)

            # Reveal which models were behind slots A (green) and B (blue).
            model_a_name = get_model_display_name(current_model_a)
            model_b_name = get_model_display_name(current_model_b)
            info_message = (
                f"<p>You voted for <strong style='color:blue;'>{model_b_name}</strong>.</p>"
                f"<p><span style='color:green;'>{model_a_name}</span> - "
                f"<span style='color:blue;'>{model_b_name}</span></p>"
            )
            gr.Info(info_message)

        except Exception as e:
            logger.error(f"❌ Error adding Model B vote: {e}")
            gr.Info(f"Error recording vote: {e}")
|
|
|
|
|
    # --- Event wiring ---

    # Clicking "Run OCR" first resets both panels to placeholder text and
    # hides the vote buttons (process_image also uploads the image and
    # randomly assigns the two model slots)...
    process_btn.click(
        process_image,
        inputs=[image_input],
        outputs=[gemini_output, mistral_output, gemini_vote_btn, mistral_vote_btn],
    )

    def process_model_a_ocr(image):
        """Run OCR with whichever model is currently assigned to slot A."""
        global current_model_a
        return process_model_ocr(image, current_model_a)

    def process_model_b_ocr(image):
        """Run OCR with whichever model is currently assigned to slot B."""
        global current_model_b
        return process_model_ocr(image, current_model_b)

    # ...then the two OCR passes run as separate click handlers, each filling
    # its own output panel as it finishes.
    process_btn.click(
        process_model_a_ocr,
        inputs=[image_input],
        outputs=[gemini_output],
    )

    process_btn.click(
        process_model_b_ocr,
        inputs=[image_input],
        outputs=[mistral_output],
    )

    # Whenever either panel changes, re-check whether both results are in and
    # toggle the vote buttons accordingly.
    gemini_output.change(
        check_ocr_completion,
        inputs=[gemini_output, mistral_output],
        outputs=[gemini_vote_btn, mistral_vote_btn],
    )

    mistral_output.change(
        check_ocr_completion,
        inputs=[gemini_output, mistral_output],
        outputs=[gemini_vote_btn, mistral_vote_btn],
    )

    # Vote buttons: the login button component is passed as input so the
    # handler can derive the username from its profile/label.
    gemini_vote_btn.click(
        vote_model_a,
        inputs=[login_button]
    )

    mistral_vote_btn.click(
        vote_model_b,
        inputs=[login_button]
    )

    # Manual refresh for the Data tab.
    refresh_btn.click(
        load_vote_data,
        inputs=None,
        outputs=[votes_table]
    )

    # Manual refresh for the Leaderboard tab.
    refresh_leaderboard_btn.click(
        load_elo_leaderboard,
        inputs=None,
        outputs=[leaderboard_display]
    )

    # Populate username, votes table, and leaderboard when the page loads.
    demo.load(fn=get_default_username, inputs=None, outputs=username_display)

    demo.load(fn=load_vote_data, inputs=None, outputs=[votes_table])

    demo.load(fn=load_elo_leaderboard, inputs=None, outputs=[leaderboard_display])
|
|
|
if __name__ == "__main__":
    logger.info("Starting OCR Comparison App...")
    try:
        # Try a plain local launch first.
        demo.launch()
    except ValueError as e:
        # Gradio raises ValueError when localhost is not accessible (e.g. in
        # some sandboxed/hosted environments); fall back to a shareable
        # public URL. The original code launched with share=True in BOTH
        # branches, making the fallback an identical (pointless) retry and
        # contradicting the log messages below.
        logger.warning(f"Localhost not accessible: {e}")
        logger.info("Launching with public URL...")
        demo.launch(share=True)