"""Gradio review UI for the Turkish Alpaca instruction dataset.

Shows one sample (instruction, input, response) from the
``merve/turkish_instructions`` dataset and sends the reviewer's flag
choice to an Argilla dataset through a custom flagging callback.
"""
import os
from typing import Any, List, Optional, Tuple

import argilla as rg
import gradio as gr
from datasets import load_dataset
from gradio import FlaggingCallback
from gradio.components import IOComponent


def load_data() -> Tuple[Any, Any, Any]:
    """Fetch one sample from the Turkish instructions dataset.

    Returns:
        A ``(instruction, input, response)`` tuple read from the
        ``"talimat"``, ``"giriş"`` and ``"çıktı"`` columns of the sample.
    """
    # NOTE(review): a fresh stream is opened on every call, so this always
    # yields the first record of the split. Confirm that is intended.
    ds = load_dataset("merve/turkish_instructions", split="train", streaming=True)
    sample = next(iter(ds))
    return sample["talimat"], sample["giriş"], sample["çıktı"]


def create_record(instruction, input, response, feedback):
    """Build an Argilla text-classification record for one flagged sample.

    Args:
        instruction: Instruction text shown to the reviewer.
        input: Input text of the sample. The name shadows the builtin but is
            kept so existing keyword callers keep working.
        response: Response text of the sample.
        feedback: Flag option picked by the reviewer. ``"Doğru"`` marks the
            record as ``"Validated"``; anything else leaves it ``"Default"``.

    Returns:
        The ``rg.TextClassificationRecord`` that was built.
    """
    status = "Validated" if feedback == "Doğru" else "Default"
    fields = {
        "talimat": instruction,
        "input": input,
        "response": response,
    }
    # The annotation is a fixed placeholder for now; the reviewer's actual
    # choice is carried in the record metadata.
    label = "True"
    record = rg.TextClassificationRecord(
        inputs=fields,
        annotation=label,
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record


class ArgillaLogger(FlaggingCallback):
    """Flagging callback that logs each flagged sample to Argilla."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL of the Argilla server.
            api_key: API key used to authenticate against the server.
            dataset_name: Name of the Argilla dataset records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples flagged through this callback instance.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: records go to Argilla, so no local setup is needed."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to the configured Argilla dataset.

        Args:
            flag_data: Component values of the flagged sample. The first three
                entries are read as instruction, input and response, matching
                the three outputs of the interface defined in this file.
            flag_option: Flag option selected by the reviewer.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged so far through this callback.

        Raises:
            ValueError: If ``flag_data`` holds fewer than three values.
        """
        if len(flag_data) < 3:
            raise ValueError(
                "expected instruction, input and response in flag_data, "
                f"got {len(flag_data)} value(s)"
            )
        instruction, input_text, response = flag_data[:3]
        record = create_record(instruction, input_text, response, flag_option)
        rg.log(name=self.dataset_name, records=record)
        self._flag_count += 1
        return self._flag_count


gr.Interface(
    load_data,
    # load_data takes no arguments, so the interface has no input components.
    inputs=None,
    outputs=["text", "text", "text"],
    title="ALPACA Veriseti Düzeltme Arayüzü",
    description="",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://sandbox.argilla.io",
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="alpaca-flags",
    ),
    flagging_options=["Doğru", "Yanlış", "Belirsiz"],
).launch()