File size: 2,312 Bytes
534e265
34eb6c0
d6c2823
fd0734b
34eb6c0
82f856d
34eb6c0
 
 
d6c2823
 
34eb6c0
 
 
 
 
 
 
 
fd0734b
34eb6c0
fd0734b
 
34eb6c0
d6c2823
34eb6c0
 
 
edb1d95
34eb6c0
 
 
534e265
 
edb1d95
 
d6c2823
edb1d95
 
 
 
 
 
 
 
 
 
 
 
d6c2823
 
edb1d95
 
34eb6c0
 
d6c2823
 
 
edb1d95
a36040c
edb1d95
82f856d
 
34eb6c0
 
d6c2823
edb1d95
 
34eb6c0
 
 
 
edb1d95
34eb6c0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os,shutil
import argparse
from comet_ml import Experiment
from src.utils.config_loader import Config,constants,set_seed
from src.utils import config_loader
from src.utils.data_utils import print_title
from src.utils.script_utils import validate_config
import importlib
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()

def train(args: argparse.Namespace) -> None:
    """Train, save, optionally upload, and evaluate the model named by a config file.

    The config is loaded and validated, published as ``config_loader.config``,
    and its seed is applied. The model class is then imported from
    ``src.<task>.model.models.<model>``, trained, saved under
    ``constants.ARTIFACT_MODEL_DIR`` and evaluated. A copy of the config is
    stored next to the weights as ``config.yaml``.

    Args:
        args: Parsed command-line arguments. ``args.config_file`` is the path
            of the config file; a truthy ``args.log`` creates a Comet
            experiment that is handed to the model.

    Raises:
        KeyError: If ``args.log`` is set and ``COMET_API_KEY`` is not present
            in the environment.
    """
    config_file_path = args.config_file
    config = Config(config_file_path)

    # validate config
    validate_config(config)

    # set config globally & set seed
    config_loader.config = config
    set_seed(config.seed)

    # Import the model module selected by the config. This is deliberately
    # done after the global config is set (presumably the model code reads
    # it -- verify before reordering).
    model_module = importlib.import_module(
        f"src.{config.task}.model.models.{config.model}"
    )
    Model = model_module.Model

    model_dir = constants.ARTIFACT_MODEL_DIR
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, "model.weights.h5")

    # Keep a copy of the config next to the exported weights, always named
    # config.yaml. Copying straight to the final name avoids a separate
    # rename step and never leaves a stray file under the original name.
    exported_config_path = os.path.join(model_dir, "config.yaml")
    try:
        shutil.copy(config_file_path, exported_config_path)
    except shutil.SameFileError:
        # The config passed in already is the exported copy; nothing to do.
        pass

    experiment = None
    if args.log:
        experiment = Experiment(
            api_key=os.environ["COMET_API_KEY"],
            project_name="image-colorization",
            workspace="anujpanthri",
            auto_histogram_activation_logging=True,
            auto_histogram_epoch_rate=True,
            auto_histogram_gradient_logging=True,
            auto_histogram_weight_logging=True,
            auto_param_logging=True,
        )

    try:
        model = Model(experiment=experiment)

        print_title("\nTraining Model")
        model.train()
        model.save(model_save_path)

        # log model to comet, skipped when LOCAL_SYSTEM is set in the environment
        if experiment and "LOCAL_SYSTEM" not in os.environ:
            experiment.log_model("model", model_dir)

        # evaluate model
        print_title("\nEvaluating Model")
        metrics = model.evaluate()
        print("Model Evaluation Metrics:", metrics)
    finally:
        # Close the experiment even when training or evaluation fails, so the
        # run is always ended explicitly.
        if experiment:
            experiment.end()

def main():
    """Parse the command line and run training.

    Takes one positional argument, ``config_file``, and an optional ``--log``
    flag (off by default), then passes the parsed namespace to ``train``.
    """
    cli = argparse.ArgumentParser(
        description="train model based on config yaml file",
    )
    cli.add_argument("config_file", type=str)
    cli.add_argument("--log", action="store_true", default=False)
    train(cli.parse_args())

# Run the CLI only when this file is executed as a script, not when imported.
if __name__=="__main__":
    main()