Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import csv | |
import datetime | |
import io | |
import json | |
import os | |
import uuid | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any, List | |
import fo_utils as fou | |
import gradio as gr | |
from gradio import utils | |
if TYPE_CHECKING: | |
from gradio.components import IOComponent | |
def _get_dataset_features_info(is_new, components):
    """
    Takes in a list of components and returns a dataset features info

    Parameters:
    is_new: boolean, whether the dataset is new or not. Headers and feature
        entries are only generated when this is truthy; otherwise the returned
        features dictionary and header list are empty.
    components: list of components; each must expose a ``label`` attribute.

    Returns:
    infos: a dictionary of the dataset features, shaped as
        ``{"flagged": {"features": {...}}}``
    file_preview_types: dictionary mapping of gradio components to appropriate string.
    header: list of header strings
    """
    infos = {"flagged": {"features": {}}}
    features = infos["flagged"]["features"]
    # File previews for certain input and output types
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # Generate the headers and dataset_infos
    if is_new:
        for component in components:
            headers.append(component.label)
            features[component.label] = {"dtype": "string", "_type": "Value"}

            # First preview type the component is an instance of, if any.
            preview_type = next(
                (
                    _type
                    for _component, _type in file_preview_types.items()
                    if isinstance(component, _component)
                ),
                None,
            )
            if preview_type is not None:
                # Build the column name once so the header and the feature key
                # always agree, and so an unlabeled component (label=None)
                # does not raise TypeError on string concatenation.
                file_column = (component.label or "") + " file"
                headers.append(file_column)
                features[file_column] = {"_type": preview_type}

        headers.append("flag")
        features["flag"] = {"dtype": "string", "_type": "Value"}

    return infos, file_preview_types, headers
class FlaggingCallback(ABC):
    """
    Base class describing the two hooks every flagging callback provides:
    ``setup()`` and ``flag()``.

    Both hooks are no-ops here (they return None); subclasses are expected to
    override them.
    """

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Prepare whatever flag() will need. Subclasses should override this.
        It is called once at the beginning of the Interface.launch() method.

        Parameters:
        components: Set of components that will provide flagged data.
        flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Record one flagged sample. Subclasses should override this and may
        accept optional additional arguments. It is called every time the
        <flag> button is pressed.

        Parameters:
        flag_data: The data to be flagged.
        flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
        flag_index (optional): The index of the sample that is being flagged.
        username (optional): The username of the user that is flagging the data, if logged in.

        Returns:
        (int) The total number of samples that have been flagged. (This base
        implementation does nothing and returns None.)
        """
class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes.  Each flagged sample (both the input and output data)
    is logged to a CSV file on the machine running the gradio app.

    The log is written to ``<flagging_dir>/log.csv`` with one row per flagged
    sample and no header row.

    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """
        Remember the components and the target directory, creating the
        directory if it does not exist yet.

        Parameters:
        components: Components whose values are passed to flag(), in the same order.
        flagging_dir: Directory that will hold log.csv.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one row describing the flagged sample to log.csv.

        Parameters:
        flag_data: One value per component given to setup(), in the same order.
        flag_option (optional): Accepted for interface compatibility; not written to the log.
        flag_index (optional): Accepted for interface compatibility; not used.
        username (optional): Accepted for interface compatibility; not used.

        Returns:
        (int) The total number of rows (flagged samples) now present in log.csv.
        """
        # Everything is stored directly in the flagging directory rather than
        # in a per-component subdirectory.
        save_dir = Path(self.flagging_dir)
        log_filepath = save_dir / "log.csv"

        # NOTE(review): deserialize() presumably stores file-like samples under
        # save_dir and returns a CSV-friendly value; the third argument is
        # passed as None exactly as before -- confirm against the gradio version in use.
        csv_data = [
            component.deserialize(sample, save_dir, None)
            for component, sample in zip(self.components, flag_data)
        ]

        with open(log_filepath, "a", newline="") as csvfile:
            csv.writer(csvfile).writerow(utils.sanitize_list_for_csv(csv_data))

        # newline="" lets the csv module handle line endings inside quoted
        # fields itself, as its documentation requires for file objects.
        with open(log_filepath, "r", newline="") as csvfile:
            # This logger never writes a header row, so every row is a flagged
            # sample; count them without subtracting one and without
            # materialising the whole file in memory.
            line_count = sum(1 for _ in csv.reader(csvfile))

        # TODO(review): an upload of samples flagged as "Bad" to CVAT via
        # fou.upload_image_to_cvat was sketched here previously but is disabled.
        return line_count
class FlagMethod:
    """
    Callable wrapper pairing a flagging callback with the flag option of one
    flag button.
    """

    def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
        # Presented under the name "Flag" wherever __name__ is inspected.
        self.__name__ = "Flag"
        self.flag_option = flag_option
        self.flagging_callback = flagging_callback

    def __call__(self, *flag_data):
        """
        Forward the positional values, as a list, to the callback's flag()
        together with the stored flag option. The callback's return value is
        discarded, so this call always returns None.
        """
        values = [*flag_data]
        callback = self.flagging_callback
        callback.flag(values, flag_option=self.flag_option)