Spaces:
Sleeping
Sleeping
Commit
·
b40aac1
1
Parent(s):
2d1d8cb
fix avxline in plots, use common legend in gradio, add reaction and loading on launch
Browse files- app.py +70 -9
- constants.py +17 -0
- mpl_data_plotter.py +28 -23
- plot_utils.py +1 -1
app.py
CHANGED
|
@@ -17,11 +17,13 @@ def convert_int64_to_int32(df):
|
|
| 17 |
|
| 18 |
print(f"Loading domains data...")
|
| 19 |
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
| 20 |
-
single_df
|
|
|
|
| 21 |
single_df = convert_int64_to_int32(single_df)
|
| 22 |
|
| 23 |
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
| 24 |
-
pair_df
|
|
|
|
| 25 |
pair_df = convert_int64_to_int32(pair_df)
|
| 26 |
|
| 27 |
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
|
@@ -33,6 +35,53 @@ print(f"Initializing data plotter...")
|
|
| 33 |
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def update_all_plots(frequency, split_name):
|
| 37 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
| 38 |
|
|
@@ -43,6 +92,8 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
| 43 |
gr.Markdown("## BGC Keyword Plotter")
|
| 44 |
gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
|
| 45 |
|
|
|
|
|
|
|
| 46 |
with gr.Row():
|
| 47 |
frequency_slider = gr.Slider(
|
| 48 |
minimum=int(unique_domain_lengths.min()),
|
|
@@ -51,14 +102,13 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
| 51 |
value=int(unique_domain_lengths.min()),
|
| 52 |
label="Min number of domains"
|
| 53 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
with gr.Row():
|
| 56 |
-
with gr.Column():
|
| 57 |
-
split_selector = gr.Dropdown(
|
| 58 |
-
choices=["stratified"] + BIOSYN_CLASS_NAMES,
|
| 59 |
-
value="stratified",
|
| 60 |
-
label="Split name"
|
| 61 |
-
)
|
| 62 |
with gr.Column():
|
| 63 |
single_domains_plot = gr.Plot(
|
| 64 |
label="Single domains",
|
|
@@ -80,11 +130,22 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
| 80 |
|
| 81 |
frequency_slider.release(
|
| 82 |
fn=update_all_plots,
|
| 83 |
-
inputs=[frequency_slider,
|
| 84 |
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
| 85 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
print(f"Launching!...")
|
| 89 |
demo.launch()
|
|
|
|
| 90 |
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
|
|
|
| 17 |
|
| 18 |
print(f"Loading domains data...")
|
| 19 |
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
| 20 |
+
single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 21 |
+
single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 22 |
single_df = convert_int64_to_int32(single_df)
|
| 23 |
|
| 24 |
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
| 25 |
+
pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 26 |
+
pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 27 |
pair_df = convert_int64_to_int32(pair_df)
|
| 28 |
|
| 29 |
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
|
|
|
| 35 |
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
| 36 |
|
| 37 |
|
| 38 |
+
def create_color_legend(class_to_color):
|
| 39 |
+
# Create HTML for the color legend
|
| 40 |
+
legend_html = """
|
| 41 |
+
<div style="
|
| 42 |
+
margin: 10px 0;
|
| 43 |
+
padding: 10px;
|
| 44 |
+
border: 1px solid #ddd;
|
| 45 |
+
border-radius: 4px;
|
| 46 |
+
background: white;
|
| 47 |
+
">
|
| 48 |
+
<div style="
|
| 49 |
+
font-weight: bold;
|
| 50 |
+
margin-bottom: 8px;
|
| 51 |
+
">Color Legend:</div>
|
| 52 |
+
<div style="
|
| 53 |
+
display: flex;
|
| 54 |
+
flex-wrap: wrap;
|
| 55 |
+
gap: 15px;
|
| 56 |
+
align-items: center;
|
| 57 |
+
">
|
| 58 |
+
"""
|
| 59 |
+
# Add each class and its color
|
| 60 |
+
for class_name, color in class_to_color.items():
|
| 61 |
+
legend_html += f"""
|
| 62 |
+
<div style="
|
| 63 |
+
display: flex;
|
| 64 |
+
align-items: center;
|
| 65 |
+
gap: 5px;
|
| 66 |
+
">
|
| 67 |
+
<div style="
|
| 68 |
+
width: 20px;
|
| 69 |
+
height: 20px;
|
| 70 |
+
background-color: {color};
|
| 71 |
+
border-radius: 3px;
|
| 72 |
+
"></div>
|
| 73 |
+
<span>{class_name}</span>
|
| 74 |
+
</div>
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
legend_html += """
|
| 78 |
+
</div>
|
| 79 |
+
</div>
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
return gr.HTML(legend_html)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
def update_all_plots(frequency, split_name):
|
| 86 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
| 87 |
|
|
|
|
| 92 |
gr.Markdown("## BGC Keyword Plotter")
|
| 93 |
gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
|
| 94 |
|
| 95 |
+
color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
|
| 96 |
+
|
| 97 |
with gr.Row():
|
| 98 |
frequency_slider = gr.Slider(
|
| 99 |
minimum=int(unique_domain_lengths.min()),
|
|
|
|
| 102 |
value=int(unique_domain_lengths.min()),
|
| 103 |
label="Min number of domains"
|
| 104 |
)
|
| 105 |
+
model_selector = gr.Radio(
|
| 106 |
+
choices=["stratified"] + BIOSYN_CLASS_NAMES,
|
| 107 |
+
value="stratified",
|
| 108 |
+
label="Model name"
|
| 109 |
+
)
|
| 110 |
|
| 111 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
with gr.Column():
|
| 113 |
single_domains_plot = gr.Plot(
|
| 114 |
label="Single domains",
|
|
|
|
| 130 |
|
| 131 |
frequency_slider.release(
|
| 132 |
fn=update_all_plots,
|
| 133 |
+
inputs=[frequency_slider, model_selector],
|
| 134 |
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
| 135 |
)
|
| 136 |
+
demo.load(
|
| 137 |
+
fn=update_all_plots,
|
| 138 |
+
inputs=[frequency_slider, model_selector],
|
| 139 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
| 140 |
+
)
|
| 141 |
+
model_selector.input(
|
| 142 |
+
fn=update_all_plots,
|
| 143 |
+
inputs=[frequency_slider, model_selector],
|
| 144 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
| 145 |
+
)
|
| 146 |
|
| 147 |
|
| 148 |
print(f"Launching!...")
|
| 149 |
demo.launch()
|
| 150 |
+
|
| 151 |
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
constants.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
POSTER_BLUE = '#01589C'
|
| 3 |
|
| 4 |
BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
|
|
@@ -6,3 +8,18 @@ BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Te
|
|
| 6 |
SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
|
| 7 |
PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
|
| 4 |
POSTER_BLUE = '#01589C'
|
| 5 |
|
| 6 |
BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
|
|
|
|
| 8 |
SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
|
| 9 |
PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
|
| 10 |
|
| 11 |
+
BIOSYN_CLASS_HEX_COLORS = {
|
| 12 |
+
'Alkaloid': '#a1c9f4',
|
| 13 |
+
'NRP': '#ffb482',
|
| 14 |
+
'Polyketide': '#8de5a1',
|
| 15 |
+
'RiPP': '#ff9f9b',
|
| 16 |
+
'Saccharide': '#d0bbff',
|
| 17 |
+
'Terpene': '#debb9b',
|
| 18 |
+
'Other': '#cfcfcf',
|
| 19 |
+
# 'stratified': '#01589C', # just in case
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
COLOR_PALETTE = sns.color_palette([
|
| 23 |
+
BIOSYN_CLASS_HEX_COLORS[biosyn_class]
|
| 24 |
+
for biosyn_class in BIOSYN_CLASS_NAMES
|
| 25 |
+
])
|
mpl_data_plotter.py
CHANGED
|
@@ -21,7 +21,12 @@ class MatplotlibDataPlotter:
|
|
| 21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 23 |
'cds_region_id'].values
|
|
|
|
| 24 |
single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# split_name = 'stratified'
|
| 26 |
column_name = f'cosine_similarity_{split_name}'
|
| 27 |
# single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
|
|
@@ -35,12 +40,8 @@ class MatplotlibDataPlotter:
|
|
| 35 |
bin_width=1
|
| 36 |
hue_group_offset=0.5
|
| 37 |
# hue_order=BIOSYN_CLASS_NAMES
|
| 38 |
-
hue2count={}
|
| 39 |
width=0.9
|
| 40 |
|
| 41 |
-
show_legend=True
|
| 42 |
-
print(matplotlib.get_backend())
|
| 43 |
-
|
| 44 |
fig = self.single_domains_fig
|
| 45 |
fig.clf()
|
| 46 |
|
|
@@ -48,23 +49,29 @@ class MatplotlibDataPlotter:
|
|
| 48 |
plot_utils.draw_barplots(
|
| 49 |
targets_list,
|
| 50 |
label_list=label_list,
|
| 51 |
-
top_n=
|
| 52 |
-
bin_width=
|
| 53 |
-
hue_group_offset=
|
| 54 |
hue_order=BIOSYN_CLASS_NAMES,
|
| 55 |
-
hue2count=
|
| 56 |
-
width=
|
| 57 |
ax=ax,
|
| 58 |
-
show_legend=
|
|
|
|
| 59 |
)
|
| 60 |
-
|
| 61 |
-
return fig
|
| 62 |
|
| 63 |
def plot_pair_domains(self, num_domains, split_name):
|
| 64 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 65 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 66 |
'cds_region_id'].values
|
|
|
|
| 67 |
pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# split_name = 'stratified'
|
| 69 |
column_name = f'cosine_similarity_{split_name}'
|
| 70 |
# pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
|
|
@@ -83,27 +90,25 @@ class MatplotlibDataPlotter:
|
|
| 83 |
hue2count={}
|
| 84 |
width=0.9
|
| 85 |
|
| 86 |
-
show_legend=
|
| 87 |
-
# fig = plt.figure(figsize=(5, 10))
|
| 88 |
fig = self.pair_domains_fig
|
| 89 |
-
# fig = plt.gcf()
|
| 90 |
fig.clf()
|
| 91 |
-
print(matplotlib.get_backend())
|
| 92 |
|
| 93 |
ax = fig.gca()
|
| 94 |
plot_utils.draw_barplots(
|
| 95 |
targets_list,
|
| 96 |
label_list=label_list,
|
| 97 |
-
top_n=
|
| 98 |
-
bin_width=
|
| 99 |
-
hue_group_offset=
|
| 100 |
hue_order=BIOSYN_CLASS_NAMES,
|
| 101 |
-
hue2count=
|
| 102 |
-
width=
|
| 103 |
ax=ax,
|
| 104 |
-
show_legend=
|
|
|
|
| 105 |
)
|
| 106 |
-
|
| 107 |
return fig #plt.gcf()
|
| 108 |
|
| 109 |
|
|
|
|
| 21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 23 |
'cds_region_id'].values
|
| 24 |
+
|
| 25 |
single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
|
| 26 |
+
|
| 27 |
+
biosyn_counts_single = single_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
| 28 |
+
hue2count_single = dict(biosyn_counts_single.values)
|
| 29 |
+
|
| 30 |
# split_name = 'stratified'
|
| 31 |
column_name = f'cosine_similarity_{split_name}'
|
| 32 |
# single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
|
|
|
|
| 40 |
bin_width=1
|
| 41 |
hue_group_offset=0.5
|
| 42 |
# hue_order=BIOSYN_CLASS_NAMES
|
|
|
|
| 43 |
width=0.9
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
fig = self.single_domains_fig
|
| 46 |
fig.clf()
|
| 47 |
|
|
|
|
| 49 |
plot_utils.draw_barplots(
|
| 50 |
targets_list,
|
| 51 |
label_list=label_list,
|
| 52 |
+
top_n=top_n,
|
| 53 |
+
bin_width=bin_width,
|
| 54 |
+
hue_group_offset=hue_group_offset,
|
| 55 |
hue_order=BIOSYN_CLASS_NAMES,
|
| 56 |
+
hue2count=hue2count_single,
|
| 57 |
+
width=width,
|
| 58 |
ax=ax,
|
| 59 |
+
show_legend=False,
|
| 60 |
+
palette=COLOR_PALETTE
|
| 61 |
)
|
| 62 |
+
fig.tight_layout()
|
| 63 |
+
return fig
|
| 64 |
|
| 65 |
def plot_pair_domains(self, num_domains, split_name):
|
| 66 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 67 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 68 |
'cds_region_id'].values
|
| 69 |
+
|
| 70 |
pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
|
| 71 |
+
|
| 72 |
+
biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
| 73 |
+
hue2count_pairs = dict(biosyn_counts_pairs.values)
|
| 74 |
+
|
| 75 |
# split_name = 'stratified'
|
| 76 |
column_name = f'cosine_similarity_{split_name}'
|
| 77 |
# pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
|
|
|
|
| 90 |
hue2count={}
|
| 91 |
width=0.9
|
| 92 |
|
| 93 |
+
show_legend=False
|
|
|
|
| 94 |
fig = self.pair_domains_fig
|
|
|
|
| 95 |
fig.clf()
|
|
|
|
| 96 |
|
| 97 |
ax = fig.gca()
|
| 98 |
plot_utils.draw_barplots(
|
| 99 |
targets_list,
|
| 100 |
label_list=label_list,
|
| 101 |
+
top_n=top_n,
|
| 102 |
+
bin_width=bin_width,
|
| 103 |
+
hue_group_offset=hue_group_offset,
|
| 104 |
hue_order=BIOSYN_CLASS_NAMES,
|
| 105 |
+
hue2count=hue2count_pairs,
|
| 106 |
+
width=width,
|
| 107 |
ax=ax,
|
| 108 |
+
show_legend=show_legend,
|
| 109 |
+
palette=COLOR_PALETTE
|
| 110 |
)
|
| 111 |
+
fig.tight_layout()
|
| 112 |
return fig #plt.gcf()
|
| 113 |
|
| 114 |
|
plot_utils.py
CHANGED
|
@@ -76,7 +76,7 @@ def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
|
|
| 76 |
# if not normalize:
|
| 77 |
# bottom[bin_indices] += bar_offset
|
| 78 |
line_pos = bin_indices.max() + width/2 + hue_group_offset/2
|
| 79 |
-
|
| 80 |
if show_legend:
|
| 81 |
ax.legend(
|
| 82 |
loc='upper center', bbox_to_anchor=(0.5, -0.05),
|
|
|
|
| 76 |
# if not normalize:
|
| 77 |
# bottom[bin_indices] += bar_offset
|
| 78 |
line_pos = bin_indices.max() + width/2 + hue_group_offset/2
|
| 79 |
+
ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)
|
| 80 |
if show_legend:
|
| 81 |
ax.legend(
|
| 82 |
loc='upper center', bbox_to_anchor=(0.5, -0.05),
|