|
"""Zero-pad template-embedding weights so a checkpoint accepts wider features.

Loads a checkpoint, widens the input side of ``templ_emb.emb.weight`` and
``templ_emb.emb_t1d.weight`` to the new t1d/t2d feature widths by inserting
zero-valued columns, and saves the modified checkpoint.

The weights are assumed to be ``nn.Linear`` weights of shape
(out_features, in_features). Under that assumption, zero columns mean the new
input channels initially contribute nothing. The widened layers therefore
reproduce the original outputs on the original channels.

Run as a script: ``python <this file>``.
"""

import torch

# Feature widths the source checkpoint was trained with.
orig_t1d = 22
orig_t2d = 44
orig_dtor = 30

# Feature widths expected by the new model. The extra channels are assumed to
# be appended *after* the original ones in each feature vector.
new_t1d = orig_t1d + 2 + 4 + 1
new_t2d = orig_t2d + 0

IN_CKPT = '/net/scratch/lisanza/diffuse_3track_fullcon/models/BFF_last.pt'
# Derive the file name from the widths so it cannot drift out of sync with them.
OUT_CKPT = ('/net/scratch/lisanza/projects/diffusion/models/'
            f't1d_{new_t1d}_t2d_{new_t2d}_BFF_SE3big_2.pt')


def _zero_cols(weight, n):
    """Return an (out_features, n) zero block matching ``weight``'s dtype/device."""
    return weight.new_zeros(weight.shape[0], n)


def _growth(name, orig, new):
    """Return ``new - orig``; raise ValueError if the width would shrink."""
    extra = new - orig
    if extra < 0:
        raise ValueError(f"{name} cannot shrink from {orig} to {new}")
    return extra


def _check_in_width(weight, expected, what):
    """Raise ValueError unless ``weight`` is 2-D with ``expected`` input columns."""
    if weight.dim() != 2 or weight.shape[1] != expected:
        raise ValueError(
            f"{what}: expected a 2-D weight with {expected} input columns, "
            f"got shape {tuple(weight.shape)}"
        )


def expand_templ_emb_weight(weight, *, orig_t2d, orig_t1d, new_t2d, new_t1d):
    """Widen the pair-template embedding weight for larger t2d/t1d features.

    The model builds this layer's input as::

        left  = t1d.unsqueeze(3).expand(-1, -1, -1, L, -1)
        right = t1d.unsqueeze(2).expand(-1, -1, L, -1, -1)
        templ = torch.cat((t2d, left, right), -1)  # (B, T, L, L, t2d + 2*t1d)
        templ = self.emb(templ)                    # template features (B, T, L, L, d_templ)

    The input columns are therefore laid out as ``[t2d | t1d (left) | t1d (right)]``.
    Zero columns are inserted at the end of each of the three segments.

    Args:
        weight: 2-D weight of shape (out_features, orig_t2d + 2 * orig_t1d).
        orig_t2d, orig_t1d: feature widths the weight was trained with.
        new_t2d, new_t1d: target widths (must be >= the original ones).

    Returns:
        A new tensor of shape (out_features, new_t2d + 2 * new_t1d) with the
        same dtype and device as ``weight``.

    Raises:
        ValueError: if ``weight`` does not match the assumed layout, or a
            target width is smaller than the original one.
    """
    _check_in_width(weight, orig_t2d + 2 * orig_t1d, 'templ_emb.emb.weight')
    t2d_extra = _growth('t2d', orig_t2d, new_t2d)
    t1d_extra = _growth('t1d', orig_t1d, new_t1d)

    t2d_part = weight[:, :orig_t2d]
    left_part = weight[:, orig_t2d:orig_t2d + orig_t1d]
    right_part = weight[:, orig_t2d + orig_t1d:]
    return torch.cat(
        (
            t2d_part, _zero_cols(weight, t2d_extra),
            left_part, _zero_cols(weight, t1d_extra),
            right_part, _zero_cols(weight, t1d_extra),
        ),
        dim=-1,
    )


def expand_emb_t1d_weight(weight, *, orig_t1d, orig_dtor, new_t1d):
    """Widen the 1-D template embedding weight for a larger t1d feature.

    The input columns are assumed to be ``[t1d | orig_dtor other features]``.
    Zero columns are inserted right after the original t1d block, and the
    trailing ``orig_dtor`` columns are kept unchanged.

    Args:
        weight: 2-D weight of shape (out_features, orig_t1d + orig_dtor).
        orig_t1d: t1d width the weight was trained with.
        orig_dtor: width of the trailing (non-t1d) input block.
        new_t1d: target t1d width (must be >= ``orig_t1d``).

    Returns:
        A new tensor of shape (out_features, new_t1d + orig_dtor) with the
        same dtype and device as ``weight``.

    Raises:
        ValueError: if ``weight`` does not match the assumed layout, or
            ``new_t1d < orig_t1d``.
    """
    _check_in_width(weight, orig_t1d + orig_dtor, 'templ_emb.emb_t1d.weight')
    t1d_extra = _growth('t1d', orig_t1d, new_t1d)
    return torch.cat(
        (weight[:, :orig_t1d], _zero_cols(weight, t1d_extra), weight[:, orig_t1d:]),
        dim=-1,
    )


def main(in_path=IN_CKPT, out_path=OUT_CKPT):
    """Load ``in_path``, widen the template-embedding weights, save to ``out_path``."""
    ckpt = torch.load(in_path, map_location=torch.device('cpu'))
    # Same dict object as ckpt['model_state_dict'], so the edits below are saved.
    weights = ckpt['model_state_dict']

    print("original weights")
    print('templ_emb.emb.weight', weights['templ_emb.emb.weight'].shape)
    print('templ_emb.emb_t1d.weight', weights['templ_emb.emb_t1d.weight'].shape)

    new_emb_weights = expand_templ_emb_weight(
        weights['templ_emb.emb.weight'],
        orig_t2d=orig_t2d, orig_t1d=orig_t1d, new_t2d=new_t2d, new_t1d=new_t1d,
    )
    new_t1d_weights = expand_emb_t1d_weight(
        weights['templ_emb.emb_t1d.weight'],
        orig_t1d=orig_t1d, orig_dtor=orig_dtor, new_t1d=new_t1d,
    )

    weights['templ_emb.emb.weight'] = new_emb_weights
    weights['templ_emb.emb_t1d.weight'] = new_t1d_weights

    print("new t1d weights dim")
    print(new_t1d_weights.shape)

    # NOTE(review): only 'model_state_dict' is updated. If this checkpoint also
    # carries other copies of these parameters (or optimizer state for them),
    # those keep their old shapes. Confirm they are not loaded downstream.
    torch.save(ckpt, out_path)


if __name__ == "__main__":
    main()
|
|
|
|