aiqtech committed on
Commit d5f7879 · verified · 1 Parent(s): bfd7809

Update app.py

Files changed (1):
  1. app.py +54 -30
app.py CHANGED
@@ -8,39 +8,57 @@ from diffusers import StableDiffusionXLPipeline
 from diffusers import EulerAncestralDiscreteScheduler
 import torch
 from compel import Compel, ReturnedEmbeddingsType
+from huggingface_hub import login
+import os
+
+# Add your Hugging Face token here or set it as an environment variable
+HF_TOKEN = os.getenv("HF_TOKEN")  # Get from environment variable
+# Or directly: HF_TOKEN = "hf_your_token_here"
+
+if HF_TOKEN:
+    login(token=HF_TOKEN)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Make sure to use torch.float16 consistently throughout the pipeline
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    "votepurchase/waiREALCN_v14",
-    torch_dtype=torch.float16,
-    variant="fp16",  # Explicitly use fp16 variant
-    use_safetensors=True  # Use safetensors if available
-)
-
-pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-pipe.to(device)
-
-# Force all components to use the same dtype
-pipe.text_encoder.to(torch.float16)
-pipe.text_encoder_2.to(torch.float16)
-pipe.vae.to(torch.float16)
-pipe.unet.to(torch.float16)
-
-# Added: Initialize Compel for long prompt processing
-compel = Compel(
-    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
-    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
-    requires_pooled=[False, True],
-    truncate_long_prompts=False
-)
+try:
+    # Make sure to use torch.float16 consistently throughout the pipeline
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        "votepurchase/waiREALCN_v14",
+        torch_dtype=torch.float16,
+        variant="fp16",          # Explicitly use fp16 variant
+        use_safetensors=True,    # Use safetensors if available
+        use_auth_token=HF_TOKEN  # Pass token to download
+    )
+
+    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+    pipe.to(device)
+
+    # Force all components to use the same dtype
+    pipe.text_encoder.to(torch.float16)
+    pipe.text_encoder_2.to(torch.float16)
+    pipe.vae.to(torch.float16)
+    pipe.unet.to(torch.float16)
+
+    # Initialize Compel for long prompt processing
+    compel = Compel(
+        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+        requires_pooled=[False, True],
+        truncate_long_prompts=False
+    )
+
+    model_loaded = True
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    model_loaded = False
+    pipe = None
+    compel = None
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
 
-# Added: Simple long prompt processing function
+# Simple long prompt processing function
 def process_long_prompt(prompt, negative_prompt=""):
     """Simple long prompt processing using Compel"""
     try:
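A note on the authentication block introduced in the hunk above: on a Hugging Face Space the token is normally stored as a repository secret named HF_TOKEN, which is exposed as an environment variable, so os.getenv("HF_TOKEN") picks it up without hard-coding anything. Newer diffusers releases also deprecate the use_auth_token= keyword in favor of token=. The sketch below shows that variant with the same model id and dtype as the diff; the token= keyword is the only change, and it assumes a diffusers version recent enough to accept it.

import os
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import login

HF_TOKEN = os.getenv("HF_TOKEN")  # e.g. a Space secret exposed as an env var
if HF_TOKEN:
    login(token=HF_TOKEN)

# Assumption: a diffusers version where `token` replaces the deprecated `use_auth_token`.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "votepurchase/waiREALCN_v14",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    token=HF_TOKEN,
)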
@@ -52,7 +70,11 @@ def process_long_prompt(prompt, negative_prompt=""):
 
 @spaces.GPU
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-    # Changed: Remove the 60-word limit warning and add long prompt check
+    if not model_loaded:
+        error_img = Image.new('RGB', (width, height), color=(50, 50, 50))
+        return error_img
+
+    # Remove the 60-word limit warning and add long prompt check
     use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
 
     if randomize_seed:
@@ -61,7 +83,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
     generator = torch.Generator(device=device).manual_seed(seed)
 
     try:
-        # Added: Try long prompt processing first if prompt is long
+        # Try long prompt processing first if prompt is long
         if use_long_prompt:
            print("Using long prompt processing...")
            conditioning, pooled = process_long_prompt(prompt, negative_prompt)
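The hunk above only shows conditioning, pooled = process_long_prompt(prompt, negative_prompt); the helper's body and the pipeline call that consume the embeddings fall outside the visible context. The sketch below shows one common way to wire Compel output into an SDXL pipeline, following the Compel README pattern for requires_pooled=[False, True]. The batching of positive and negative prompts, the tensor slicing, and the fallback return values are assumptions for illustration, not code from this commit.

# Relies on the `compel` and `pipe` objects created at module load time (first hunk).
def process_long_prompt(prompt, negative_prompt=""):
    """Simple long prompt processing using Compel"""
    try:
        # Encoding both prompts in one batch lets Compel pad them to the same
        # sequence length, which SDXL needs when truncation is disabled.
        conditioning, pooled = compel([prompt, negative_prompt])
        return conditioning, pooled
    except Exception as e:
        print(f"Long prompt processing failed: {e}")
        return None, None

# Inside infer(), the embeddings would replace the raw prompt strings:
conditioning, pooled = process_long_prompt(prompt, negative_prompt)
if conditioning is not None:
    image = pipe(
        prompt_embeds=conditioning[0:1],
        negative_prompt_embeds=conditioning[1:2],
        pooled_prompt_embeds=pooled[0:1],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]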
@@ -109,13 +131,15 @@ css = """
 with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
+        if not model_loaded:
+            gr.Markdown("⚠️ **Model failed to load. Please check your Hugging Face token.**")
 
         with gr.Row():
             prompt = gr.Text(
                 label="Prompt",
                 show_label=False,
                 max_lines=1,
-                placeholder="Enter your prompt (long prompts are automatically supported)",  # Changed: Updated placeholder
+                placeholder="Enter your prompt (long prompts are automatically supported)",
                 container=False,
             )
 
@@ -182,4 +206,4 @@ with gr.Blocks(css=css) as demo:
         outputs=[result]
     )
 
-demo.queue().launch()
+demo.queue().launch()