paulinusjua's picture
Update app.py
dd7d85d verified
import gradio as gr
import os
import uuid
import time
from datetime import datetime
from threading import Thread
from google.cloud import storage, bigquery
from transformers import AutoModel,AutoModelForImageClassification, AutoConfig
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
from collections import deque
# Setup GCP credentials
credentials_content = os.environ['gcp_cam']
with open('gcp_key.json', 'w') as f:
f.write(credentials_content)
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'gcp_key.json'
# GCP config
bucket_name = os.environ['gcp_bucket']
pkl_blob = os.environ['pretrained_model']
upload_folder = os.environ['user_data_gcp']
bq_dataset = os.environ['bq_dataset']
bq_table = os.environ['bq_table']
# Load transformer model
model = AutoModel.from_pretrained("paulinusjua/cameroon-meals", trust_remote_code=True)
model.eval()
test_input = torch.randn(1, 3, 224, 224) # Assuming standard input
with torch.no_grad():
out = model(test_input)
print("Sample output:", out)
config = AutoConfig.from_pretrained("paulinusjua/cameroon-meals", trust_remote_code=True)
labels = config.labels
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
bq_client = bigquery.Client()
bucket = storage.Client().bucket(bucket_name)
classifier = None
chat_state = {"meal": None}
deferred_feedback = deque(maxlen=100)
def classify_intent(user_input):
global classifier
if classifier is None:
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(user_input, labels)
return result['labels'][0]
def get_meal_info_from_bq(meal_name):
query = f"""
SELECT ingredients, nutrients
FROM `{bq_client.project}.{bq_dataset}.cameroon_meals_info`
WHERE LOWER(meal) = LOWER(@meal_name)
LIMIT 1
"""
job_config = bigquery.QueryJobConfig(
query_parameters=[bigquery.ScalarQueryParameter("meal_name", "STRING", meal_name)]
)
try:
query_job = bq_client.query(query, job_config=job_config)
result = list(query_job.result())
if not result:
return "No extra info found for this meal."
row = result[0]
return f"🍽️ *Ingredients:* {row.ingredients}\n🥗 *Nutrients:* {row.nutrients}"
except Exception as e:
print("BQ Fetch Error:", e)
return "❌ Could not retrieve meal info."
def handle_chat(user_input, last_pred_meal):
if not last_pred_meal:
return "Please upload a meal image first."
intent = classify_intent(user_input)
info = get_meal_info_from_bq(last_pred_meal).split("\n")
if intent == "ingredients":
return info[0]
elif intent == "nutrients":
return info[1]
elif intent == "restaurants":
return f"📍 Restaurants for {last_pred_meal} coming soon."
else:
return "❓ I didn’t understand. Ask about ingredients, nutrients, or restaurants."
def upload_image_to_gcs(local_path, dest_folder, dest_filename):
blob = bucket.blob(f"{upload_folder}/{dest_folder}{dest_filename}")
blob.upload_from_filename(local_path)
return f"gs://{bucket_name}/{upload_folder}/{dest_folder}{dest_filename}"
def log_to_bigquery(record):
table_id = f"{bq_client.project}.{bq_dataset}.{bq_table}"
try:
errors = bq_client.insert_rows_json(table_id, [record])
if errors:
print("BigQuery insert errors:", errors)
except Exception as e:
print("Logging error:", e)
def async_log(record):
Thread(target=log_to_bigquery, args=(record,), daemon=True).start()
def predict(image_path, threshold=0.275, user_feedback=None):
start_time = time.time()
unique_id = str(uuid.uuid4())
timestamp = datetime.utcnow().isoformat()
try:
img = Image.open(image_path).convert("RGB")
img_tensor = transform(img).unsqueeze(0)
except Exception as e:
print("Image processing error:", e)
return "Image could not be processed."
with torch.no_grad():
logits = model(img_tensor)
if isinstance(logits, tuple):
logits = logits[0]
print("Logits shape:", logits.shape)
probs = torch.nn.functional.softmax(logits[0], dim=0)
print("Probabilities:", probs.tolist())
pred_idx = torch.argmax(probs).item()
print("Predicted index:", pred_idx)
pred_class = labels[pred_idx]
prob = probs[pred_idx].item()
dest_folder = f"user_data/{pred_class}/" if prob >= threshold else "user_data/unknown/"
uploaded_gcs_path = upload_image_to_gcs(image_path, dest_folder, f"{unique_id}.jpg")
async_log({
"id": unique_id,
"timestamp": timestamp,
"image_gcs_path": uploaded_gcs_path,
"predicted_class": pred_class,
"confidence": prob,
"threshold": threshold,
"user_feedback": user_feedback or ""
})
deferred_feedback.append((time.time(), unique_id))
chat_state["meal"] = pred_class
return (
f"❓ Unknown Meal: Provide Name. Thanks" if prob <= threshold else
f"⚠️ Meal: {pred_class}, Low Confidence" if 0.275 <= prob <= 0.5 else
f"✅ Meal: {pred_class}"
)
def submit_feedback_only(feedback_text):
if not feedback_text.strip():
return "⚠️ No feedback provided."
now = time.time()
for ts, uid in reversed(deferred_feedback):
if now - ts <= 120:
async_log({
"id": uid,
"timestamp": datetime.utcnow().isoformat(),
"image_gcs_path": "feedback_only",
"predicted_class": "feedback_update",
"confidence": 0.1,
"threshold": 0.0,
"user_feedback": feedback_text
})
return "✅ Feedback Submitted. Thank you!"
return "⚠️ Feedback not linked: time expired."
def unified_predict(upload_files, webcam_img, clipboard_img, feedback):
files = []
if upload_files:
files = [file.name for file in upload_files]
elif webcam_img:
files = [webcam_img]
elif clipboard_img:
files = [clipboard_img]
else:
return "No image provided."
return "\n\n".join([predict(f, user_feedback=feedback) for f in files])
with gr.Blocks(theme="peach", analytics_enabled=False) as demo:
gr.Markdown("""# Cameroonian Meal Recognizer
<p><b>Welcome to Version 1:</b> Identify traditional Cameroonian dishes from a photo.</p>
<p style='background-color: #b3e5fc; padding: 5px; border-radius: 4px;'>This tool offers a friendly playground to learn about our diverse dishes. Therefore multiple image upload is encouraged for improvement in subsequent versions predictions.</p>
<p><i>Choose an input source below, and our AI will recognize the meal.</i></p>
""")
with gr.Tabs():
with gr.Tab("Upload"):
upload_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Meal Images")
with gr.Tab("Webcam"):
webcam_input = gr.Image(type="filepath", sources=["webcam"], label="Capture from Webcam")
with gr.Tab("Clipboard"):
clipboard_input = gr.Image(type="filepath", sources=["clipboard"], label="Paste from Clipboard")
submit_btn = gr.Button("Identify Meal")
output_box = gr.Textbox(label="Prediction Result", lines=6)
gr.Markdown("### Feedback")
with gr.Row():
feedback_input = gr.Textbox(label=None, placeholder="If prediction is wrong, enter correct meal name...", lines=1, scale=4)
feedback_btn = gr.Button("Submit Feedback", scale=1)
feedback_ack = gr.HTML("")
submit_btn.click(fn=unified_predict, inputs=[upload_input, webcam_input, clipboard_input, feedback_input], outputs=output_box)
def styled_feedback_msg(feedback_text):
msg = submit_feedback_only(feedback_text)
if msg.startswith("✅"):
return f"<span style='color: green; font-weight: bold;'>{msg}</span>"
elif msg.startswith("⚠️"):
return f"<span style='color: orange; font-weight: bold;'>{msg}</span>"
return msg
feedback_btn.click(fn=styled_feedback_msg, inputs=feedback_input, outputs=feedback_ack)
#gr.Markdown("### Ask About the Meal")
#with gr.Row():
# user_msg = gr.Textbox(label="Ask about ingredients, nutrients or where to find the meal", placeholder="e.g. What are the ingredients?", lines=1, scale=4)
# chat_btn = gr.Button("Ask", scale=1)
#chat_out = gr.Textbox(label="Bot Reply")
#chat_btn.click(fn=lambda x: handle_chat(x, chat_state["meal"]), inputs=user_msg, outputs=chat_out)
gr.Markdown("""
<p>Future updates will include:
<ul>
<li>Ingredient lists</li>
<li>Meal preparation details</li>
<li>Origin (locality) info</li>
<li>Nearby restaurants</li>
</ul></p>
<p>Learn more on <a href="https://www.linkedin.com/in/paulinus-jua-21255116b/" target="_blank">Paulinus Jua's LinkedIn</a>.</p>
<p>© 2025 Paulinus Jua. All rights reserved.</p>
""")
if __name__ == "__main__":
print("App setup complete — launching Gradio...")
demo.launch(share=True)
print("Launched.")