Johannes Kolbe
commited on
Commit
·
7db1e87
1
Parent(s):
ed6b6d6
added better functionality
Browse files- README.md +2 -2
- app.py +14 -4
- interface.py +17 -4
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
title: Sefa
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: streamlit
|
| 7 |
sdk_version: 1.2.0
|
| 8 |
app_file: app.py
|
|
|
|
| 1 |
---
|
| 2 |
title: Sefa
|
| 3 |
+
emoji: 🔮
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: streamlit
|
| 7 |
sdk_version: 1.2.0
|
| 8 |
app_file: app.py
|
app.py
CHANGED
|
@@ -54,11 +54,16 @@ def synthesize(model, gan_type, code):
|
|
| 54 |
image = postprocess(image)[0]
|
| 55 |
return image
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
"""Main function (loop for StreamLit)."""
|
| 59 |
st.title('Closed-Form Factorization of Latent Semantics in GANs')
|
| 60 |
st.sidebar.title('Options')
|
| 61 |
-
|
| 62 |
|
| 63 |
model_name = st.sidebar.selectbox(
|
| 64 |
'Model to Interpret',
|
|
@@ -72,7 +77,7 @@ layer_idx = st.sidebar.selectbox(
|
|
| 72 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 73 |
|
| 74 |
num_semantics = st.sidebar.number_input(
|
| 75 |
-
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
| 76 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 77 |
if gan_type == 'pggan':
|
| 78 |
max_step = 5.0
|
|
@@ -87,10 +92,12 @@ for sem_idx in steps:
|
|
| 87 |
value=0.0,
|
| 88 |
min_value=-max_step,
|
| 89 |
max_value=max_step,
|
| 90 |
-
step=0.04 * max_step
|
|
|
|
| 91 |
|
| 92 |
image_placeholder = st.empty()
|
| 93 |
button_placeholder = st.empty()
|
|
|
|
| 94 |
|
| 95 |
try:
|
| 96 |
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
|
|
@@ -105,13 +112,16 @@ if state.model_name != model_name:
|
|
| 105 |
state.code_idx = 0
|
| 106 |
state.codes = base_codes[0:1]
|
| 107 |
|
| 108 |
-
if button_placeholder.button('
|
| 109 |
state.code_idx += 1
|
| 110 |
if state.code_idx < base_codes.shape[0]:
|
| 111 |
state.codes = base_codes[state.code_idx][np.newaxis]
|
| 112 |
else:
|
| 113 |
state.codes = sample(model, gan_type)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
| 115 |
code = state.codes.copy()
|
| 116 |
for sem_idx, step in steps.items():
|
| 117 |
if gan_type == 'pggan':
|
|
|
|
| 54 |
image = postprocess(image)[0]
|
| 55 |
return image
|
| 56 |
|
| 57 |
+
def _update_slider():
|
| 58 |
+
num_semantics = st.session_state["num_semantics"]
|
| 59 |
+
for sem_idx in range(num_semantics):
|
| 60 |
+
st.session_state[f"semantic_slider_{sem_idx}"] = 0
|
| 61 |
+
|
| 62 |
|
| 63 |
"""Main function (loop for StreamLit)."""
|
| 64 |
st.title('Closed-Form Factorization of Latent Semantics in GANs')
|
| 65 |
st.sidebar.title('Options')
|
| 66 |
+
st.sidebar.button('Reset', on_click=_update_slider, kwargs={})
|
| 67 |
|
| 68 |
model_name = st.sidebar.selectbox(
|
| 69 |
'Model to Interpret',
|
|
|
|
| 77 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 78 |
|
| 79 |
num_semantics = st.sidebar.number_input(
|
| 80 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
|
| 81 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 82 |
if gan_type == 'pggan':
|
| 83 |
max_step = 5.0
|
|
|
|
| 92 |
value=0.0,
|
| 93 |
min_value=-max_step,
|
| 94 |
max_value=max_step,
|
| 95 |
+
step=0.04 * max_step,
|
| 96 |
+
key=f"semantic_slider_{sem_idx}")
|
| 97 |
|
| 98 |
image_placeholder = st.empty()
|
| 99 |
button_placeholder = st.empty()
|
| 100 |
+
button_totally_random = st.empty()
|
| 101 |
|
| 102 |
try:
|
| 103 |
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
|
|
|
|
| 112 |
state.code_idx = 0
|
| 113 |
state.codes = base_codes[0:1]
|
| 114 |
|
| 115 |
+
if button_placeholder.button('Next Sample'):
|
| 116 |
state.code_idx += 1
|
| 117 |
if state.code_idx < base_codes.shape[0]:
|
| 118 |
state.codes = base_codes[state.code_idx][np.newaxis]
|
| 119 |
else:
|
| 120 |
state.codes = sample(model, gan_type)
|
| 121 |
|
| 122 |
+
if button_totally_random.button('Totally Random'):
|
| 123 |
+
state.codes = sample(model, gan_type)
|
| 124 |
+
|
| 125 |
code = state.codes.copy()
|
| 126 |
for sem_idx, step in steps.items():
|
| 127 |
if gan_type == 'pggan':
|
interface.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
# python 3.7
|
| 2 |
"""Demo."""
|
|
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
@@ -55,11 +56,18 @@ def synthesize(model, gan_type, code):
|
|
| 55 |
return image
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def main():
|
| 59 |
"""Main function (loop for StreamLit)."""
|
|
|
|
| 60 |
st.title('Closed-Form Factorization of Latent Semantics in GANs')
|
| 61 |
st.sidebar.title('Options')
|
| 62 |
-
|
| 63 |
|
| 64 |
model_name = st.sidebar.selectbox(
|
| 65 |
'Model to Interpret',
|
|
@@ -73,7 +81,7 @@ def main():
|
|
| 73 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 74 |
|
| 75 |
num_semantics = st.sidebar.number_input(
|
| 76 |
-
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
| 77 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 78 |
if gan_type == 'pggan':
|
| 79 |
max_step = 5.0
|
|
@@ -88,10 +96,12 @@ def main():
|
|
| 88 |
value=0.0,
|
| 89 |
min_value=-max_step,
|
| 90 |
max_value=max_step,
|
| 91 |
-
step=0.04 * max_step
|
|
|
|
| 92 |
|
| 93 |
image_placeholder = st.empty()
|
| 94 |
button_placeholder = st.empty()
|
|
|
|
| 95 |
|
| 96 |
try:
|
| 97 |
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
|
|
@@ -106,13 +116,16 @@ def main():
|
|
| 106 |
state.code_idx = 0
|
| 107 |
state.codes = base_codes[0:1]
|
| 108 |
|
| 109 |
-
if button_placeholder.button('
|
| 110 |
state.code_idx += 1
|
| 111 |
if state.code_idx < base_codes.shape[0]:
|
| 112 |
state.codes = base_codes[state.code_idx][np.newaxis]
|
| 113 |
else:
|
| 114 |
state.codes = sample(model, gan_type)
|
| 115 |
|
|
|
|
|
|
|
|
|
|
| 116 |
code = state.codes.copy()
|
| 117 |
for sem_idx, step in steps.items():
|
| 118 |
if gan_type == 'pggan':
|
|
|
|
| 1 |
# python 3.7
|
| 2 |
"""Demo."""
|
| 3 |
+
import random
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
| 56 |
return image
|
| 57 |
|
| 58 |
|
| 59 |
+
def _update_slider():
|
| 60 |
+
num_semantics = st.session_state["num_semantics"]
|
| 61 |
+
for sem_idx in range(num_semantics):
|
| 62 |
+
st.session_state[f"semantic_slider_{sem_idx}"] = 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def main():
|
| 66 |
"""Main function (loop for StreamLit)."""
|
| 67 |
+
|
| 68 |
st.title('Closed-Form Factorization of Latent Semantics in GANs')
|
| 69 |
st.sidebar.title('Options')
|
| 70 |
+
st.sidebar.button('Reset', on_click=_update_slider, kwargs={})
|
| 71 |
|
| 72 |
model_name = st.sidebar.selectbox(
|
| 73 |
'Model to Interpret',
|
|
|
|
| 81 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
| 82 |
|
| 83 |
num_semantics = st.sidebar.number_input(
|
| 84 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
|
| 85 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
| 86 |
if gan_type == 'pggan':
|
| 87 |
max_step = 5.0
|
|
|
|
| 96 |
value=0.0,
|
| 97 |
min_value=-max_step,
|
| 98 |
max_value=max_step,
|
| 99 |
+
step=0.04 * max_step,
|
| 100 |
+
key=f"semantic_slider_{sem_idx}")
|
| 101 |
|
| 102 |
image_placeholder = st.empty()
|
| 103 |
button_placeholder = st.empty()
|
| 104 |
+
button_totally_random = st.empty()
|
| 105 |
|
| 106 |
try:
|
| 107 |
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
|
|
|
|
| 116 |
state.code_idx = 0
|
| 117 |
state.codes = base_codes[0:1]
|
| 118 |
|
| 119 |
+
if button_placeholder.button('Next Sample'):
|
| 120 |
state.code_idx += 1
|
| 121 |
if state.code_idx < base_codes.shape[0]:
|
| 122 |
state.codes = base_codes[state.code_idx][np.newaxis]
|
| 123 |
else:
|
| 124 |
state.codes = sample(model, gan_type)
|
| 125 |
|
| 126 |
+
if button_totally_random.button('Totally Random'):
|
| 127 |
+
state.codes = sample(model, gan_type)
|
| 128 |
+
|
| 129 |
code = state.codes.copy()
|
| 130 |
for sem_idx, step in steps.items():
|
| 131 |
if gan_type == 'pggan':
|