Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,347 Bytes
56238f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os.path
import random
import re
import unicodedata
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import List, Union
def clean_filename(s):
    """Convert an arbitrary string into a lowercase, ASCII-only filename stem.

    Steps, in order: trim surrounding whitespace and then surrounding dots,
    fold the text to ASCII (characters without an ASCII decomposition are
    dropped), replace ``/`` with ``_``, collapse runs of underscores,
    lowercase, and truncate to 200 characters.

    Args:
        s: The string to sanitize.

    Returns:
        The sanitized name, or ``'untitled'`` when nothing is left.
    """
    # Trim whitespace first, then any dots left at either end.
    trimmed = s.strip().strip('.')

    # NFKD splits accented characters into base + combining mark; encoding
    # with 'ignore' then discards everything that is not plain ASCII.
    ascii_bytes = unicodedata.normalize('NFKD', trimmed).encode('ASCII', 'ignore')
    name = ascii_bytes.decode('ASCII')

    # Replace the path separator, then squeeze consecutive underscores.
    name = re.sub(r'_{2,}', '_', re.sub(r'[/]', '_', name))
    name = name.lower()

    # Reserved-name hook: the set is empty, so this check never fires today.
    reserved = set()
    if name.upper() in reserved:
        name += '_'

    # Cap the length of the resulting name.
    name = name[:200]
    return name if name else 'untitled'
def save_fn(image, metadata, root_path):
    """Write ``image`` as a PNG file named after ``metadata['filename']``.

    Args:
        image: Array-like image data, passed straight to ``Image.fromarray``.
        metadata: Mapping that provides a ``'filename'`` entry (used without
            extension; ``.png`` is appended here).
        root_path: Directory the file is written into.
    """
    png_name = str(metadata['filename']) + ".png"
    destination = os.path.join(root_path, png_name)
    Image.fromarray(image).save(destination)
class RandomNDataset(Dataset):
    """Dataset of seeded Gaussian latents paired with generation conditions.

    Items are laid out condition-major: indices ``[0, num_seeds)`` belong to
    the first condition, the next ``num_seeds`` indices to the second, and so
    on. Each item is a ``(latent, condition, metadata)`` tuple where
    ``latent`` is a float32 tensor of shape ``latent_shape`` drawn with
    ``torch.randn`` from a generator seeded for that item.
    """

    def __init__(self, latent_shape=(4, 64, 64), conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        """Resolve the condition list and work out the dataset size.

        Args:
            latent_shape: Shape of every latent tensor that is produced.
            conditions: An int ``n`` (expanded to the labels ``0..n-1``), a
                path to a text file with one condition per line, or an
                already-built sequence of conditions.
            seeds: Optional sequence of integer seeds. When given, every
                condition is paired with every seed and the size arguments
                below are ignored.
            max_num_instances: Target dataset size when ``seeds`` is not
                given; rounded up to a multiple of the number of conditions.
            num_samples_per_instance: When positive, overrides
                ``max_num_instances`` with this many samples per condition.

        Raises:
            TypeError: If ``conditions`` is ``None``.
            FileNotFoundError: If ``conditions`` is a path that does not exist.
        """
        super().__init__()
        if conditions is None:
            # Fail early with a readable message instead of `len(None)`.
            raise TypeError(
                "conditions must be an int, a list, or a path to a text file, not None"
            )
        if isinstance(conditions, int):
            conditions = list(range(conditions))  # class labels
        elif isinstance(conditions, str):
            if not os.path.exists(conditions):
                raise FileNotFoundError(conditions)
            # Context manager so the file handle is closed deterministically.
            with open(conditions, "r") as handle:
                conditions = handle.read().splitlines()

        self.conditions = conditions
        # NOTE: attribute name (including its spelling) is kept as-is because
        # it is visible to code outside this class.
        self.num_conditons = len(conditions)
        self.seeds = seeds
        self.latent_shape = latent_shape

        if num_samples_per_instance > 0:
            max_num_instances = num_samples_per_instance * self.num_conditons

        if seeds is not None:
            self.num_seeds = len(seeds)
        elif self.num_conditons > 0:
            # Ceiling division: enough seeds per condition to reach the target.
            self.num_seeds = (max_num_instances + self.num_conditons - 1) // self.num_conditons
        else:
            # No conditions: produce an empty dataset rather than dividing by
            # zero, matching what already happens when seeds are supplied.
            self.num_seeds = 0
        self.max_num_instances = self.num_seeds * self.num_conditons

    def __getitem__(self, idx):
        """Return ``(latent, condition, metadata)`` for item ``idx``.

        With explicit ``seeds`` the item is reproducible. Without them a new
        random seed is drawn on every access, so repeated reads of the same
        index give different latents and filenames.

        ``metadata`` holds ``filename`` (sanitized condition plus seed),
        ``seed``, ``condition`` and ``save_fn``.
        """
        condition = self.conditions[idx // self.num_seeds]
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]
        else:
            # Only touch the global RNG when a seed is actually needed.
            seed = random.randint(0, 1 << 31)

        filename = f"{clean_filename(str(condition))}_{seed}"
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)

        metadata = dict(
            filename=filename,
            seed=seed,
            condition=condition,
            save_fn=save_fn,
        )
        return latent, condition, metadata

    def __len__(self):
        """Return the total number of (condition, seed) items."""
        return self.max_num_instances
class ClassLabelRandomNDataset(RandomNDataset):
    """``RandomNDataset`` whose conditions default to integer class labels."""

    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, conditions: Union[int, List, str] = None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1):
        """Build the dataset, defaulting to the labels ``0..num_classes-1``.

        Args:
            latent_shape: Forwarded to the base class.
            num_classes: Number of class labels to generate; only used when
                ``conditions`` is ``None``.
            conditions: Explicit conditions; when given, ``num_classes`` is
                ignored and this value is forwarded unchanged.
            seeds: Forwarded to the base class.
            max_num_instances: Forwarded to the base class.
            num_samples_per_instance: Forwarded to the base class.
        """
        labels = conditions if conditions is not None else list(range(num_classes))
        super().__init__(
            latent_shape=latent_shape,
            conditions=labels,
            seeds=seeds,
            max_num_instances=max_num_instances,
            num_samples_per_instance=num_samples_per_instance,
        )
|