File size: 28,328 Bytes
dc80a97
 
a02a090
 
dc80a97
6ddb8bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc80a97
 
 
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc80a97
a02a090
 
 
 
dc80a97
 
 
 
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e68ca85
 
 
 
 
 
 
 
 
 
 
a02a090
1f33751
a02a090
 
1f33751
a02a090
 
1f33751
 
a02a090
 
 
1f33751
a02a090
dc80a97
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e886e
 
 
 
6ffc9b1
 
 
 
 
 
 
e2e886e
a02a090
e2e886e
a02a090
 
 
e68ca85
a02a090
 
e2e886e
e68ca85
e2e886e
 
a02a090
 
e2e886e
 
 
 
 
 
 
 
 
 
6ffc9b1
 
 
 
 
 
 
 
 
e2e886e
 
 
 
 
 
6ffc9b1
e2e886e
6ffc9b1
e2e886e
6ffc9b1
 
e2e886e
6ffc9b1
 
 
 
 
 
e2e886e
 
6ffc9b1
e2e886e
 
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e68ca85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
4bdf408
a02a090
4bdf408
a02a090
 
 
4bdf408
 
a02a090
 
 
 
 
 
dc80a97
 
a02a090
dc80a97
 
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc80a97
a02a090
 
dc80a97
a02a090
 
 
 
 
 
 
 
 
e68ca85
 
 
a02a090
 
 
 
 
 
dc80a97
 
 
a02a090
 
dc80a97
4bdf408
a02a090
 
 
 
 
dc80a97
 
a02a090
dc80a97
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc80a97
 
a02a090
 
 
 
 
 
 
 
 
1f33751
a02a090
e68ca85
 
 
 
 
a02a090
 
 
 
e68ca85
1f33751
a02a090
 
 
 
 
 
 
 
 
 
dc80a97
a02a090
 
 
1f33751
dc80a97
1f33751
a02a090
dc80a97
a02a090
dc80a97
a02a090
1f33751
a02a090
 
1f33751
a02a090
 
1f33751
dc80a97
 
a02a090
 
dc80a97
a02a090
dc80a97
 
a02a090
 
 
 
 
 
 
 
 
1f33751
a02a090
 
 
dc80a97
 
a02a090
 
dc80a97
a02a090
 
 
 
 
 
e68ca85
1f33751
a02a090
 
 
 
 
 
 
1f33751
a02a090
1f33751
dc80a97
1f33751
dc80a97
a02a090
 
1f33751
a02a090
 
 
 
 
 
 
 
 
 
 
 
 
1f33751
 
 
 
a02a090
 
 
dc80a97
a02a090
dc80a97
 
 
a02a090
 
 
dc80a97
 
 
 
 
 
a02a090
 
 
 
 
 
 
dc80a97
a02a090
 
 
8febf87
a02a090
8febf87
 
 
 
dc80a97
e68ca85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc80a97
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
#!/usr/bin/env python3
"""
🎥 Video Content Safety Analysis - MiniGPT4-Video + 巨量引擎规则集成版
基于MiniGPT4-Video的真实视频内容分析 + 巨量引擎299条禁投规则检测
"""

# ZeroGPU装饰器 - 必须在torch等包之前导入!
try:
    import spaces
    GPU_AVAILABLE = True
    print("✅ ZeroGPU spaces 可用")
except ImportError:
    print("⚠️ ZeroGPU spaces 不可用,使用CPU模式")
    GPU_AVAILABLE = False
    # 创建一个空的装饰器
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator

import os
import gradio as gr
import torch
import gc
import whisper
import argparse
import yaml
import random
import numpy as np
import torch.backends.cudnn as cudnn
from minigpt4.common.eval_utils import init_model
from minigpt4.conversation.conversation import CONV_VISION
import tempfile
import shutil
import cv2
import webvtt
import moviepy.editor as mp
from torchvision import transforms
from datetime import timedelta
from moviepy.editor import VideoFileClip

# 导入巨量引擎禁投规则引擎
from prohibited_rules import ProhibitedRulesEngine

# 设置中国镜像
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# 全局变量
model = None
vis_processor = None
whisper_model = None
args = None
seed = 42

# 初始化巨量引擎规则引擎
rules_engine = ProhibitedRulesEngine()
print("✅ 巨量引擎299条禁投规则引擎初始化完成")

# ======================== MiniGPT4-Video 核心函数 ========================

def format_timestamp(seconds):
    """格式化时间戳为VTT格式"""
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    milliseconds = int(td.microseconds / 1000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"

def extract_video_info(video_path, max_images_length):
    """提取视频信息"""
    clip = VideoFileClip(video_path)
    total_num_frames = int(clip.duration * clip.fps)
    clip.close()
    sampling_interval = int(total_num_frames / max_images_length)
    if sampling_interval == 0:
        sampling_interval = 1
    return sampling_interval, clip.fps

def time_to_milliseconds(time_str):
    """将时间格式转换为毫秒"""
    h, m, s = map(float, time_str.split(':'))
    return int((h * 3600 + m * 60 + s) * 1000)

def extract_subtitles(subtitle_path):
    """提取字幕"""
    if not subtitle_path or not os.path.exists(subtitle_path):
        return []
    
    subtitles = []
    try:
        for caption in webvtt.read(subtitle_path):
            start_ms = time_to_milliseconds(caption.start)
            end_ms = time_to_milliseconds(caption.end)
            text = caption.text.strip().replace('\n', ' ')
            subtitles.append((start_ms, end_ms, text))
    except:
        return []
    return subtitles

def find_subtitle(subtitles, frame_count, fps):
    """查找对应帧的字幕"""
    if not subtitles:
        return None
        
    frame_time = (frame_count / fps) * 1000
    left, right = 0, len(subtitles) - 1
    
    while left <= right:
        mid = (left + right) // 2
        start, end, subtitle_text = subtitles[mid]
        if start <= frame_time <= end:
            return subtitle_text
        elif frame_time < start:
            right = mid - 1
        else:
            left = mid + 1
    
    return None

def match_frames_and_subtitles(video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames):
    """匹配视频帧和字幕"""
    global vis_processor
    
    cap = cv2.VideoCapture(video_path)
    images = []
    frame_count = 0
    img_placeholder = ""
    subtitle_text_in_interval = ""
    history_subtitles = {}
    number_of_words = 0
    
    transform = transforms.Compose([
        transforms.ToPILImage(),
    ])
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
            
        if len(subtitles) > 0:
            frame_subtitle = find_subtitle(subtitles, frame_count, fps)
            if frame_subtitle and not history_subtitles.get(frame_subtitle, False):
                subtitle_text_in_interval += frame_subtitle + " "
                history_subtitles[frame_subtitle] = True
                
        if frame_count % sampling_interval == 0:
            frame = transform(frame[:,:,::-1])  # 转换为RGB
            frame = vis_processor(frame)
            images.append(frame)
            img_placeholder += '<Img><ImageHere>'
            
            if subtitle_text_in_interval != "" and number_of_words < max_sub_len:
                img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                number_of_words += len(subtitle_text_in_interval.split(' '))
                subtitle_text_in_interval = ""
                
        frame_count += 1
        if len(images) >= max_frames:
            break
            
    cap.release()
    cv2.destroyAllWindows()
    
    if len(images) == 0:
        return None, None
        
    images = torch.stack(images)
    return images, img_placeholder

def extract_audio(video_path, audio_path):
    """提取音频"""
    video_clip = mp.VideoFileClip(video_path)
    audio_clip = video_clip.audio
    audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k", verbose=False, logger=None)
    video_clip.close()

def get_subtitles(video_path):
    """生成字幕"""
    global whisper_model
    
    if whisper_model is None:
        return None
        
    audio_dir = "workspace/inference_subtitles/mp3"
    subtitle_dir = "workspace/inference_subtitles"
    os.makedirs(subtitle_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)
    
    video_id = video_path.split('/')[-1].split('.')[0]
    audio_path = f"{audio_dir}/{video_id}.mp3"
    subtitle_path = f"{subtitle_dir}/{video_id}.vtt"
    
    # 如果字幕已存在,直接返回
    if os.path.exists(subtitle_path):
        return subtitle_path
        
    try:
        extract_audio(video_path, audio_path)
        # 🔧 优化中文语音识别
        result = whisper_model.transcribe(
            audio_path, 
            language="zh",  # 明确指定中文
            task="transcribe",  # 明确指定转录任务
            temperature=0.0,  # 降低随机性
            best_of=5,  # 使用最佳结果
            beam_size=5,  # 增加beam搜索
            patience=2.0,  # 增加耐心参数
            initial_prompt="以下是一段中文视频的语音内容:"  # 中文提示
        )
        
        # 创建VTT文件
        with open(subtitle_path, "w", encoding="utf-8") as vtt_file:
            vtt_file.write("WEBVTT\n\n")
            for segment in result['segments']:
                start = format_timestamp(segment['start'])
                end = format_timestamp(segment['end'])
                text = segment['text']
                vtt_file.write(f"{start} --> {end}\n{text}\n\n")
                
        return subtitle_path
    except Exception as e:
        print(f"字幕生成错误: {e}")
        return None

def prepare_input(video_path, subtitle_path, instruction):
    """准备输入"""
    global args
    
    # 根据模型设置参数
    if args and "mistral" in args.ckpt:
        max_frames = 90
        max_sub_len = 800
    else:
        max_frames = 45
        max_sub_len = 400
    
    sampling_interval, fps = extract_video_info(video_path, max_frames)
    subtitles = extract_subtitles(subtitle_path)
    frames_features, input_placeholder = match_frames_and_subtitles(
        video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames
    )
    
    if input_placeholder:
        input_placeholder += "\n" + instruction
    else:
        input_placeholder = instruction
        
    return frames_features, input_placeholder

def model_generate(*model_args, **kwargs):
    """模型生成函数"""
    global model
    
    with model.maybe_autocast():
        output = model.llama_model.generate(*model_args, **kwargs)
    return output

def generate_prediction(video_path, instruction, gen_subtitles=True, stream=False):
    """生成预测结果"""
    global model, args, seed
    
    if gen_subtitles:
        subtitle_path = get_subtitles(video_path)
    else:
        subtitle_path = None
        
    prepared_images, prepared_instruction = prepare_input(video_path, subtitle_path, instruction)
    
    if prepared_images is None:
        return "视频无法打开,请检查视频路径"
        
    length = len(prepared_images)
    prepared_images = prepared_images.unsqueeze(0)
    
    conv = CONV_VISION.copy()
    conv.system = ""
    conv.append_message(conv.roles[0], prepared_instruction)
    conv.append_message(conv.roles[1], None)
    prompt = [conv.get_prompt()]
    
    # 设置随机种子
    setup_seeds(seed)
    
    # 🔧 GPU内存优化和cuBLAS错误处理
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # 清理缓存
        torch.cuda.synchronize()  # 同步GPU操作
        
        # 🚀 H200特定优化
        gpu_name = torch.cuda.get_device_name(0)
        if "H200" in gpu_name:
            # H200额外内存清理
            gc.collect()
            torch.cuda.reset_peak_memory_stats()
    
    try:
        # 🔧 使用更保守的生成参数避免cuBLAS错误
        answers = model.generate(
            prepared_images, 
            prompt, 
            max_new_tokens=512,  # 增加token数以获得更详细的分析
            do_sample=True, 
            lengths=[length],
            num_beams=1,  # 保持beam=1减少计算
            temperature=0.7,  # 稍微降低温度获得更稳定输出
            top_p=0.9,     # 添加top_p参数
            repetition_penalty=1.1  # 避免重复
        )
        return answers[0]
    except RuntimeError as e:
        if "cublasLt" in str(e) or "cuBLAS" in str(e):
            # 🚨 cuBLAS错误特殊处理
            print(f"⚠️ 检测到cuBLAS错误,尝试降级处理: {e}")
            
            # 强制清理GPU内存
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                gc.collect()
                
                # 🚀 H200特定恢复策略
                gpu_name = torch.cuda.get_device_name(0)
                if "H200" in gpu_name:
                    print("🔧 应用H200特定恢复策略...")
                    torch.cuda.reset_peak_memory_stats()
                    # 临时禁用TF32以避免H200精度问题
                    torch.backends.cuda.matmul.allow_tf32 = False
                    torch.backends.cudnn.allow_tf32 = False
            
            try:
                # 🔧 使用更小的参数重试
                answers = model.generate(
                    prepared_images, 
                    prompt, 
                    max_new_tokens=256,  # 减少token数
                    do_sample=False,     # 关闭采样减少计算
                    lengths=[min(length, 24)],  # 增加一点长度,但不要太多
                    num_beams=1,
                    temperature=1.0,
                    use_cache=False  # H200上禁用缓存
                )
                
                # 🚀 H200恢复TF32设置
                if torch.cuda.is_available() and "H200" in torch.cuda.get_device_name(0):
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True
                
                return answers[0]
            except Exception as e2:
                return f"GPU运算错误,请重试。H200特定优化已应用。错误信息: {str(e2)}"
        else:
            return f"生成预测时出错: {str(e)}"
    except Exception as e:
        return f"生成预测时出错: {str(e)}"

# ======================== 巨量引擎规则检测函数 ========================

def format_violations_report(violations_result):
    """格式化违规检测报告"""
    if not violations_result["has_violations"]:
        return """
🛡️ **巨量引擎规则检测结果**: ✅ 无违规内容
- 已检测规则: 299条巨量引擎禁投规则
- 检测维度: 低危(P1) + 中危(P2) + 高危(P3)
- 检测结果: 内容符合平台规范
        """
    
    report = f"""
🚨 **巨量引擎规则检测结果**: ⚠️ 发现 {violations_result["total_violations"]} 项违规

📊 **违规统计**:
- 🔴 高危违规(P3): {violations_result["high_risk"]["count"]}
- 🟡 中危违规(P2): {violations_result["medium_risk"]["count"]}
- 🟠 低危违规(P1): {violations_result["low_risk"]["count"]}

📋 **详细违规列表**:
    """
    
    # 按风险等级排序显示违规
    for violation in sorted(violations_result["all_violations"], 
                          key=lambda x: {"P3": 3, "P2": 2, "P1": 1}[x["risk_level"]], 
                          reverse=True):
        risk_icon = {"P3": "🚨", "P2": "⚠️", "P1": "💭"}[violation["risk_level"]]
        report += f"""
{risk_icon} **{violation["risk_level"]} - {violation["category"]}**
   规则: {violation["description"]}
   匹配词: "{violation["matched_keyword"]}"
   规则ID: {violation["rule_id"]}
        """
    
    return report

def get_overall_risk_level(violations_result):
    """获取综合风险等级"""
    if not violations_result["has_violations"]:
        return "✅ P3 (安全) - 内容健康,符合平台规范"
    
    if violations_result["high_risk"]["count"] > 0:
        return f"🚨 P0 (极高危) - 发现 {violations_result['high_risk']['count']} 项高危违规,禁止投放"
    elif violations_result["medium_risk"]["count"] > 2:
        return f"⚠️ P1 (高危) - 发现 {violations_result['medium_risk']['count']} 项中危违规,需严格审核"
    elif violations_result["medium_risk"]["count"] > 0:
        return f"⚠️ P1 (中危) - 发现 {violations_result['medium_risk']['count']} 项中危违规,需要审核"
    else:
        return f"⚡ P2 (低危) - 发现 {violations_result['low_risk']['count']} 项低危违规,建议关注"

# ======================== 应用主要函数 ========================

def setup_seeds(seed):
    """设置随机种子"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True

def optimize_gpu_memory():
    """GPU内存优化"""
    print("🔍 开始GPU内存优化...")
    
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🔍 GPU: {gpu_name}")
        
        # 🔧 H200特定优化
        if "H200" in gpu_name:
            print("🚀 检测到H200显卡,应用特定优化...")
            # H200优化设置
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True'
            os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # H200上设置为0
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # H200 cuBLAS优化
            os.environ['NCCL_AVOID_RECORD_STREAMS'] = '1'  # 避免H200内存问题
            
            # 设置混合精度优化
            torch.backends.cudnn.allow_tf32 = True  # 启用TF32提升H200性能
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True  # H200上启用benchmark
            
        else:
            # 标准设置(A100等)
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256,garbage_collection_threshold:0.6'
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        
        print(f"💾 总显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        
        # 强制清理所有GPU缓存
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        gc.collect()
        
        print(f"💾 清理后可用显存: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1024**3:.1f} GB")

def get_arguments():
    """获取参数配置"""
    parser = argparse.ArgumentParser(description="MiniGPT4-Video参数")
    parser.add_argument("--cfg-path", help="配置文件路径", 
                       default="test_configs/mistral_test_config.yaml")  # 使用mistral配置
    parser.add_argument("--ckpt", type=str, 
                       default='checkpoints/video_mistral_checkpoint_last.pth',  # 使用mistral checkpoint
                       help="模型检查点路径")
    parser.add_argument("--max_new_tokens", type=int, default=512, 
                       help="最大生成token数")
    parser.add_argument("--lora_r", type=int, default=64, help="LoRA rank")  # 修改为64匹配checkpoint
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")  # 修改为16匹配checkpoint
    parser.add_argument("--options", nargs="+", help="覆盖配置选项")
    return parser.parse_args()

def load_minigpt4_model():
    """加载MiniGPT4-Video模型"""
    global model, vis_processor, whisper_model, args, seed
    
    if model is not None:
        return model, vis_processor, whisper_model
    
    try:
        print("🔄 正在加载MiniGPT4-Video模型...")
        
        # 获取参数
        args = get_arguments()
        
        # 加载配置
        config_path = args.cfg_path
        if not os.path.exists(config_path):
            config_path = "test_configs/llama2_test_config.yaml"  # 回退到默认配置
            
        with open(config_path) as file:
            config = yaml.load(file, Loader=yaml.FullLoader)
        
        seed = config['run']['seed']
        setup_seeds(seed)
        
        # GPU内存优化
        optimize_gpu_memory()
        
        print("🚀 开始初始化MiniGPT4-Video模型...")
        model, vis_processor, whisper_gpu_id, minigpt4_gpu_id, answer_module_gpu_id = init_model(args)
        
        # 清理缓存
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"💾 模型加载后显存使用: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
        
        print("🚀 开始初始化Whisper模型...")
        # 🔧 使用更强的Whisper模型以提升中文识别
        whisper_model = whisper.load_model("medium").to(f"cuda:{whisper_gpu_id}" if torch.cuda.is_available() else "cpu")
        print("✅ Whisper模型加载完成 (medium版本,优化中文识别)")
        
        if torch.cuda.is_available():
            print(f"💾 全部加载后显存使用: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
        
        print("✅ 所有模型加载完成!")
        return model, vis_processor, whisper_model
        
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        print("🔄 回退到模拟模式...")
        return None, None, None

@spaces.GPU(duration=600)  # 增加到10分钟以支持模型下载
def analyze_video_with_minigpt4(video_file, instruction):
    """使用MiniGPT4-Video分析视频内容并进行巨量引擎规则检测"""
    if video_file is None:
        return "❌ 请上传视频文件", "无法评估"
    
    try:
        # 加载模型
        model_loaded, vis_proc, whisper_loaded = load_minigpt4_model()
        
        if model_loaded is None:
            # 模拟模式
            return f"""
🎬 **视频内容分析结果 (模拟模式)**

📋 **基本信息**:
- 视频文件: {video_file}
- 分析指令: {instruction}

⚠️ **注意**: 当前运行在模拟模式,真实模型加载失败
请检查模型文件和配置是否正确

🛡️ **巨量引擎规则检测**: 仅在真实模式下可用
            """, "⚠️ 模拟模式"
        
        print(f"🔄 开始分析视频: {video_file}")
        print(f"📝 分析指令: {instruction}")
        
        # 复制视频到临时路径(如果需要)
        temp_video_path = video_file
        if not os.path.exists(video_file):
            # 如果是Gradio的临时文件,复制到工作目录
            temp_dir = "workspace/tmp"
            os.makedirs(temp_dir, exist_ok=True)
            temp_video_path = os.path.join(temp_dir, "analysis_video.mp4")
            shutil.copy2(video_file, temp_video_path)
        
        # 使用MiniGPT4-Video进行真实分析
        if not instruction or instruction.strip() == "":
            instruction = "请详细分析这个视频的内容,包括场景、人物、动作、对话等。请用中文输出,并详细记录视频中谁说了什么话。"
        
        # 🧠 使用智能规则感知指令
        intelligent_instruction = create_intelligent_instruction(instruction)
        print(f"🧠 使用智能规则感知指令进行分析...")
        
        # 调用MiniGPT4-Video的生成函数
        prediction = generate_prediction(
            video_path=temp_video_path,
            instruction=intelligent_instruction,  # 使用智能指令
            gen_subtitles=True,  # 生成字幕
            stream=False
        )
        
        # 🚨 巨量引擎规则检测 🚨
        print("🔍 开始巨量引擎299条规则检测...")
        violations_result = rules_engine.check_all_content(prediction, instruction)
        
        # 格式化完整分析报告
        enhanced_result = f"""
🎬 **MiniGPT4-Video 视频内容分析 + 巨量引擎规则检测报告**

📋 **基本信息**:
- 视频文件: {os.path.basename(video_file)}
- 分析设备: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU模式'}
- 分析指令: {instruction}

🔍 **视频内容描述**:
{prediction}

{format_violations_report(violations_result)}

📊 **技术信息**:
- 内容理解: MiniGPT4-Video + Whisper  
- 规则引擎: 巨量引擎299条禁投规则
- 检测等级: P1(低危) + P2(中危) + P3(高危)
- 分析模式: 多模态理解 (视觉+语音+文本)

💡 **说明**: 
基于MiniGPT4-Video的深度内容理解,结合巨量引擎完整禁投规则库进行专业违规检测。
        """
        
        # 获取综合风险等级
        safety_score = get_overall_risk_level(violations_result)
        
        return enhanced_result, safety_score
        
    except Exception as e:
        error_msg = f"""
❌ **分析过程中出错**

错误信息: {str(e)}

🔄 **可能的解决方案**:
1. 检查视频文件格式 (建议MP4)
2. 确认模型文件是否正确加载
3. 检查GPU内存是否充足
4. 验证配置文件路径

💡 **提示**: 如果问题持续,请检查模型和依赖项安装
        """
        return error_msg, "⚠️ 错误"

def create_app():
    """创建Gradio应用"""
    
    interface = gr.Interface(
        fn=analyze_video_with_minigpt4,
        inputs=[
            gr.Video(label="上传视频文件"),
            gr.Textbox(
                label="分析指令", 
                value="请详细分析这个视频的内容,包括场景、人物、动作、对话等。请用中文输出,并详细记录视频中谁说了什么话。",
                placeholder="输入您希望AI如何分析这个视频...",
                lines=3
            )
        ],
        outputs=[
            gr.Textbox(label="MiniGPT4-Video 内容分析 + 巨量引擎规则检测", lines=20),
            gr.Textbox(label="巨量引擎风险评级")
        ],
        title="🎥 智能视频内容安全分析 - MiniGPT4-Video + 巨量引擎",
        description="""
        ## 🎬 基于MiniGPT4-Video + 巨量引擎299条禁投规则的专业视频安全检测系统
        
        ⚡ **ZeroGPU加速** | 🎬 **MiniGPT4-Video** | 🎙️ **Whisper语音** | 🛡️ **巨量引擎299条规则**
        
        **🔥 核心功能:**
        - 🎞️ **深度视频理解**: MiniGPT4-Video多模态分析
        - 🎙️ **语音转文字**: Whisper自动生成字幕
        - 🛡️ **专业违规检测**: 巨量引擎完整禁投规则库
        - 📊 **智能风险评级**: P0-P3四级风险等级
        
        **🎯 检测维度:**
        - **高危(P3)**: 违法出版物、烟草、医疗等严重违规 
        - **中危(P2)**: 赌博周边、房地产、金融等中等风险
        - **低危(P1)**: 化妆品、汽车、游戏等轻微风险
        
        **📋 规则覆盖:**
        涵盖化妆品类、汽车类、游戏类、赌博类、房地产类、工具软件类、教育培训类、
        金融类、医疗类、烟草类等全部299条巨量引擎禁投规则
        """,
        examples=[
            [None, "分析这个视频是否包含禁投内容"],
            [None, "检测视频中是否有巨量引擎禁止的产品或服务"],
            [None, "评估视频内容的投放风险等级"],
            [None, "详细描述视频内容并进行合规检查"]
        ],
        cache_examples=False
    )
    
    return interface

def main():
    """主函数"""
    print("🚀 启动MiniGPT4-Video + 巨量引擎视频安全分析应用")
    print("🎬 MiniGPT4-Video: 深度视频内容理解")
    print("🛡️ 巨量引擎: 299条禁投规则检测")
    
    if torch.cuda.is_available():
        print(f"✅ GPU可用: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ 使用CPU模式")
    
    # 创建必要的目录
    os.makedirs("workspace/tmp", exist_ok=True)
    os.makedirs("workspace/inference_subtitles", exist_ok=True)
    os.makedirs("workspace/inference_subtitles/mp3", exist_ok=True)
    
    print("📁 工作目录准备完成")
    print("🚀 正在启动Gradio应用...")
    
    app = create_app()
    
    # 启动应用
    app.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )

def create_intelligent_instruction(original_instruction):
    """创建具备规则理解能力的智能分析指令"""
    
    # 核心禁投规则摘要 - 让AI知道需要检测什么
    rules_summary = """
请特别注意以下巨量引擎禁投内容(如发现请在描述中明确指出):

🚨 **高危违规内容 (P3)**:
- 医疗器械、药品、保健品、医美服务
- 烟草制品、电子烟相关产品  
- 虚拟货币、区块链、NFT、数字藏品
- 违法出版物、政治敏感内容
- 贷款、信贷、金融投资、股票
- 赌博、博彩、棋牌游戏

⚠️ **中危违规内容 (P2)**:
- 房地产买卖、租赁、中介服务
- 工具软件、刷机、破解软件
- 教育培训、学历提升、考试代办
- 翡翠、玉石、文玩、珠宝盲盒
- 黄金回收、贵金属投资

💭 **低危违规内容 (P1)**:
- 化妆品中的特殊功效产品
- 汽车修复、代办服务
- 游戏账号交易、代练
- 特殊食品、减肥产品
"""

    intelligent_instruction = f"""
你是专业的巨量引擎广告内容审核专家。请用中文详细分析这个视频,包括:

📹 **视频内容详细描述**:
- 场景环境:描述视频拍摄场所、背景环境
- 人物信息:谁在视频中出现,年龄、性别、穿着特征  
- 关键动作:详细描述人物的具体动作和行为
- 产品展示:如有产品展示,请详细描述产品外观、材质、用途
- 文字信息:视频中出现的任何文字、标识、品牌名称

🎙️ **语音对话内容**:
- 详细记录视频中的所有对话内容
- 明确标注"谁说了什么话"
- 记录任何产品介绍、价格信息、功效宣传
- 注意推销话术、营销用语

🔍 **潜在违规风险分析**:
{rules_summary}

🎯 **分析要求**:
1. 用中文输出所有内容
2. 对于任何可能涉及上述违规内容的元素,请明确指出
3. 重点关注翡翠、玉石、珠宝等文玩制品
4. 注意医疗、金融、房产、教育等敏感行业
5. 记录所有营销宣传语句

原始指令:{original_instruction}
"""
    
    return intelligent_instruction

if __name__ == "__main__":
    main()