"""Gradio app that suggests plant species to add to a user-selected community.

The ONNX model is stored encrypted on disk and is decrypted in memory with the
Fernet key read from the ``ONNX_KEY`` environment variable. Three JSON files
map between species names, model indices and GBIF species keys.
"""

import json
import os

import gradio as gr
import numpy as np
import onnxruntime as ort
from cryptography.fernet import Fernet
from dotenv import load_dotenv

load_dotenv()


def _load_json(path):
    """Return the parsed content of the JSON file at *path* (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Model load
key = os.getenv("ONNX_KEY")
if not key:
    # Fail early with a readable message instead of an opaque error from Fernet.
    raise RuntimeError(
        "Environment variable ONNX_KEY is not set; it is needed to decrypt "
        "species_bag.onnx.encrypted."
    )
cipher = Fernet(key)

with open("species_bag.onnx.encrypted", "rb") as f:
    encrypted = f.read()
decrypted = cipher.decrypt(encrypted)

# Initialize ONNX session from the decrypted model bytes
ort_session = ort.InferenceSession(decrypted)
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Load mappings
idx2spec = _load_json("idx2spec.json")
spec2idx = _load_json("spec2idx.json")
spec2key = _load_json("spec2key.json")

baseurl = "https://www.gbif.org/species/"


def _coerce_n_hits(n_hits, default=10):
    """Return *n_hits* as an int of at least 1, or *default* when it is None.

    The value comes from a ``gr.Number`` field, which may deliver a float;
    a float cannot be used as a slice bound.
    NOTE(review): an emptied field is presumed to arrive as None - confirm.
    """
    if n_hits is None:
        return default
    return max(1, int(n_hits))


def _species_link(species):
    """Return an HTML anchor for *species* pointing at ``baseurl`` + its key."""
    return '<a href="{}{}" target="_blank">{}</a>'.format(
        baseurl, spec2key[species], species
    )


def _top_predictions(selected_species, n_hits):
    """Run the model and return the *n_hits* best ``(species, score)`` pairs.

    Args:
        selected_species: Non-empty list of species names present in
            ``spec2idx``.
        n_hits: Positive int, number of predictions to return.

    Returns:
        List of ``(species_name, score)`` tuples ordered by descending score.

    Raises:
        KeyError: If a name is missing from ``spec2idx``.
    """
    # Convert species names to indices using spec2idx
    input_indices = [int(spec2idx[name]) for name in selected_species]

    # Model inference on a single row of indices
    input_np = np.array(input_indices, dtype=np.int64).reshape(1, -1)
    output = ort_session.run([output_name], {input_name: input_np})[0][0]

    # Highest scores last in argsort order, so take the tail and reverse it
    top_indices = output.argsort()[-n_hits:][::-1]
    return [(idx2spec[str(idx)], float(output[idx])) for idx in top_indices]


def predict_species(selected_species, n_hits=10):
    """Render the selected species and the model's top predictions as HTML.

    Args:
        selected_species: List of selected species names (may be empty/None).
        n_hits: Number of predictions to show; floats are truncated to int.

    Returns:
        Tuple ``(selected_html, predictions_html)``. Each is a string of
        anchors separated by ``<br>``; prediction anchors are followed by the
        score multiplied by 100 and shown with one decimal and a percent sign.
        Both strings are empty when nothing is selected.
    """
    if not selected_species:
        return "", ""

    predictions = _top_predictions(selected_species, _coerce_n_hits(n_hits))

    # Format selected species with links
    selected_html = [_species_link(species) for species in selected_species]

    # Format predictions with species names, links and scores
    predictions_html = [
        "{} ({:.1f}%)".format(_species_link(species), 100 * score)
        for species, score in predictions
    ]

    return "<br>".join(selected_html), "<br>".join(predictions_html)


def add_top_prediction(selected_species):
    """Return the selection extended with the model's top prediction.

    The input list is not mutated. The selection is returned unchanged when
    it is empty or when the top prediction is already part of it.
    """
    if not selected_species:
        return selected_species

    # Take the name straight from the ranking instead of parsing rendered HTML
    top_species, _score = _top_predictions(selected_species, 1)[0]
    if top_species in selected_species:
        return selected_species
    return [*selected_species, top_species]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Danmarks planter - hvem mangler?")
    gr.Markdown("*Sammensæt et plantesamfund og få forslag til andre arter der passer ind. Vælg mellem 3199 danske terrestriske og akvatiske planter.*")

    with gr.Row():
        species_dropdown = gr.Dropdown(
            choices=sorted(spec2idx.keys()),
            multiselect=True,
            label="Find arter",
        )

    with gr.Row():
        with gr.Column(scale=5, min_width=200):
            selected_output = gr.HTML(
                label="Arter",
                show_label=True
            )
        with gr.Column(scale=5, min_width=200):
            predictions_output = gr.HTML(
                label="Top hits",
                show_label=True
            )
        with gr.Column(scale=1, min_width=100):
            n_hits = gr.Number(10, label="Antal hits", minimum=1, maximum=100)
            add_button = gr.Button("Tilføj top hit", scale=8)

    gr.Markdown("Forslag er baseret på et neuralt netværk trænet til at forudsige de mest sandsynlige arter som mangler i et plantesamfund. Trænet på stort datasæt af plantesamfund registreret i Danmark (**4.3 millioner registreringer af 3199 arter/slægter/varianter i mere end 180.000 undersøgelser**).")
    gr.Markdown("App og model af Kenneth Thorø Martinsen (kenneth2810@gmail.com).")

    # Re-render both HTML panels whenever the selection changes
    species_dropdown.change(
        predict_species,
        inputs=[species_dropdown, n_hits],
        outputs=[selected_output, predictions_output]
    )

    # Writing back to the dropdown triggers the change handler above
    add_button.click(
        add_top_prediction,
        inputs=[species_dropdown],
        outputs=[species_dropdown]
    )

if __name__ == "__main__":
    demo.launch()