# Scraped from a Hugging Face Spaces page; status banner at capture time: "Runtime error".
import torch | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, AutoProcessor | |
import gc | |
# Model identifier passed to both AutoProcessor.from_pretrained and
# LlavaNextForConditionalGeneration.from_pretrained in load_model().
MODEL_ID = "arjunanand13/gas_pipe_llava_finetunedv3"
def clear_memory():
    """Run the garbage collector and, if CUDA is present, flush its cache.

    Always triggers ``gc.collect()``. When ``torch.cuda.is_available()`` is
    true it additionally empties the CUDA caching allocator and waits for
    outstanding GPU work via ``torch.cuda.synchronize()``.
    """
    gc.collect()

    # Nothing GPU-related to do on CPU-only hosts.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def extract_frames_from_video(video_path, num_frames=4):
    """Sample evenly spaced frames from a video as 112x112 RGB PIL images.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Number of frames to return. The result always has exactly
            this many entries; when fewer frames can be decoded the list is
            padded (see Returns).

    Returns:
        A list of ``num_frames`` PIL images, each resized to 112x112. If some
        frames could not be read, the last decoded frame is repeated; if no
        frame could be read at all, black frames are used instead.

    Raises:
        ValueError: If ``num_frames`` is less than 1 or the video cannot be
            opened.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be at least 1, got {num_frames}")

    target_size = (112, 112)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        # Guard against a non-positive reported frame count so np.linspace
        # never receives a negative number of samples.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        sample_count = min(total_frames, num_frames)
        frame_indices = np.linspace(0, total_frames - 1, sample_count, dtype=int)

        frames = []
        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if not ret:
                # Skip unreadable frames; padding below restores the count.
                continue
            # OpenCV decodes to BGR, PIL expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame_rgb)
            frames.append(frame_pil.resize(target_size, Image.Resampling.LANCZOS))
    finally:
        # Release the capture handle even when opening or decoding fails.
        cap.release()

    # Pad up to the requested count (previously hard-coded to 4).
    if not frames:
        frames.append(Image.new('RGB', target_size, color='black'))
    while len(frames) < num_frames:
        frames.append(frames[-1].copy())

    return frames[:num_frames]
def create_frame_grid(frames, grid_size=(2, 2)):
    """Tile frames into a single RGB image, filling rows left to right.

    Args:
        frames: Iterable of PIL images; each is pasted into a 112-pixel cell.
        grid_size: ``(cols, rows)`` of the grid.

    Returns:
        A new PIL image of size ``(112 * cols, 112 * rows)`` with frame ``i``
        pasted at column ``i % cols`` and row ``i // cols``.
    """
    cell = 112
    n_cols, n_rows = grid_size
    canvas = Image.new('RGB', (cell * n_cols, cell * n_rows))

    for position, tile in enumerate(frames):
        grid_row, grid_col = divmod(position, n_cols)
        canvas.paste(tile, (grid_col * cell, grid_row * cell))

    return canvas
def load_model():
    """Load the processor and the 4-bit quantized LLaVA-NeXT model.

    The processor is loaded first and its tokenizer is set to right padding
    with the EOS token reused as the pad token. The model is then loaded from
    ``MODEL_ID`` with an NF4 bitsandbytes configuration, has ``use_cache``
    disabled on its config, and is switched to eval mode.

    Returns:
        A ``(model, processor)`` tuple.
    """
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.uint8,
    )

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    tokenizer = processor.tokenizer
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token

    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=quant_config,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model.config.use_cache = False
    model.eval()

    return model, processor
# Load the model and processor once at import time; predict_gas_pipe_quality
# reads these module-level globals on every call.
model, processor = load_model()
def predict_gas_pipe_quality(video_path):
    """Run the model on a 2x2 grid of frames sampled from a video.

    Args:
        video_path: Path of the video to analyze. May be ``None`` or empty,
            e.g. when the UI delivers an event without a video.

    Returns:
        A ``(grid_image, text)`` tuple. On success ``grid_image`` is the PIL
        frame grid and ``text`` is the decoded model output. When no video is
        given, or any step fails, ``grid_image`` is ``None`` and ``text`` is a
        human-readable message (failures are prefixed with ``"Error: "``).
    """
    # The UI wires this function to the video component's change event, which
    # can fire without a video; answer directly instead of surfacing an
    # exception from the frame extraction step.
    if not video_path:
        return None, "Please upload a video to analyze."

    try:
        frames = extract_frames_from_video(video_path, num_frames=4)
        grid_image = create_frame_grid(frames, grid_size=(2, 2))

        prompt = "[INST] <image>\nGas pipe test result? [/INST]"
        inputs = processor(text=prompt, images=grid_image, return_tensors="pt")

        if torch.cuda.is_available():
            # Move tensor inputs next to the model; leave other values as-is.
            inputs = {
                k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs["pixel_values"],
                image_sizes=inputs["image_sizes"],
                max_new_tokens=16,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # Decode only the tokens produced after the prompt.
        prompt_length = inputs["input_ids"].size(1)
        prediction = processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=True,
        )[0].strip()

        return grid_image, prediction
    except Exception as e:
        # UI boundary: report the failure in the output textbox.
        return None, f"Error: {str(e)}"
    finally:
        # Single cleanup point for both the success and failure paths.
        clear_memory()
def create_interface():
    """Build the Gradio Blocks UI for the gas pipe quality demo.

    The layout has a video upload plus an "Analyze" button in one column and
    the extracted frame grid plus the model output textbox in the other.
    Example videos are listed, and both the button click and the video change
    event invoke ``predict_gas_pipe_quality``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    example_files = (
        "13.mp4",
        "14.mp4",
        "04.mp4",
        "07_part1.mp4",
        "09_part1.mp4",
        "29_part1.mp4",
    )

    with gr.Blocks(title="Gas Pipe Quality Control") as iface:
        gr.Markdown("# Gas Pipe Quality Control")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                frame_grid = gr.Image(label="Extracted Frames")
                result_output = gr.Textbox(label="Model Output", lines=3)

        # NOTE(review): the source's indentation was lost; the examples are
        # assumed to sit below the row at Blocks level — confirm.
        gr.Examples(
            examples=[[name] for name in example_files],
            inputs=video_input,
            outputs=[frame_grid, result_output],
            fn=predict_gas_pipe_quality,
            cache_examples=False,
        )

        def _wire(register_event):
            # Attach the prediction callback with identical inputs/outputs.
            register_event(
                fn=predict_gas_pipe_quality,
                inputs=video_input,
                outputs=[frame_grid, result_output],
            )

        _wire(analyze_btn.click)
        _wire(video_input.change)

    return iface
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    iface = create_interface()
    # share=True asks Gradio to also create a public share link.
    iface.launch(share=True)