smog / src /utils /get_model_and_data.py
vonexel's picture
add: src
fe64bad verified
from ..datasets.get_dataset import get_datasets
from ..models.get_model import get_model as get_gen_model
import clip
def get_model_and_data(parameters, split="train"):
    """Load CLIP, prepare it for training or freezing, and build model + data.

    Loads the "ViT-B/32" CLIP model on ``parameters['device']``. For every
    domain named in ``parameters['clip_training']`` the CLIP parameters are
    re-initialized and the matching encoder is truncated to its first
    ``clip_layers`` residual blocks. When no domain is requested, CLIP is put
    in eval mode and all of its parameters are frozen.

    Args:
        parameters: Dict-like run configuration. Keys read here:
            ``'device'`` (required): device passed to ``clip.load``.
            ``'clip_training'`` (optional, default ``''``): underscore
            separated domains to train; ``'text'`` and ``'image'`` are
            recognized, and an empty string means CLIP stays frozen.
            ``'clip_layers'`` (optional, default ``12``): number of leading
            residual blocks kept in each trained encoder.
            The same object is forwarded to ``get_datasets`` and
            ``get_gen_model``.
        split: Dataset split name forwarded to ``get_datasets``.

    Returns:
        Tuple ``(model, datasets)``.
    """
    # Must set jit=False for training.
    clip_model, clip_preprocess = clip.load(
        "ViT-B/32", device=parameters['device'], jit=False
    )
    # Original author's note: this call is actually unnecessary since CLIP is
    # already on float16 by default.
    clip.model.convert_weights(clip_model)

    clip_training = parameters.get('clip_training', '')
    clip_num_layers = parameters.get('clip_layers', 12)

    for domain in clip_training.split('_'):
        if domain == 'text':
            clip_model.initialize_parameters()
            text_transformer = clip_model.transformer
            text_transformer.resblocks = text_transformer.resblocks[:clip_num_layers]
        if domain == 'image':
            clip_model.initialize_parameters()
            # Truncate the vision encoder's own blocks. The previous code
            # overwrote ``visual.transformer`` with a slice of the *text*
            # encoder's resblocks, replacing the vision transformer module
            # with mismatched layers.
            visual_transformer = clip_model.visual.transformer
            visual_transformer.resblocks = visual_transformer.resblocks[:clip_num_layers]

    # No CLIP training requested: freeze CLIP weights.
    if clip_training == '':
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

    datasets = get_datasets(parameters, clip_preprocess, split)
    model = get_gen_model(parameters, clip_model)
    return model, datasets