File size: 3,347 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os.path
import random
import re
import unicodedata
import torch
from torch.utils.data import Dataset
from PIL import Image

from typing import List, Union

def clean_filename(s, max_length=200):
    """Turn an arbitrary string into a lowercase, ASCII-only file-name stem.

    Processing order: strip surrounding whitespace and dots, fold Unicode to
    ASCII via NFKD (code points with no ASCII decomposition are dropped),
    replace '/' and NUL bytes with '_', collapse runs of underscores,
    lowercase, and truncate to ``max_length`` characters.

    Args:
        s: The string to clean.
        max_length: Maximum length of the returned stem (default 200).

    Returns:
        The cleaned stem, or ``'untitled'`` if nothing is left.
    """
    # Strip surrounding whitespace, then surrounding dots, so names such as
    # "." or "..." collapse to nothing instead of becoming special paths.
    s = s.strip().strip('.')
    # NFKD splits accented letters into base + combining mark; the ASCII
    # encode with 'ignore' then drops the marks and any other non-ASCII chars.
    s = unicodedata.normalize('NFKD', s).encode('ASCII', 'ignore').decode('ASCII')
    # '/' would be treated as a directory separator, and an embedded NUL byte
    # makes OS-level file APIs raise, so both become underscores.
    s = re.sub(r'[/\x00]', '_', s)
    # Collapse consecutive underscores into one.
    s = re.sub(r'_{2,}', '_', s)
    s = s.lower()
    s = s[:max_length]
    return s or 'untitled'

def save_fn(image, metadata, root_path):
    """Write ``image`` to ``<root_path>/<metadata['filename']>.png`` via PIL.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        metadata: Mapping with a ``'filename'`` entry used as the file stem.
        root_path: Directory the PNG is written into.
    """
    stem = str(metadata['filename'])
    target = os.path.join(root_path, f"{stem}.png")
    pil_image = Image.fromarray(image)
    pil_image.save(target)

class RandomNDataset(Dataset):
    """Map-style dataset yielding seeded Gaussian-noise latents per condition.

    Flat index ``idx`` maps to ``conditions[idx // num_seeds]``. When explicit
    ``seeds`` are given, the seed is ``seeds[idx % num_seeds]``, so every
    condition is paired with every seed. Without ``seeds``, a fresh random seed
    is drawn on every access.

    Args:
        latent_shape: Shape passed to ``torch.randn`` for each latent.
        conditions: An int ``n`` (class labels ``0..n-1``), a path to a text
            file with one condition per line, or a list of conditions.
        seeds: Optional sequence of seeds. When given, it alone determines the
            number of samples per condition.
        max_num_instances: Target dataset size when ``seeds`` is None. It is
            rounded up to a multiple of the number of conditions.
        num_samples_per_instance: If > 0, replaces ``max_num_instances`` with
            ``num_samples_per_instance * len(conditions)``.

    Raises:
        TypeError: If ``conditions`` is None.
        FileNotFoundError: If ``conditions`` is a string naming a missing path.
    """

    def __init__(self, latent_shape=(4, 64, 64), conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        if conditions is None:
            raise TypeError("conditions must be an int, a list, or a path to a text file; got None")
        if isinstance(conditions, int):
            conditions = list(range(conditions))  # class labels 0..n-1
        elif isinstance(conditions, str):
            if not os.path.exists(conditions):
                raise FileNotFoundError(conditions)
            # Close the handle deterministically and decode independently of locale.
            with open(conditions, "r", encoding="utf-8") as fh:
                conditions = fh.read().splitlines()
        # Any other sequence type (list, tuple, ...) is used as-is.
        self.conditions = conditions
        # NOTE: attribute name spelling kept as-is for backward compatibility.
        self.num_conditons = len(conditions)
        self.seeds = seeds

        if num_samples_per_instance > 0:
            max_num_instances = num_samples_per_instance * self.num_conditons

        if seeds is not None:
            self.num_seeds = len(seeds)
        elif self.num_conditons == 0:
            # No conditions: empty dataset instead of a ZeroDivisionError.
            self.num_seeds = 0
        else:
            # Ceil-divide so every condition receives the same number of seeds;
            # clamp so a non-positive target yields an empty dataset, not a negative length.
            self.num_seeds = max(0, (max_num_instances + self.num_conditons - 1) // self.num_conditons)
        self.max_num_instances = self.num_seeds * self.num_conditons
        self.latent_shape = latent_shape

    def __getitem__(self, idx):
        """Return ``(latent, condition, metadata)`` for flat index ``idx``.

        ``latent`` is a float32 tensor of ``latent_shape`` drawn from a
        dedicated ``torch.Generator`` seeded with the sample's seed.
        ``metadata`` holds ``filename``, ``seed``, ``condition`` and ``save_fn``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if not -self.max_num_instances <= idx < self.max_num_instances:
            raise IndexError(f"index {idx} out of range for dataset of size {self.max_num_instances}")
        condition = self.conditions[idx // self.num_seeds]

        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]
        else:
            # No explicit seeds: draw a fresh one per access (not reproducible across runs).
            seed = random.randint(0, 1 << 31)

        filename = f"{clean_filename(str(condition))}_{seed}"
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)

        metadata = dict(
            filename=filename,
            seed=seed,
            condition=condition,
            save_fn=save_fn,
        )
        return latent, condition, metadata

    def __len__(self):
        """Return the number of (condition, seed) samples."""
        return self.max_num_instances

class ClassLabelRandomNDataset(RandomNDataset):
    """RandomNDataset whose conditions default to the labels ``0..num_classes-1``.

    Args:
        latent_shape: Forwarded to ``RandomNDataset``.
        num_classes: Number of class labels used when ``conditions`` is None.
        conditions: Explicit conditions. When not None, they are forwarded
            unchanged and ``num_classes`` is ignored.
        seeds: Forwarded to ``RandomNDataset``.
        max_num_instances: Forwarded to ``RandomNDataset``.
        num_samples_per_instance: Forwarded to ``RandomNDataset``.
    """

    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        labels = list(range(num_classes)) if conditions is None else conditions
        super().__init__(
            latent_shape=latent_shape,
            conditions=labels,
            seeds=seeds,
            max_num_instances=max_num_instances,
            num_samples_per_instance=num_samples_per_instance,
        )