Spaces:
Runtime error
Runtime error
File size: 3,545 Bytes
0eb032f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import html
import os
import string
import ftfy
import regex as re
import torch
from transformers import AutoTokenizer
from ..models.wan_video_text_encoder import WanTextEncoder
from .base_prompter import BasePrompter
def basic_clean(text):
    """Repair and unescape raw prompt text, then trim surrounding whitespace.

    The text is first passed through ``ftfy.fix_text`` (ftfy's text-repair
    routine). HTML entities are then decoded twice, so that doubly-escaped
    input such as ``"&amp;amp;"`` collapses all the way to ``"&"``.
    """
    repaired = ftfy.fix_text(text)
    # Two passes on purpose: one pass would leave "&amp;amp;" as "&amp;".
    for _ in range(2):
        repaired = html.unescape(repaired)
    return repaired.strip()
def whitespace_clean(text):
    """Collapse every run of whitespace in *text* to one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()
def canonicalize(text, keep_punctuation_exact_string=None):
    """Normalise *text* to lowercase words without punctuation.

    Steps: underscores become spaces; every ``string.punctuation`` character
    is removed; the result is lowercased; whitespace runs are collapsed to a
    single space; and the ends are trimmed.

    Args:
        text: Input string.
        keep_punctuation_exact_string: Optional delimiter that is preserved
            verbatim. The text is split on it, punctuation is stripped from
            each piece, and the pieces are re-joined with the delimiter.
            A falsy value (``None`` or ``""``) disables this behaviour.

    Returns:
        The canonicalised string.
    """
    # Build the deletion table once rather than once per split piece.
    drop_punctuation = str.maketrans("", "", string.punctuation)
    text = text.replace("_", " ")

    if not keep_punctuation_exact_string:
        text = text.translate(drop_punctuation)
    else:
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(drop_punctuation) for piece in pieces
        )

    collapsed = re.sub(r"\s+", " ", text.lower())
    return collapsed.strip()
class HuggingfaceTokenizer:
    """Wrapper around a Hugging Face ``AutoTokenizer`` with optional text cleaning.

    Args:
        name: Model name or path forwarded to ``AutoTokenizer.from_pretrained``.
        seq_len: When not ``None``, every call pads to, and truncates at,
            exactly this many tokens (``padding="max_length"``).
        clean: Normalisation applied to each input string before tokenizing.
            Must be one of ``None`` (no cleaning), ``"whitespace"``,
            ``"lower"`` or ``"canonicalize"``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Raises:
        ValueError: If ``clean`` is not one of the supported modes.
    """

    # Accepted values for the ``clean`` constructor argument.
    _CLEAN_MODES = (None, "whitespace", "lower", "canonicalize")

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``. An unknown mode would then be accepted, and
        # ``_clean`` would silently return the text unchanged.
        if clean not in self._CLEAN_MODES:
            raise ValueError(
                f"clean must be one of {self._CLEAN_MODES!r}, got {clean!r}"
            )
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a single string or a list of strings.

        Args:
            sequence: A string, or a list of strings.
            return_mask: Flag popped from ``kwargs``. When true, the attention
                mask is returned alongside the token ids.
            **kwargs: Remaining options forwarded to the underlying tokenizer.
                They override the defaults set here: ``return_tensors="pt"``
                and, when ``seq_len`` is set, fixed-length padding/truncation.

        Returns:
            ``input_ids``, or ``(input_ids, attention_mask)`` when
            ``return_mask`` is true. With the default ``return_tensors="pt"``
            these are PyTorch tensors.
        """
        return_mask = kwargs.pop("return_mask", False)

        # Defaults go in first so that caller-supplied kwargs take precedence.
        tokenizer_kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            tokenizer_kwargs.update(
                padding="max_length",
                truncation=True,
                max_length=self.seq_len,
            )
        tokenizer_kwargs.update(kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **tokenizer_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        return ids.input_ids

    def _clean(self, text):
        """Apply the configured ``clean`` mode to one string; ``None`` leaves it as-is."""
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
class WanPrompter(BasePrompter):
    """Prompt encoder for Wan video models.

    Prompts are tokenized with ``HuggingfaceTokenizer``, using a fixed length
    ``text_len`` and ``"whitespace"`` cleaning. They are then embedded with the
    text encoder attached via ``fetch_models``.

    Args:
        tokenizer_path: Name or path passed to ``HuggingfaceTokenizer``. When
            ``None``, no tokenizer is loaded, and ``fetch_tokenizer`` must be
            called with a path before ``encode_prompt``.
        text_len: Token length every prompt is padded/truncated to.
    """

    def __init__(self, tokenizer_path=None, text_len=512):
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        """Load the tokenizer from ``tokenizer_path``; a ``None`` path does nothing."""
        if tokenizer_path is not None:
            self.tokenizer = HuggingfaceTokenizer(
                name=tokenizer_path, seq_len=self.text_len, clean="whitespace"
            )

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        """Attach the text encoder used by ``encode_prompt``."""
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        """Encode ``prompt`` and strip the padding from each resulting embedding.

        Args:
            prompt: Prompt passed through ``self.process_prompt`` (inherited
                from ``BasePrompter``) before tokenizing. Presumably a string
                or a list of strings; confirm against ``process_prompt``.
            positive: Forwarded to ``process_prompt``.
            device: Device the token ids and mask are moved to.

        Returns:
            A list with one entry per row of the text encoder's output. Each
            entry is sliced to that row's token count, i.e. the number of
            positive attention-mask entries.

        Raises:
            RuntimeError: If no tokenizer or no text encoder has been loaded.
        """
        # Fail with an actionable message. Otherwise the caller sees an
        # AttributeError for ``tokenizer``, or "'NoneType' object is not callable".
        if getattr(self, "tokenizer", None) is None:
            raise RuntimeError(
                "WanPrompter has no tokenizer; pass tokenizer_path or call "
                "fetch_tokenizer() first"
            )
        if self.text_encoder is None:
            raise RuntimeError(
                "WanPrompter has no text encoder; call fetch_models() first"
            )

        prompt = self.process_prompt(prompt, positive=positive)
        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Real (unpadded) length of each row = count of positive mask entries.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = self.text_encoder(ids, mask)
        # Keep only the real-token positions of each row, dropping the padding.
        return [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|