saadfarhad committed
Commit 9cfec01 · verified · 1 Parent(s): 95d4486

Update app.py

Files changed (1): app.py +27 -15
app.py CHANGED
@@ -1,43 +1,55 @@
 import gradio as gr
 import torch
-from transformers import AutoConfig, AutoProcessor
-# Import the custom model class directly.
-from transformers.models.llava.modeling_llava import LlavaQwenForCausalLM
+import importlib
+from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
+from transformers.models.llava.configuration_llava import LlavaConfig

-# --- Diagnostic Print (Optional) ---
-config = AutoConfig.from_pretrained(
-    "lmms-lab/LLaVA-Video-7B-Qwen2",
-    trust_remote_code=True
-)
+# --- Diagnostic: Load the configuration ---
+config = AutoConfig.from_pretrained("lmms-lab/LLaVA-Video-7B-Qwen2", trust_remote_code=True)
 print("Configuration type:", type(config))
 print("Configuration architectures:", config.architectures)
-# --- End Diagnostic ---

-# Load the processor and the model using the custom model class.
+# Expecting the architecture name to be "LlavaQwenForCausalLM"
+arch = config.architectures[0]  # This should be "LlavaQwenForCausalLM"
+
+# --- Dynamic Import: Retrieve the model class by name ---
+# Import the module that (should) contain the custom model class.
+module = importlib.import_module("transformers.models.llava.modeling_llava")
+try:
+    model_cls = getattr(module, arch)
+    print("Successfully imported model class:", model_cls)
+except AttributeError:
+    raise ImportError(f"Cannot find class {arch} in module transformers.models.llava.modeling_llava")
+
+# --- Register the Custom Model Class ---
+# This tells the auto loader that for LlavaConfig, use our dynamically imported model class.
+AutoModelForCausalLM.register(LlavaConfig, model_cls)
+
+# --- Load Processor and Model ---
 processor = AutoProcessor.from_pretrained(
     "lmms-lab/LLaVA-Video-7B-Qwen2",
     trust_remote_code=True
 )
-model = LlavaQwenForCausalLM.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
     "lmms-lab/LLaVA-Video-7B-Qwen2",
     trust_remote_code=True
 )

-# Move model to the appropriate device.
+# Move model to GPU if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)

 def analyze_video(video_path):
     prompt = "Analyze this video of a concert and determine the moment when the crowd is most engaged."
-    # Process the text and video.
+    # Process the text and video input
     inputs = processor(text=prompt, video=video_path, return_tensors="pt")
     inputs = {k: v.to(device) for k, v in inputs.items()}
-    # Generate output (assumes the custom model has a generate method).
+    # Generate output (assuming the custom model implements generate)
     outputs = model.generate(**inputs, max_new_tokens=100)
     answer = processor.decode(outputs[0], skip_special_tokens=True)
     return answer

-# Create the Gradio Interface.
+# Create the Gradio Interface
 iface = gr.Interface(
     fn=analyze_video,
     inputs=gr.Video(label="Upload Concert/Event Video", type="filepath"),
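
The change replaces a direct import of LlavaQwenForCausalLM (which fails if that class is not defined in transformers.models.llava.modeling_llava) with a runtime lookup: the class named in config.architectures is resolved via importlib and handed to the auto loader. The load-bearing call is AutoModelForCausalLM.register(config_class, model_class), which maps a config type to a model class so that from_pretrained can dispatch on the checkpoint's config. Below is a minimal, self-contained sketch of that registration mechanism; ToyConfig and ToyModelForCausalLM are hypothetical stand-ins, not the LLaVA classes.

# Sketch of the Auto-class registration mechanism the commit relies on.
# ToyConfig and ToyModelForCausalLM are hypothetical stand-ins.
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyModelForCausalLM(PreTrainedModel):
    config_class = ToyConfig
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(8, 8)

# Register the config type with AutoConfig and map it to the model class.
AutoConfig.register("toy", ToyConfig)
AutoModelForCausalLM.register(ToyConfig, ToyModelForCausalLM)

# The auto loader now dispatches on the config type.
model = AutoModelForCausalLM.from_config(ToyConfig())
print(type(model).__name__)  # ToyModelForCausalLM

Note that register validates the pairing: if the model class defines a config_class attribute, it must match the config class being registered, otherwise a ValueError is raised.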