Spaces:
Runtime error
Runtime error
File size: 2,458 Bytes
dc4cc7d 82a5fcc 5f7bf3e 9fbbfae c439932 3484a4e f23d43c c439932 9053bac dc4cc7d 6dd641d b59be17 dc4cc7d b59be17 dc4cc7d 8e5554f dc4cc7d 3820b00 dc4cc7d df494d8 d05d42f dc4cc7d 2551a08 dc4cc7d df494d8 dc4cc7d cc59196 1dac32a dc4cc7d d05d42f 50cb45b cb3447a dc4cc7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import functools
import os
import os
from typing import List, Optional, Any

import argilla as rg
import gradio as gr
from datasets import load_dataset
from gradio import FlaggingCallback
from gradio.components import IOComponent
@functools.lru_cache(maxsize=1)
def _dataset_frame():
    """Load the ``merve/turkish_instructions`` train split once, as a DataFrame.

    Cached so repeated UI requests (and every flag, which re-reads the row)
    do not re-download and re-convert the whole dataset. A failed load is not
    cached, so it is retried on the next call.
    """
    return load_dataset("merve/turkish_instructions", split="train").to_pandas()


def load_data(idx):
    """Fetch one dataset row for display in the UI.

    Args:
        idx: Row position; any value ``int()`` accepts (the Gradio slider
            passes a number).

    Returns:
        ``(instruction, input_sample, response)`` taken from the row's
        second, third and fourth columns. ``input_sample`` is ``"-"`` when
        that cell is empty, ``None`` or NaN.

    Raises:
        IndexError: if ``idx`` is outside the dataset.
    """
    sample = _dataset_frame().iloc[int(idx)]
    # Use .iloc for positional access: ``sample[1]`` on a label-indexed Series
    # depends on pandas' deprecated integer-key fallback.
    instruction = sample.iloc[1]
    input_sample = sample.iloc[2]
    # ``x != x`` is only true for NaN, which is truthy and would otherwise
    # slip past the emptiness check.
    if not input_sample or input_sample != input_sample:
        input_sample = "-"
    response = sample.iloc[3]
    return instruction, input_sample, response
def create_record(text, feedback):
    """Turn one flagged dataset row into an Argilla text-classification record.

    Args:
        text: Index of the flagged row; passed through ``int`` before being
            handed to ``load_data``.
        feedback: The flag option the user picked; stored verbatim under the
            ``"feedback"`` metadata key.

    Returns:
        The constructed ``rg.TextClassificationRecord`` (it is also printed).
    """
    # Only the "correct" flag option promotes the record to Validated.
    # NOTE(review): the Turkish literals in this copy look mis-encoded (likely
    # "Doğru" / "çıktı"); whatever the bytes are, this comparison must match
    # the ``flagging_options`` passed to gr.Interface exactly.
    if feedback == "DoΔru":
        record_status = "Validated"
    else:
        record_status = "Default"

    talimat, girdi, cikti = load_data(int(text))

    record = rg.TextClassificationRecord(
        inputs={"talimat": talimat, "girdi": girdi, "Γ§Δ±ktΔ±": cikti},
        # Placeholder annotation: per the original intent, the label is meant
        # to come from the Gradio flag instead.
        annotation="True",
        status=record_status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs each flagged sample to Argilla.

    Args:
        api_url: Argilla server URL, passed to ``rg.init``.
        api_key: Argilla API key, passed to ``rg.init``.
        dataset_name: Name of the Argilla dataset records are logged to.
    """

    def __init__(self, api_url, api_key, dataset_name):
        # Initialises the global Argilla client as soon as the callback is built.
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples flagged through this instance; returned by ``flag``.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No local setup needed: records go to Argilla, not to ``flagging_dir``."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to the configured Argilla dataset.

        Args:
            flag_data: Component values supplied by Gradio; element 0 is
                treated as the dataset row index (presumably the slider
                input — confirm against the Interface's ``inputs``).
            flag_option: The flag button the user pressed; recorded as the
                record's feedback metadata.
            flag_index: Unused.
            username: Unused.

        Returns:
            Total number of samples flagged through this callback so far, as
            the ``FlaggingCallback.flag`` contract expects (previously this
            returned ``None`` despite the ``-> int`` annotation).
        """
        row_index = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(row_index, flag_option))
        self._flag_count += 1
        return self._flag_count
# --- UI wiring ---------------------------------------------------------------
# step=1: when ``step`` is omitted, Gradio 3.x derives it from the range
# (10 ** floor(log10(51564) - 2) == 100), so dragging could only reach every
# 100th row.
# NOTE(review): maximum=51564 assumes the train split has at least 51565 rows;
# if it has exactly 51564, the top slider value raises IndexError — verify.
# NOTE(review): the Turkish UI strings in this copy look mis-encoded (e.g.
# "Doğru", "Satır"); "DoΔru" below must stay byte-identical to the comparison
# in create_record.
idx_input = gr.Slider(minimum=0, maximum=51564, step=1, label="SatΔ±r")
instruction = gr.Textbox(label="Talimat")
input_sample = gr.Textbox(label="Girdi")
response = gr.Textbox(label="ΓΔ±ktΔ±")

demo = gr.Interface(
    load_data,
    title="ALPACA Veriseti DΓΌzeltme ArayΓΌzΓΌ",
    description="Bir satΔ±r sayΔ±sΔ± verip ΓΆrnek alΔ±n. Γeviride gΓΆzΓΌnΓΌze doΔru gelmeyen bir Εey olursa iΕaretleyin.",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://pro.argilla.io",
        # None when the API_KEY env var is unset; the Space secret must be configured.
        api_key=os.getenv("API_KEY"),
        dataset_name="alpaca-flags",
    ),
    inputs=[idx_input],
    outputs=[instruction, input_sample, response],
    flagging_options=["DoΔru", "YanlΔ±Ε", "Belirsiz"],
    theme="gradio/soft",
)
demo.launch()