File size: 4,304 Bytes
1f630db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from loguru import logger
from typing import Dict, List, Tuple

# ---- If you already have your own Dataset / transforms defined, import them here ----
# from your_dataset import BlendShapeDataset, image_transform

def unify_index(idx: int, group_size: int) -> int:
    """
    Normalize a blend-shape selection index.

    Returns -1 (the "none" marker) when ``idx`` is equal to ``group_size``;
    any other value, including an existing -1, is returned unchanged.
    """
    return -1 if idx == group_size else idx

def build_combination_to_filename(
    teacher_data_file: str,
    meta_file: str,
) -> Tuple[Dict[Tuple[int, ...], str], List[int]]:
    """
    Build a lookup from blend-shape selection tuples to photo file names.

    ``meta_file`` is a JSON file whose "blendShapeGroupsMeta" list supplies,
    per group, a "blendShapeNames" list. ``teacher_data_file`` is a JSON file
    whose "dataList" entries each carry a "photoFileName" and a
    "blendShapeSelectionsPerGroup" list of objects with a
    "selectedBlendShapeIndex".

    Returns a pair ``(combination_to_filename, group_sizes)``:

    * ``combination_to_filename`` maps a tuple such as
      ``(g0_idx, g1_idx, g2_idx)`` to a file name such as ``"000001.png"``.
      Each index is passed through ``unify_index`` together with its group's
      size, so an index equal to the group size becomes -1. If two entries
      produce the same tuple, the later entry replaces the earlier one.
    * ``group_sizes[i]`` is the number of blend-shape names in group ``i``
      plus one (the extra slot stands for "none").

    An entry with more selections than there are groups raises ``IndexError``.
    """
    # Group sizes come from the meta file: name count + 1 for the "none" slot.
    with open(meta_file, "r", encoding="utf-8") as meta_fp:
        groups_meta = json.load(meta_fp)["blendShapeGroupsMeta"]
    group_sizes = [len(group["blendShapeNames"]) + 1 for group in groups_meta]

    # Load the per-photo selection records.
    with open(teacher_data_file, "r", encoding="utf-8") as teacher_fp:
        records = json.load(teacher_fp)["dataList"]

    combination_to_filename = {}
    for record in records:
        filename = record["photoFileName"]
        # Tuples are hashable, so the normalized selection can be a dict key.
        key = tuple(
            unify_index(choice["selectedBlendShapeIndex"], group_sizes[position])
            for position, choice in enumerate(record["blendShapeSelectionsPerGroup"])
        )
        combination_to_filename[key] = filename

    return combination_to_filename, group_sizes



def main_offline_precompute():
    """
    Command-line entry point: export the combination -> filename mapping.

    Steps performed by this function:
      1. Parse command-line options.
      2. Build the (combination -> filename) dictionary and the group sizes
         from the teacher-data and meta JSON files.
      3. Write both to ``--out_comb2fn`` as JSON, under the keys
         "group_sizes" and "mapping". Tuple keys are serialized as
         comma-joined strings because JSON object keys must be strings.
      4. Log every (combination, filename) pair.

    NOTE(review): ``--image_dir``, ``--clip_model_name``, ``--clip_pretrained``,
    ``--batch_size``, ``--device`` and ``--out_sims`` are accepted but not used
    by the code in this function; no CLIP model loading, embedding, or
    pairwise-similarity computation happens here.
    """
    import argparse

    # (flag, type, default) for every supported option, in declaration order.
    option_table = (
        ("--image_dir", str, "lapwing/images"),
        ("--meta_file", str, "lapwing/texts/BlendShapeGroupsMeta.json"),
        ("--teacher_data_file", str, "lapwing/texts/BlendShapeData.json"),
        ("--clip_model_name", str, "ViT-L-14"),
        ("--clip_pretrained", str, "openai"),
        ("--batch_size", int, 8),
        ("--device", str, "cuda"),
        ("--out_comb2fn", str, "combination_to_filename.json"),
        ("--out_sims", str, "pairwise_clip_sims.json"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Build the combination -> filename dictionary.
    combination_to_filename, group_sizes = build_combination_to_filename(
        meta_file=args.meta_file,
        teacher_data_file=args.teacher_data_file,
    )

    # Serialize: group_sizes is stored alongside the mapping so a later stage
    # can read both from one file.
    serialized_mapping = {}
    for selection, photo_name in combination_to_filename.items():
        serialized_mapping[",".join(str(part) for part in selection)] = photo_name
    comb2fn_dict = {"group_sizes": group_sizes, "mapping": serialized_mapping}

    with open(args.out_comb2fn, "w", encoding="utf-8") as out_fp:
        json.dump(comb2fn_dict, out_fp, ensure_ascii=False, indent=2)

    # Report every entry.
    for combo, filename in combination_to_filename.items():
        logger.info(f"Combination: {combo}, Filename: {filename}")


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_offline_precompute()