# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
import torch
import spacy
# Initialize FastAPI app
app = FastAPI()
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic model for request body
class SloganRequest(BaseModel):
    brand: str
    description: str
    industry: str
    tone: Optional[str] = "playful"
    num: Optional[int] = 5
    liked_slogan: Optional[str] = None
# Load models
nlp = spacy.load("en_core_web_sm")
model = GPT2LMHeadModel.from_pretrained("./")      # model files in the Space root (originally "./slogan_generator_medium")
tokenizer = GPT2Tokenizer.from_pretrained("./")    # tokenizer saved alongside the model
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=0 if torch.cuda.is_available() else -1
)
# Tone presets: sampling parameters (temperature, top_p, repetition_penalty) tuned per tone
TONE_PRESETS = {
"playful": {"temperature": 0.95, "top_p": 0.95, "repetition_penalty": 1.2},
"bold": {"temperature": 0.8, "top_p": 0.9, "repetition_penalty": 1.45},
"minimalist": {"temperature": 0.6, "top_p": 0.8, "repetition_penalty": 1.5},
"luxury": {"temperature": 0.7, "top_p": 0.85, "repetition_penalty": 1.35},
"classic": {"temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.25}
}
def summarize_description(text: str) -> str:
"""Extract key words from description using spaCy"""
doc = nlp(text)
keywords = [token.text for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"]]
return " ".join(keywords[:12])
@app.get("/")
def read_root():
return {"message": "Welcome to Slogan Generator API. Use POST / to generate slogans."}
# @app.post("/generate-slogans")
@app.post("/")
async def generate_slogans(request: SloganRequest):
    try:
        # Reduce the description to a short keyword summary
        processed_desc = summarize_description(request.description)

        # Build two prompts, conditioned on a liked slogan if one was provided
        if request.liked_slogan:
            prompt1 = (
                f"Create {request.industry} brand slogans similar to: '{request.liked_slogan}'\n"
                f"Brand: {request.brand}\n"
                f"Key Attributes: {processed_desc}\n"
                "Slogan:"
            )
            prompt2 = (
                f"Generate slogans in the style of: '{request.liked_slogan}'\n"
                f"For: {request.brand}\n"
                f"Details: {processed_desc}\n"
                "Slogan:"
            )
        else:
            prompt1 = (
                f"Create a {request.industry} brand slogan that's {request.tone} and unique.\n"
                f"Brand: {request.brand}\n"
                f"Attributes: {processed_desc}\n"
                "Slogan:"
            )
            prompt2 = (
                f"Write {request.tone} marketing slogans for this {request.industry} brand:\n"
                f"Name: {request.brand}\n"
                f"About: {processed_desc}\n"
                "Slogan:"
            )

        # Generation parameters (fall back to the "playful" preset for unknown tones)
        gen_params = {
            **TONE_PRESETS.get(request.tone, TONE_PRESETS["playful"]),
            "max_new_tokens": 25,
            "num_return_sequences": request.num,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id
        }

        # Generate from both prompts
        outputs1 = generator(prompt1, **gen_params)
        outputs2 = generator(prompt2, **gen_params)

        # Clean up and deduplicate slogans
        slogans = []
        for output_group in [outputs1, outputs2]:
            for o in output_group:
                raw = o['generated_text'].split("Slogan:")[-1].strip()
                clean = raw.split("\n")[0].replace('"', '').replace('(', '').split(".")[0].strip()
                if len(clean) > 4 and clean not in slogans:
                    slogans.append(clean)

        return {"slogans": slogans[:request.num * 2]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
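
# Example usage (illustrative sketch; the launch command and sample payload below are
# assumptions, not part of this file -- Hugging Face Spaces conventionally expose port 7860):
#   uvicorn main:app --host 0.0.0.0 --port 7860
#   curl -X POST http://localhost:7860/ \
#        -H "Content-Type: application/json" \
#        -d '{"brand": "Acme Coffee", "description": "small-batch roasted beans", "industry": "food and beverage", "tone": "playful", "num": 5}'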