luck210 committed on
Commit c5c5df6 · verified · 1 Parent(s): dd9de16

Update app.py

Files changed (1)
  1. app.py +43 -12
app.py CHANGED
@@ -31,7 +31,7 @@ logging.basicConfig(
 logger = logging.getLogger("cosmic_ai")
 
 # Set a custom NLTK data directory
-nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/tmp/nltk_data')
+nltk_data_dir = os.getenv('NLTK_DATA', '/tmp/nltk_data')
 os.makedirs(nltk_data_dir, exist_ok=True)
 nltk.data.path.append(nltk_data_dir)
 
@@ -131,10 +131,14 @@ def load_model(task: str, model_name: str = None):
             return vqa_function
 
         # Use pipeline for summarization, image-to-text, and file-qa
-        return pipeline(task if task != "file-qa" else "question-answering", model=model_to_load)
+        return pipeline(
+            task if task != "file-qa" else "question-answering",
+            model=model_to_load,
+            tokenizer_kwargs={"clean_up_tokenization_spaces": True}  # Suppress warning
+        )
 
     except Exception as e:
-        logger.error(f"Model load failed: {str(e)}")
+        logger.error(f"Model load failed for {task}: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
 
 def get_gemini_response(user_input: str, is_generation: bool = False):
@@ -170,8 +174,21 @@ def translate_text(text: str, target_language: str):
 
     lang_code = SUPPORTED_LANGUAGES[target_lang]
 
+    # Load translation model on demand if not pre-loaded
     if translation_model is None or translation_tokenizer is None:
-        raise Exception("Translation model not initialized")
+        logger.info("Translation model not pre-loaded, loading on demand...")
+        model_name = MODELS["translation"]
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        translation_model.to(device)
+        logger.info("Translation model loaded on demand successfully")
 
     match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
     if match:
@@ -829,29 +846,44 @@ async def list_models():
 
 @app.on_event("startup")
 async def startup_event():
-    """Pre-load models at startup with timeout"""
+    """Pre-load models at startup with timeout and fallback"""
     global translation_model, translation_tokenizer
     logger.info("Starting model pre-loading...")
 
-    async def load_model_with_timeout(task):
+    async def load_model_with_timeout(task, model_name=None):
         try:
-            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
-            logger.info(f"Successfully loaded {task} model")
+            await asyncio.wait_for(
+                asyncio.to_thread(load_model, task, model_name),
+                timeout=60.0
+            )
+            logger.info(f"Successfully pre-loaded {task} model")
         except asyncio.TimeoutError:
             logger.warning(f"Timeout loading {task} model - will load on demand")
         except Exception as e:
             logger.error(f"Error pre-loading {task}: {str(e)}")
 
+    # Load translation model separately with retry mechanism
     try:
         model_name = MODELS["translation"]
-        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
-        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+        logger.info(f"Attempting to load translation model: {model_name}")
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            model_name,
+            cache_dir=os.getenv("HF_HOME", "/app/cache")
+        )
         device = "cuda" if torch.cuda.is_available() else "cpu"
        translation_model.to(device)
         logger.info("Translation model pre-loaded successfully")
     except Exception as e:
         logger.error(f"Error pre-loading translation model: {str(e)}")
+        # Fallback: Set to None and load on demand
+        translation_model = None
+        translation_tokenizer = None
 
+    # Pre-load other models concurrently
     await asyncio.gather(
         load_model_with_timeout("summarization"),
         load_model_with_timeout("image-to-text"),
@@ -862,5 +894,4 @@ async def startup_event():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
-
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)