Remove flagging
Browse files
app.py
CHANGED
|
@@ -11,35 +11,6 @@ model_name = "ellenhp/query2osm-bert-v1"
|
|
| 11 |
# Load the tokenizer and the query-classification model once, at import
# time, so every request reuses the same in-memory instances.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
model = Hydra.from_pretrained(model_name)
model = model.to('cpu')  # inference runs on CPU in this Space
|
| 13 |
|
| 14 |
-
|
| 15 |
-
class DatasetSaver(gr.FlaggingCallback):
    """Flagging callback that forwards to an optional HuggingFaceDatasetSaver.

    The surrounding module constructs this with ``hf_writer = None`` when no
    ``HF_TOKEN`` environment variable is set.  The original implementation
    then raised ``AttributeError`` as soon as Gradio called ``setup``/``flag``;
    this version degrades to a no-op instead, so the app still runs (without
    persistence) when no token is configured.
    """

    # Wrapped saver; None when flagging persistence is not configured.
    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        # Bug fix: guard against inner being None (HF_TOKEN unset).
        if self.inner is not None:
            self.inner.setup(components, flagging_dir)

    def flag(self,
             flag_data: list[Any],
             flag_option: str = "",
             username: str | None = None):
        # Bug fix: no-op when there is nowhere to persist the flag.
        if self.inner is None:
            return
        # Keep only the input and the top-level predicted label; pass
        # username as None so no user identity reaches the dataset.
        flag_data = [flag_data[0], {"label": flag_data[1]['label']}]
        self.inner.flag(flag_data, flag_option, None)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
# Crowdsourced flag persistence is opt-in: it is wired up only when an
# HF_TOKEN is present in the environment.
HF_TOKEN = os.getenv('HF_TOKEN')

hf_writer = None
if HF_TOKEN is not None:
    hf_writer = gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)

# The wrapper is always installed; it receives None when no token is set.
flag_callback = DatasetSaver(hf_writer)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
def predict(input_query):
|
| 44 |
with torch.no_grad():
|
| 45 |
print(input_query)
|
|
@@ -58,9 +29,6 @@ gradio_app = gr.Interface(
|
|
| 58 |
inputs=[textbox],
|
| 59 |
outputs=[label],
|
| 60 |
title="Query Classification",
|
| 61 |
-
allow_flagging="manual",
|
| 62 |
-
flagging_options=["correct classification", "incorrect classification"],
|
| 63 |
-
flagging_callback=flag_callback,
|
| 64 |
)
|
| 65 |
|
| 66 |
if __name__ == "__main__":
|
|
|
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
|
| 12 |
model = Hydra.from_pretrained(model_name).to('cpu')
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def predict(input_query):
|
| 15 |
with torch.no_grad():
|
| 16 |
print(input_query)
|
|
|
|
| 29 |
inputs=[textbox],
|
| 30 |
outputs=[label],
|
| 31 |
title="Query Classification",
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|