wifix199 committed on
Commit
b32f8d8
·
verified ·
1 Parent(s): b6bd576

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -21
app.py CHANGED
@@ -6,32 +6,16 @@ model_id = "SG161222/RealVisXL_V4.0"
6
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
7
  pipe.to("cpu") # Use "cuda" if GPU is available
8
 
9
- unet = pipe.unet
10
-
11
- def generate_image(prompt, unet, pipe):
12
- # Tokenize the prompt
13
- tokens = pipe.tokenizer(prompt, padding=True, truncation=True, max_length=50, return_tensors="pt").to(unet.device)
14
-
15
- # Generate the image
16
- text_embeds = pipe.text_encoder(tokens.input_ids)
17
- image = unet(text_embeds=text_embeds).images[0]
18
  return image
19
 
20
  def chatbot(prompt):
21
  # Generate the image based on the user's input
22
- image = generate_image(prompt, unet, pipe)
23
  return image
24
 
25
- def get_aug_embed(self, text_embeds, image):
26
- if text_embeds is None:
27
- text_embeds = self.text_encoder(
28
- text_embeds=text_embeds,
29
- image=image,
30
- height=self.unet.config.sample_size,
31
- width=self.unet.config.sample_size,
32
- )
33
- return text_embeds
34
-
35
  # Create the Gradio interface
36
  interface = gr.Interface(
37
  fn=chatbot,
@@ -42,4 +26,4 @@ interface = gr.Interface(
42
  )
43
 
44
  # Launch the interface
45
- interface.launch()
 
6
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
7
  pipe.to("cpu") # Use "cuda" if GPU is available
8
 
9
def generate_image(prompt, pipe):
    """Run the diffusion pipeline on *prompt* and return the first generated image.

    Args:
        prompt: Text prompt describing the desired image.
        pipe: A callable diffusion pipeline; calling it with the prompt must
            return an object exposing an ``images`` sequence.

    Returns:
        The first image produced by the pipeline.
    """
    result = pipe(prompt)
    return result.images[0]
13
 
14
def chatbot(prompt):
    """Gradio callback: generate an image for the user's text prompt.

    Delegates to ``generate_image`` using the module-level ``pipe``.
    """
    return generate_image(prompt, pipe)
18
 
 
 
 
 
 
 
 
 
 
 
19
  # Create the Gradio interface
20
  interface = gr.Interface(
21
  fn=chatbot,
 
26
  )
27
 
28
  # Launch the interface
29
+ interface.launch()