Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from src.models.conditioner.base import BaseConditioner | |
from transformers import Qwen3Model, Qwen2Tokenizer | |
class Qwen3TextEncoder(BaseConditioner): | |
def __init__(self, weight_path: str, embed_dim:int=None, max_length=128): | |
super().__init__() | |
self.tokenizer = Qwen2Tokenizer.from_pretrained(weight_path, max_length=max_length, padding_side="right") | |
# self.model = Qwen3Model.from_pretrained(weight_path, attn_implementation="flex_attention").to(torch.bfloat16) | |
self.model = Qwen3Model.from_pretrained(weight_path).to(torch.bfloat16) | |
# self.model.compile() | |
self.uncondition_embedding = None | |
self.embed_dim = embed_dim | |
self.max_length = max_length | |
# torch._dynamo.config.optimize_ddp = False | |
def _impl_condition(self, y, metadata:dict={}): | |
tokenized = self.tokenizer(y, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt") | |
input_ids = tokenized.input_ids.cuda() | |
attention_mask = tokenized.attention_mask.cuda() | |
metadata["valid_length_y"] = torch.sum(attention_mask, dim=-1) | |
y = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] | |
if y.shape[2] < self.embed_dim: | |
y = torch.cat([y, torch.zeros(y.shape[0], y.shape[1], self.embed_dim - y.shape[2]).to(y.device, y.dtype)], dim=-1) | |
if y.shape[2] > self.embed_dim: | |
y = y[:, :, :self.embed_dim] | |
return y | |
def _impl_uncondition(self, y, metadata:dict=None): | |
if self.uncondition_embedding is not None: | |
return self.uncondition_embedding.repeat(len(y), 1, 1) | |
self.uncondition_embedding = self._impl_condition(["",]) | |
return self.uncondition_embedding.repeat(len(y), 1, 1) |