|
|
|
""" |
|
T5 Prompt Enhancer V0.3 Demo Script |
|
Quick test of all four instruction types |
|
""" |
|
|
|
import torch |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
def load_model():
    """Load the T5 Prompt Enhancer V0.3 model and tokenizer from the current directory.

    Returns:
        tuple: ``(model, tokenizer)``; the model is moved to GPU when CUDA
        is available, otherwise it stays on CPU.
    """
    print("Loading T5 Prompt Enhancer V0.3...")

    # Weights and tokenizer files are expected alongside this script (".").
    tokenizer = T5Tokenizer.from_pretrained(".")
    model = T5ForConditionalGeneration.from_pretrained(".")

    if torch.cuda.is_available():
        model = model.cuda()
        print("[OK] Model loaded on GPU")
    else:
        print("[OK] Model loaded on CPU")

    return model, tokenizer
|
|
|
def enhance_prompt(model, tokenizer, text, style="clean"):
    """Generate an enhanced prompt with style control.

    Args:
        model: seq2seq model exposing ``generate`` (e.g. T5ForConditionalGeneration).
        tokenizer: tokenizer used to encode the instruction and decode the output.
        text: raw prompt text to transform.
        style: one of ``"clean"``, ``"technical"``, ``"simplify"``, ``"standard"``.

    Returns:
        str: the decoded enhanced prompt.

    Raises:
        KeyError: if ``style`` is not one of the four known styles.
    """
    style_prompts = {
        "clean": f"Enhance this prompt (no lora): {text}",
        "technical": f"Enhance this prompt (with lora): {text}",
        "simplify": f"Simplify this prompt: {text}",
        "standard": f"Enhance this prompt: {text}",
    }

    prompt = style_prompts[style]
    inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)

    # Move inputs to the model's own device instead of assuming CUDA:
    # the original moved inputs to GPU whenever CUDA was available, which
    # crashes if the model itself was left on CPU.
    try:
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except StopIteration:
        pass  # parameter-less model (e.g. a test stub); leave inputs as-is

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=80,
            num_beams=2,
            repetition_penalty=2.0,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def main():
    """Demo all four instruction types against a few sample prompts.

    Loads the model once, then runs every style over every test prompt,
    printing results; a failure in one style does not stop the demo.
    """
    model, tokenizer = load_model()

    test_prompts = [
        "woman in red dress",
        "cat on chair",
        "cyberpunk cityscape",
        "masterpiece, best quality, ultra-detailed render of a fantasy dragon with golden scales",
    ]

    styles = ["standard", "clean", "technical", "simplify"]

    print("\nT5 Prompt Enhancer V0.3 Demo")
    print("=" * 60)

    for prompt in test_prompts:
        print(f"\nInput: '{prompt}'")
        print("-" * 40)

        for style in styles:
            try:
                result = enhance_prompt(model, tokenizer, prompt, style)
                print(f"{style:>10}: {result}")
            except Exception as e:
                # Keep the demo running even if one style fails.
                print(f"{style:>10}: ERROR - {e}")

        print()
|
|
|
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":

    main()