PSNbst committed
Commit 96db9b0 · verified · 1 Parent(s): b9a151f

Create batch-app.py

Files changed (1)
  1. batch-app.py +174 -0
batch-app.py ADDED
@@ -0,0 +1,174 @@
+ import gradio as gr
+ import os
+ import numpy as np
+ from PIL import Image, ImageChops, ImageFilter
+ from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
+ import torch
+ import matplotlib.pyplot as plt
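+ # Requires: gradio, torch, transformers, pillow, matplotlib, numpy, openai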
+
+ # Initialize the models
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
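+ # Both checkpoints are downloaded from the Hugging Face Hub on first run and
+ # run on CPU here; move them to GPU with .to("cuda") if one is available.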
+
+ # Image-processing helpers
+ def compute_difference_images(img_a, img_b):
+     def extract_sketch(image):
+         # Screen the grayscale image with its inverse for a washed-out,
+         # sketch-like rendition
+         grayscale = image.convert("L")
+         inverted = ImageChops.invert(grayscale)
+         sketch = ImageChops.screen(grayscale, inverted)
+         return sketch
+
+     def compute_normal_map(image):
+         # Edge detection as a cheap stand-in for a true normal map
+         edges = image.filter(ImageFilter.FIND_EDGES)
+         return edges
+
+     # ImageChops.difference requires equally sized inputs, so align B to A first
+     if img_a.size != img_b.size:
+         img_b = img_b.resize(img_a.size)
+     diff_overlay = ImageChops.difference(img_a, img_b)
+     return {
+         "original_a": img_a,
+         "original_b": img_b,
+         "sketch_a": extract_sketch(img_a),
+         "sketch_b": extract_sketch(img_b),
+         "normal_a": compute_normal_map(img_a),
+         "normal_b": compute_normal_map(img_b),
+         "diff_overlay": diff_overlay
+     }
+
+ # Save rendered images to files; the (path, caption) tuples are the format
+ # gr.Gallery accepts, and a per-pair prefix keeps batch runs from
+ # overwriting one another's output
+ def save_images(images, prefix=""):
+     paths = []
+     for key, img in images.items():
+         path = f"{prefix}{key}.png"
+         img.save(path)
+         paths.append((path, key.replace("_", " ").capitalize()))
+     return paths
+
+ # Generate a more detailed caption with BLIP; beam search plus the no-repeat
+ # constraint gives longer, less repetitive captions than greedy decoding
+ def generate_detailed_caption(image):
+     inputs = blip_processor(image, return_tensors="pt")
+     caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
+     return blip_processor.decode(caption[0], skip_special_tokens=True)
+
+ # Visualize the feature differences; charts are written under the same per-pair prefix
+ def plot_feature_differences(latent_diff, prefix=""):
+     diff_magnitude = [abs(x) for x in latent_diff[0]]
+     indices = range(len(diff_magnitude))
+
+     plt.figure(figsize=(8, 4))
+     plt.bar(indices, diff_magnitude, alpha=0.7)
+     plt.xlabel("Feature Index")
+     plt.ylabel("Magnitude of Difference")
+     plt.title("Feature Differences (Bar Chart)")
+     bar_chart_path = f"{prefix}bar_chart.png"
+     plt.savefig(bar_chart_path)
+     plt.close()
+
+     # Pick the 10 largest differences so the chart matches its "Top 10" title
+     top_idx = np.argsort(diff_magnitude)[::-1][:10]
+     plt.figure(figsize=(6, 6))
+     plt.pie([diff_magnitude[i] for i in top_idx], labels=[str(i) for i in top_idx], autopct="%1.1f%%", startangle=140)
+     plt.title("Top 10 Feature Differences (Pie Chart)")
+     pie_chart_path = f"{prefix}pie_chart.png"
+     plt.savefig(pie_chart_path)
+     plt.close()
+
+     return bar_chart_path, pie_chart_path
+
+ # Generate a detailed written analysis via a chat-completions API
+ def generate_text_analysis(api_key, api_type, caption_a, caption_b):
+     from openai import OpenAI
+
+     if api_type == "DeepSeek":
+         client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
+         model = "deepseek-chat"
+     else:
+         client = OpenAI(api_key=api_key)
+         model = "gpt-4"
+
+     response = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": f"Image A is described as: {caption_a}. Image B is described as: {caption_b}.\nPlease analyze the differences in content and latent features between the two images in detail, and give a concise, well-structured summary."}
+         ]
+     )
+     return response.choices[0].message.content.strip()
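+ # Note: DeepSeek's chat endpoint is OpenAI-compatible, which is why the same
+ # OpenAI client class serves both backends; only base_url and the model name differ.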
+
+ # Analyze a single image pair
+ def analyze_images(img_a, img_b, api_key, api_type, prefix=""):
+     images_diff = compute_difference_images(img_a, img_b)
+     saved_images = save_images(images_diff, prefix)
+
+     caption_a = generate_detailed_caption(img_a)
+     caption_b = generate_detailed_caption(img_b)
+
+     # CLIP image embeddings; no_grad avoids building an unused autograd graph
+     with torch.no_grad():
+         inputs = clip_processor(images=img_a, return_tensors="pt")
+         features_a = clip_model.get_image_features(**inputs).numpy()
+
+         inputs = clip_processor(images=img_b, return_tensors="pt")
+         features_b = clip_model.get_image_features(**inputs).numpy()
+
+     latent_diff = np.abs(features_a - features_b).tolist()
+
+     bar_chart, pie_chart = plot_feature_differences(latent_diff, prefix)
+     text_analysis = generate_text_analysis(api_key, api_type, caption_a, caption_b)
+
+     return {
+         "saved_images": saved_images,
+         "caption_a": caption_a,
+         "caption_b": caption_b,
+         "text_analysis": text_analysis,
+         "bar_chart": bar_chart,
+         "pie_chart": pie_chart
+     }
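+ # Each pair costs two BLIP captioning passes, two CLIP forward passes and one
+ # chat-API call, so batch runtime and API spend grow linearly with the pair count.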
+
+ # Batch analysis over two folders of images
+ def batch_analyze(folder_a, folder_b, api_key, api_type):
+     def load_images(folder_path):
+         files = sorted([os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+         return [Image.open(f).convert("RGB") for f in files]
+
+     images_a = load_images(folder_a)
+     images_b = load_images(folder_b)
+     num_pairs = min(len(images_a), len(images_b))
+
+     results = []
+     for i in range(num_pairs):
+         result = analyze_images(images_a[i], images_b[i], api_key, api_type, prefix=f"pair{i+1}_")
+         results.append({
+             "pair": (f"Image A-{i+1}", f"Image B-{i+1}"),
+             **result
+         })
+     return results
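+ # Images are paired by sorted filename order; extras in the larger folder are ignored.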
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Batch Image Comparison and Analysis Tool")
+
+     api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
+     api_type_input = gr.Dropdown(label="API type", choices=["GPT", "DeepSeek"], value="GPT")
+     folder_a_input = gr.Textbox(label="Folder A path", placeholder="Path to the folder containing the A images")
+     folder_b_input = gr.Textbox(label="Folder B path", placeholder="Path to the folder containing the B images")
+     analyze_button = gr.Button("Start batch analysis")
+
+     with gr.Row():
+         # .style(grid=...) was removed in Gradio 4; columns= is the current equivalent
+         result_gallery = gr.Gallery(label="Difference images", columns=3)
+         result_text_analysis = gr.Textbox(label="Detailed analysis", interactive=False, lines=5)
+
+     def process_batch_analysis(folder_a, folder_b, api_key, api_type):
+         results = batch_analyze(folder_a, folder_b, api_key, api_type)
+         all_images = []
+         all_texts = []
+
+         # Flatten every pair's images and charts into one gallery payload
+         for result in results:
+             all_images.extend(result["saved_images"])
+             all_images.append((result["bar_chart"], "Bar Chart"))
+             all_images.append((result["pie_chart"], "Pie Chart"))
+             all_texts.append(f"{result['pair'][0]} vs {result['pair'][1]}:\n{result['text_analysis']}")
+
+         return all_images, "\n\n".join(all_texts)
+
+     analyze_button.click(
+         fn=process_batch_analysis,
+         inputs=[folder_a_input, folder_b_input, api_key_input, api_type_input],
+         outputs=[result_gallery, result_text_analysis]
+     )
+
+ demo.launch()