Basee_model / model_utils.py
mohammed-aljafry's picture
Upload model_utils.py with huggingface_hub
cbcacad verified
# ============================================================================
# model_utils.py - أدوات مساعدة لإدارة نموذج Interfuser
# ============================================================================
# هذا الملف مسؤول عن كل العمليات المتعلقة بنموذج PyTorch:
# 1. العثور على النماذج المتاحة.
# 2. تحميل نموذج محدد إلى الذاكرة (CPU/GPU).
# 3. توفير وصول سهل إلى النموذج المحمل حاليًا.
# هذا يفصل منطق النموذج بشكل كامل عن منطق واجهة المستخدم.
# ============================================================================
import os
import torch
import logging
# استيراد الأدوات اللازمة من ملف تعريف النموذج
try:
from model_definition import load_and_prepare_model, create_model_config
except ImportError as e:
print(f"خطأ في الاستيراد: تأكد من وجود ملف model_definition.py. الخطأ: {e}")
exit()
# --- المتغيرات العامة الخاصة بالنموذج ---
# الدليل الذي يحتوي على ملفات النماذج
MODEL_DIR = "model"
# تحديد الجهاز تلقائيًا (سيستخدم GPU إذا كان متاحًا)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# هذه المتغيرات ستحتفظ بالنموذج المحمل حاليًا في الذاكرة لتجنب إعادة التحميل
CURRENTLY_LOADED_MODEL: torch.nn.Module = None
CURRENT_MODEL_NAME: str = None
# إعداد نظام التسجيل لمتابعة عمليات التحميل
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_available_models():
"""
يبحث في مجلد 'model' ويعيد قائمة بأسماء النماذج المتاحة.
Returns:
list[str]: قائمة بأسماء الملفات للنماذج المتاحة.
"""
if not os.path.isdir(MODEL_DIR):
logging.warning(f"مجلد النماذج '{MODEL_DIR}' غير موجود.")
return []
try:
models = [f for f in os.listdir(MODEL_DIR) if f.endswith(('.pth', '.pt'))]
logging.info(f"تم العثور على النماذج التالية: {models}")
return models
except Exception as e:
logging.error(f"حدث خطأ أثناء قراءة مجلد النماذج: {e}")
return []
def load_model_by_name(model_name: str):
"""
يحمل نموذجًا محددًا بالاسم. إذا كان النموذج المطلوب محملًا بالفعل،
فإنه يتخطى عملية التحميل.
Args:
model_name (str): اسم ملف النموذج المراد تحميله (e.g., 'best_model.pth').
Returns:
str: رسالة نصية تشير إلى حالة عملية التحميل.
"""
global CURRENTLY_LOADED_MODEL, CURRENT_MODEL_NAME
if not model_name:
return "لم يتم اختيار نموذج."
# إذا كان النموذج المطلوب هو نفسه المحمل حاليًا، فلا داعي لفعل أي شيء
if model_name == CURRENT_MODEL_NAME and CURRENTLY_LOADED_MODEL is not None:
message = f"النموذج '{model_name}' محمل بالفعل."
logging.info(message)
return message
logging.info(f"بدء تحميل النموذج: '{model_name}' على الجهاز {DEVICE}...")
model_path = os.path.join(MODEL_DIR, model_name)
if not os.path.exists(model_path):
error_message = f"ملف النموذج '{model_path}' غير موجود."
logging.error(error_message)
# تفريغ النموذج الحالي إذا كان المسار خاطئًا
CURRENTLY_LOADED_MODEL = None
CURRENT_MODEL_NAME = None
raise FileNotFoundError(error_message)
try:
# استخدام الدوال من model_definition.py لإنشاء وتحميل النموذج
model_config = create_model_config(model_path=model_path)
model = load_and_prepare_model(model_config, DEVICE)
# تحديث المتغيرات العامة بالنموذج الجديد
CURRENTLY_LOADED_MODEL = model
CURRENT_MODEL_NAME = model_name
success_message = f"✅ تم تحميل النموذج بنجاح: {model_name}"
logging.info(success_message)
return success_message
except Exception as e:
logging.error(f"❌ حدث خطأ فادح أثناء تحميل النموذج '{model_name}': {e}", exc_info=True)
# إعادة تعيين المتغيرات العامة في حالة الفشل
CURRENTLY_LOADED_MODEL = None
CURRENT_MODEL_NAME = None
# إرسال الخطأ للأعلى ليتم عرضه في واجهة Gradio
raise e
def get_current_model() -> torch.nn.Module:
"""
يعيد كائن النموذج المحمل حاليًا.
إذا لم يكن هناك نموذج محمل، يحاول تحميل أول نموذج متاح كخيار افتراضي.
Returns:
torch.nn.Module or None: كائن النموذج المحمل أو None إذا فشل التحميل.
"""
if CURRENTLY_LOADED_MODEL is None:
logging.info("لا يوجد نموذج محمل حاليًا. محاولة تحميل النموذج الافتراضي...")
available_models = get_available_models()
if available_models:
# محاولة تحميل أول نموذج في القائمة
try:
load_model_by_name(available_models[0])
except Exception as e:
logging.error(f"فشل تحميل النموذج الافتراضي '{available_models[0]}': {e}")
return None
else:
logging.warning("لا توجد نماذج متاحة في مجلد 'model' لتحميلها.")
return None
return CURRENTLY_LOADED_MODEL