File size: 2,458 Bytes
dc4cc7d
 
 
 
 
 
 
 
 
 
82a5fcc
5f7bf3e
9fbbfae
c439932
3484a4e
 
 
 
f23d43c
c439932
9053bac
dc4cc7d
6dd641d
b59be17
dc4cc7d
b59be17
 
 
 
 
 
dc4cc7d
 
 
 
 
8e5554f
dc4cc7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3820b00
dc4cc7d
 
df494d8
d05d42f
 
 
 
dc4cc7d
 
2551a08
dc4cc7d
df494d8
dc4cc7d
 
cc59196
1dac32a
dc4cc7d
 
d05d42f
 
50cb45b
cb3447a
dc4cc7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent

from datasets import load_dataset
from typing import List, Optional, Any
import argilla as rg
import os

# Lazily-populated cache of the training split, so the dataset is fetched and
# converted to pandas once per process instead of on every UI request.
_DATASET_FRAME = None


def _get_dataset_frame():
    """Return the training split as a DataFrame, loading it on first use only."""
    global _DATASET_FRAME
    if _DATASET_FRAME is None:
        _DATASET_FRAME = load_dataset(
            "merve/turkish_instructions", split="train"
        ).to_pandas()
    return _DATASET_FRAME


def load_data(idx):
    """Fetch one row of the dataset and split it into its three text fields.

    Args:
        idx: Row position in the training split; anything ``int()`` accepts
            (the UI slider may deliver a float).

    Returns:
        A ``(instruction, input_sample, response)`` tuple taken from the
        row's columns at positions 1, 2 and 3. A falsy input (e.g. ``None``
        or an empty string) is replaced by ``"-"``.

    Raises:
        IndexError: If ``idx`` is outside the dataset.
    """
    sample = _get_dataset_frame().iloc[int(idx)]

    # Use .iloc for positional access: plain integer keys on a string-labelled
    # Series only work through pandas' deprecated positional fallback.
    instruction = sample.iloc[1]
    input_sample = sample.iloc[2] or "-"
    response = sample.iloc[3]
    return instruction, input_sample, response

def create_record(text, feedback):
    """Build the Argilla record describing one flagged dataset row.

    Args:
        text: Row index of the flagged sample (anything ``int()`` accepts).
        feedback: The flag option chosen by the reviewer. ``"Doğru"`` marks
            the record as validated; any other value leaves it in the
            default status.

    Returns:
        An ``rg.TextClassificationRecord`` whose inputs are the row's
        instruction, input and response, with the raw feedback stored in
        its metadata.
    """
    status = "Validated" if feedback == "Doğru" else "Default"
    instruction, input_sample, response = load_data(int(text))

    fields = {
        "talimat": instruction,
        "girdi": input_sample,
        # Key was mojibake ("Γ§Δ±ktΔ±"); restored to the intended Turkish word.
        "çıktı": response,
    }

    # The annotation is a fixed placeholder; the reviewer's actual choice is
    # carried by `status` and by the "feedback" metadata entry.
    label = "True"

    record = rg.TextClassificationRecord(
        inputs=fields,
        annotation=label,
        status=status,
        metadata={"feedback": feedback},
    )

    print(record)
    return record




class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that forwards each flagged sample to Argilla."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL of the Argilla server.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the Argilla dataset that records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples logged through this instance; returned by flag().
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing: records go to Argilla, so no local setup is performed."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Values of the flagged interface components; the first
                entry is used as the row index of the sample.
            flag_option: The flagging option the reviewer selected.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this logger so far.
        """
        text = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(text, flag_option))
        # Honour the declared ``-> int`` contract instead of returning None.
        self._flag_count += 1
        return self._flag_count


# UI components. The Turkish labels were mojibake (UTF-8 text decoded with the
# wrong codepage) and are restored to their intended spelling here.
# step=1 makes the slider move in whole row indices.
idx_input = gr.Slider(minimum=0, maximum=51564, step=1, label="Satır")
instruction = gr.Textbox(label="Talimat")
input_sample = gr.Textbox(label="Girdi")
response = gr.Textbox(label="Çıktı")


# Build the review interface: the slider picks a row, load_data fills the three
# text boxes, and manual flags are sent to Argilla through ArgillaLogger.
gr.Interface(
    load_data,
    title="ALPACA Veriseti Düzeltme Arayüzü",
    description="Bir satır sayısı verip örnek alın. Çeviride gözünüze doğru gelmeyen bir şey olursa işaretleyin.",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://pro.argilla.io",
        # The API key is read from the environment rather than hard-coded.
        api_key=os.getenv("API_KEY"),
        dataset_name="alpaca-flags",
    ),
    inputs=[idx_input],
    outputs=[instruction, input_sample, response],
    flagging_options=["Doğru", "Yanlış", "Belirsiz"],
    theme="gradio/soft",
).launch()