import gradio as gr
import os
import uuid
import time
from datetime import datetime
from threading import Thread
from google.cloud import storage, bigquery
from transformers import AutoModel, AutoModelForImageClassification, AutoConfig
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
from collections import deque

# Setup GCP credentials
# The service-account JSON arrives in the `gcp_cam` env var; it is written to disk
# and GOOGLE_APPLICATION_CREDENTIALS is pointed at that file.
credentials_content = os.environ['gcp_cam']
with open('gcp_key.json', 'w') as f:
    f.write(credentials_content)
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'gcp_key.json'

# GCP config
bucket_name = os.environ['gcp_bucket']
pkl_blob = os.environ['pretrained_model']  # NOTE(review): not referenced elsewhere in this file
upload_folder = os.environ['user_data_gcp']
bq_dataset = os.environ['bq_dataset']
bq_table = os.environ['bq_table']

# Top-class probability below this is reported as "Unknown Meal" and the image is
# archived under user_data/unknown/.
CONFIDENCE_THRESHOLD = 0.275
# Probability at or below this (but >= the threshold) is reported as "Low Confidence".
LOW_CONFIDENCE_CEILING = 0.5
# Feedback is linked to the latest prediction only if it arrives within this many seconds.
FEEDBACK_WINDOW_SECONDS = 120
# Candidate labels for the zero-shot intent classifier; handle_chat branches on these.
INTENT_LABELS = ["ingredients", "nutrients", "restaurants"]

# Load transformer model
model = AutoModel.from_pretrained("paulinusjua/cameroon-meals", trust_remote_code=True)
model.eval()

# Startup smoke test: one forward pass on random data so a broken model fails at launch.
# NOTE(review): this uses 224x224 while `transform` below resizes to 256x256 -- confirm
# which input size the model actually expects.
test_input = torch.randn(1, 3, 224, 224)  # Assuming standard input
with torch.no_grad():
    out = model(test_input)
print("Sample output:", out)

config = AutoConfig.from_pretrained("paulinusjua/cameroon-meals", trust_remote_code=True)
labels = config.labels  # indexed by the argmax class index in predict()

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

bq_client = bigquery.Client()
bucket = storage.Client().bucket(bucket_name)
classifier = None  # zero-shot pipeline, created lazily by classify_intent()
# NOTE(review): chat_state and deferred_feedback are module-level, so every user of the
# running app shares them (feedback links to whoever predicted last), not per-session.
chat_state = {"meal": None}
deferred_feedback = deque(maxlen=100)  # (time.time(), prediction id), appended in order


def classify_intent(user_input):
    """Return the most likely intent of a chat message.

    Loads a zero-shot BART-MNLI pipeline on first use and returns the
    top-scoring label from INTENT_LABELS.
    """
    global classifier
    if classifier is None:
        from transformers import pipeline
        classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    # Score against the intents handle_chat understands -- scoring against the
    # meal names (config.labels) could never match its branches.
    result = classifier(user_input, INTENT_LABELS)
    return result['labels'][0]


def get_meal_info_from_bq(meal_name):
    """Look up ingredients and nutrients for a meal in BigQuery.

    The meal name is matched case-insensitively and bound as a query
    parameter (never interpolated into the SQL).

    Returns:
        "🍽️ *Ingredients:* ...\\n🥗 *Nutrients:* ..." on success, otherwise a
        single-line "not found" or error message.
    """
    query = f"""
        SELECT ingredients, nutrients
        FROM `{bq_client.project}.{bq_dataset}.cameroon_meals_info`
        WHERE LOWER(meal) = LOWER(@meal_name)
        LIMIT 1
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("meal_name", "STRING", meal_name)]
    )
    try:
        query_job = bq_client.query(query, job_config=job_config)
        result = list(query_job.result())
        if not result:
            return "No extra info found for this meal."
        row = result[0]
        return f"🍽️ *Ingredients:* {row.ingredients}\n🥗 *Nutrients:* {row.nutrients}"
    except Exception as e:
        print("BQ Fetch Error:", e)
        return "❌ Could not retrieve meal info."


def handle_chat(user_input, last_pred_meal):
    """Answer a chat question about the most recently predicted meal.

    Args:
        user_input: the user's free-text question.
        last_pred_meal: meal name from the latest prediction, or None.

    Returns:
        A single reply string.
    """
    if not last_pred_meal:
        return "Please upload a meal image first."
    intent = classify_intent(user_input)
    if intent == "restaurants":
        return f"📍 Restaurants for {last_pred_meal} coming soon."
    if intent in ("ingredients", "nutrients"):
        # Only hit BigQuery for intents that need the meal info.
        info = get_meal_info_from_bq(last_pred_meal).split("\n")
        if intent == "ingredients":
            return info[0]
        # "Not found"/error replies are one line; return them rather than
        # raising IndexError on info[1].
        return info[1] if len(info) > 1 else info[0]
    return "❓ I didn’t understand. Ask about ingredients, nutrients, or restaurants."


def upload_image_to_gcs(local_path, dest_folder, dest_filename):
    """Upload a local file to the configured bucket and return its gs:// URI.

    The object name is "{upload_folder}/{dest_folder}{dest_filename}", so
    dest_folder should end with "/".
    """
    blob_name = f"{upload_folder}/{dest_folder}{dest_filename}"
    bucket.blob(blob_name).upload_from_filename(local_path)
    return f"gs://{bucket_name}/{blob_name}"


def log_to_bigquery(record):
    """Insert one JSON row into the configured BigQuery table.

    Best-effort: insert errors and exceptions are printed, never raised.
    """
    table_id = f"{bq_client.project}.{bq_dataset}.{bq_table}"
    try:
        errors = bq_client.insert_rows_json(table_id, [record])
        if errors:
            print("BigQuery insert errors:", errors)
    except Exception as e:
        print("Logging error:", e)


def async_log(record):
    """Log `record` to BigQuery on a daemon thread so the UI is not blocked."""
    Thread(target=log_to_bigquery, args=(record,), daemon=True).start()


def predict(image_path, threshold=CONFIDENCE_THRESHOLD, user_feedback=None):
    """Classify one meal image, archive it to GCS and log the prediction.

    Args:
        image_path: local path of the image file.
        threshold: minimum top-class probability for a named prediction;
            below it the meal is reported as unknown.
        user_feedback: optional free text stored with the logged prediction.

    Returns:
        A user-facing result string ("❓ Unknown ...", "⚠️ ... Low Confidence",
        "✅ ...", or an image-error message).
    """
    unique_id = str(uuid.uuid4())
    timestamp = datetime.utcnow().isoformat()
    try:
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        img_tensor = transform(img).unsqueeze(0)
    except Exception as e:
        print("Image processing error:", e)
        return "Image could not be processed."

    with torch.no_grad():
        logits = model(img_tensor)
    if isinstance(logits, tuple):
        logits = logits[0]
    print("Logits shape:", logits.shape)
    probs = torch.nn.functional.softmax(logits[0], dim=0)  # first (only) batch item
    print("Probabilities:", probs.tolist())
    pred_idx = torch.argmax(probs).item()
    print("Predicted index:", pred_idx)
    pred_class = labels[pred_idx]
    prob = probs[pred_idx].item()

    # One test drives both the storage folder and the reply so they never disagree.
    is_known = prob >= threshold
    dest_folder = f"user_data/{pred_class}/" if is_known else "user_data/unknown/"
    # NOTE(review): always stored with a .jpg suffix regardless of the uploaded format.
    uploaded_gcs_path = upload_image_to_gcs(image_path, dest_folder, f"{unique_id}.jpg")
    async_log({
        "id": unique_id,
        "timestamp": timestamp,
        "image_gcs_path": uploaded_gcs_path,
        "predicted_class": pred_class,
        "confidence": prob,
        "threshold": threshold,
        "user_feedback": user_feedback or ""
    })
    deferred_feedback.append((time.time(), unique_id))
    chat_state["meal"] = pred_class

    if not is_known:
        return "❓ Unknown Meal: Provide Name. Thanks"
    if prob <= LOW_CONFIDENCE_CEILING:
        return f"⚠️ Meal: {pred_class}, Low Confidence"
    return f"✅ Meal: {pred_class}"


def submit_feedback_only(feedback_text):
    """Attach free-text feedback to the most recent prediction.

    The feedback is logged as an extra BigQuery row that reuses that
    prediction's id, provided the prediction was made within
    FEEDBACK_WINDOW_SECONDS. Otherwise nothing is logged.

    Returns:
        A status message for the UI.
    """
    if not feedback_text or not feedback_text.strip():
        return "⚠️ No feedback provided."
    now = time.time()
    # Iterate a snapshot: predict() may append from another request thread, and
    # mutating a deque during iteration raises RuntimeError.
    for ts, uid in reversed(list(deferred_feedback)):
        if now - ts <= FEEDBACK_WINDOW_SECONDS:
            async_log({
                "id": uid,
                "timestamp": datetime.utcnow().isoformat(),
                "image_gcs_path": "feedback_only",
                "predicted_class": "feedback_update",
                "confidence": 0.1,
                "threshold": 0.0,
                "user_feedback": feedback_text
            })
            return "✅ Feedback Submitted. Thank you!"
    return "⚠️ Feedback not linked: time expired."


def unified_predict(upload_files, webcam_img, clipboard_img, feedback):
    """Run predict() on whichever input source the user filled in.

    Priority is uploads, then webcam, then clipboard; only the first
    non-empty source is used. Results are joined by blank lines.
    """
    if upload_files:
        # Upload items may be path strings (or str subclasses) or tempfile-like
        # objects that expose the path as `.name`; accept both.
        files = [f if isinstance(f, str) else f.name for f in upload_files]
    elif webcam_img:
        files = [webcam_img]
    elif clipboard_img:
        files = [clipboard_img]
    else:
        return "No image provided."
    return "\n\n".join(predict(path, user_feedback=feedback) for path in files)


with gr.Blocks(theme="peach", analytics_enabled=False) as demo:
    gr.Markdown("""# Cameroonian Meal Recognizer
Welcome to Version 1: Identify traditional Cameroonian dishes from a photo.
This tool offers a friendly playground to learn about our diverse dishes. Therefore multiple image upload is encouraged for improvement in subsequent versions predictions.
Choose an input source below, and our AI will recognize the meal.
""")

    with gr.Tabs():
        with gr.Tab("Upload"):
            upload_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Meal Images")
        with gr.Tab("Webcam"):
            webcam_input = gr.Image(type="filepath", sources=["webcam"], label="Capture from Webcam")
        with gr.Tab("Clipboard"):
            clipboard_input = gr.Image(type="filepath", sources=["clipboard"], label="Paste from Clipboard")

    submit_btn = gr.Button("Identify Meal")
    output_box = gr.Textbox(label="Prediction Result", lines=6)

    gr.Markdown("### Feedback")
    with gr.Row():
        feedback_input = gr.Textbox(label=None, placeholder="If prediction is wrong, enter correct meal name...", lines=1, scale=4)
        feedback_btn = gr.Button("Submit Feedback", scale=1)
    feedback_ack = gr.HTML("")

    submit_btn.click(fn=unified_predict, inputs=[upload_input, webcam_input, clipboard_input, feedback_input], outputs=output_box)

    def styled_feedback_msg(feedback_text):
        """Submit feedback and return the acknowledgement for the HTML box.

        Success ("✅") and warning ("⚠️") messages are currently returned
        unstyled; wrap them in markup here if styling is wanted.
        """
        return submit_feedback_only(feedback_text)

    feedback_btn.click(fn=styled_feedback_msg, inputs=feedback_input, outputs=feedback_ack)

    # gr.Markdown("### Ask About the Meal")
    # with gr.Row():
    #     user_msg = gr.Textbox(label="Ask about ingredients, nutrients or where to find the meal", placeholder="e.g. What are the ingredients?", lines=1, scale=4)
    #     chat_btn = gr.Button("Ask", scale=1)
    # chat_out = gr.Textbox(label="Bot Reply")
    # chat_btn.click(fn=lambda x: handle_chat(x, chat_state["meal"]), inputs=user_msg, outputs=chat_out)

    gr.Markdown("""Future updates will include:
Learn more on Paulinus Jua's LinkedIn.
© 2025 Paulinus Jua. All rights reserved.
""")

if __name__ == "__main__":
    print("App setup complete — launching Gradio...")
    demo.launch(share=True)
    print("Launched.")