Spaces:
Runtime error
Runtime error
File size: 6,068 Bytes
2b84d47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from __future__ import annotations
import csv
import datetime
import io
import json
import os
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, List
import fo_utils as fou
import gradio as gr
from gradio import utils
if TYPE_CHECKING:
from gradio.components import IOComponent
def _get_dataset_features_info(is_new, components):
    """
    Takes in a list of components and returns a dataset features info.

    Parameters:
    is_new: boolean, whether the dataset is new or not. Headers and feature
        entries are only generated when this is True; otherwise the returned
        feature dict and header list are empty.
    components: list of components; each must expose a ``label`` attribute,
        which may be None.
    Returns:
    infos: a dictionary of the dataset features, shaped as
        ``{"flagged": {"features": {...}}}``.
    file_preview_types: dictionary mapping of gradio components to appropriate string.
    headers: list of header strings. For each component this holds its label,
        followed by "<label> file" for components listed in
        ``file_preview_types``, and finally a trailing "flag" column.
    """
    infos = {"flagged": {"features": {}}}
    features = infos["flagged"]["features"]
    # File previews for certain input and output types
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # Generate the headers and dataset_infos
    if is_new:
        for component in components:
            headers.append(component.label)
            features[component.label] = {
                "dtype": "string",
                "_type": "Value",
            }
            # The first matching preview type wins (dict insertion order).
            for preview_cls, preview_type in file_preview_types.items():
                if isinstance(component, preview_cls):
                    # An unlabeled component would make ``None + " file"``
                    # raise TypeError. Build the key once so the header and
                    # the feature entry always agree.
                    file_label = (component.label or "") + " file"
                    headers.append(file_label)
                    features[file_label] = {"_type": preview_type}
                    break

        headers.append("flag")
        features["flag"] = {
            "dtype": "string",
            "_type": "Value",
        }

    return infos, file_preview_types, headers
class FlaggingCallback(ABC):
    """
    An abstract class for defining the methods that any FlaggingCallback should have.

    Subclasses must implement both ``setup`` and ``flag``; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        This method should be overridden and ensure that everything is set up correctly for flag().
        This method gets called once at the beginning of the Interface.launch() method.
        Parameters:
        components: List of components that will provide flagged data.
        flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """
        pass

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
        This gets called every time the <flag> button is pressed.
        Parameters:
        flag_data: The data to be flagged.
        flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
        flag_index (optional): The index of the sample that is being flagged.
        username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
        (int) The total number of samples that have been flagged.
        """
        pass
class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes. Each flagged sample (both the input and output data)
    is logged to a CSV file (``log.csv`` inside the flagging directory) on the machine
    running the gradio app. No header row is written.

    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """Store the components to log and create ``flagging_dir`` if it is missing."""
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one CSV row for ``flag_data`` to ``log.csv`` and return the row count.

        Values in ``flag_data`` are paired with ``self.components`` via ``zip``
        (surplus entries on either side are ignored). Each value is converted
        with ``component.deserialize(sample, save_dir, None)``, where
        ``save_dir`` is the flagging directory. ``flag_option``,
        ``flag_index`` and ``username`` are accepted for interface
        compatibility but are not recorded.

        Returns:
            The number of rows in ``log.csv`` after this write. Because this
            logger writes no header row, that equals the number of samples
            flagged so far.
        """
        save_dir = Path(self.flagging_dir)
        log_filepath = save_dir / "log.csv"

        csv_data = [
            component.deserialize(sample, save_dir, None)
            for component, sample in zip(self.components, flag_data)
        ]

        # Explicit UTF-8 so non-ASCII labels/values don't depend on the locale;
        # newline="" is required by the csv module to avoid blank lines.
        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        # Count data rows. The file has no header, so nothing is subtracted.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile))

        # TODO: the author drafted uploading the image of samples flagged as
        # "Bad" to CVAT via ``fou.upload_image_to_cvat``; that draft was never
        # enabled and has been removed from the active code path.
        return line_count
class FlagMethod:
    """
    Callable that binds a flagging callback to a single flag option.

    Calling an instance with component values forwards them, as a list, to
    ``flagging_callback.flag`` along with the stored ``flag_option``.
    """

    def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
        # Mimic a plain function's ``__name__`` attribute
        # (presumably for code that labels handlers by name).
        self.__name__ = "Flag"
        self.flag_option = flag_option
        self.flagging_callback = flagging_callback

    def __call__(self, *flag_data):
        """Forward ``flag_data`` to the callback; its return value is discarded."""
        payload = [*flag_data]
        callback = self.flagging_callback
        callback.flag(payload, flag_option=self.flag_option)
|