import os
from typing import Any, Optional

import gradio as gr
import torch
from gradio.components import Component
from transformers import AutoTokenizer

from hydra import Hydra
| model_name = "ellenhp/query2osm-bert-v1" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True) | |
| model = Hydra.from_pretrained(model_name).to('cpu') | |
| class DatasetSaver(gr.FlaggingCallback): | |
| inner: Optional[gr.HuggingFaceDatasetSaver] = None | |
| def __init__(self, inner): | |
| self.inner = inner | |
| def setup(self, components: list[Component], flagging_dir: str): | |
| self.inner.setup(components, flagging_dir) | |
| def flag(self, | |
| flag_data: list[Any], | |
| flag_option: str = "", | |
| username: str | None = None): | |
| flag_data = [flag_data[0], {"label": flag_data[1]['label']}] | |
| self.inner.flag(flag_data, flag_option, None) | |
| HF_TOKEN = os.getenv('HF_TOKEN') | |
| if HF_TOKEN is not None: | |
| hf_writer = gr.HuggingFaceDatasetSaver( | |
| HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False) | |
| else: | |
| hf_writer = None | |
| flag_callback = DatasetSaver(hf_writer) | |
| def predict(input_query): | |
| with torch.no_grad(): | |
| print(input_query) | |
| input_text = input_query.strip().lower() | |
| inputs = tokenizer(input_text, return_tensors="pt") | |
| outputs = model.forward(inputs.input_ids) | |
| return {classification[0]: classification[1] for classification in outputs.classifications[0]} | |
| textbox = gr.Textbox(label="Query", | |
| placeholder="Where can I get a quick bite to eat?") | |
| label = gr.Label(label="Result", num_top_classes=5) | |
| gradio_app = gr.Interface( | |
| predict, | |
| inputs=[textbox], | |
| outputs=[label], | |
| title="Query Classification", | |
| allow_flagging="manual", | |
| flagging_options=["potentially harmful", "wrong classification"], | |
| flagging_callback=flag_callback, | |
| ) | |
| if __name__ == "__main__": | |
| gradio_app.launch() | |