jamino30 commited on
Commit
7340bf8
·
verified ·
1 Parent(s): 4d7ae60

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +3 -2
inference.py CHANGED
@@ -29,7 +29,8 @@ def inference(
29
  iterations=3,
30
  optim_caller=optim.LBFGS,
31
  alpha=1,
32
- beta=1
 
33
  ):
34
  generated_image = content_image.clone().requires_grad_(True)
35
  optimizer = optim_caller([generated_image], lr=lr)
@@ -42,7 +43,7 @@ def inference(
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
44
  total_loss.backward()
45
- torch.nn.utils.clip_grad_norm_([generated_image], max_norm=1.0) # clip gradients
46
  return total_loss
47
 
48
  for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
 
29
  iterations=3,
30
  optim_caller=optim.LBFGS,
31
  alpha=1,
32
+ beta=1,
33
+ clip_grad_norm=5.0
34
  ):
35
  generated_image = content_image.clone().requires_grad_(True)
36
  optimizer = optim_caller([generated_image], lr=lr)
 
43
  generated_features = model(generated_image)
44
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
45
  total_loss.backward()
46
+ torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
47
  return total_loss
48
 
49
  for _ in tqdm(range(iterations), desc='The magic is happening ✨'):