File size: 4,982 Bytes
0eb032f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import shutil
from typing import List

from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
from typing_extensions import Literal, TypeAlias

from ..configs.model_config import (Preset_model_id,
                                    preset_models_on_huggingface,
                                    preset_models_on_modelscope)


def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch one file from a ModelScope repository into ``local_dir``.

    The file ends up directly inside ``local_dir`` under its base name. If an
    entry with that base name is already listed in ``local_dir`` nothing is
    downloaded.

    Args:
        model_id: Repository identifier passed to ``snapshot_download``.
        origin_file_path: Path of the file inside the repository; also used
            as the ``allow_file_pattern`` of the download call.
        local_dir: Destination directory; created when missing.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    final_path = os.path.join(local_dir, file_name)

    # Guard: an entry with the same base name is already present.
    if file_name in os.listdir(local_dir):
        print(f"    {file_name} has been already in {local_dir}.")
        return

    print(f"    Start downloading {final_path}")
    snapshot_download(
        model_id, allow_file_pattern=origin_file_path, local_dir=local_dir
    )

    # The code expects the download at local_dir/origin_file_path. When that
    # is a nested location, flatten it into local_dir and remove the
    # top-level directory the nested path started with.
    fetched_path = os.path.join(local_dir, origin_file_path)
    if fetched_path == final_path:
        return
    shutil.move(fetched_path, final_path)
    top_level_dir = origin_file_path.split("/")[0]
    shutil.rmtree(os.path.join(local_dir, top_level_dir))


def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch one file from a Hugging Face repository into ``local_dir``.

    The file ends up directly inside ``local_dir`` under its base name. If an
    entry with that base name is already listed in ``local_dir`` nothing is
    downloaded.

    Args:
        model_id: Repository identifier passed to ``hf_hub_download``.
        origin_file_path: Path of the file inside the repository.
        local_dir: Destination directory; created when missing.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    final_path = os.path.join(local_dir, file_name)

    # Guard: an entry with the same base name is already present.
    if file_name in os.listdir(local_dir):
        print(f"    {file_name} has been already in {local_dir}.")
        return

    print(f"    Start downloading {final_path}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)

    # The code expects the download at local_dir/origin_file_path. When that
    # is a nested location, flatten it into local_dir and remove the
    # top-level directory the nested path started with.
    fetched_path = os.path.join(local_dir, origin_file_path)
    if fetched_path == final_path:
        return
    shutil.move(fetched_path, final_path)
    top_level_dir = origin_file_path.split("/")[0]
    shutil.rmtree(os.path.join(local_dir, top_level_dir))


# Names of the hosting sites a model can be downloaded from.
Preset_model_website: TypeAlias = Literal["HuggingFace", "ModelScope"]

# Site name -> preset-model table imported from the model config.
website_to_preset_models = dict(
    HuggingFace=preset_models_on_huggingface,
    ModelScope=preset_models_on_modelscope,
)

# Site name -> function that downloads a single file from that site.
website_to_download_fn = dict(
    HuggingFace=download_from_huggingface,
    ModelScope=download_from_modelscope,
)


def download_customized_models(
    model_id,
    origin_file_path,
    local_dir,
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download a single file, trying each website in priority order.

    Args:
        model_id: Repository identifier on the hosting site.
        origin_file_path: Path of the file inside the repository.
        local_dir: Directory the file should end up in.
        downloading_priority: Website names to try, in order. The default
            list is only read, never mutated.

    Returns:
        A list holding ``local_dir/<base name>`` once the file is listed in
        ``local_dir`` after a download attempt, otherwise an empty list.
    """
    file_name = os.path.basename(origin_file_path)
    target_path = os.path.join(local_dir, file_name)

    fetched = []
    for site in downloading_priority:
        fetch = website_to_download_fn[site]
        fetch(model_id, origin_file_path, local_dir)
        # Stop at the first site after which the file is present; the
        # remaining sites would be skipped anyway.
        if file_name in os.listdir(local_dir):
            fetched.append(target_path)
            break
    return fetched


def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download preset models and return the paths that should be loaded.

    For each preset model id, the websites in ``downloading_priority`` are
    tried in order. The first website that yields at least one newly
    recorded file wins and the remaining websites are skipped for that model.

    Args:
        model_id_list: Preset model ids to download. The default list is
            only read, never mutated.
        downloading_priority: Website names to try, in order. The default
            list is only read, never mutated.

    Returns:
        A list of file paths. For a model whose metadata is a dict with a
        ``"load_path"`` entry, that entry's items are used; otherwise the
        paths of the files recorded for the model are used.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    load_files = []

    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue

            # Metadata is either a bare list of file entries or a dict
            # carrying them under "file_list".
            model_metadata = preset_models[model_id]
            if isinstance(model_metadata, list):
                file_data = model_metadata
            else:
                file_data = model_metadata.get("file_list", [])

            # Try downloading the model from this website.
            download_fn = website_to_download_fn[website]
            model_files = []
            # NOTE: the repository id gets its own name so it does not
            # overwrite the preset ``model_id`` of the outer loop; otherwise
            # the lookup on the next website would use the wrong key.
            for repo_id, origin_file_path, local_dir in file_data:
                file_name = os.path.basename(origin_file_path)
                file_to_download = os.path.join(local_dir, file_name)
                # Skip files already recorded earlier in this call.
                if file_to_download in downloaded_files:
                    continue
                download_fn(repo_id, origin_file_path, local_dir)
                if file_name in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
                    model_files.append(file_to_download)

            # If the model is successfully downloaded, stop trying websites.
            if model_files:
                if isinstance(model_metadata, dict) and "load_path" in model_metadata:
                    # presumably a list of paths; verify against model_config
                    model_files = model_metadata["load_path"]
                load_files.extend(model_files)
                break

    return load_files