jamino30 commited on
Commit
0064e4b
·
verified ·
1 Parent(s): 7340bf8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +11 -4
inference.py CHANGED
@@ -32,13 +32,15 @@ def inference(
32
  beta=1,
33
  clip_grad_norm=5.0
34
  ):
 
 
35
  generated_image = content_image.clone().requires_grad_(True)
36
- optimizer = optim_caller([generated_image], lr=lr)
37
 
38
  with torch.no_grad():
39
  content_features = model(content_image)
40
 
41
- def closure():
42
  optimizer.zero_grad()
43
  generated_features = model(generated_image)
44
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
@@ -46,7 +48,12 @@ def inference(
46
  torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
47
  return total_loss
48
 
49
- for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
50
- optimizer.step(closure)
 
 
 
 
 
51
 
52
  return generated_image
 
32
  beta=1,
33
  clip_grad_norm=5.0
34
  ):
35
+ torch.manual_seed(42)
36
+
37
  generated_image = content_image.clone().requires_grad_(True)
38
+ adam_optimizer = optim.AdamW([generated_image], lr=lr)
39
 
40
  with torch.no_grad():
41
  content_features = model(content_image)
42
 
43
+ def closure(optimizer):
44
  optimizer.zero_grad()
45
  generated_features = model(generated_image)
46
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
 
48
  torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
49
  return total_loss
50
 
51
+ for _ in tqdm(range(iterations), desc='The magic is happening (1/2) ✨'):
52
+ adam_optimizer.step(lambda: closure(adam_optimizer))
53
+
54
+ lbfgs_optimizer = optim.LBFGS([generated_image], lr=lr)
55
+
56
+ for _ in tqdm(range(iterations), desc='The magic is happening (2/2) ✨'):
57
+ lbfgs_optimizer.step(lambda: closure(lbfgs_optimizer))
58
 
59
  return generated_image