multimodalart HF Staff commited on
Commit
a8749db
·
verified ·
1 Parent(s): 60b20b3
Files changed (2) hide show
  1. app.py +20 -45
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,38 +1,3 @@
1
- # Add this code at the very beginning of your script, before any other imports
2
-
3
- import sys
4
- import types
5
-
6
- # Create a fake torch._six module with string_classes
7
- torch_six = types.ModuleType('torch._six')
8
- torch_six.string_classes = (str,) # In Python 3, string_classes is just (str,)
9
-
10
- # Create torch module if it doesn't exist in sys.modules
11
- if 'torch' not in sys.modules:
12
- import torch
13
-
14
- # Add the _six submodule to torch
15
- sys.modules['torch._six'] = torch_six
16
- torch._six = torch_six
17
-
18
- # Monkey patch for pytorch_lightning.utilities.distributed
19
- try:
20
- from pytorch_lightning.utilities.rank_zero import rank_zero_only
21
- # Create the old module path
22
- pl_utils_dist = types.ModuleType('pytorch_lightning.utilities.distributed')
23
- pl_utils_dist.rank_zero_only = rank_zero_only
24
- sys.modules['pytorch_lightning.utilities.distributed'] = pl_utils_dist
25
- except ImportError:
26
- # If even the new import fails, create a dummy decorator
27
- def rank_zero_only(fn):
28
- """Dummy decorator that just returns the function as-is"""
29
- return fn
30
-
31
- pl_utils_dist = types.ModuleType('pytorch_lightning.utilities.distributed')
32
- pl_utils_dist.rank_zero_only = rank_zero_only
33
- sys.modules['pytorch_lightning.utilities.distributed'] = pl_utils_dist
34
-
35
- # Now continue with your original imports
36
  from pydoc import describe
37
  import gradio as gr
38
  import torch
@@ -46,9 +11,9 @@ from ldm.util import instantiate_from_config
46
  from huggingface_hub import hf_hub_download
47
  import spaces
48
 
49
- # Rest of your code continues here...
50
  model_path_e = hf_hub_download(repo_id="multimodalart/compvis-latent-diffusion-text2img-large", filename="txt2img-f8-large.ckpt")
51
 
 
52
  import argparse, os, sys, glob
53
  import numpy as np
54
  from PIL import Image
@@ -62,14 +27,22 @@ from ldm.models.diffusion.plms import PLMSSampler
62
  from open_clip import tokenizer
63
  import open_clip
64
 
65
- config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
66
-
67
- pl_sd = torch.load(model_path_e, map_location="cuda")
68
- sd = pl_sd["state_dict"]
69
- model = instantiate_from_config(config.model)
70
- m, u = model.load_state_dict(sd, strict=False)
71
-
72
- model.half().to("cuda")
 
 
 
 
 
 
 
 
73
 
74
  def load_safety_model(clip_model):
75
  """load the safety model"""
@@ -119,8 +92,10 @@ def is_unsafe(safety_model, embeddings, threshold=0.5):
119
  x = np.array([e[0] for e in nsfw_values])
120
  return True if x > threshold else False
121
 
122
- # model = load_model_from_config(config,model_path_e)
 
123
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
124
 
125
  #NSFW CLIP Filter
126
  safety_model = load_safety_model("ViT-B/32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydoc import describe
2
  import gradio as gr
3
  import torch
 
11
  from huggingface_hub import hf_hub_download
12
  import spaces
13
 
 
14
  model_path_e = hf_hub_download(repo_id="multimodalart/compvis-latent-diffusion-text2img-large", filename="txt2img-f8-large.ckpt")
15
 
16
+ #@title Import stuff
17
  import argparse, os, sys, glob
18
  import numpy as np
19
  from PIL import Image
 
27
  from open_clip import tokenizer
28
  import open_clip
29
 
30
+ def load_model_from_config(config, ckpt, verbose=False):
31
+ print(f"Loading model from {ckpt}")
32
+ pl_sd = torch.load(ckpt, map_location="cuda")
33
+ sd = pl_sd["state_dict"]
34
+ model = instantiate_from_config(config.model)
35
+ m, u = model.load_state_dict(sd, strict=False)
36
+ if len(m) > 0 and verbose:
37
+ print("missing keys:")
38
+ print(m)
39
+ if len(u) > 0 and verbose:
40
+ print("unexpected keys:")
41
+ print(u)
42
+
43
+ model = model.half().cuda()
44
+ model.eval()
45
+ return model
46
 
47
  def load_safety_model(clip_model):
48
  """load the safety model"""
 
92
  x = np.array([e[0] for e in nsfw_values])
93
  return True if x > threshold else False
94
 
95
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
96
+ model = load_model_from_config(config,model_path_e)
97
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
98
+ model = model.to(device)
99
 
100
  #NSFW CLIP Filter
101
  safety_model = load_safety_model("ViT-B/32")
requirements.txt CHANGED
@@ -3,11 +3,12 @@ ftfy
3
  regex
4
  tqdm
5
  omegaconf
6
- pytorch-lightning
7
  torch-fidelity
8
  transformers
9
  einops
10
  gradio
 
11
  open_clip_torch
12
  numpy
13
  tqdm
 
3
  regex
4
  tqdm
5
  omegaconf
6
+ pytorch-lightning==1.7.7
7
  torch-fidelity
8
  transformers
9
  einops
10
  gradio
11
+ torch==1.13.1
12
  open_clip_torch
13
  numpy
14
  tqdm