Update app.py
app.py
CHANGED
@@ -31,7 +31,7 @@ logging.basicConfig(
 logger = logging.getLogger("cosmic_ai")
 
 # Set a custom NLTK data directory
-nltk_data_dir = os.getenv('
+nltk_data_dir = os.getenv('NLTK_DATA', '/tmp/nltk_data')
 os.makedirs(nltk_data_dir, exist_ok=True)
 nltk.data.path.append(nltk_data_dir)
 
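Background on this change: on container hosts whose application directory is read-only (such as Hugging Face Spaces), NLTK must be pointed at a writable data directory, which the NLTK_DATA environment variable now controls with /tmp/nltk_data as the fallback. A minimal sketch of how a resource download would then use that path; the punkt resource name is illustrative and not part of this commit:

import os
import nltk

nltk_data_dir = os.getenv('NLTK_DATA', '/tmp/nltk_data')
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

# Fetch a resource into the writable directory ("punkt" is illustrative).
nltk.download('punkt', download_dir=nltk_data_dir, quiet=True)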
@@ -131,10 +131,14 @@ def load_model(task: str, model_name: str = None):
             return vqa_function
 
         # Use pipeline for summarization, image-to-text, and file-qa
-        return pipeline(
+        return pipeline(
+            task if task != "file-qa" else "question-answering",
+            model=model_to_load,
+            tokenizer_kwargs={"clean_up_tokenization_spaces": True}  # Suppress warning
+        )
 
     except Exception as e:
-        logger.error(f"Model load failed: {str(e)}")
+        logger.error(f"Model load failed for {task}: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
 
 def get_gemini_response(user_input: str, is_generation: bool = False):
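The "file-qa" task is an app-level alias with no counterpart in transformers, so it is remapped to the standard "question-answering" pipeline; model_to_load is resolved in code elided from this hunk. A hedged usage sketch of the resulting pipeline, with an illustrative model checkpoint:

from transformers import pipeline

# "file-qa" resolves to a plain extractive question-answering pipeline.
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
result = qa(question="When is the report due?", context="The report is due on Friday.")
print(result["answer"])  # e.g. "Friday"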
@@ -170,8 +174,21 @@ def translate_text(text: str, target_language: str):
 
     lang_code = SUPPORTED_LANGUAGES[target_lang]
 
+    # Load translation model on demand if not pre-loaded
     if translation_model is None or translation_tokenizer is None:
-
+        logger.info("Translation model not pre-loaded, loading on demand...")
+        model_name = MODELS["translation"]
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        translation_model.to(device)
+        logger.info("Translation model loaded on demand successfully")
 
     match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
     if match:
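Two notes on this fallback. First, rebinding translation_model and translation_tokenizer inside translate_text only updates the module-level globals if the function declares them with a global statement; if that declaration is not among the lines elided above this hunk, Python treats the names as local and the None check itself would raise UnboundLocalError. Second, the hunk only loads the M2M100 pair; a sketch of how such a pair is typically used for generation, with an illustrative checkpoint and language codes:

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model_name = "facebook/m2m100_418M"  # illustrative; the app reads MODELS["translation"]
model = M2M100ForConditionalGeneration.from_pretrained(model_name)
tokenizer = M2M100Tokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

tokenizer.src_lang = "en"  # source language code, illustrative
encoded = tokenizer("Hello, world!", return_tensors="pt").to(device)
# M2M100 needs the target language forced as the first generated token.
generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("fr"))
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])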
@@ -829,29 +846,44 @@ async def list_models():
 
 @app.on_event("startup")
 async def startup_event():
-    """Pre-load models at startup with timeout"""
+    """Pre-load models at startup with timeout and fallback"""
     global translation_model, translation_tokenizer
     logger.info("Starting model pre-loading...")
 
-    async def load_model_with_timeout(task):
+    async def load_model_with_timeout(task, model_name=None):
         try:
-            await asyncio.wait_for(
-
+            await asyncio.wait_for(
+                asyncio.to_thread(load_model, task, model_name),
+                timeout=60.0
+            )
+            logger.info(f"Successfully pre-loaded {task} model")
         except asyncio.TimeoutError:
             logger.warning(f"Timeout loading {task} model - will load on demand")
         except Exception as e:
             logger.error(f"Error pre-loading {task}: {str(e)}")
 
+    # Load translation model separately with retry mechanism
     try:
         model_name = MODELS["translation"]
-
-
+        logger.info(f"Attempting to load translation model: {model_name}")
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
         device = "cuda" if torch.cuda.is_available() else "cpu"
         translation_model.to(device)
         logger.info("Translation model pre-loaded successfully")
     except Exception as e:
         logger.error(f"Error pre-loading translation model: {str(e)}")
+        # Fallback: Set to None and load on demand
+        translation_model = None
+        translation_tokenizer = None
 
+    # Pre-load other models concurrently
     await asyncio.gather(
         load_model_with_timeout("summarization"),
         load_model_with_timeout("image-to-text"),
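The preloading helper combines asyncio.to_thread (Python 3.9+) with asyncio.wait_for so that a blocking from_pretrained call cannot stall the event loop past the timeout. One caveat worth knowing: a timed-out thread is not cancelled, only abandoned, so the load may still finish in the background. A self-contained sketch of the pattern:

import asyncio
import time

def slow_blocking_load():
    """Stand-in for a blocking model load such as from_pretrained."""
    time.sleep(2)
    return "model"

async def main():
    try:
        # The blocking call runs in a worker thread; the event loop stays free.
        model = await asyncio.wait_for(asyncio.to_thread(slow_blocking_load), timeout=60.0)
    except asyncio.TimeoutError:
        model = None  # give up waiting; the worker thread itself keeps running

asyncio.run(main())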
@@ -862,5 +894,4 @@ async def startup_event():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
-
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)