# PixNerd/src/models/conditioner/qwen3_text_encoder.py
# wangshuai6 · init · 79ac8fc
import torch
import torch.nn as nn
from src.models.conditioner.base import BaseConditioner
from transformers import Qwen3Model, Qwen2Tokenizer
class Qwen3TextEncoder(BaseConditioner):
    """Text conditioner that embeds prompts with a pretrained Qwen3 model.

    Prompts are tokenized with ``Qwen2Tokenizer`` (right-padded/truncated to
    ``max_length`` tokens) and encoded by ``Qwen3Model`` cast to bfloat16.
    The feature axis of the encoder output is then zero-padded or truncated
    to ``embed_dim`` when one is configured.
    """

    def __init__(self, weight_path: str, embed_dim: int = None, max_length=128):
        """Load the tokenizer and encoder from ``weight_path``.

        Args:
            weight_path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
            embed_dim: Target size of the last (feature) axis of the returned
                embeddings. ``None`` leaves the encoder output unresized.
            max_length: Token length every prompt is padded/truncated to.
        """
        super().__init__()
        self.tokenizer = Qwen2Tokenizer.from_pretrained(weight_path, max_length=max_length, padding_side="right")
        # self.model = Qwen3Model.from_pretrained(weight_path, attn_implementation="flex_attention").to(torch.bfloat16)
        self.model = Qwen3Model.from_pretrained(weight_path).to(torch.bfloat16)
        # self.model.compile()
        # Lazily filled cache for the empty-prompt embedding (see _impl_uncondition).
        self.uncondition_embedding = None
        self.embed_dim = embed_dim
        self.max_length = max_length
        # torch._dynamo.config.optimize_ddp = False

    def _impl_condition(self, y, metadata: dict = None):
        """Encode the prompts ``y`` into a rank-3 embedding tensor.

        Args:
            y: Prompt text(s) handed directly to the tokenizer.
            metadata: Optional dict that receives ``"valid_length_y"``, the
                per-prompt sum of the attention mask (i.e. the number of
                non-padding tokens). A fresh dict is used when omitted so
                that no state leaks between calls.

        Returns:
            The first element of the model output, with its last axis
            zero-padded or truncated to ``self.embed_dim`` (unchanged when
            ``self.embed_dim`` is ``None``).
        """
        if metadata is None:
            # Never share one dict across calls: it is mutated below.
            metadata = {}

        tokenized = self.tokenizer(
            y,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # Follow the encoder's own device instead of assuming the current
        # CUDA device, so CPU and non-default-GPU placements also work.
        device = self.model.device
        input_ids = tokenized.input_ids.to(device)
        attention_mask = tokenized.attention_mask.to(device)
        metadata["valid_length_y"] = torch.sum(attention_mask, dim=-1)

        y = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if self.embed_dim is None:
            # No target width configured: return the encoder features as-is.
            return y

        feature_dim = y.shape[2]
        if feature_dim < self.embed_dim:
            # Zero-pad the feature axis up to embed_dim (same device/dtype as y).
            padding = y.new_zeros(y.shape[0], y.shape[1], self.embed_dim - feature_dim)
            y = torch.cat([y, padding], dim=-1)
        elif feature_dim > self.embed_dim:
            # Keep only the leading embed_dim features.
            y = y[:, :, :self.embed_dim]
        return y

    def _impl_uncondition(self, y, metadata: dict = None):
        """Return the empty-prompt embedding repeated once per item in ``y``.

        The embedding of ``""`` is computed on first use and cached on the
        instance; ``metadata`` is accepted for interface parity and is not
        modified here.
        """
        if self.uncondition_embedding is None:
            self.uncondition_embedding = self._impl_condition(["",])
        return self.uncondition_embedding.repeat(len(y), 1, 1)