sagar007 committed
Commit 8d741e2 · verified · 1 Parent(s): 0612b41

Update app.py

Files changed (1):
  1. app.py +138 -117
app.py CHANGED
@@ -1,134 +1,155 @@
- import os
  import gradio as gr
  import torch
- from peft import LoraConfig, get_peft_model
- import torch.nn as nn
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftModel, PeftConfig
-
  from PIL import Image
- import clip
- import spaces
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- class MultimodalPhi(nn.Module):
-     def __init__(self, phi_model):
-         super().__init__()
-         self.phi_model = phi_model
-         self.embedding_projection = nn.Linear(512, phi_model.config.hidden_size)
-
-     def forward(self, image_embeddings, input_ids, attention_mask):
-         projected_embeddings = self.embedding_projection(image_embeddings).unsqueeze(1)
-         inputs_embeds = self.phi_model.get_input_embeddings()(input_ids)
-         combined_embeds = torch.cat([projected_embeddings, inputs_embeds], dim=1)
-
-         extended_attention_mask = torch.cat([torch.ones(attention_mask.shape[0], 1).to(attention_mask.device), attention_mask], dim=1)
-
-         outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
-         return outputs.logits[:, 1:, :]  # Exclude the image token from output
-
- def load_models():
-     try:
-         print("Loading models...")
-         peft_model_name = "sagar007/phi-1_5-finetuned"
-
-         # Manually load and create LoraConfig, ignoring unknown arguments
-         config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
-         # Remove 'layer_replication' if present
-         config_dict.pop('layer_replication', None)
-         lora_config = LoraConfig(**config_dict)
-         print("PEFT config loaded")
-
-         base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-         print("Base model loaded")
-
-         phi_model = get_peft_model(base_model, lora_config)
-         phi_model.load_state_dict(torch.load(peft_model_name + '/adapter_model.bin', map_location=device), strict=False)
-         print("PEFT model loaded")
-
-         multimodal_model = MultimodalPhi(phi_model)
-         multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
-         multimodal_model.to(device)
-         multimodal_model.eval()
-         print("Multimodal model loaded")
-
-         tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
-         tokenizer.pad_token = tokenizer.eos_token
-         print("Tokenizer loaded")
-
-         audio_model = whisper.load_model("base").to(device)
-         print("Audio model loaded")
-
-         clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
-         print("CLIP model loaded")
-
-         return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
-     except Exception as e:
-         print(f"Error in load_models: {str(e)}")
-         raise
-
- model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()

- @spaces.GPU
- def get_clip_embedding(image):
-     image = clip_preprocess(Image.open(image)).unsqueeze(0).to(device)
-     with torch.no_grad():
-         image_features = clip_model.encode_image(image)
-     return image_features.squeeze(0)

- @spaces.GPU
- def process_text(text):
      try:
-         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
-         dummy_image_embedding = torch.zeros(512).to(device)  # Dummy image embedding for text-only input
          with torch.no_grad():
-             outputs = model(dummy_image_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
-         return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
      except Exception as e:
-         return f"Error in process_text: {str(e)}"

- @spaces.GPU
- def process_image(image):
      try:
-         clip_embedding = get_clip_embedding(image)
-         prompt = "Describe this image:"
-         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
          with torch.no_grad():
-             outputs = model(clip_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
-         return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
-     except Exception as e:
-         return f"Error in process_image: {str(e)}"
-
- @spaces.GPU
- def process_audio(audio):
-     try:
-         result = audio_model.transcribe(audio)
-         transcription = result["text"]
-         return process_text(f"Transcription: {transcription}\nPlease respond to this:")
      except Exception as e:
-         return f"Error in process_audio: {str(e)}"

- def chat(message, image, audio):
-     if audio is not None:
-         return process_audio(audio)
-     elif image is not None:
-         return process_image(image)
-     else:
-         return process_text(message)
-
- iface = gr.Interface(
-     fn=chat,
-     inputs=[
-         gr.Textbox(placeholder="Enter text here..."),
-         gr.Image(type="pil"),
-         gr.Audio(type="filepath")
-     ],
-     outputs="text",
-     title="Multi-Modal Assistant",
-     description="Chat with an AI using text, images, or audio!"
- )

  if __name__ == "__main__":
-     print("Starting Gradio interface...")
-     iface.launch(share=True)
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
  import torch
  from PIL import Image
+ import os

+ # Check if CUDA is available, otherwise use CPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Load model and tokenizer with optimizations for CPU deployment
+ def load_model():
+     print("Loading model and tokenizer...")
+     model = AutoModelForCausalLM.from_pretrained(
+         "sagar007/Lava_phi",
+         torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
+         low_cpu_mem_usage=True,
+     )
+     model = model.to(device)

+     tokenizer = AutoTokenizer.from_pretrained("sagar007/Lava_phi")
+     processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
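+     # The CLIP processor resizes and normalizes uploaded images to the
+     # 224x224 pixel format that the ViT-B/32 vision encoder expects.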
+
+     print("Model and tokenizer loaded successfully!")
+     return model, tokenizer, processor

+ # Load models
+ model, tokenizer, processor = load_model()
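+ # Loading once at import time means every Gradio request reuses the same
+ # in-memory model rather than reloading it per call.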

+ # For text-only generation
+ def generate_text(prompt, max_length=128):
      try:
+         inputs = tokenizer(f"human: {prompt}\ngpt:", return_tensors="pt").to(device)
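+         # "human: ...\ngpt:" appears to be the conversation template the
+         # checkpoint was fine-tuned on; generation continues after "gpt:".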
+
+         # Generate with low memory footprint settings
          with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_length,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+             )
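+         # do_sample with temperature/top_p trades determinism for variety;
+         # max_new_tokens bounds both response length and memory use.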
+
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract only the model's response
+         if "gpt:" in generated_text:
+             generated_text = generated_text.split("gpt:", 1)[1].strip()
+
+         return generated_text
      except Exception as e:
+         return f"Error generating text: {str(e)}"

+ # For image and text processing
+ def process_image_and_prompt(image, prompt, max_length=128):
      try:
+         if image is None:
+             return "No image provided. Please upload an image."
+
+         # Process image
+         image_tensor = processor(images=image, return_tensors="pt").pixel_values.to(device)
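+         # pixel_values is a [1, 3, 224, 224] float tensor for this checkpoint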
+
+         # Tokenize input with image token
+         inputs = tokenizer(f"human: <image>\n{prompt}\ngpt:", return_tensors="pt").to(device)
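+         # <image> presumably marks where the model splices in the projected
+         # CLIP features supplied via the images= argument below.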
+
+         # Generate with memory optimizations
          with torch.no_grad():
+             outputs = model.generate(
+                 input_ids=inputs["input_ids"],
+                 attention_mask=inputs["attention_mask"],
+                 images=image_tensor,
+                 max_new_tokens=max_length,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+             )
+
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract only the model's response
+         if "gpt:" in generated_text:
+             generated_text = generated_text.split("gpt:", 1)[1].strip()
+
+         return generated_text
      except Exception as e:
+         return f"Error processing image: {str(e)}"

+ # Create Gradio Interface
+ with gr.Blocks(title="LLaVA-Phi: Vision-Language Model") as demo:
+     gr.Markdown("# LLaVA-Phi: Vision-Language Model")
+     gr.Markdown("This model can generate text responses from text prompts or analyze images with text prompts.")
+
+     with gr.Tab("Text Generation"):
+         with gr.Row():
+             with gr.Column():
+                 text_input = gr.Textbox(label="Enter your prompt", lines=3, placeholder="What is artificial intelligence?")
+                 text_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
+                 text_button = gr.Button("Generate")
+
+             text_output = gr.Textbox(label="Generated response", lines=8)
+
+         text_button.click(
+             fn=generate_text,
+             inputs=[text_input, text_max_length],
+             outputs=text_output
+         )
+
+     with gr.Tab("Image + Text Analysis"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="pil", label="Upload an image")
+                 image_text_input = gr.Textbox(label="Enter your prompt about the image",
+                                               lines=2,
+                                               placeholder="Describe this image in detail.")
+                 image_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
+                 image_button = gr.Button("Analyze")
+
+             image_output = gr.Textbox(label="Model response", lines=8)
+
+         image_button.click(
+             fn=process_image_and_prompt,
+             inputs=[image_input, image_text_input, image_max_length],
+             outputs=image_output
+         )
+
+     # Example inputs for each tab
+     gr.Examples(
+         examples=["What is the advantage of vision-language models?",
+                   "Explain how multimodal AI models work.",
+                   "Tell me a short story about robots."],
+         inputs=text_input
+     )
+
+     # Add examples for image tab if you have example images
+     # gr.Examples(
+     #     examples=[["example1.jpg", "What's in this image?"]],
+     #     inputs=[image_input, image_text_input]
+     # )

+ # Launch the app with memory optimizations
  if __name__ == "__main__":
+     # Memory cleanup before launch
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # Set low CPU thread usage to reduce memory
+     os.environ["OMP_NUM_THREADS"] = "4"
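+     # NOTE: OMP_NUM_THREADS is usually read when torch first initializes, so
+     # setting it this late may be a no-op; torch.set_num_threads(4) is the
+     # reliable runtime equivalent.
+     torch.set_num_threads(4)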
+
+     # Launch with minimal resource usage; queue() replaces the enable_queue
+     # launch argument, which newer Gradio releases no longer accept
+     demo.queue()
+     demo.launch(
+         share=True,  # Set to False in production
+         max_threads=4,
+         show_error=True
+     )