# QuickMT-Demo / app.py
# radinplaid's picture
# Update app.py
# b3dbb08 verified
from pathlib import Path
import uvicorn
from faicons import icon_svg as icon
from fire import Fire
from shiny import App, reactive, render, ui
from quickmt import Translator
from quickmt.hub import hf_download, hf_list
# Process-wide cache for the currently loaded translator; stays None until the
# first translation request assigns it.
t = None

# Network settings handed to uvicorn when this file is executed as a script.
# NOTE: `port` must be a plain int. A trailing comma after the value would
# bind a one-element tuple `(7860,)` instead of the integer.
port: int = 7860
host: str = "0.0.0.0"
# NOTE(review): the object returned here is neither bound to a name nor passed
# to anything in this file; the page built below supplies its own
# `navbar_options=` argument. This call looks like a leftover -- confirm
# before removing.
ui.navbar_options(bg="red")
# --- Main panel -------------------------------------------------------------
# Left card: free-text input area and the button that triggers a translation.
_input_card = ui.card(
    ui.h4("Input Text"),
    ui.input_text_area(
        "input_text",
        "",
        value="",
        width="100%",
        height="600px",
    ),
    ui.input_action_button("translate_button", "Translate!", class_="btn-primary"),
)

# Right card: placeholder filled by the server-side `translate` output.
_output_card = ui.card(ui.h4("Translation"), ui.output_ui("translate"))

_main_panel = ui.nav_panel(None, ui.layout_columns(_input_card, _output_card))

# --- Navbar controls --------------------------------------------------------
_dark_mode_control = ui.nav_control(
    ui.input_dark_mode(
        id="darkmode_toggle",
        mode="dark",
        style="padding-top: 10px;",
    ),
)

_github_control = ui.nav_control(
    ui.a(
        icon("github", height="30px", width="30px", fill="#17a2b8"),
        href="https://github.com/quickmt/quickmt",
        target="_blank",
        class_="btn btn-link",
    ),
)

# --- Sidebar ----------------------------------------------------------------
# Each entry returned by hf_list() is split on "/" and the second component is
# offered as a choice (presumably "<org>/<model>" repository ids -- the server
# prepends "quickmt/" again before downloading).
# NOTE: hf_list() is called while this module is being imported.
_model_choices = [repo_id.split("/")[1] for repo_id in hf_list()]

_sidebar = ui.sidebar(
    ui.tooltip(
        ui.input_selectize("model", "Select model", choices=_model_choices),
        "QuickMT model to use. quickmt-fr-en will translate from French (fr) to English (en)",
    ),
    ui.tooltip(
        ui.input_slider("beam_size", "Beam size", min=1, max=8, step=1, value=2),
        "Balances speed and quality. 1 for fastest speed, 8 for highest quality, in between for a balance.",
    ),
    width="350px",
)

# --- Page -------------------------------------------------------------------
# Top-level UI definition handed to `App`.
app_ui = ui.page_navbar(
    _main_panel,
    ui.nav_spacer(),
    _dark_mode_control,
    _github_control,
    sidebar=_sidebar,
    title=ui.h2("QuickMT Machine Translation Demo"),
    window_title="QuickMT",
    theme=ui.Theme.from_brand(__file__),
    navbar_options=ui.navbar_options(underline=False, theme="auto"),
)
def server(input, output, session):
    """Register the reactive outputs of the app.

    Defines two `render.ui` outputs:

    * ``model_download_output`` -- downloads the selected model when
      ``input.quickmt_model_download`` fires.
    * ``translate`` -- translates the text area contents line by line when
      ``input.translate_button`` fires, downloading and loading the selected
      model first if needed.

    Args:
        input: Shiny input accessor for the current session.
        output: Shiny output registry (unused directly; the decorators
            register the outputs).
        session: Shiny session object (unused).
    """
    # Directory under which each model is stored in a sub-folder named after it.
    models_root = Path("/code/models")

    # NOTE(review): no input with id "quickmt_model_download" and no output
    # placeholder named "model_download_output" is declared in the UI visible
    # in this file, so this handler appears to be unused -- confirm.
    @render.ui
    @reactive.event(input.quickmt_model_download)  # Take a dependency on the button
    def model_download_output():
        """Download the selected model into the models directory."""
        model_name = input.model()
        hf_download(
            model_name="quickmt/" + model_name,
            output_dir=models_root / model_name,
        )
        return "Model downloaded"

    @render.ui
    @reactive.event(input.translate_button)  # Take a dependency on the button
    def translate():
        """Translate the input text with the selected model.

        Returns:
            ``""`` when the input text is empty, otherwise a list with one
            ``ui.p`` per translated input line. On any failure (download,
            model load, or translation) a list holding a single error
            ``ui.value_box`` is returned instead of raising.
        """
        # The translator lives in the module-level `t` so that it is reused
        # across invocations instead of being reloaded on every click.
        global t

        model_name = input.model()
        model_path = models_root / model_name

        try:
            # Fetch the model on first use. Kept inside the `try` so that a
            # failed download is reported through the same error box as a
            # failed load, rather than escaping as an unhandled exception.
            if not model_path.exists():
                ui.notification_show(
                    f"Downloading model {model_name}...",
                    type="message",
                    duration=3,
                )
                hf_download(
                    model_name="quickmt/" + model_name,
                    output_dir=model_path,
                )

            # (Re)load only when nothing is loaded yet or the selection differs
            # from the folder name of the currently loaded model.
            if t is None or str(model_name) != str(Path(t.model_path).name):
                print(f"Loading model {model_name}")
                t = Translator(
                    str(model_path),
                    device="cpu",
                    inter_threads=2,
                )

            source_text = input.input_text()
            if not source_text:
                return ""

            translated_lines = t(
                source_text.splitlines(), beam_size=input.beam_size()
            )
            return [ui.p(line) for line in translated_lines]
        except Exception as exc:
            # `Exception` rather than a bare `except:` so that interpreter-level
            # signals (KeyboardInterrupt, SystemExit, ...) still propagate, and
            # the cause is printed instead of being silently dropped.
            print(f"Translation request failed: {exc!r}")
            return [
                ui.value_box(
                    title="Unexpected error",
                    value="Failed to load model",
                    showcase=icon("bug"),
                ),
            ]
# Shiny application object: pairs the UI definition with the server function.
app = App(app_ui, server)

if __name__ == "__main__":
    # Executed directly as a script: serve the app with uvicorn using the
    # module-level host/port settings.
    uvicorn.run(app, host=host, port=port)