from tokenxxx import *
from constants import *
from utils import *
import os, json, urllib.request, urllib.parse, torch, hashlib, inspect
# nn, F and img_as_ubyte are used below; imported explicitly in case the star imports above do not provide them.
import torch.nn as nn
import torch.nn.functional as F
from skimage import img_as_ubyte
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.models.xtts import Xtts
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
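# This module bundles the building blocks used by the Space: download/caching helpers,
# a minimal BPE tokenizer for the CodeGen vocabulary, lightweight stand-in model definitions
# (GPT-2, mBART/BART, CodeGen, a KL autoencoder, OpenLRM, a 3D video UNet, sentiment/STT/TTS,
# MusicGen, and an XTTS wrapper), and initialize_* loaders that download checkpoints and load
# them with shape-tolerant state-dict loading.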
def filter_kwargs(cls, kwargs): sig = inspect.signature(cls.__init__); accepted = set(sig.parameters.keys()) - {"self"}; return {k: v for k, v in kwargs.items() if k in accepted}
def sanitize_filename(name, url=None):
    # Strip characters that are invalid in filenames; fall back to a URL hash if nothing is left.
    for c in '<>:"/\\|?*': name = name.replace(c, '')
    if not name and url is not None: name = hashlib.md5(url.encode()).hexdigest()
    return name
def download_file(url, filepath):
    d = os.path.dirname(filepath)
    if d and not os.path.exists(d): os.makedirs(d, exist_ok=True)
    # Retry until the file actually exists on disk; failed attempts are retried indefinitely.
    while not os.path.exists(filepath):
        try:
            def prog(t):
                last = [0]
                def inner(n, bs, ts):
                    if ts > 0: t.total = ts
                    t.update(n * bs - last[0]); last[0] = n * bs
                return inner
            with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
                urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
        except Exception:
            continue
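# download_files accepts either a {filename: url} dict or a list whose items are bare URL strings,
# (url, filename) pairs, or {"filename": ..., "url": ...} dicts. A hypothetical example:
#   download_files("weights/gpt2", {"gpt2-pytorch_model.bin": "https://example.com/gpt2.bin"})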
def download_files(folder, files_spec):
    if isinstance(files_spec, dict):
        for fn, url in files_spec.items():
            fn = sanitize_filename(fn, url); fp = os.path.join(folder, fn); download_file(url, fp)
    elif isinstance(files_spec, list):
        for item in files_spec:
            if isinstance(item, str):
                url = item; parsed = urllib.parse.urlparse(url); fn = os.path.basename(parsed.path)
                if not fn: fn = hashlib.md5(url.encode()).hexdigest()
                fn = sanitize_filename(fn, url)
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                url, fn = item; fn = sanitize_filename(fn, url)
            elif isinstance(item, dict) and "filename" in item and "url" in item:
                fn = sanitize_filename(item["filename"], item["url"]); url = item["url"]
            else:
                raise ValueError("Invalid file specification")
            fp = os.path.join(folder, fn); download_file(url, fp)
    else:
        raise ValueError("files_spec must be dict or list")
def read_json(fp):
    with open(fp, 'r', encoding='utf-8') as f: return json.load(f)
def get_codegen_tokenizer(vocab_path, merges_path):
    with open(vocab_path, 'r', encoding='utf-8') as f: vocab = json.load(f)
    with open(merges_path, 'r', encoding='utf-8') as f: merges = f.read().splitlines()
    merge_ranks = {}
    for i, merge in enumerate(merges):
        parts = merge.strip().split()
        if len(parts) == 2: merge_ranks[tuple(parts)] = i
    def bpe(token):
        # Greedy byte-pair encoding: repeatedly merge the adjacent pair with the lowest rank.
        word = list(token)
        pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
        while True:
            candidate = None; candidate_rank = None; candidate_index = None
            for i, pair in enumerate(pairs):
                if pair in merge_ranks:
                    rank = merge_ranks[pair]
                    if candidate is None or rank < candidate_rank:
                        candidate = pair; candidate_rank = rank; candidate_index = i
            if candidate is None: break
            first, second = candidate; new_word = []; i = 0
            while i < len(word):
                try:
                    j = word.index(first, i); new_word.extend(word[i:j]); i = j
                except ValueError:
                    new_word.extend(word[i:]); break
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second); i += 2
                else:
                    new_word.append(word[i]); i += 1
            word = new_word
            if len(word) == 1: break
            pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
        return word
    def tokenizer(text):
        tokens = []
        for token in text.split():
            bpe_tokens = bpe(token)
            for subtoken in bpe_tokens: tokens.append(vocab.get(subtoken, 0))
        return tokens
    return tokenizer
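# Example usage (hypothetical paths, assuming GPT-2/CodeGen-style vocab.json and merges.txt):
#   tok = get_codegen_tokenizer("weights/codegen/vocab.json", "weights/codegen/merges.txt")
#   ids = tok("def add(a, b): return a + b")  # -> list of token ids; unknown subtokens map to 0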
def simple_tokenizer(text, vocab, max_length=77):
    toks = text.split(); ids = [vocab.get(t, 1) for t in toks]
    # Pad with 0 or truncate to max_length, then return a batch of size 1 on the target device.
    if len(ids) < max_length: ids = ids + [0] * (max_length - len(ids))
    else: ids = ids[:max_length]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
def load_state_dict_safe(model, loaded_state_dict):
    # Copy only tensors whose keys and shapes match; keep the model's own weights otherwise.
    model_state = model.state_dict(); new_state = {}
    for key, value in model_state.items():
        if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
            new_state[key] = loaded_state_dict[key]
        else:
            new_state[key] = value
    model.load_state_dict(new_state, strict=False)
class GPT2Config:
    def __init__(self, vocab_size=50257, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class MBartConfig:
    def __init__(self, vocab_size=50265, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class CodeGenConfig:
    def __init__(self, vocab_size=50257, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class BartConfig:
    def __init__(self, vocab_size=50265, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class AutoencoderKLConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class OpenLRMConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class UNet2DConditionModelConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class MusicGenConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class XTTSConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.lm_head = nn.Linear(768, config.vocab_size)
    def forward(self, x): return self.lm_head(self.transformer(x))
class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__(); self.config = config
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__(); d_model = getattr(config, "d_model", 1024); n_head = getattr(config, "n_head", 16)
        num_layers = getattr(config, "num_layers", 12); dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head); self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers); self.lm_head = nn.Linear(d_model, config.vocab_size)
    def forward(self, tgt, memory=None):
        if memory is None: memory = torch.zeros_like(tgt)
        return self.lm_head(self.transformer_decoder(tgt, memory))
class BartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
class ResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.norm1 = nn.GroupNorm(32, in_ch); self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1); self.norm2 = nn.GroupNorm(32, out_ch); self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1); self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        sc = self.conv_shortcut(x)
        h = F.silu(self.norm1(x)); h = self.conv1(h)
        h = F.silu(self.norm2(h)); h = self.conv2(h)  # norm2 operates on h (out_ch channels), not on x
        return h + sc
class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
    def forward(self, x): return self.conv(x)
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
    def forward(self, x):
        for r in self.resnets: x = r(x)
        for ds in self.downsamplers: x = ds(x)
        return x
class Upsample(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
    def forward(self, x): return self.conv(x)
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.upsampler = Upsample(out_ch, out_ch)
    def forward(self, x):
        for r in self.resnets: x = r(x)
        return self.upsampler(x)
class AttentionBlock(nn.Module):
    def __init__(self, ch): super().__init__(); self.norm = nn.GroupNorm(32, ch); self.query = nn.Conv2d(ch, ch, 1); self.key = nn.Conv2d(ch, ch, 1); self.value = nn.Conv2d(ch, ch, 1); self.proj_attn = nn.Conv2d(ch, ch, 1)
    def forward(self, x):
        b, c, h, w = x.shape; xn = self.norm(x)
        q = self.query(xn).view(b, c, -1).permute(0, 2, 1); k = self.key(xn).view(b, c, -1); v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
        attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
        out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
        return x + self.proj_attn(out)
class Encoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1); self.down_blocks = nn.ModuleList([DownBlock(base_ch, base_ch, 2), DownBlock(base_ch, base_ch * 2, 2), DownBlock(base_ch * 2, base_ch * 4, 2), DownBlock(base_ch * 4, base_ch * 4, 2)]); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.conv_norm_out = nn.GroupNorm(32, base_ch * 4); self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1); self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
    def forward(self, x):
        x = self.conv_in(x)
        for blk in self.down_blocks: x = blk(x)
        for m in self.mid_block: x = m(x)
        x = self.conv_norm_out(x); x = self.conv_out(x)
        return self.quant_conv(x)
class Decoder(nn.Module):
    def __init__(self, out_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1); self.conv_in = nn.Conv2d(latent_ch * 2, base_ch * 4, 3, padding=1); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.up_blocks = nn.ModuleList([UpBlock(base_ch * 4, base_ch * 4, 3), UpBlock(base_ch * 4, base_ch * 2, 3), UpBlock(base_ch * 2, base_ch, 3), UpBlock(base_ch, base_ch, 3)]); self.conv_norm_out = nn.GroupNorm(32, base_ch); self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)  # conv_in consumes the post_quant_conv output (latent_ch * 2 channels)
    def forward(self, x):
        x = self.post_quant_conv(x); x = self.conv_in(x)
        for m in self.mid_block: x = m(x)
        for up in self.up_blocks: x = up(x)
        x = self.conv_norm_out(x)
        return self.conv_out(x)
class AutoencoderKL(nn.Module):
    def __init__(self, config):
        super().__init__()
        in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
        out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
        base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
        latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
        self.encoder = Encoder(in_ch, base_ch, latent_ch); self.decoder = Decoder(out_ch, base_ch, latent_ch)
    def forward(self, x): return self.decoder(self.encoder(x))
    def decode(self, x): return self.decoder(x)
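# Note: with four stride-2 DownBlocks the Encoder reduces spatial resolution by 16x and emits
# latent_ch-channel latents; the Decoder mirrors this with four UpBlocks.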
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__(); self.norm1 = nn.LayerNorm(embed_dim); self.attn = nn.MultiheadAttention(embed_dim, num_heads); self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = embed_dim * 4; self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
    def forward(self, x):
        res = x; x = self.norm1(x)
        x = x.transpose(0, 1); attn, _ = self.attn(x, x, x); x = attn.transpose(0, 1)
        x = res + x
        return x + self.mlp(self.norm2(x))
class VisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        cfg = config if isinstance(config, dict) else config.__dict__  # accept a dict or a config object
        self.img_size = cfg.get("img_size", 592); self.patch_size = cfg.get("patch_size", 16); self.embed_dim = cfg.get("hidden_size", 768)
        depth = cfg.get("depth", 12); num_heads = cfg.get("num_heads", 12)
        num_patches = (self.img_size // self.patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)); self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(self.embed_dim)
        self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
        self._init_weights()
    def _init_weights(self): nn.init.normal_(self.cls_token, std=0.02); nn.init.normal_(self.pos_embed, std=0.02)
    def forward(self, x):
        x = self.patch_embed(x); x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1); x = torch.cat((cls_tokens, x), dim=1); x = x + self.pos_embed
        for blk in self.blocks: x = blk(x)
        return self.norm(x)[:, 0]
class OpenLRM(nn.Module):
    def __init__(self, config):
        super().__init__(); self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
        hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
        self.linear = nn.Linear(hidden, hidden)
    def forward(self, x): return self.linear(self.encoder["model"](x))
class VideoUNet(nn.Module):
    def __init__(self, in_ch=4, out_ch=4, features=None):
        super().__init__()
        if features is None: features = [64, 128, 256]
        self.encoder = nn.ModuleList(); self.pool = nn.MaxPool3d(2, 2); self.decoder = nn.ModuleList()
        for f in features:
            self.encoder.append(nn.Sequential(nn.Conv3d(in_ch, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True))); in_ch = f
        # Each decoder level consumes the upsampled features from the level below plus the matching
        # encoder skip connection, so its input width is prev + f channels.
        prev = features[-1]
        for f in reversed(features):
            self.decoder.append(nn.Sequential(nn.Conv3d(prev + f, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True))); prev = f
        self.final_conv = nn.Conv3d(features[0], out_ch, 1)
    def forward(self, x, t, encoder_hidden_states):
        skips = []
        for enc in self.encoder:
            x = enc(x); skips.append(x); x = self.pool(x)
        for dec in self.decoder:
            skip = skips.pop()
            x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)
class SentimentClassifierModel(nn.Module):
    def __init__(self, config): super().__init__(); self.classifier = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2))
    def forward(self, x): return self.classifier(x)
class STTModel(nn.Module):
    def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
    def forward(self, x): return self.net(x)
class TTSModel(nn.Module):
    def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
    def forward(self, x): return self.net(x)
class MusicGenModel(nn.Module):
    def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.linear = nn.Linear(768, 768)
    def forward(self, x): return self.linear(self.transformer(x))
class SimpleTextEncoder(nn.Module):
    def __init__(self, vocab_size=10000, embed_dim=768, max_length=77): super().__init__(); self.embedding = nn.Embedding(vocab_size, embed_dim); self.max_length = max_length
    def forward(self, text_tokens): return self.embedding(text_tokens)
class DiffusionScheduler:
    def __init__(self, steps): self.steps = steps; self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device); self.alphas = 1 - self.betas; self.alpha_bars = torch.cumprod(self.alphas, dim=0)
    def step(self, noise, t, sample):
        alpha_bar = self.alpha_bars[t]; alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
        x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
        new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
        return new_sample
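# step() is a DDIM-style deterministic update: it estimates x0 from the current sample and the
# predicted noise, then re-projects to the previous timestep as
#   x_{t-1} = sqrt(alpha_bar_{t-1}) * x0 + sqrt(1 - alpha_bar_{t-1}) * noise.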
class VideoOutput:
    def __init__(self, frames): self.frames = [img_as_ubyte(frame) for frame in frames[0]]
class VideoPipeline(nn.Module):
    def __init__(self, unet, vae, text_encoder, vocab): super().__init__(); self.unet = unet; self.vae = vae; self.text_encoder = text_encoder; self.vocab = vocab
    def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
        token_ids = simple_tokenizer(prompt, self.vocab); text_emb = self.text_encoder(token_ids)
        latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
        sched = DiffusionScheduler(steps)
        for t in range(steps):
            noise = self.unet(latent, t, text_emb); latent = sched.step(noise, t, latent)
        frames = self.vae.decode(latent / 0.18215)
        frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
        return VideoOutput(frames)
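# Hypothetical usage (folder and files_spec are placeholders for an actual checkpoint spec):
#   pipeline = initialize_text_to_video_model("weights/text2video", files_spec)
#   video = pipeline("a dog running on the beach", steps=25, num_frames=24)  # video.frames -> list of uint8 arrays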
class XTTSModelClass(nn.Module):
    def __init__(self, config):
        super().__init__()
        # XTTSModel is not imported explicitly above; it is assumed to be provided by one of the star imports.
        self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25)  # Adjust num_speakers, num_languages as needed
    def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
        return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)
    def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
        return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)
def initialize_gpt2_model(folder, files):
    download_files(folder, files)
    config = GPT2Config(); model = GPT2LMHeadModel(config).to(device)
    sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
    load_state_dict_safe(model, sd); model.eval()
    enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
    return model, enc
def initialize_translation_model(folder, files):
    download_files(folder, files)
    config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MBartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab = read_json(vp); model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
    else:
        model.tokenizer = lambda txt: txt
    model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
    return model
def initialize_codegen_model(folder, files):
    download_files(folder, files)
    config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = CodeGenForCausalLM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
    vocab = read_json(os.path.join(folder, "vocab.json")); idx2w = {v: k for k, v in vocab.items()}
    model.tokenizer = tok
    return model, tok, vocab, idx2w, vocab
def initialize_summarization_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = BartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab_json = read_json(vp); vocab = set(vocab_json.keys())
        return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
    return model, None, None, None
def initialize_imagegen_model(folder, files):
    download_files(folder, files)
    config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    vae = AutoencoderKL(config).to(device)
    sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd); vae.eval()
    return vae
def initialize_image_to_3d_model(folder, files):
    download_files(folder, files)
    config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model3d = OpenLRM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model3d, sd); model3d.eval()
    return model3d
def initialize_text_to_video_model(folder, files):
    download_files(folder, files)
    unet_cfg = read_json(os.path.join(folder, "config.json"))
    unet_cfg = filter_kwargs(VideoUNet, unet_cfg); unet = VideoUNet(**unet_cfg).half().to(device)
    sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device); load_state_dict_safe(unet, sd_unet); unet.eval()
    vae_cfg = read_json(os.path.join(folder, "config.json")); vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg); vae = AutoencoderKL(vae_cfg).half().to(device)
    sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd_vae); vae.eval()
    vp = os.path.join(folder, "vocab.json"); text_vocab = read_json(vp) if os.path.exists(vp) else {}
    te_path = os.path.join(folder, "text_encoder.bin")
    text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values()) + 1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
    if os.path.exists(te_path):
        sd_te = torch.load(te_path, map_location=device); load_state_dict_safe(text_encoder, sd_te)
    text_encoder.eval()
    return VideoPipeline(unet, vae, text_encoder, text_vocab)
def initialize_sentiment_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = SentimentClassifierModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp): read_json(vp)  # vocab.json is parsed but its contents are not used here
    return model
def initialize_stt_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = STTModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp): read_json(vp)
    return model
def initialize_tts_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = TTSModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp): read_json(vp)
    return model
def initialize_musicgen_model(folder, files):
    download_files(folder, files)
    config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MusicGenModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
    return model
def initialize_xtts_model(folder, files):
    download_files(folder, files)
    config_xtts = XTTSConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = XTTSModelClass(config_xtts).to(device)
    checkpoint = torch.load(os.path.join(folder, "model.pth"), map_location=torch.device(device))
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()
    return model.xtts