|
from typing import * |
|
import os |
|
os.environ['TOKENIZERS_PARALLELISM'] = 'true' |
|
import torch |
|
from transformers import AutoTokenizer, CLIPTextModel |
|
|
|
from ....utils import dist_utils |
|
|
|
|
|
class TextConditionedMixin: |
|
""" |
|
Mixin for text-conditioned models. |
|
|
|
Args: |
|
text_cond_model: The text conditioning model. |
|
""" |
|
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.text_cond_model_name = text_cond_model |
|
self.text_cond_model = None |
|
|
|
def _init_text_cond_model(self): |
|
""" |
|
Initialize the text conditioning model. |
|
""" |
|
|
|
with dist_utils.local_master_first(): |
|
model = CLIPTextModel.from_pretrained(self.text_cond_model_name) |
|
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name) |
|
model.eval() |
|
model = model.cuda() |
|
self.text_cond_model = { |
|
'model': model, |
|
'tokenizer': tokenizer, |
|
} |
|
self.text_cond_model['null_cond'] = self.encode_text(['']) |
|
|
|
@torch.no_grad() |
|
def encode_text(self, text: List[str]) -> torch.Tensor: |
|
""" |
|
Encode the text. |
|
""" |
|
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond" |
|
if self.text_cond_model is None: |
|
self._init_text_cond_model() |
|
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt') |
|
tokens = encoding['input_ids'].cuda() |
|
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state |
|
|
|
return embeddings |
|
|
|
def get_cond(self, cond, **kwargs): |
|
""" |
|
Get the conditioning data. |
|
""" |
|
cond = self.encode_text(cond) |
|
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) |
|
cond = super().get_cond(cond, **kwargs) |
|
return cond |
|
|
|
def get_inference_cond(self, cond, **kwargs): |
|
""" |
|
Get the conditioning data for inference. |
|
""" |
|
cond = self.encode_text(cond) |
|
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) |
|
cond = super().get_inference_cond(cond, **kwargs) |
|
return cond |
|
|