Spaces:
Sleeping
Sleeping
File size: 6,304 Bytes
cbcacad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# ============================================================================
# model_utils.py - أدوات مساعدة لإدارة نموذج Interfuser
# ============================================================================
# هذا الملف مسؤول عن كل العمليات المتعلقة بنموذج PyTorch:
# 1. العثور على النماذج المتاحة.
# 2. تحميل نموذج محدد إلى الذاكرة (CPU/GPU).
# 3. توفير وصول سهل إلى النموذج المحمل حاليًا.
# هذا يفصل منطق النموذج بشكل كامل عن منطق واجهة المستخدم.
# ============================================================================
import os
import torch
import logging
# استيراد الأدوات اللازمة من ملف تعريف النموذج
try:
from model_definition import load_and_prepare_model, create_model_config
except ImportError as e:
print(f"خطأ في الاستيراد: تأكد من وجود ملف model_definition.py. الخطأ: {e}")
exit()
# --- المتغيرات العامة الخاصة بالنموذج ---
# الدليل الذي يحتوي على ملفات النماذج
MODEL_DIR = "model"
# تحديد الجهاز تلقائيًا (سيستخدم GPU إذا كان متاحًا)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# هذه المتغيرات ستحتفظ بالنموذج المحمل حاليًا في الذاكرة لتجنب إعادة التحميل
CURRENTLY_LOADED_MODEL: torch.nn.Module = None
CURRENT_MODEL_NAME: str = None
# إعداد نظام التسجيل لمتابعة عمليات التحميل
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_available_models():
"""
يبحث في مجلد 'model' ويعيد قائمة بأسماء النماذج المتاحة.
Returns:
list[str]: قائمة بأسماء الملفات للنماذج المتاحة.
"""
if not os.path.isdir(MODEL_DIR):
logging.warning(f"مجلد النماذج '{MODEL_DIR}' غير موجود.")
return []
try:
models = [f for f in os.listdir(MODEL_DIR) if f.endswith(('.pth', '.pt'))]
logging.info(f"تم العثور على النماذج التالية: {models}")
return models
except Exception as e:
logging.error(f"حدث خطأ أثناء قراءة مجلد النماذج: {e}")
return []
def load_model_by_name(model_name: str):
"""
يحمل نموذجًا محددًا بالاسم. إذا كان النموذج المطلوب محملًا بالفعل،
فإنه يتخطى عملية التحميل.
Args:
model_name (str): اسم ملف النموذج المراد تحميله (e.g., 'best_model.pth').
Returns:
str: رسالة نصية تشير إلى حالة عملية التحميل.
"""
global CURRENTLY_LOADED_MODEL, CURRENT_MODEL_NAME
if not model_name:
return "لم يتم اختيار نموذج."
# إذا كان النموذج المطلوب هو نفسه المحمل حاليًا، فلا داعي لفعل أي شيء
if model_name == CURRENT_MODEL_NAME and CURRENTLY_LOADED_MODEL is not None:
message = f"النموذج '{model_name}' محمل بالفعل."
logging.info(message)
return message
logging.info(f"بدء تحميل النموذج: '{model_name}' على الجهاز {DEVICE}...")
model_path = os.path.join(MODEL_DIR, model_name)
if not os.path.exists(model_path):
error_message = f"ملف النموذج '{model_path}' غير موجود."
logging.error(error_message)
# تفريغ النموذج الحالي إذا كان المسار خاطئًا
CURRENTLY_LOADED_MODEL = None
CURRENT_MODEL_NAME = None
raise FileNotFoundError(error_message)
try:
# استخدام الدوال من model_definition.py لإنشاء وتحميل النموذج
model_config = create_model_config(model_path=model_path)
model = load_and_prepare_model(model_config, DEVICE)
# تحديث المتغيرات العامة بالنموذج الجديد
CURRENTLY_LOADED_MODEL = model
CURRENT_MODEL_NAME = model_name
success_message = f"✅ تم تحميل النموذج بنجاح: {model_name}"
logging.info(success_message)
return success_message
except Exception as e:
logging.error(f"❌ حدث خطأ فادح أثناء تحميل النموذج '{model_name}': {e}", exc_info=True)
# إعادة تعيين المتغيرات العامة في حالة الفشل
CURRENTLY_LOADED_MODEL = None
CURRENT_MODEL_NAME = None
# إرسال الخطأ للأعلى ليتم عرضه في واجهة Gradio
raise e
def get_current_model() -> torch.nn.Module:
"""
يعيد كائن النموذج المحمل حاليًا.
إذا لم يكن هناك نموذج محمل، يحاول تحميل أول نموذج متاح كخيار افتراضي.
Returns:
torch.nn.Module or None: كائن النموذج المحمل أو None إذا فشل التحميل.
"""
if CURRENTLY_LOADED_MODEL is None:
logging.info("لا يوجد نموذج محمل حاليًا. محاولة تحميل النموذج الافتراضي...")
available_models = get_available_models()
if available_models:
# محاولة تحميل أول نموذج في القائمة
try:
load_model_by_name(available_models[0])
except Exception as e:
logging.error(f"فشل تحميل النموذج الافتراضي '{available_models[0]}': {e}")
return None
else:
logging.warning("لا توجد نماذج متاحة في مجلد 'model' لتحميلها.")
return None
return CURRENTLY_LOADED_MODEL |