Spaces:
Runtime error
Runtime error
File size: 2,312 Bytes
534e265 34eb6c0 d6c2823 fd0734b 34eb6c0 82f856d 34eb6c0 d6c2823 34eb6c0 fd0734b 34eb6c0 fd0734b 34eb6c0 d6c2823 34eb6c0 edb1d95 34eb6c0 534e265 edb1d95 d6c2823 edb1d95 d6c2823 edb1d95 34eb6c0 d6c2823 edb1d95 a36040c edb1d95 82f856d 34eb6c0 d6c2823 edb1d95 34eb6c0 edb1d95 34eb6c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os,shutil
import argparse
from comet_ml import Experiment
from src.utils.config_loader import Config,constants,set_seed
from src.utils import config_loader
from src.utils.data_utils import print_title
from src.utils.script_utils import validate_config
import importlib
from pathlib import Path
from dotenv import load_dotenv
# Populate os.environ from a local .env file at import time, so settings such
# as the COMET_API_KEY read by train() can be kept out of the shell environment.
load_dotenv()
def _export_config(config_file_path, model_dir):
    """Copy the training config into ``model_dir`` under the name ``config.yaml``.

    Copying straight to the final name avoids an intermediate file under the
    config's original name (which could clobber an existing file in
    ``model_dir``). It also tolerates re-training from a config that already
    *is* the exported ``config.yaml``.
    """
    destination = os.path.join(model_dir, "config.yaml")
    try:
        shutil.copy(config_file_path, destination)
    except shutil.SameFileError:
        # Training from the previously exported config; it is already in place.
        pass


def _create_experiment():
    """Create the Comet experiment used when ``--log`` is passed.

    Raises:
        KeyError: if ``COMET_API_KEY`` is not present in the environment.
    """
    try:
        api_key = os.environ["COMET_API_KEY"]
    except KeyError as err:
        raise KeyError(
            "COMET_API_KEY must be set (environment or .env) when --log is used"
        ) from err
    return Experiment(
        api_key=api_key,
        project_name="image-colorization",
        workspace="anujpanthri",
        auto_histogram_activation_logging=True,
        auto_histogram_epoch_rate=1,  # log histograms every epoch (was `True`, i.e. 1)
        auto_histogram_gradient_logging=True,
        auto_histogram_weight_logging=True,
        auto_param_logging=True,
    )


def train(args):
    """Train, save and evaluate the model described by a YAML config file.

    Args:
        args: parsed CLI namespace with ``config_file`` (path to the config
            file) and ``log`` (bool; when true, the run is logged to Comet).

    Raises:
        KeyError: if ``args.log`` is set but ``COMET_API_KEY`` is missing.
        ModuleNotFoundError: if ``src.<task>.model.models.<model>`` does not exist.
    """
    config_file_path = args.config_file
    config = Config(config_file_path)
    validate_config(config)

    # Publish the config globally before the model module is imported;
    # presumably model code reads config_loader.config (TODO confirm).
    config_loader.config = config
    set_seed(config.seed)

    # The model class is resolved dynamically from the config's task/model names.
    Model = importlib.import_module(f"src.{config.task}.model.models.{config.model}").Model

    model_dir = constants.ARTIFACT_MODEL_DIR
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, "model.weights.h5")
    _export_config(config_file_path, model_dir)

    experiment = _create_experiment() if args.log else None
    try:
        model = Model(experiment=experiment)
        print_title("\nTraining Model")
        model.train()
        model.save(model_save_path)

        # Artifact upload is skipped when LOCAL_SYSTEM is set
        # (presumably local development runs — TODO confirm).
        if experiment is not None and "LOCAL_SYSTEM" not in os.environ:
            experiment.log_model("model", model_dir)

        print_title("\nEvaluating Model")
        metrics = model.evaluate()
        print("Model Evaluation Metrics:", metrics)
    finally:
        # Close the Comet run even if training or evaluation raised.
        if experiment is not None:
            experiment.end()
def main():
    """Command-line entry point: parse arguments and run training."""
    cli = argparse.ArgumentParser(description="train model based on config yaml file")
    cli.add_argument("config_file", type=str)
    cli.add_argument("--log", action="store_true", default=False)
    train(cli.parse_args())
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()