File size: 6,068 Bytes
2b84d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from __future__ import annotations

import csv
import datetime
import io
import json
import os
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, List
import fo_utils as fou 

import gradio as gr
from gradio import utils

if TYPE_CHECKING:
    from gradio.components import IOComponent



def _get_dataset_features_info(is_new, components):
    """
    Build the dataset features description and the CSV headers for a list of components.

    Parameters:
    is_new: boolean, whether the dataset is new or not. When falsy, no features or headers are generated.
    components: list of components; each one is read through its `label` attribute.

    Returns:
    infos: a dictionary of the dataset features, shaped as {"flagged": {"features": {...}}}
    file_preview_types: dictionary mapping of gradio components to appropriate string.
    headers: list of header strings (empty when is_new is falsy)

    """
    infos = {"flagged": {"features": {}}}
    # File previews for certain input and output types
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
    headers = []

    # Nothing to generate for an already existing dataset
    if not is_new:
        return infos, file_preview_types, headers

    features = infos["flagged"]["features"]

    # Generate the headers and dataset_infos
    for component in components:
        headers.append(component.label)
        features[component.label] = {"dtype": "string", "_type": "Value"}

        # The first matching preview class wins; components that match none
        # of them get no extra "<label> file" column.
        for preview_cls, preview_type in file_preview_types.items():
            if isinstance(component, preview_cls):
                # Fall back to "" for an unlabeled component so that the
                # header and the feature key are always the same string
                # (a bare `label + " file"` raises TypeError on None).
                file_key = (component.label or "") + " file"
                headers.append(file_key)
                features[file_key] = {"_type": preview_type}
                break

    headers.append("flag")
    features["flag"] = {"dtype": "string", "_type": "Value"}

    return infos, file_preview_types, headers


class FlaggingCallback(ABC):
    """
    Abstract interface that every flagging callback has to implement: a one-time
    setup() step followed by any number of flag() calls.
    """

    @abstractmethod
    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Prepare everything that flag() will need. Subclasses must override this method.
        This method gets called once at the beginning of the Interface.launch() method.
        Parameters:
        components: Set of components that will provide flagged data.
        flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Record one flagged sample. Subclasses must override this method and may accept optional additional arguments.
        This gets called every time the <flag> button is pressed.
        Parameters:
        flag_data: The data to be flagged.
        flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
        flag_index (optional): The index of the sample that is being flagged.
        username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
        (int) The total number of samples that have been flagged.
        """


class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes.  Each flagged sample (both the input and output data)
    is logged to a CSV file on the machine running the gradio app.
    The log is written to "log.csv" inside the flagging directory and has no header row:
    every row is one flagged sample.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """
        Remember the components and the target directory, creating the directory if needed.
        Parameters:
        components: Components that will provide flagged data, in the same order as the values passed to flag().
        flagging_dir: Directory in which "log.csv" and any files produced by the components are stored.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one flagged sample to the CSV log.
        Parameters:
        flag_data: The data to be flagged, one value per component given to setup().
        flag_option (optional): Accepted for interface compatibility; not used by this logger.
        flag_index (optional): Accepted for interface compatibility; not used by this logger.
        username (optional): Accepted for interface compatibility; not used by this logger.
        Returns:
        (int) The total number of samples (rows) in the log after this one was added.
        """
        # Every component saves into the flagging directory itself; the
        # per-component sub-directory (named after the label) is disabled.
        save_dir = Path(self.flagging_dir)
        log_filepath = save_dir / "log.csv"

        csv_data = [
            component.deserialize(sample, save_dir, None)
            for component, sample in zip(self.components, flag_data)
        ]

        # Explicit UTF-8 so flagged text does not depend on the platform's
        # default encoding; newline="" is what the csv module expects.
        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        # No header row is ever written by this logger, so the number of
        # flagged samples is simply the number of rows (no "- 1").
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile))

        # NOTE: a hook that uploaded samples flagged as "Bad" through
        # fou.upload_image_to_cvat() used to live here and is disabled.

        return line_count



class FlagMethod:
    """
    Callable wrapper that binds a flagging callback to one flag option, so that
    invoking the instance flags the given data with that option.
    """

    def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
        # Fixed display name for the callable
        # (presumably consumed by code that expects a function-like __name__ — verify against caller).
        self.__name__ = "Flag"
        self.flag_option = flag_option
        self.flagging_callback = flagging_callback

    def __call__(self, *flag_data):
        # The callback receives the positional values as a list; whatever it
        # returns is discarded.
        samples = [*flag_data]
        callback = self.flagging_callback
        callback.flag(samples, flag_option=self.flag_option)