File size: 2,687 Bytes
7c34c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import json
import os
import random
import re
from typing import List

from datasets import Dataset

from multi_token.constants import ROLE_ASSISTANT, ROLE_USER


# Modalities a LLaVA conversation can be relabelled as; one is drawn per
# conversation. A name's position here is the column it selects in REPLACEMENTS.
TYPES = ["audio", "image", "text"]

# Word -> [audio, image, text] counterparts, column-aligned with TYPES.
# The "image" column maps every word to itself, so drawing "image" keeps the
# original wording.
REPLACEMENTS = dict(
    image=["audio", "image", "document"],
    picture=["audio file", "picture", "text snippet"],
    photo=["sound", "photo", "text"],
    visual=["audio", "visual", "textual"],
    see=["hear", "see", "read"],
    look=["sound", "look", "read"],
    visible=["audible", "visible", "readable"],
)

# Placeholder that shields the literal "<image>" tag while words are substituted.
TEMP_TOKEN = "<<<TEMP-TOKEN>>>"

# Substrings that, if found anywhere in the lower-cased conversation, cause the
# sample to be skipped (tasks judged too image-specific to relabel).
EXCLUDE_WORDS = [
    "region",
    "ocr",
    "color",
    "right",
    "left",
]


def _convert_convo(convo) -> List:
    """Rewrite one LLaVA conversation as role/content messages for a random modality.

    A single modality index is drawn from ``TYPES`` for the whole conversation,
    so every turn is relabelled consistently. In each turn, words listed in
    ``REPLACEMENTS`` are swapped for that modality's counterpart, and the
    ``<image>`` tag becomes ``<imagebind>``.

    Only whole words are substituted, optionally followed by an ``s``, ``ing`` or
    ``ed`` suffix. So "images", "seeing", "looked" and "pictures" are translated,
    while unrelated words that merely contain a key ("seems", "seen", "seek",
    "photograph", "outlook") are left intact. Matching is case-sensitive.

    Args:
        convo: list of LLaVA turns, each a dict with a ``"from"`` key
            (``"human"`` or ``"gpt"``) and a ``"value"`` key holding the text.

    Returns:
        list of ``{"role": ..., "content": ...}`` dicts, one per turn, in order.

    Raises:
        KeyError: if a turn's ``"from"`` is neither ``"human"`` nor ``"gpt"``.
    """
    # Same draw as TYPES.index(random.choice(TYPES)), without the list round-trip.
    type_idx = random.randrange(len(TYPES))
    roles = {"gpt": ROLE_ASSISTANT, "human": ROLE_USER}

    # One alternation over all keys (longest first), anchored at word boundaries.
    # A plain str.replace of "see" would turn "seems" into "hearms".
    keys = sorted(REPLACEMENTS, key=len, reverse=True)
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, keys)) + r")(s|ing|ed)?\b")

    def _swap(match):
        word, suffix = match.group(1), match.group(2) or ""
        return REPLACEMENTS[word][type_idx] + suffix

    msgs = []
    for m in convo:
        # Shield the literal <image> tag: "<" and ">" are word boundaries, so the
        # "image" rule would otherwise rewrite it.
        content = m["value"].replace("<image>", TEMP_TOKEN)
        content = pattern.sub(_swap, content)
        content = content.replace(TEMP_TOKEN, "<imagebind>")
        msgs.append({"role": roles[m["from"]], "content": content})
    return msgs


def _fix_path(path):
    """Return ``path`` with its second ``/``-separated component repeated.

    For example ``"coco/train2017/x.jpg"`` becomes
    ``os.path.join("coco", "train2017", "train2017", "x.jpg")``. Presumably this
    mirrors how the image archives were extracted on disk — TODO confirm.

    Raises:
        IndexError: if ``path`` has fewer than two ``/``-separated components.
    """
    pieces = path.split("/")
    top, sub = pieces[0], pieces[1]
    return os.path.join(top, sub, sub, *pieces[2:])


def main(args):
    """Convert LLaVA conversation JSON files into a dataset saved to disk.

    Each kept LLaVA row becomes a record with a string ``"id"``, a one-element
    ``"imagebinds"`` list holding the resolved image path, and ``"messages"``
    built by ``_convert_convo``. A row is dropped when any of these holds:

    - it has no ``"image"`` key;
    - its image path contains ``"ocr"``;
    - its lower-cased conversation contains any ``EXCLUDE_WORDS`` substring;
    - the resolved image file does not exist (a message is printed).

    Args:
        args: parsed CLI namespace with ``llava_json`` (list of JSON paths, each
            holding a list of rows), ``image_folder``, ``output_folder`` and
            ``num_proc``.
    """
    rows = []
    for json_fn in args.llava_json:
        # JSON is UTF-8 by spec; pin the encoding so loading doesn't depend on
        # the platform's locale default.
        with open(json_fn, encoding="utf-8") as f:
            rows.extend(json.load(f))

    def gen(rows):
        for row in rows:
            try:
                img_path = row["image"]
            except KeyError:
                # Rows without an image have nothing to relabel.
                continue

            # avoid tasks too image-y. NOTE(review): this is a substring test over
            # the repr of the whole conversation, so e.g. "right" also matches
            # "bright" and "ocr" matches "democracy".
            convo_text = repr(row["conversations"]).lower()

            if "ocr" in img_path or any(w in convo_text for w in EXCLUDE_WORDS):
                continue

            fn = os.path.join(args.image_folder, _fix_path(img_path))
            if not os.path.exists(fn):
                print("Skipping (does not exist)", fn)
                continue
            yield {
                "id": str(row["id"]),
                "imagebinds": [fn],
                "messages": _convert_convo(row["conversations"]),
            }

    # The row list is passed through gen_kwargs; presumably `datasets` shards it
    # across workers when num_proc > 1 — see Dataset.from_generator docs.
    ds = Dataset.from_generator(gen, gen_kwargs={"rows": rows}, num_proc=args.num_proc)
    ds.save_to_disk(args.output_folder)


if __name__ == "__main__":
    # -i, -f and -o are mandatory: main() iterates llava_json and joins/saves with
    # image_folder/output_folder, which fails with an opaque TypeError on None.
    parser = argparse.ArgumentParser(
        description="Build an 'imagebinds' conversation dataset from LLaVA JSON files."
    )
    parser.add_argument(
        "-i",
        "--llava_json",
        type=str,
        action="append",
        required=True,
        help="LLaVA conversation JSON file; repeat the flag to combine several.",
    )
    parser.add_argument(
        "-f",
        "--image_folder",
        type=str,
        required=True,
        help="Root folder that the LLaVA image paths are resolved against.",
    )
    parser.add_argument(
        "-o",
        "--output_folder",
        type=str,
        required=True,
        help="Directory to save the resulting dataset to.",
    )
    parser.add_argument(
        "-n",
        "--num_proc",
        type=int,
        default=1,
        help="Number of processes used to generate the dataset.",
    )
    args = parser.parse_args()
    main(args)