Spaces:
Sleeping
Sleeping
# ============================================================================ | |
# model_utils.py - أدوات مساعدة لإدارة نموذج Interfuser | |
# ============================================================================ | |
# هذا الملف مسؤول عن كل العمليات المتعلقة بنموذج PyTorch: | |
# 1. العثور على النماذج المتاحة. | |
# 2. تحميل نموذج محدد إلى الذاكرة (CPU/GPU). | |
# 3. توفير وصول سهل إلى النموذج المحمل حاليًا. | |
# هذا يفصل منطق النموذج بشكل كامل عن منطق واجهة المستخدم. | |
# ============================================================================ | |
import os | |
import torch | |
import logging | |
# استيراد الأدوات اللازمة من ملف تعريف النموذج | |
try: | |
from model_definition import load_and_prepare_model, create_model_config | |
except ImportError as e: | |
print(f"خطأ في الاستيراد: تأكد من وجود ملف model_definition.py. الخطأ: {e}") | |
exit() | |
# --- المتغيرات العامة الخاصة بالنموذج --- | |
# الدليل الذي يحتوي على ملفات النماذج | |
MODEL_DIR = "model" | |
# تحديد الجهاز تلقائيًا (سيستخدم GPU إذا كان متاحًا) | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# هذه المتغيرات ستحتفظ بالنموذج المحمل حاليًا في الذاكرة لتجنب إعادة التحميل | |
CURRENTLY_LOADED_MODEL: torch.nn.Module = None | |
CURRENT_MODEL_NAME: str = None | |
# إعداد نظام التسجيل لمتابعة عمليات التحميل | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
def get_available_models(): | |
""" | |
يبحث في مجلد 'model' ويعيد قائمة بأسماء النماذج المتاحة. | |
Returns: | |
list[str]: قائمة بأسماء الملفات للنماذج المتاحة. | |
""" | |
if not os.path.isdir(MODEL_DIR): | |
logging.warning(f"مجلد النماذج '{MODEL_DIR}' غير موجود.") | |
return [] | |
try: | |
models = [f for f in os.listdir(MODEL_DIR) if f.endswith(('.pth', '.pt'))] | |
logging.info(f"تم العثور على النماذج التالية: {models}") | |
return models | |
except Exception as e: | |
logging.error(f"حدث خطأ أثناء قراءة مجلد النماذج: {e}") | |
return [] | |
def load_model_by_name(model_name: str): | |
""" | |
يحمل نموذجًا محددًا بالاسم. إذا كان النموذج المطلوب محملًا بالفعل، | |
فإنه يتخطى عملية التحميل. | |
Args: | |
model_name (str): اسم ملف النموذج المراد تحميله (e.g., 'best_model.pth'). | |
Returns: | |
str: رسالة نصية تشير إلى حالة عملية التحميل. | |
""" | |
global CURRENTLY_LOADED_MODEL, CURRENT_MODEL_NAME | |
if not model_name: | |
return "لم يتم اختيار نموذج." | |
# إذا كان النموذج المطلوب هو نفسه المحمل حاليًا، فلا داعي لفعل أي شيء | |
if model_name == CURRENT_MODEL_NAME and CURRENTLY_LOADED_MODEL is not None: | |
message = f"النموذج '{model_name}' محمل بالفعل." | |
logging.info(message) | |
return message | |
logging.info(f"بدء تحميل النموذج: '{model_name}' على الجهاز {DEVICE}...") | |
model_path = os.path.join(MODEL_DIR, model_name) | |
if not os.path.exists(model_path): | |
error_message = f"ملف النموذج '{model_path}' غير موجود." | |
logging.error(error_message) | |
# تفريغ النموذج الحالي إذا كان المسار خاطئًا | |
CURRENTLY_LOADED_MODEL = None | |
CURRENT_MODEL_NAME = None | |
raise FileNotFoundError(error_message) | |
try: | |
# استخدام الدوال من model_definition.py لإنشاء وتحميل النموذج | |
model_config = create_model_config(model_path=model_path) | |
model = load_and_prepare_model(model_config, DEVICE) | |
# تحديث المتغيرات العامة بالنموذج الجديد | |
CURRENTLY_LOADED_MODEL = model | |
CURRENT_MODEL_NAME = model_name | |
success_message = f"✅ تم تحميل النموذج بنجاح: {model_name}" | |
logging.info(success_message) | |
return success_message | |
except Exception as e: | |
logging.error(f"❌ حدث خطأ فادح أثناء تحميل النموذج '{model_name}': {e}", exc_info=True) | |
# إعادة تعيين المتغيرات العامة في حالة الفشل | |
CURRENTLY_LOADED_MODEL = None | |
CURRENT_MODEL_NAME = None | |
# إرسال الخطأ للأعلى ليتم عرضه في واجهة Gradio | |
raise e | |
def get_current_model() -> torch.nn.Module: | |
""" | |
يعيد كائن النموذج المحمل حاليًا. | |
إذا لم يكن هناك نموذج محمل، يحاول تحميل أول نموذج متاح كخيار افتراضي. | |
Returns: | |
torch.nn.Module or None: كائن النموذج المحمل أو None إذا فشل التحميل. | |
""" | |
if CURRENTLY_LOADED_MODEL is None: | |
logging.info("لا يوجد نموذج محمل حاليًا. محاولة تحميل النموذج الافتراضي...") | |
available_models = get_available_models() | |
if available_models: | |
# محاولة تحميل أول نموذج في القائمة | |
try: | |
load_model_by_name(available_models[0]) | |
except Exception as e: | |
logging.error(f"فشل تحميل النموذج الافتراضي '{available_models[0]}': {e}") | |
return None | |
else: | |
logging.warning("لا توجد نماذج متاحة في مجلد 'model' لتحميلها.") | |
return None | |
return CURRENTLY_LOADED_MODEL |