Spaces:
Sleeping
Sleeping
Commit
·
af49af1
1
Parent(s):
b40aac1
cleanup
Browse files- app.py +73 -81
- mpl_data_plotter.py +3 -5
app.py
CHANGED
|
@@ -11,29 +11,9 @@ from mpl_data_plotter import MatplotlibDataPlotter
|
|
| 11 |
def convert_int64_to_int32(df):
|
| 12 |
for col in df.columns:
|
| 13 |
if df[col].dtype == 'int64':
|
| 14 |
-
print(col)
|
| 15 |
df[col] = df[col].astype('int32')
|
| 16 |
return df
|
| 17 |
|
| 18 |
-
print(f"Loading domains data...")
|
| 19 |
-
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
| 20 |
-
single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 21 |
-
single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 22 |
-
single_df = convert_int64_to_int32(single_df)
|
| 23 |
-
|
| 24 |
-
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
| 25 |
-
pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 26 |
-
pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 27 |
-
pair_df = convert_int64_to_int32(pair_df)
|
| 28 |
-
|
| 29 |
-
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
| 30 |
-
columns={'as_domain_id': 'num_domains'})
|
| 31 |
-
|
| 32 |
-
unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
|
| 33 |
-
|
| 34 |
-
print(f"Initializing data plotter...")
|
| 35 |
-
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
| 36 |
-
|
| 37 |
|
| 38 |
def create_color_legend(class_to_color):
|
| 39 |
# Create HTML for the color legend
|
|
@@ -86,66 +66,78 @@ def update_all_plots(frequency, split_name):
|
|
| 86 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
| 87 |
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
single_domains_plot
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
# height: 100% !important;
|
| 122 |
-
# width: 100% !important;
|
| 123 |
-
# }
|
| 124 |
-
# </style>
|
| 125 |
-
# """)
|
| 126 |
-
with gr.Column():
|
| 127 |
-
pair_domains_plot = gr.Plot(label="Pair domains")
|
| 128 |
-
# with gr.Column():
|
| 129 |
-
# combined_plot = gr.Plot(label="Combined Wave")
|
| 130 |
-
|
| 131 |
-
frequency_slider.release(
|
| 132 |
-
fn=update_all_plots,
|
| 133 |
-
inputs=[frequency_slider, model_selector],
|
| 134 |
-
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
| 135 |
-
)
|
| 136 |
-
demo.load(
|
| 137 |
-
fn=update_all_plots,
|
| 138 |
-
inputs=[frequency_slider, model_selector],
|
| 139 |
-
outputs=[single_domains_plot, pair_domains_plot]
|
| 140 |
-
)
|
| 141 |
-
model_selector.input(
|
| 142 |
-
fn=update_all_plots,
|
| 143 |
-
inputs=[frequency_slider, model_selector],
|
| 144 |
-
outputs=[single_domains_plot, pair_domains_plot]
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
print(f"Launching!...")
|
| 149 |
-
demo.launch()
|
| 150 |
-
|
| 151 |
-
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
|
|
|
| 11 |
def convert_int64_to_int32(df):
|
| 12 |
for col in df.columns:
|
| 13 |
if df[col].dtype == 'int64':
|
|
|
|
| 14 |
df[col] = df[col].astype('int32')
|
| 15 |
return df
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def create_color_legend(class_to_color):
|
| 19 |
# Create HTML for the color legend
|
|
|
|
| 66 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
| 67 |
|
| 68 |
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
print(f"Loading domains data...")
|
| 71 |
+
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
| 72 |
+
single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 73 |
+
single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 74 |
+
single_df = convert_int64_to_int32(single_df)
|
| 75 |
+
|
| 76 |
+
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
| 77 |
+
pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
| 78 |
+
pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
| 79 |
+
pair_df = convert_int64_to_int32(pair_df)
|
| 80 |
+
|
| 81 |
+
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
| 82 |
+
columns={'as_domain_id': 'num_domains'})
|
| 83 |
+
|
| 84 |
+
unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
|
| 85 |
+
|
| 86 |
+
print(f"Initializing data plotter...")
|
| 87 |
+
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
print(f"Defining blocks...")
|
| 91 |
+
|
| 92 |
+
# Create Gradio interface
|
| 93 |
+
with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
| 94 |
+
gr.Markdown("## BGC Keyword Plotter")
|
| 95 |
+
gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
|
| 96 |
+
|
| 97 |
+
color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
|
| 98 |
+
|
| 99 |
+
with gr.Row():
|
| 100 |
+
frequency_slider = gr.Slider(
|
| 101 |
+
minimum=int(unique_domain_lengths.min()),
|
| 102 |
+
maximum=int(unique_domain_lengths.max()),
|
| 103 |
+
step=1,
|
| 104 |
+
value=int(unique_domain_lengths.min()),
|
| 105 |
+
label="Min number of domains"
|
| 106 |
+
)
|
| 107 |
+
model_selector = gr.Radio(
|
| 108 |
+
choices=["stratified"] + BIOSYN_CLASS_NAMES,
|
| 109 |
+
value="stratified",
|
| 110 |
+
label="Model name"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
with gr.Row():
|
| 114 |
+
with gr.Column():
|
| 115 |
+
single_domains_plot = gr.Plot(
|
| 116 |
+
label="Single domains",
|
| 117 |
+
container=True,
|
| 118 |
+
elem_id="single_domains_plot"
|
| 119 |
+
)
|
| 120 |
+
with gr.Column():
|
| 121 |
+
pair_domains_plot = gr.Plot(label="Pair domains")
|
| 122 |
+
|
| 123 |
+
frequency_slider.release(
|
| 124 |
+
fn=update_all_plots,
|
| 125 |
+
inputs=[frequency_slider, model_selector],
|
| 126 |
+
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
| 127 |
)
|
| 128 |
+
demo.load(
|
| 129 |
+
fn=update_all_plots,
|
| 130 |
+
inputs=[frequency_slider, model_selector],
|
| 131 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
| 132 |
)
|
| 133 |
+
model_selector.input(
|
| 134 |
+
fn=update_all_plots,
|
| 135 |
+
inputs=[frequency_slider, model_selector],
|
| 136 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
print(f"Launching!...")
|
| 141 |
+
demo.launch()
|
| 142 |
+
|
| 143 |
+
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mpl_data_plotter.py
CHANGED
|
@@ -17,7 +17,7 @@ class MatplotlibDataPlotter:
|
|
| 17 |
self.single_domains_fig = plt.figure(figsize=(5, 10))
|
| 18 |
self.pair_domains_fig = plt.figure(figsize=(5, 10))
|
| 19 |
|
| 20 |
-
def plot_single_domains(self, num_domains, split_name):
|
| 21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 23 |
'cds_region_id'].values
|
|
@@ -39,7 +39,6 @@ class MatplotlibDataPlotter:
|
|
| 39 |
top_n=5
|
| 40 |
bin_width=1
|
| 41 |
hue_group_offset=0.5
|
| 42 |
-
# hue_order=BIOSYN_CLASS_NAMES
|
| 43 |
width=0.9
|
| 44 |
|
| 45 |
fig = self.single_domains_fig
|
|
@@ -62,7 +61,7 @@ class MatplotlibDataPlotter:
|
|
| 62 |
fig.tight_layout()
|
| 63 |
return fig
|
| 64 |
|
| 65 |
-
def plot_pair_domains(self, num_domains, split_name):
|
| 66 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 67 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 68 |
'cds_region_id'].values
|
|
@@ -72,9 +71,8 @@ class MatplotlibDataPlotter:
|
|
| 72 |
biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
| 73 |
hue2count_pairs = dict(biosyn_counts_pairs.values)
|
| 74 |
|
| 75 |
-
# split_name = 'stratified'
|
| 76 |
column_name = f'cosine_similarity_{split_name}'
|
| 77 |
-
|
| 78 |
selected_keyword_index = pair_df_subset.groupby('cds_region_id').agg(
|
| 79 |
{column_name: 'idxmax'}
|
| 80 |
).values.flatten()
|
|
|
|
| 17 |
self.single_domains_fig = plt.figure(figsize=(5, 10))
|
| 18 |
self.pair_domains_fig = plt.figure(figsize=(5, 10))
|
| 19 |
|
| 20 |
+
def plot_single_domains(self, num_domains, split_name="stratified"):
|
| 21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 23 |
'cds_region_id'].values
|
|
|
|
| 39 |
top_n=5
|
| 40 |
bin_width=1
|
| 41 |
hue_group_offset=0.5
|
|
|
|
| 42 |
width=0.9
|
| 43 |
|
| 44 |
fig = self.single_domains_fig
|
|
|
|
| 61 |
fig.tight_layout()
|
| 62 |
return fig
|
| 63 |
|
| 64 |
+
def plot_pair_domains(self, num_domains, split_name="stratified"):
|
| 65 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
| 66 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
| 67 |
'cds_region_id'].values
|
|
|
|
| 71 |
biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
| 72 |
hue2count_pairs = dict(biosyn_counts_pairs.values)
|
| 73 |
|
|
|
|
| 74 |
column_name = f'cosine_similarity_{split_name}'
|
| 75 |
+
|
| 76 |
selected_keyword_index = pair_df_subset.groupby('cds_region_id').agg(
|
| 77 |
{column_name: 'idxmax'}
|
| 78 |
).values.flatten()
|