Update app.py
app.py
CHANGED
@@ -1,134 +1,155 @@
Removed (old version):

-import os
 import gradio as gr
 import torch
-from peft import LoraConfig, get_peft_model
-import torch.nn as nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import PeftModel, PeftConfig
-
 from PIL import Image
-import whisper
-import spaces
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        extended_attention_mask = torch.cat([torch.ones(attention_mask.shape[0], 1).to(attention_mask.device), attention_mask], dim=1)
-
-        outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
-        return outputs.logits[:, 1:, :]  # Exclude the image token from output
-
-def load_models():
-    try:
-        print("Loading models...")
-        peft_model_name = "sagar007/phi-1_5-finetuned"
-
-        # Manually load and create LoraConfig, ignoring unknown arguments
-        config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
-        # Remove 'layer_replication' if present
-        config_dict.pop('layer_replication', None)
-        lora_config = LoraConfig(**config_dict)
-        print("PEFT config loaded")
-
-        base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-        print("Base model loaded")
-
-        phi_model = get_peft_model(base_model, lora_config)
-        phi_model.load_state_dict(torch.load(peft_model_name + '/adapter_model.bin', map_location=device), strict=False)
-        print("PEFT model loaded")
-
-        multimodal_model = MultimodalPhi(phi_model)
-        multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
-        multimodal_model.to(device)
-        multimodal_model.eval()
-        print("Multimodal model loaded")
-
-        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
-        tokenizer.pad_token = tokenizer.eos_token
-        print("Tokenizer loaded")
-
-        audio_model = whisper.load_model("base").to(device)
-        print("Audio model loaded")
-
-        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
-        print("CLIP model loaded")
-
-        return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
-    except Exception as e:
-        print(f"Error in load_models: {str(e)}")
-        raise
-
-model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()
-
-    image = clip_preprocess(Image.open(image)).unsqueeze(0).to(device)
-    with torch.no_grad():
-        image_features = clip_model.encode_image(image)
-    return image_features.squeeze(0)
-
-def
     try:
-        inputs = tokenizer(
         with torch.no_grad():
-            outputs = model
     except Exception as e:
-        return f"Error
-
-def
     try:
         with torch.no_grad():
-            outputs = model
     except Exception as e:
-        return f"Error
-
-)
 
 if __name__ == "__main__":
Added (new version):

 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
 import torch
 from PIL import Image
+import os

+# Check if CUDA is available, otherwise use CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+
+# Load model and tokenizer with optimizations for CPU deployment
+def load_model():
+    print("Loading model and tokenizer...")
+    model = AutoModelForCausalLM.from_pretrained(
+        "sagar007/Lava_phi",
+        torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
+        low_cpu_mem_usage=True,
+    )
+    model = model.to(device)

+    tokenizer = AutoTokenizer.from_pretrained("sagar007/Lava_phi")
+    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+    print("Model and tokenizer loaded successfully!")
+    return model, tokenizer, processor

+# Load models
+model, tokenizer, processor = load_model()
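One optional addition after the loader, not part of the commit above: since the Space targets CPU inference, the model can be put in eval mode and the intra-op thread count capped. A minimal sketch using standard PyTorch calls; the thread count of 4 is an assumption that simply mirrors the OMP_NUM_THREADS setting at the bottom of the file.

# Optional post-load tweaks (sketch, not in the commit): disable dropout and
# cap CPU threads; both are plain PyTorch APIs.
model.eval()                  # inference mode
if device == "cpu":
    torch.set_num_threads(4)  # matches OMP_NUM_THREADS below; value is an assumption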

+# For text-only generation
+def generate_text(prompt, max_length=128):
     try:
+        inputs = tokenizer(f"human: {prompt}\ngpt:", return_tensors="pt").to(device)
+
+        # Generate with low memory footprint settings
         with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_length,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
+            )
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract only the model's response
+        if "gpt:" in generated_text:
+            generated_text = generated_text.split("gpt:", 1)[1].strip()
+
+        return generated_text
     except Exception as e:
+        return f"Error generating text: {str(e)}"
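For a quick smoke test of the text path without launching the UI, the function can be called directly once the module has loaded; the prompt below is only an example.

# Example call; returns the decoded completion, or an error string on failure.
print(generate_text("What is a vision-language model?", max_length=64))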
+# For image and text processing
+def process_image_and_prompt(image, prompt, max_length=128):
     try:
+        if image is None:
+            return "No image provided. Please upload an image."
+
+        # Process image
+        image_tensor = processor(images=image, return_tensors="pt").pixel_values.to(device)
+
+        # Tokenize input with image token
+        inputs = tokenizer(f"human: <image>\n{prompt}\ngpt:", return_tensors="pt").to(device)
+
+        # Generate with memory optimizations
         with torch.no_grad():
+            outputs = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                images=image_tensor,
+                max_new_tokens=max_length,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
+            )
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract only the model's response
+        if "gpt:" in generated_text:
+            generated_text = generated_text.split("gpt:", 1)[1].strip()
+
+        return generated_text
     except Exception as e:
+        return f"Error processing image: {str(e)}"
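The image path can be exercised the same way; `sample.jpg` is a placeholder filename. Note that forwarding `images=` through `generate()` assumes the `sagar007/Lava_phi` checkpoint ships a multimodal `generate`/`forward` that accepts that keyword; a plain causal-LM head would reject it.

# Example call with a local test image (placeholder path).
from PIL import Image
img = Image.open("sample.jpg")  # hypothetical file
print(process_image_and_prompt(img, "Describe this image in detail.", max_length=96))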
+# Create Gradio Interface
+with gr.Blocks(title="LLaVA-Phi: Vision-Language Model") as demo:
+    gr.Markdown("# LLaVA-Phi: Vision-Language Model")
+    gr.Markdown("This model can generate text responses from text prompts or analyze images with text prompts.")
+
+    with gr.Tab("Text Generation"):
+        with gr.Row():
+            with gr.Column():
+                text_input = gr.Textbox(label="Enter your prompt", lines=3, placeholder="What is artificial intelligence?")
+                text_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
+                text_button = gr.Button("Generate")
+
+        text_output = gr.Textbox(label="Generated response", lines=8)
+
+        text_button.click(
+            fn=generate_text,
+            inputs=[text_input, text_max_length],
+            outputs=text_output
+        )
+
+    with gr.Tab("Image + Text Analysis"):
+        with gr.Row():
+            with gr.Column():
+                image_input = gr.Image(type="pil", label="Upload an image")
+                image_text_input = gr.Textbox(label="Enter your prompt about the image",
+                                              lines=2,
+                                              placeholder="Describe this image in detail.")
+                image_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
+                image_button = gr.Button("Analyze")
+
+        image_output = gr.Textbox(label="Model response", lines=8)
+
+        image_button.click(
+            fn=process_image_and_prompt,
+            inputs=[image_input, image_text_input, image_max_length],
+            outputs=image_output
+        )
+
+    # Example inputs for each tab
+    gr.Examples(
+        examples=["What is the advantage of vision-language models?",
+                  "Explain how multimodal AI models work.",
+                  "Tell me a short story about robots."],
+        inputs=text_input
+    )
+
+    # Add examples for image tab if you have example images
+    # gr.Examples(
+    #     examples=[["example1.jpg", "What's in this image?"]],
+    #     inputs=[image_input, image_text_input]
+    # )

+# Launch the app with memory optimizations
 if __name__ == "__main__":
+    # Memory cleanup before launch
+    torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+    # Set low CPU thread usage to reduce memory
+    os.environ["OMP_NUM_THREADS"] = "4"
+
+    # Launch with minimal resource usage
+    demo.launch(
+        share=True,  # Set to False in production
+        enable_queue=True,
+        max_threads=4,
+        show_error=True
+    )
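One deployment note on the launch block: `enable_queue` is a Gradio 3.x `launch()` argument. If the Space is pinned to Gradio 4.x, `launch()` no longer accepts it and queueing is configured on the Blocks object instead; a sketch of the equivalent is below, where the `max_size` value is purely illustrative.

# Gradio 4.x equivalent (sketch): configure the queue on demo, then launch.
demo.queue(max_size=16)       # max_size chosen arbitrarily here
demo.launch(
    share=True,               # Set to False in production
    max_threads=4,
    show_error=True,
)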