import gradio as gr
import torch
import torch.nn as nn
import whisper
import clip
import spaces
from PIL import Image
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MultimodalPhi(nn.Module):
    """Wraps a Phi-1.5 model and prepends a projected CLIP image embedding as a single token."""

    def __init__(self, phi_model):
        super().__init__()
        self.phi_model = phi_model
        # Project the 512-dim CLIP embedding into the Phi hidden size.
        self.embedding_projection = nn.Linear(512, phi_model.config.hidden_size)

    def forward(self, image_embeddings, input_ids, attention_mask):
        # Keep dtypes consistent: the projection layer is fp32 while the Phi model may be fp16 on GPU.
        image_embeddings = image_embeddings.to(self.embedding_projection.weight.dtype)
        projected_embeddings = self.embedding_projection(image_embeddings).unsqueeze(1)
        inputs_embeds = self.phi_model.get_input_embeddings()(input_ids)
        projected_embeddings = projected_embeddings.to(inputs_embeds.dtype)
        combined_embeds = torch.cat([projected_embeddings, inputs_embeds], dim=1)
        # Extend the attention mask so the prepended image position is attended to.
        extended_attention_mask = torch.cat(
            [torch.ones(attention_mask.shape[0], 1, device=attention_mask.device, dtype=attention_mask.dtype),
             attention_mask],
            dim=1,
        )
        outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
        return outputs.logits[:, 1:, :]  # Exclude the image position from the output logits


def load_models():
    try:
        print("Loading models...")
        peft_model_name = "sagar007/phi-1_5-finetuned"
        # Manually load and recreate the LoraConfig, ignoring unknown arguments.
        config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
        # Remove 'layer_replication' if present (not accepted by older peft versions).
        config_dict.pop('layer_replication', None)
        lora_config = LoraConfig(**config_dict)
        print("PEFT config loaded")

        base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-1_5",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        print("Base model loaded")

        phi_model = get_peft_model(base_model, lora_config)
        # Expects the adapter weights to be available locally under this path.
        phi_model.load_state_dict(
            torch.load(peft_model_name + '/adapter_model.bin', map_location=device),
            strict=False,
        )
        print("PEFT model loaded")

        multimodal_model = MultimodalPhi(phi_model)
        multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
        multimodal_model.to(device)
        multimodal_model.eval()
        print("Multimodal model loaded")

        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
        tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded")

        audio_model = whisper.load_model("base").to(device)
        print("Audio model loaded")

        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
        print("CLIP model loaded")

        return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
    except Exception as e:
        print(f"Error in load_models: {str(e)}")
        raise


model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()


@spaces.GPU
def get_clip_embedding(image):
    # The Gradio Image component uses type="pil", so `image` is already a PIL image;
    # fall back to Image.open for file-path inputs.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    return image_features.squeeze(0)


@spaces.GPU
def process_text(text):
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128,
                           padding='max_length').to(device)
        # Dummy image embedding for text-only input.
        dummy_image_embedding = torch.zeros(512).to(device)
        with torch.no_grad():
            outputs = model(dummy_image_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        # Greedy per-position decode of the logits from a single forward pass.
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_text: {str(e)}"


@spaces.GPU
def process_image(image):
    try:
        clip_embedding = get_clip_embedding(image)
        prompt = "Describe this image:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128,
                           padding='max_length').to(device)
        with torch.no_grad():
            outputs = model(clip_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_image: {str(e)}"


@spaces.GPU
def process_audio(audio):
    try:
        # Transcribe the audio file with Whisper, then respond to the transcription as text.
        result = audio_model.transcribe(audio)
        transcription = result["text"]
        return process_text(f"Transcription: {transcription}\nPlease respond to this:")
    except Exception as e:
        return f"Error in process_audio: {str(e)}"


def chat(message, image, audio):
    # Route to the appropriate handler: audio takes priority, then image, then text.
    if audio is not None:
        return process_audio(audio)
    elif image is not None:
        return process_image(image)
    else:
        return process_text(message)


iface = gr.Interface(
    fn=chat,
    inputs=[
        gr.Textbox(placeholder="Enter text here..."),
        gr.Image(type="pil"),
        gr.Audio(type="filepath"),
    ],
    outputs="text",
    title="Multi-Modal Assistant",
    description="Chat with an AI using text, images, or audio!",
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    iface.launch(share=True)