Azzan Dwi Riski commited on
Commit
12c9666
·
1 Parent(s): 2f33391

add models

Browse files
Files changed (1) hide show
  1. app.py +9 -19
app.py CHANGED
@@ -75,14 +75,6 @@ class LateFusionModel(nn.Module):
75
 
76
  return fused_logits, image_logits, text_logits, weights
77
 
78
- # def unwrap_dataparallel(model):
79
- # """Recursively unwrap all DataParallel layers inside a model."""
80
- # if isinstance(model, torch.nn.DataParallel):
81
- # model = model.module
82
- # for name, module in model.named_children():
83
- # setattr(model, name, unwrap_dataparallel(module))
84
- # return model
85
-
86
  # Load model
87
  model_path = "models/best_fusion_model.pt"
88
  if os.path.exists(model_path):
@@ -108,21 +100,19 @@ if os.path.exists(image_model_path):
108
  image_only_model.eval()
109
  print("Image-only model loaded from state_dict successfully!")
110
  else:
111
- raise FileNotFoundError("Image-only model not found in models/ folder.")
 
 
 
 
 
 
 
 
112
 
113
 
114
  # --- Functions ---
115
  def clean_text(text):
116
- # text = re.sub(r"http\S+", "", text)
117
- # text = re.sub('\n', '', text)
118
- # text = re.sub("[^a-zA-Z^']", " ", text)
119
- # text = re.sub(" {2,}", " ", text)
120
- # text = text.strip()
121
- # text = re.sub(r'\s+', ' ', text)
122
- # text = re.sub(r'\b\w{1,2}\b', '', text)
123
- # text = re.sub(r'\b\w{20,}\b', '', text)
124
- # text = text.lower()
125
- # Kata 1–2 huruf yang penting dan tidak boleh dihapus
126
  exceptions = {
127
  "di", "ke", "ya"
128
  }
 
75
 
76
  return fused_logits, image_logits, text_logits, weights
77
 
 
 
 
 
 
 
 
 
78
  # Load model
79
  model_path = "models/best_fusion_model.pt"
80
  if os.path.exists(model_path):
 
100
  image_only_model.eval()
101
  print("Image-only model loaded from state_dict successfully!")
102
  else:
103
+ print("Image-only model not found locally. Downloading from Hugging Face Hub...")
104
+ image_model_path = hf_hub_download(repo_id="azzandr/gambling-iamge-model", filename="best_image_model_Adam_lr0.0001_bs32_state_dict.pt")
105
+ image_only_model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
106
+ num_features = image_only_model.classifier[1].in_features
107
+ image_only_model.classifier = nn.Linear(num_features, 1)
108
+ image_only_model.load_state_dict(torch.load(image_model_path, map_location=device))
109
+ image_only_model.to(device)
110
+ image_only_model.eval()
111
+ print("Image-only model downloaded and loaded successfully!")
112
 
113
 
114
  # --- Functions ---
115
  def clean_text(text):
 
 
 
 
 
 
 
 
 
 
116
  exceptions = {
117
  "di", "ke", "ya"
118
  }