File size: 5,178 Bytes
43c5292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9ad1cd
43c5292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import re
import loguru
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import cpu_offload_with_hook

"""
English translation of the System prompt:
----------------------------------------
You are an expert in writing image generation prompts. Please rewrite the user's prompt according to the following requirements:
1. The main subject/action/quantity/style/layout/relationship/attribute/text in the rewritten prompt must be consistent with the original intention;
2. The rewritten prompt should follow the "overall-detail-conclusion" structure, ensuring the clarity of information hierarchy;
3. The rewritten prompt should be objective and neutral, avoiding subjective judgment and emotional evaluation;
4. The rewritten prompt should be from the main to the secondary, always describing the most important elements first, and then the secondary and background elements;
5. The rewritten prompt should be logically clear, strictly follow the spatial logic or main-secondary logic, allowing the reader to reconstruct the image in the brain;
6. The rewritten prompt should end with a summary sentence, summarizing the overall style or type of the image.
"""

# Default system message used by RePrompt.predict (its `sys_prompt` default).
# The text is Chinese and is built from adjacent string literals that Python
# concatenates into a single string; the bare string literal just above this
# constant carries an English translation of the six rewriting requirements.
SYSTEM_PROMPT = (
    "你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:"
    "1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; "
    "2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;"
    "3 客观中立,避免主观臆断和情感评价;"
    "4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;"
    "5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;"
    "6 结尾点题,必须用一句话总结图像的整体风格或类型。"
)


def replace_single_quotes(text):
    """
    Normalise single-quoted spans in *text* to double quotes.

    A pair of straight single quotes becomes a pair of straight double
    quotes when the opening quote is not preceded by a word character and
    the closing quote is not followed by one, so apostrophes inside words
    (e.g. "it's") are left alone. Afterwards every curly single quote is
    swapped for the corresponding curly double quote.

    Args:
        text (str): Text to normalise.

    Returns:
        str: The text with the quote substitutions applied.
    """
    quoted_span = re.compile(r"\B'([^']*)'\B")
    straightened = quoted_span.sub(lambda hit: '"' + hit.group(1) + '"', text)

    # Single-character swaps are independent of each other, so one
    # translate pass is equivalent to the two chained replacements.
    curly_to_double = {ord("’"): "”", ord("‘"): "“"}
    return straightened.translate(curly_to_double)


class RePrompt:
    """
    Rewrite image-generation prompts with a causal language model.

    The model is prompted with a system message plus the user's prompt and is
    expected to wrap the rewritten prompt in ``<answer>...</answer>`` tags.
    """

    def __init__(self, models_root_path, device_map="auto", enable_offloading=True):
        """
        Load the rewriting model and its tokenizer.

        Args:
            models_root_path (str): Path to the pretrained model.
            device_map (str): Device mapping for model loading. Overridden
                with ``None`` when ``enable_offloading`` is true.
            enable_offloading (bool): When true, the model is wrapped with
                ``cpu_offload_with_hook`` (execution device ``cuda``) and
                ``predict`` calls the hook's ``offload()`` after generation.
        """
        if enable_offloading:
            device_map = None
        self.model = AutoModelForCausalLM.from_pretrained(models_root_path, device_map=device_map, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(models_root_path, trust_remote_code=True)
        self.enable_offloading = enable_offloading

        # Always define the attribute so it exists whether or not offloading is on.
        self.offload_hook = None
        if enable_offloading:
            _, self.offload_hook = cpu_offload_with_hook(self.model, execution_device=torch.device('cuda'))
        self.device_map = device_map
        self.original_device_map = getattr(self.model, 'hf_device_map', None)

    @torch.inference_mode()
    def predict(
        self,
        prompt_cot,
        sys_prompt=SYSTEM_PROMPT,
    ):
        """
        Generate a rewritten prompt using the model.

        Any failure (tokenization, generation, or a model reply without a
        non-empty ``<answer>...</answer>`` block) is logged and the original
        prompt is returned unchanged.

        Args:
            prompt_cot (str): The original prompt to be rewritten.
            sys_prompt (str): System prompt to guide the rewriting.

        Returns:
            str: The rewritten prompt, or the original if generation fails.
        """
        org_prompt_cot = prompt_cot
        try:
            messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": org_prompt_cot},
            ]
            tokenized_chat = self.tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", enable_thinking=False  # Toggle thinking mode (default: True)
            )
            if self.model.device != torch.device('meta'):
                tokenized_chat = tokenized_chat.to(self.model.device)
            try:
                outputs = self.model.generate(tokenized_chat, max_new_tokens=2048)
            finally:
                # Run the offload hook even when generation raises, so a
                # failed call does not skip the offload step.
                if self.enable_offloading:
                    self.offload_hook.offload()
            # NOTE(review): outputs[0] is decoded as a whole; if it also holds
            # the prompt tokens, an <answer> block inside the user's prompt
            # would be matched first -- confirm against generate()'s output.
            output_res = self.tokenizer.decode(outputs[0])
            answer_match = re.search(r'<answer>(.*?)</answer>', output_res, re.DOTALL)
            if answer_match is None:
                raise ValueError("model output contains no <answer>...</answer> block")
            answer = answer_match.group(1).strip()
            if not answer:
                # An empty rewrite is useless downstream; treat it as a failure.
                raise ValueError("model output contains an empty <answer> block")
            prompt_cot = replace_single_quotes(answer)
        except Exception as e:
            # Deliberate best-effort: never let re-prompting break the caller.
            prompt_cot = org_prompt_cot
            loguru.logger.error(f"✗ Re-prompting failed, fall back to generate prompt. Cause: {e}")

        return prompt_cot

    def to(self, device, *args, **kwargs):
        """
        Move the wrapped model with ``model.to(device, *args, **kwargs)``.

        Returns:
            RePrompt: ``self``, to allow call chaining.
        """
        self.model = self.model.to(device, *args, **kwargs)
        return self