yjernite commited on
Commit
8df4211
·
1 Parent(s): 427730e

template for examplars

Browse files
Files changed (1) hide show
  1. app.py +41 -13
app.py CHANGED
@@ -133,7 +133,10 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
133
  .to_html()
134
  )
135
 
136
-
 
 
 
137
 
138
  with gr.Blocks() as demo:
139
  gr.Markdown("# 🤗 Diffusion Cluster Explorer")
@@ -183,31 +186,56 @@ with gr.Blocks() as demo:
183
  # with gr.Accordion("Tag Frequencies", open=False):
184
 
185
  with gr.Tab("Profession Focus"):
186
- with gr.Row():
187
- num_clusters = gr.Radio(
188
- [12, 24, 48],
189
- value=12,
190
- label="How many clusters do you want to use to represent identities?",
191
- )
192
  with gr.Row():
193
  with gr.Column():
 
 
 
 
 
 
194
  profession_choice_focus = gr.Dropdown(
195
  choices=professions,
196
  value="social worker",
197
  label="Select profession:",
198
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  with gr.Column():
200
  plot = gr.Plot(
201
  label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
202
  )
203
- profession_choice_focus.change(
204
- make_profession_plot,
205
- [num_clusters, profession_choice_focus],
206
- plot,
 
 
 
 
 
 
 
 
 
 
207
  queue=False,
208
  )
209
- with gr.Row():
210
- gr.Markdown("TODO: show examplars for cluster")
211
 
212
 
213
  demo.launch()
 
133
  .to_html()
134
  )
135
 
136
+ def show_examplars(num_clusters, prof_name, mod_name, cl_id):
137
+ # TODO: show the actual images
138
+ examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name]["cluster_examplars"][str(cl_id)]
139
+ return json.dumps(examplars_dict)
140
 
141
  with gr.Blocks() as demo:
142
  gr.Markdown("# 🤗 Diffusion Cluster Explorer")
 
186
  # with gr.Accordion("Tag Frequencies", open=False):
187
 
188
  with gr.Tab("Profession Focus"):
 
 
 
 
 
 
189
  with gr.Row():
190
  with gr.Column():
191
+ gr.Markdown("Select profession to visualize here:")
192
+ num_clusters_focus = gr.Radio(
193
+ [12, 24, 48],
194
+ value=12,
195
+ label="How many clusters do you want to use to represent identities?",
196
+ )
197
  profession_choice_focus = gr.Dropdown(
198
  choices=professions,
199
  value="social worker",
200
  label="Select profession:",
201
  )
202
+ gr.Markdown("You can show examples of profession images assigned to each cluster:")
203
+ model_choices_focus = gr.Dropdown(
204
+ [
205
+ "All Models",
206
+ "Stable Diffusion 1.4",
207
+ "Stable Diffusion 2",
208
+ "Dall-E 2",
209
+ ],
210
+ value="All Models",
211
+ label="Select generation model:",
212
+ interactive=True,
213
+ )
214
+ cluster_id_focus = gr.Dropdown(
215
+ choices=[i for i in range(num_clusters_focus.value)],
216
+ value=0,
217
+ label="Select cluster to visualize:",
218
+ )
219
  with gr.Column():
220
  plot = gr.Plot(
221
  label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
222
  )
223
+ for var in [num_clusters_focus, profession_choice_focus]:
224
+ var.change(
225
+ make_profession_plot,
226
+ [num_clusters_focus, profession_choice_focus],
227
+ plot,
228
+ queue=False,
229
+ )
230
+ with gr.Row():
231
+ examplars_plot = gr.JSON() # TODO: turn this into a plot with the actual images
232
+ for var in [model_choices_focus, cluster_id_focus]:
233
+ var.change(
234
+ show_examplars,
235
+ [num_clusters_focus, profession_choice_focus, model_choices_focus, cluster_id_focus],
236
+ examplars_plot,
237
  queue=False,
238
  )
 
 
239
 
240
 
241
  demo.launch()