File size: 2,987 Bytes
ada534f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
#!/usr/bin/env python3
"""
Model Download Script for Oil Spill Detection System
Downloads pre-trained models from Hugging Face Hub during deployment
"""
import os
import requests
from pathlib import Path
import sys
# Hugging Face model repository ("<user>/<repo>"). It can be overridden with the
# HUGGINGFACE_REPO environment variable. An unset *or blank* variable falls back
# to the default: os.getenv's default only applies when the variable is unset,
# so an empty entry in a deployment .env would otherwise produce a malformed
# download URL ("https://huggingface.co//resolve/...").
HUGGINGFACE_REPO = (
    os.getenv("HUGGINGFACE_REPO", "").strip()
    or "sahilvishwa2108/oil-spill-detection-models"
)
# Model files to download: local filename (saved under models/) -> filename
# inside the Hugging Face repository.
MODEL_FILES = {
    "deeplab_final_model.h5": "deeplab_final_model.h5",
    "unet_final_model.h5": "unet_final_model.h5",
}
def download_model(
    repo: str,
    filename: str,
    models_dir: Path,
    timeout: float = 60.0,
    chunk_size: int = 1024 * 1024,
) -> bool:
    """Download one file from the ``main`` revision of a Hugging Face repo.

    The file is streamed into a temporary ``<filename>.part`` sibling. It is
    moved to ``models_dir / filename`` only after the whole body has been
    written. A failed or interrupted download therefore never leaves a
    truncated file under the final name, where a later run would mistake it
    for a complete model.

    Args:
        repo: Hugging Face repository id (``"<user>/<repo>"``).
        filename: Path of the file inside the repository; it is saved under
            the same relative path in ``models_dir``.
        models_dir: Local directory to save the file into (created if needed).
        timeout: Seconds passed to ``requests`` as the connect/read timeout.
            Without it, a stalled server would block the deployment forever.
        chunk_size: Bytes read per iteration while streaming the body.

    Returns:
        True if the file was downloaded and moved into place. False on any
        failure; the error is printed and the partial file is removed.
    """
    url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
    model_path = models_dir / filename
    tmp_path = model_path.with_name(model_path.name + ".part")
    print(f"Downloading {filename} from Hugging Face...")
    try:
        model_path.parent.mkdir(parents=True, exist_ok=True)
        # The context manager releases the streamed connection even if
        # reading the body fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Rename only once the body is complete; on the same filesystem this
        # is atomic, so the final path is either absent or complete.
        tmp_path.replace(model_path)
    except Exception as e:
        # Best-effort per file: report and let the caller continue with the
        # remaining models (the function's contract is a bool result).
        print(f"❌ Failed to download {filename}: {e}")
        try:
            tmp_path.unlink()
        except OSError:
            pass  # the .part file may never have been created
        return False
    print(f"✅ Downloaded {filename}")
    return True
def main():
    """Download every model listed in MODEL_FILES that is not already present.

    Models live in a ``models`` directory next to this script. A model counts
    as present only if ``models/<local_name>`` is a regular, non-empty file.
    A missing path, a directory, or a zero-byte leftover is (re)downloaded.

    Returns:
        True if every model is present afterwards, False if any download
        failed.
    """
    models_dir = Path(__file__).parent / "models"
    print("🤖 Oil Spill Detection - Model Downloader")
    print("=" * 50)

    # Check if models already exist
    missing_models = []
    for local_name, remote_name in MODEL_FILES.items():
        model_path = models_dir / local_name
        if model_path.is_file() and model_path.stat().st_size > 0:
            print(f"✅ {local_name} already exists")
        else:
            missing_models.append((local_name, remote_name))

    if not missing_models:
        print("\n🎉 All models are already downloaded!")
        return True

    # Download missing models
    print(f"\n📥 Downloading {len(missing_models)} missing models...")
    models_dir.mkdir(parents=True, exist_ok=True)
    success_count = 0
    for local_name, remote_name in missing_models:
        if not download_model(HUGGINGFACE_REPO, remote_name, models_dir):
            continue
        if local_name != remote_name:
            # download_model saves under the *remote* name. Rename it so the
            # presence check above finds it under the local name next time.
            target = models_dir / local_name
            try:
                target.parent.mkdir(parents=True, exist_ok=True)
                (models_dir / remote_name).replace(target)
            except OSError as e:
                print(f"❌ Failed to save {remote_name} as {local_name}: {e}")
                continue
        success_count += 1

    if success_count == len(missing_models):
        print(f"\n🎉 Successfully downloaded all {success_count} models!")
        return True
    print(f"\n⚠️ Downloaded {success_count}/{len(missing_models)} models")
    return False
if __name__ == "__main__":
    # Script entry point. The process exits with status 0 only when main()
    # reports success. A Ctrl-C or any unexpected error is reported and
    # mapped to status 1.
    exit_status = 1
    try:
        exit_status = 0 if main() else 1
    except KeyboardInterrupt:
        print("\n\n⚠️ Download interrupted by user")
    except Exception as exc:
        print(f"\n❌ Unexpected error: {exc}")
    sys.exit(exit_status)
|