Spaces:
Runtime error
Runtime error
yjernite
commited on
Commit
·
8df4211
1
Parent(s):
427730e
template for examplars
Browse files
app.py
CHANGED
@@ -133,7 +133,10 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
|
|
133 |
.to_html()
|
134 |
)
|
135 |
|
136 |
-
|
|
|
|
|
|
|
137 |
|
138 |
with gr.Blocks() as demo:
|
139 |
gr.Markdown("# 🤗 Diffusion Cluster Explorer")
|
@@ -183,31 +186,56 @@ with gr.Blocks() as demo:
|
|
183 |
# with gr.Accordion("Tag Frequencies", open=False):
|
184 |
|
185 |
with gr.Tab("Profession Focus"):
|
186 |
-
with gr.Row():
|
187 |
-
num_clusters = gr.Radio(
|
188 |
-
[12, 24, 48],
|
189 |
-
value=12,
|
190 |
-
label="How many clusters do you want to use to represent identities?",
|
191 |
-
)
|
192 |
with gr.Row():
|
193 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
profession_choice_focus = gr.Dropdown(
|
195 |
choices=professions,
|
196 |
value="social worker",
|
197 |
label="Select profession:",
|
198 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
with gr.Column():
|
200 |
plot = gr.Plot(
|
201 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
202 |
)
|
203 |
-
profession_choice_focus
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
queue=False,
|
208 |
)
|
209 |
-
with gr.Row():
|
210 |
-
gr.Markdown("TODO: show examplars for cluster")
|
211 |
|
212 |
|
213 |
demo.launch()
|
|
|
133 |
.to_html()
|
134 |
)
|
135 |
|
136 |
+
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
|
137 |
+
# TODO: show the actual images
|
138 |
+
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name]["cluster_examplars"][str(cl_id)]
|
139 |
+
return json.dumps(examplars_dict)
|
140 |
|
141 |
with gr.Blocks() as demo:
|
142 |
gr.Markdown("# 🤗 Diffusion Cluster Explorer")
|
|
|
186 |
# with gr.Accordion("Tag Frequencies", open=False):
|
187 |
|
188 |
with gr.Tab("Profession Focus"):
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
with gr.Row():
|
190 |
with gr.Column():
|
191 |
+
gr.Markdown("Select profession to visualize here:")
|
192 |
+
num_clusters_focus = gr.Radio(
|
193 |
+
[12, 24, 48],
|
194 |
+
value=12,
|
195 |
+
label="How many clusters do you want to use to represent identities?",
|
196 |
+
)
|
197 |
profession_choice_focus = gr.Dropdown(
|
198 |
choices=professions,
|
199 |
value="social worker",
|
200 |
label="Select profession:",
|
201 |
)
|
202 |
+
gr.Markdown("You can show examples of profession images assigned to each cluster:")
|
203 |
+
model_choices_focus = gr.Dropdown(
|
204 |
+
[
|
205 |
+
"All Models",
|
206 |
+
"Stable Diffusion 1.4",
|
207 |
+
"Stable Diffusion 2",
|
208 |
+
"Dall-E 2",
|
209 |
+
],
|
210 |
+
value="All Models",
|
211 |
+
label="Select generation model:",
|
212 |
+
interactive=True,
|
213 |
+
)
|
214 |
+
cluster_id_focus = gr.Dropdown(
|
215 |
+
choices=[i for i in range(num_clusters_focus.value)],
|
216 |
+
value=0,
|
217 |
+
label="Select cluster to visualize:",
|
218 |
+
)
|
219 |
with gr.Column():
|
220 |
plot = gr.Plot(
|
221 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
222 |
)
|
223 |
+
for var in [num_clusters_focus, profession_choice_focus]:
|
224 |
+
var.change(
|
225 |
+
make_profession_plot,
|
226 |
+
[num_clusters_focus, profession_choice_focus],
|
227 |
+
plot,
|
228 |
+
queue=False,
|
229 |
+
)
|
230 |
+
with gr.Row():
|
231 |
+
examplars_plot = gr.JSON() # TODO: turn this into a plot with the actual images
|
232 |
+
for var in [model_choices_focus, cluster_id_focus]:
|
233 |
+
var.change(
|
234 |
+
show_examplars,
|
235 |
+
[num_clusters_focus, profession_choice_focus, model_choices_focus, cluster_id_focus],
|
236 |
+
examplars_plot,
|
237 |
queue=False,
|
238 |
)
|
|
|
|
|
239 |
|
240 |
|
241 |
demo.launch()
|