Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import uvicorn | |
from faicons import icon_svg as icon | |
from fire import Fire | |
from shiny import App, reactive, render, ui | |
from quickmt import Translator | |
from quickmt.hub import hf_download, hf_list | |
# Translator shared by every session (a module-level global). ``server``
# creates it lazily and replaces it whenever a different model is selected.
t = None

# Default address for the uvicorn server started when run as a script.
# NOTE: no trailing comma -- ``port: int = 7860,`` would bind the tuple (7860,).
port: int = 7860
host: str = "0.0.0.0"
# Page layout: input and translation cards side by side in the main panel,
# model / beam-size controls in the sidebar, and a dark-mode toggle plus a
# GitHub link in the navbar.
app_ui = ui.page_navbar(
    ui.nav_panel(
        None,
        ui.layout_columns(
            ui.card(
                ui.h4("Input Text"),
                ui.input_text_area(
                    "input_text",
                    "",
                    value="",
                    width="100%",
                    height="600px",
                ),
                ui.input_action_button(
                    "translate_button", "Translate!", class_="btn-primary"
                ),
            ),
            # Output slot for the server-side function named ``translate``.
            ui.card(ui.h4("Translation"), ui.output_ui("translate")),
        ),
    ),
    ui.nav_spacer(),
    ui.nav_control(
        ui.input_dark_mode(
            id="darkmode_toggle", mode="dark", style="padding-top: 10px;"
        ),
    ),
    ui.nav_control(
        ui.a(
            icon("github", height="30px", width="30px", fill="#17a2b8"),
            href="https://github.com/quickmt/quickmt",
            target="_blank",
            class_="btn btn-link",
        ),
    ),
    sidebar=ui.sidebar(
        ui.tooltip(
            ui.input_selectize(
                "model",
                "Select model",
                # hf_list() runs once, when this module is imported. Keep only
                # the part after the first "/" (presumably ids look like
                # "quickmt/<name>"); the server prepends "quickmt/" again
                # before downloading.
                choices=[i.split("/")[1] for i in hf_list()],
            ),
            "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
        ),
        ui.tooltip(
            ui.input_slider(
                "beam_size", "Beam size", min=1, max=8, step=1, value=2
            ),
            "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
        ),
        width="350px",
    ),
    title=ui.h2("QuickMT Machine Translation Demo"),
    window_title="QuickMT",
    # Brand theme resolved relative to this file (presumably a _brand.yml
    # next to it -- confirm in the deployment).
    theme=ui.Theme.from_brand(__file__),
    navbar_options=ui.navbar_options(underline=False, theme="auto"),
)
def server(input, output, session):
    """Server logic for the QuickMT demo.

    Registers one output, ``translate``, which fills the "Translation" card.
    It re-runs only when the "Translate!" button is clicked. On first use of a
    model it downloads the model into ``/code/models/<name>``. It then caches
    the loaded ``Translator`` in the module-level ``t`` so consecutive
    requests for the same model reuse it.

    ``output`` and ``session`` are part of Shiny's server signature and are
    not used here.
    """
    import traceback

    models_dir = Path("/code/models")

    def _ensure_model(model_name):
        """Return the local directory for *model_name*, downloading it if missing."""
        model_path = models_dir / model_name
        if not model_path.exists():
            ui.notification_show(
                f"Downloading model {model_name}...",
                type="message",
                duration=3,
            )
            hf_download(
                model_name="quickmt/" + model_name,
                output_dir=model_path,
            )
        return model_path

    def _load_translator(model_name):
        """Return a Translator for *model_name*, reusing the cached one when it matches."""
        global t
        # NOTE(review): relies on Translator exposing the directory it was
        # loaded from as ``model_path`` -- confirm against quickmt.
        if t is None or model_name != str(Path(t.model_path).name):
            print(f"Loading model {model_name}")
            t = Translator(
                str(_ensure_model(model_name)),
                device="cpu",
                inter_threads=2,
            )
        return t

    def _error_box(message):
        """Build the value box shown in place of a translation on failure."""
        return [
            ui.value_box(
                title="Unexpected error",
                value=message,
                showcase=icon("bug"),
            ),
        ]

    @render.ui
    @reactive.event(input.translate_button)
    def translate():
        """Translate the input text line by line when the button is clicked."""
        text = input.input_text()
        # Nothing to translate: skip the potentially slow download and load.
        if not text.strip():
            return ""
        model_name = str(input.model())
        beam_size = input.beam_size()

        # Report load and translation failures separately, and keep the
        # traceback in the server log instead of silently discarding it.
        try:
            translator = _load_translator(model_name)
        except Exception:
            traceback.print_exc()
            return _error_box("Failed to load model")
        try:
            translations = translator(text.splitlines(), beam_size=beam_size)
        except Exception:
            traceback.print_exc()
            return _error_box("Failed to translate text")

        # Render each returned translation as its own paragraph
        # (presumably one per input line).
        return [ui.p(line) for line in translations]
# ASGI application combining the page layout and the server logic.
app = App(app_ui, server)


def main(port: int = port, host: str = host) -> None:
    """Serve ``app`` with uvicorn.

    Exposed through python-fire below, so the defaults (the module-level
    ``port`` and ``host``) can be overridden from the command line, e.g.
    ``--port 8000 --host 127.0.0.1``.
    """
    uvicorn.run(app, port=port, host=host)


if __name__ == "__main__":
    Fire(main)