"""QuickMT machine-translation demo built with Shiny for Python.

Running this file serves the app with uvicorn; ``--port`` and ``--host``
can be passed on the command line (parsed by Fire).
"""

import traceback
from pathlib import Path

import uvicorn
from faicons import icon_svg as icon
from fire import Fire
from shiny import App, reactive, render, ui

from quickmt import Translator
from quickmt.hub import hf_download, hf_list

# Local directory holding downloaded models, one sub-folder per model name.
MODELS_DIR = Path("/code/models")
# Hub namespace prefixed to the selected model name when downloading.
HF_ORG = "quickmt"

# Translator reused between requests; replaced when a different model is
# requested. NOTE(review): this is module-global, so it is shared by every
# session -- one user switching models reloads it for everyone.
t = None
# Name of the model currently loaded into ``t`` (None until the first load).
_loaded_model_name = None


def _model_path(model_name: str) -> Path:
    """Return the local directory used for *model_name*."""
    return MODELS_DIR / model_name


def _download_model(model_name: str) -> None:
    """Download ``quickmt/<model_name>`` from the hub into its local directory."""
    hf_download(
        model_name=f"{HF_ORG}/{model_name}",
        output_dir=_model_path(model_name),
    )


def _get_translator(model_name: str):
    """Return a Translator for *model_name*, loading it only when not cached.

    The module-level cache is updated only after the Translator is built
    successfully, so a failed load leaves the previous model in place.
    """
    global t, _loaded_model_name
    if t is None or _loaded_model_name != model_name:
        print(f"Loading model {model_name}")
        t = Translator(
            str(_model_path(model_name)),
            device="cpu",
            inter_threads=2,
        )
        _loaded_model_name = model_name
    return t


def _error_box(message: str):
    """Build the value box shown in the output panel when something fails."""
    return ui.value_box(
        title="Unexpected error",
        value=message,
        showcase=icon("bug"),
    )


app_ui = ui.page_navbar(
    ui.nav_panel(
        None,
        ui.layout_columns(
            ui.card(
                ui.h4("Input Text"),
                ui.input_text_area(
                    "input_text",
                    "",
                    value="",
                    width="100%",
                    height="600px",
                ),
                ui.input_action_button(
                    "translate_button", "Translate!", class_="btn-primary"
                ),
            ),
            ui.card(ui.h4("Translation"), ui.output_ui("translate")),
        ),
    ),
    ui.nav_spacer(),
    ui.nav_control(
        ui.input_dark_mode(
            id="darkmode_toggle", mode="dark", style="padding-top: 10px;"
        ),
    ),
    ui.nav_control(
        ui.a(
            icon("github", height="30px", width="30px", fill="#17a2b8"),
            href="https://github.com/quickmt/quickmt",
            target="_blank",
            class_="btn btn-link",
        ),
    ),
    sidebar=ui.sidebar(
        ui.tooltip(
            ui.input_selectize(
                "model",
                "Select model",
                # Presumably hub ids look like "quickmt/<name>"; show only the
                # final path component (identical to [1] for one-slash ids).
                choices=[i.split("/")[-1] for i in hf_list()],
            ),
            "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
        ),
        ui.tooltip(
            ui.input_slider(
                "beam_size", "Beam size", min=1, max=8, step=1, value=2
            ),
            "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
        ),
        width="350px",
    ),
    title=ui.h2("QuickMT Machine Translation Demo"),
    window_title="QuickMT",
    theme=ui.Theme.from_brand(__file__),
    navbar_options=ui.navbar_options(underline=False, theme="auto"),
)


def server(input, output, session):
    """Shiny server: wires the Translate button to the translation output."""

    @render.ui
    @reactive.event(input.quickmt_model_download)  # Take a dependency on the button
    def model_download_output():
        # NOTE(review): app_ui defines neither a "quickmt_model_download"
        # button nor a "model_download_output" slot, so this handler appears
        # unreachable -- confirm whether it can be deleted.
        _download_model(input.model())
        return "Model downloaded"

    @render.ui
    @reactive.event(input.translate_button)  # Take a dependency on the button
    def translate():
        """Translate the input text line by line with the selected model.

        Downloads the model on first use, then loads (or reuses) the shared
        Translator. Returns one paragraph per input line, "" for empty
        input, or an error value box if loading or translating fails.
        """
        model_name = input.model()
        if not model_name:
            return _error_box("No model selected")

        try:
            # NOTE(review): only directory existence is checked; a partially
            # downloaded model directory would be treated as complete.
            if not _model_path(model_name).exists():
                ui.notification_show(
                    f"Downloading model {model_name}...",
                    type="message",
                    duration=3,
                )
                _download_model(model_name)
            translator = _get_translator(model_name)
        except Exception:
            traceback.print_exc()
            return _error_box("Failed to load model")

        text = input.input_text()
        if not text:
            return ""

        try:
            translations = translator(
                text.splitlines(), beam_size=input.beam_size()
            )
        except Exception:
            traceback.print_exc()
            return _error_box("Translation failed")

        return [ui.p(line) for line in translations]


app = App(app_ui, server)


def main(port: int = 7860, host: str = "0.0.0.0") -> None:
    """Serve the app with uvicorn.

    Args:
        port: TCP port to listen on.
        host: Interface address to bind; "0.0.0.0" binds all interfaces.
    """
    uvicorn.run(app, port=port, host=host)


if __name__ == "__main__":
    # Fire exposes main's parameters as --port / --host CLI flags.
    Fire(main)