import os
import time
import json
import yaml
import shutil
import random
import argparse
from datetime import timedelta
from threading import Thread
from bisect import bisect_left

import cv2
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import webvtt
import whisper
import soundfile as sf
import gradio as gr
import moviepy.editor as mp
from moviepy.editor import VideoFileClip
from pytubefix import YouTube
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

from minigpt4.common.eval_utils import prepare_texts, init_model
from minigpt4.conversation.conversation import CONV_VISION


# Format a timestamp in seconds as "hh:mm:ss.mmm" for VTT output
def format_timestamp(seconds):
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    milliseconds = int(td.microseconds / 1000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"


def extract_video_info(video_path, max_images_length):
    # Return the frame sampling interval and the video fps
    clip = VideoFileClip(video_path)
    total_num_frames = int(clip.duration * clip.fps)
    fps = clip.fps
    clip.close()
    sampling_interval = int(total_num_frames / max_images_length)
    if sampling_interval == 0:
        sampling_interval = 1
    return sampling_interval, fps


def time_to_milliseconds(time_str):
    # Convert time format "hh:mm:ss.sss" to milliseconds
    h, m, s = map(float, time_str.split(':'))
    return int((h * 3600 + m * 60 + s) * 1000)


def extract_subtitles(subtitle_path):
    # Parse the VTT file into a list of (start_time_ms, end_time_ms, text)
    subtitles = []
    for caption in webvtt.read(subtitle_path):
        start_ms = time_to_milliseconds(caption.start)
        end_ms = time_to_milliseconds(caption.end)
        text = caption.text.strip().replace('\n', ' ')
        subtitles.append((start_ms, end_ms, text))
    return subtitles


def find_subtitle(subtitles, frame_count, fps):
    # Binary search for the subtitle whose [start, end] interval contains the frame time
    frame_time = (frame_count / fps) * 1000
    left, right = 0, len(subtitles) - 1
    while left <= right:
        mid = (left + right) // 2
        start, end, subtitle_text = subtitles[mid]
        if start <= frame_time <= end:
            return subtitle_text
        elif frame_time < start:
            right = mid - 1
        else:
            left = mid + 1
    return None  # No subtitle found for this frame


def match_frames_and_subtitles(video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames):
    cap = cv2.VideoCapture(video_path)
    images = []
    frame_count = 0
    img_placeholder = ""
    subtitle_text_in_interval = ""
    history_subtitles = {}
    number_of_words = 0
    transform = transforms.Compose([
        transforms.ToPILImage(),
    ])
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if len(subtitles) > 0:
            # find the subtitle whose start/end interval contains the current frame time
            frame_subtitle = find_subtitle(subtitles, frame_count, fps)
            if frame_subtitle and not history_subtitles.get(frame_subtitle, False):
                subtitle_text_in_interval += frame_subtitle + " "
                history_subtitles[frame_subtitle] = True
        if frame_count % sampling_interval == 0:
            frame = transform(frame[:, :, ::-1])  # convert BGR to RGB
            frame = vis_processor(frame)
            images.append(frame)
            img_placeholder += '<Img><ImageHere>'  # placeholder tokens for this frame in the prompt
            if subtitle_text_in_interval != "" and number_of_words < max_sub_len:
                img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                number_of_words += len(subtitle_text_in_interval.split(' '))
                subtitle_text_in_interval = ""
        frame_count += 1
        if len(images) >= max_frames:
            break
    cap.release()
    cv2.destroyAllWindows()
    if len(images) == 0:
        # skip the video if no frame is extracted
        return None, None
    images = torch.stack(images)
    return images, img_placeholder


def prepare_input(video_path, subtitle_path, instruction):
    if "mistral" in args.ckpt:
        max_frames = 90
        max_sub_len = 800
    else:
        max_frames = 45
        max_sub_len = 400
    sampling_interval, fps = extract_video_info(video_path, max_frames)
    # no subtitle file means an empty subtitle list
    subtitles = extract_subtitles(subtitle_path) if subtitle_path else []
    frames_features, input_placeholder = match_frames_and_subtitles(
        video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames)
    if frames_features is None:
        return None, None
    input_placeholder += "\n" + instruction
    return frames_features, input_placeholder


def extract_audio(video_path, audio_path):
    video_clip = mp.VideoFileClip(video_path)
    audio_clip = video_clip.audio
    audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")


def get_subtitles(video_path):
    audio_dir = "workspace/inference_subtitles/mp3"
    subtitle_dir = "workspace/inference_subtitles"
    os.makedirs(subtitle_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)
    video_id = video_path.split('/')[-1].split('.')[0]
    subtitle_path = f"{subtitle_dir}/{video_id}.vtt"
    # if the subtitles were already generated, return their path
    if os.path.exists(subtitle_path):
        return subtitle_path
    audio_path = f"{audio_dir}/{video_id}.mp3"
    try:
        extract_audio(video_path, audio_path)
        result = whisper_model.transcribe(audio_path, language="en")
        # Create the VTT file from the Whisper segments
        with open(subtitle_path, "w", encoding="utf-8") as vtt_file:
            vtt_file.write("WEBVTT\n\n")
            for segment in result['segments']:
                start = format_timestamp(segment['start'])
                end = format_timestamp(segment['end'])
                text = segment['text']
                vtt_file.write(f"{start} --> {end}\n{text}\n\n")
        return subtitle_path
    except Exception as e:
        print(f"Error during subtitle generation for {video_path}: {e}")
        return None


def stream_answer(generation_kwargs):
    # Run generation in a background thread and stream tokens as they are produced
    streamer = TextIteratorStreamer(model.llama_tokenizer, skip_special_tokens=True)
    generation_kwargs['streamer'] = streamer
    thread = Thread(target=model_generate, kwargs=generation_kwargs)
    thread.start()
    return streamer


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']
    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)
    return text


def model_generate(*args, **kwargs):
    # for 8-bit and 16-bit compatibility
    with model.maybe_autocast():
        output = model.llama_model.generate(*args, **kwargs)
    return output


def generate_prediction(video_path, instruction, gen_subtitles=True, stream=True):
    if gen_subtitles:
        subtitle_path = get_subtitles(video_path)
    else:
        subtitle_path = None
    prepared_images, prepared_instruction = prepare_input(video_path, subtitle_path, instruction)
    if prepared_images is None:
        return "Video can't be opened, check the video path again"
    length = len(prepared_images)
    prepared_images = prepared_images.unsqueeze(0)
    conv = CONV_VISION.copy()
    conv.system = ""
    # for a multi-turn conversation, comment out the two lines above and use the global conv instead
    conv.append_message(conv.roles[0], prepared_instruction)
    conv.append_message(conv.roles[1], None)
    prompt = [conv.get_prompt()]
    if stream:
        generation_kwargs = model.answer_prepare_for_streaming(
            prepared_images, prompt, max_new_tokens=args.max_new_tokens,
            do_sample=True, lengths=[length], num_beams=1)
        streamer = stream_answer(generation_kwargs)
        print("Streamed answer:", end='')
        for a in streamer:
            print(a, end='')
        print()
    else:
        setup_seeds(50)
        answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens,
                                 do_sample=True, lengths=[length], num_beams=1)
        print("Generated answer:", answers[0])


def get_arguments():
    parser = argparse.ArgumentParser(description="Inference parameters")
    parser.add_argument("--cfg-path", help="path to configuration file",
                        default="test_configs/llama2_test_config.yaml")
    parser.add_argument("--ckpt", type=str, default='checkpoints/video_llama_checkpoint_last.pth',
                        help="path to checkpoint")
    parser.add_argument("--add_subtitles", action='store_true', help="whether to add subtitles")
    parser.add_argument("--stream", action='store_true', help="whether to stream the answer")
    parser.add_argument("--question", type=str, help="question to ask")
    parser.add_argument("--video_path", type=str, help="path to the video file")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecated), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()


args = get_arguments()


def setup_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


with open('test_configs/llama2_test_config.yaml') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
seed = config['run']['seed']
print("seed", seed)

model, vis_processor, whisper_gpu_id, minigpt4_gpu_id, answer_module_gpu_id = init_model(args)
whisper_model = whisper.load_model("large").to(f"cuda:{whisper_gpu_id}")
conv = CONV_VISION.copy()
conv.system = ""

if __name__ == "__main__":
    video_path = args.video_path
    instruction = args.question
    add_subtitles = args.add_subtitles
    stream = args.stream
    setup_seeds(50)
    t1 = time.time()
    generate_prediction(video_path, instruction, gen_subtitles=add_subtitles, stream=stream)
    print("Time taken for inference:", time.time() - t1)
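
# Example invocation (the script and file names below are illustrative, not fixed by the repo;
# adjust them to your checkout and checkpoint layout):
#   python minigpt4_video_inference.py \
#       --video_path path/to/video.mp4 \
#       --question "What is happening in this video?" \
#       --add_subtitles --stream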