jamino30 commited on
Commit
3e75c58
·
verified ·
1 Parent(s): 0064e4b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +5 -6
inference.py CHANGED
@@ -26,8 +26,8 @@ def inference(
26
  content_image,
27
  style_features,
28
  lr,
29
- iterations=3,
30
- optim_caller=optim.LBFGS,
31
  alpha=1,
32
  beta=1,
33
  clip_grad_norm=5.0
@@ -35,7 +35,6 @@ def inference(
35
  torch.manual_seed(42)
36
 
37
  generated_image = content_image.clone().requires_grad_(True)
38
- adam_optimizer = optim.AdamW([generated_image], lr=lr)
39
 
40
  with torch.no_grad():
41
  content_features = model(content_image)
@@ -48,12 +47,12 @@ def inference(
48
  torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
49
  return total_loss
50
 
51
- for _ in tqdm(range(iterations), desc='The magic is happening (1/2) ✨'):
 
52
  adam_optimizer.step(lambda: closure(adam_optimizer))
53
 
54
  lbfgs_optimizer = optim.LBFGS([generated_image], lr=lr)
55
-
56
- for _ in tqdm(range(iterations), desc='The magic is happening (2/2) ✨'):
57
  lbfgs_optimizer.step(lambda: closure(lbfgs_optimizer))
58
 
59
  return generated_image
 
26
  content_image,
27
  style_features,
28
  lr,
29
+ adam_iterations=1,
30
+ lbfgs_iterations=3,
31
  alpha=1,
32
  beta=1,
33
  clip_grad_norm=5.0
 
35
  torch.manual_seed(42)
36
 
37
  generated_image = content_image.clone().requires_grad_(True)
 
38
 
39
  with torch.no_grad():
40
  content_features = model(content_image)
 
47
  torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
48
  return total_loss
49
 
50
+ adam_optimizer = optim.AdamW([generated_image], lr=lr)
51
+ for _ in tqdm(range(adam_iterations), desc='The magic is happening (1/2) ✨'):
52
  adam_optimizer.step(lambda: closure(adam_optimizer))
53
 
54
  lbfgs_optimizer = optim.LBFGS([generated_image], lr=lr)
55
+ for _ in tqdm(range(lbfgs_iterations), desc='The magic is happening (2/2) ✨'):
 
56
  lbfgs_optimizer.step(lambda: closure(lbfgs_optimizer))
57
 
58
  return generated_image