"""Train, save, and evaluate a model described by a YAML config file.

Pass the config path as the positional argument. Add ``--log`` to track the
run in Comet, which requires ``COMET_API_KEY`` in the environment or ``.env``.
"""
import argparse
import importlib
import os
import shutil
from pathlib import Path

# NOTE: Comet asks for comet_ml to be imported before ML frameworks so its
# auto-logging can hook them. Keep it above the project imports below, which
# presumably pull one in.
from comet_ml import Experiment
from dotenv import load_dotenv

from src.utils import config_loader
from src.utils.config_loader import Config, constants, set_seed
from src.utils.data_utils import print_title
from src.utils.script_utils import validate_config

load_dotenv()

COMET_PROJECT_NAME = "image-colorization"
COMET_WORKSPACE = "anujpanthri"


def _save_config_copy(config_file_path, model_dir):
    """Copy the config file into ``model_dir`` under the name ``config.yaml``."""
    # Copy directly to the final name instead of copy-then-rename. This avoids
    # an intermediate file, and it avoids clobbering anything in model_dir that
    # happens to share the config's original basename.
    shutil.copy(config_file_path, Path(model_dir) / "config.yaml")


def _create_experiment():
    """Create the Comet ``Experiment`` used to track this training run.

    Raises:
        KeyError: if ``COMET_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("COMET_API_KEY")
    if not api_key:
        raise KeyError(
            "COMET_API_KEY is not set; add it to the environment or .env, "
            "or run without --log"
        )
    return Experiment(
        api_key=api_key,
        project_name=COMET_PROJECT_NAME,
        workspace=COMET_WORKSPACE,
        auto_histogram_activation_logging=True,
        # Comet documents this parameter as an int. The original passed
        # True, which equals 1, so this is the same value with the documented type.
        auto_histogram_epoch_rate=1,
        auto_histogram_gradient_logging=True,
        auto_histogram_weight_logging=True,
        auto_param_logging=True,
    )


def train(args):
    """Train, save, and evaluate the model configured by ``args.config_file``.

    The steps are:

    1. Load the config with ``Config`` and check it with ``validate_config``.
    2. Publish it as ``config_loader.config`` and call ``set_seed(config.seed)``.
    3. Import ``src.<config.task>.model.models.<config.model>.Model``.
    4. Copy the config into ``constants.ARTIFACT_MODEL_DIR`` as ``config.yaml``.
    5. Train the model, save its weights to ``model.weights.h5`` in that
       directory, and evaluate it.

    Args:
        args: Parsed CLI namespace with ``config_file`` (path to the YAML
            config) and ``log`` (bool; track the run in Comet when true).

    Raises:
        KeyError: if ``args.log`` is set but ``COMET_API_KEY`` is missing.
    """
    config_file_path = args.config_file
    config = Config(config_file_path)
    validate_config(config)

    # Set the global config and seed before importing the model module and
    # reading constants. Those presumably depend on the global config;
    # preserve this order.
    config_loader.config = config
    set_seed(config.seed)

    module_name = f"src.{config.task}.model.models.{config.model}"
    Model = importlib.import_module(module_name).Model

    model_dir = constants.ARTIFACT_MODEL_DIR
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, "model.weights.h5")

    # Keep the exact config next to the exported weights.
    _save_config_copy(config_file_path, model_dir)

    experiment = _create_experiment() if args.log else None
    try:
        model = Model(experiment=experiment)

        print_title("\nTraining Model")
        model.train()
        model.save(model_save_path)

        # Upload the model directory to Comet, except when LOCAL_SYSTEM is
        # present in the environment.
        if experiment is not None and "LOCAL_SYSTEM" not in os.environ:
            experiment.log_model("model", model_dir)

        print_title("\nEvaluating Model")
        metrics = model.evaluate()
        print("Model Evaluation Metrics:", metrics)
    finally:
        # End the experiment even if training/saving/evaluation raised, so the
        # Comet run is closed rather than left open.
        if experiment is not None:
            experiment.end()


def main():
    """Parse command-line arguments and run :func:`train`."""
    parser = argparse.ArgumentParser(
        description="train model based on config yaml file"
    )
    parser.add_argument(
        "config_file", type=str, help="path to the YAML config file"
    )
    parser.add_argument(
        "--log",
        action="store_true",
        default=False,
        help="log the run to Comet (requires COMET_API_KEY)",
    )
    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()