import re
import loguru
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import cpu_offload_with_hook
"""
English translation of the System prompt:
----------------------------------------
You are an expert at writing image-generation prompts. Rewrite the user's prompt according to the following requirements:
1. The subjects/actions/quantities/styles/layout/relationships/attributes/text in the rewritten prompt must be consistent with the original intent;
2. The rewritten prompt should follow a "general-specific-general" structure so that the information hierarchy is clear;
3. The rewritten prompt should be objective and neutral, avoiding subjective judgments and emotional evaluations;
4. The rewritten prompt should proceed from primary to secondary, always describing the most important elements first, then the secondary and background elements;
5. The rewritten prompt should be logically clear, strictly following spatial or primary-secondary logic so that the reader can reconstruct the image in their mind;
6. The rewritten prompt must end with a single sentence summarizing the overall style or type of the image.
"""
SYSTEM_PROMPT = (
"你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:"
"1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; "
"2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;"
"3 客观中立,避免主观臆断和情感评价;"
"4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;"
"5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;"
"6 结尾点题,必须用一句话总结图像的整体风格或类型。"
)
def replace_single_quotes(text):
"""
Replace single quotes within words with double quotes, and convert
curly single quotes to curly double quotes for consistency.
"""
pattern = r"\B'([^']*)'\B"
replaced_text = re.sub(pattern, r'"\1"', text)
replaced_text = replaced_text.replace("’", "”")
replaced_text = replaced_text.replace("‘", "“")
return replaced_text
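# Illustrative examples (hypothetical inputs): quoted spans are rewritten to
# double quotes, while apostrophes inside words are left alone:
#   replace_single_quotes("a neon sign reading 'OPEN'")  # -> a neon sign reading "OPEN"
#   replace_single_quotes("it's a rainy night")          # -> it's a rainy night (unchanged)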
class RePrompt:
def __init__(self, models_root_path, device_map="auto", enable_offloading=True):
"""
Initialize the RePrompt class with model and processor.
Args:
models_root_path (str): Path to the pretrained model.
device_map (str): Device mapping for model loading.
"""
if enable_offloading:
device_map = None
self.model = AutoModelForCausalLM.from_pretrained(models_root_path, device_map=device_map, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(models_root_path, trust_remote_code=True)
self.enable_offloading = enable_offloading
        if enable_offloading:
            # cpu_offload_with_hook keeps the model on CPU and moves it to the
            # GPU on forward; it stays there until the hook's offload() is called.
            _, self.offload_hook = cpu_offload_with_hook(self.model, execution_device=torch.device('cuda'))
self.device_map = device_map
self.original_device_map = getattr(self.model, 'hf_device_map', None)
@torch.inference_mode()
def predict(
self,
prompt_cot,
sys_prompt=SYSTEM_PROMPT,
):
"""
Generate a rewritten prompt using the model.
Args:
prompt_cot (str): The original prompt to be rewritten.
sys_prompt (str): System prompt to guide the rewriting.
temperature (float): Sampling temperature.
device (str): Device for inference.
Returns:
str: The rewritten prompt, or the original if generation fails.
"""
org_prompt_cot = prompt_cot
try:
messages = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": org_prompt_cot},
]
            tokenized_chat = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                enable_thinking=False,  # Toggle thinking mode (default: True)
            )
            # When weights are dispatched/offloaded the model may report the
            # "meta" device and accelerate's hooks place the inputs; otherwise
            # move the inputs to the model's device explicitly.
            if self.model.device != torch.device('meta'):
                tokenized_chat = tokenized_chat.to(self.model.device)
outputs = self.model.generate(tokenized_chat, max_new_tokens=2048)
            if self.enable_offloading:
                # Generation is done; move the model back to CPU.
                self.offload_hook.offload()
            output_res = self.tokenizer.decode(outputs[0])
            # Extract the rewritten prompt from the <answer>...</answer> block;
            # if no block is found, the IndexError is caught by the fallback below.
            answer_pattern = r'<answer>(.*?)</answer>'
            answer_matches = re.findall(answer_pattern, output_res, re.DOTALL)
            prompt_cot = answer_matches[0].strip()
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            loguru.logger.error(f"✗ Re-prompting failed, falling back to the original prompt. Cause: {e}")
return prompt_cot
def to(self, device, *args, **kwargs):
self.model = self.model.to(device, *args, **kwargs)
return self
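if __name__ == "__main__":
    # Minimal usage sketch. The model path below is a placeholder (assumption:
    # any local chat model that emits <answer>...</answer> blocks will work);
    # offloading assumes a CUDA device is available.
    reprompter = RePrompt("path/to/prompt-rewrite-model", enable_offloading=True)
    print(reprompter.predict("a cat sleeping on a sofa"))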