# NOTE(review): removed non-Python residue here ("Spaces:" and two
# "Runtime error" lines) — these look like banner/output text pasted in
# from an online editor, not part of the program.
# main.py
"""FastAPI service that generates brand slogans with a fine-tuned GPT-2."""

from typing import Optional

import spacy
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
# Initialize the FastAPI application.
app = FastAPI()

# Allow cross-origin requests so a separately hosted frontend can call
# this API.  NOTE(review): browsers reject credentialed requests when
# allow_origins=["*"] is combined with allow_credentials=True — replace
# "*" with an explicit origin list before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic model for the POST request body.
class SloganRequest(BaseModel):
    """Client request for slogan generation."""

    brand: str          # brand name to feature in the slogans
    description: str    # free-text description; keywords are extracted from it
    industry: str       # industry/sector used to frame the prompt
    tone: Optional[str] = "playful"      # expected to be a TONE_PRESETS key
    num: Optional[int] = 5               # completions sampled per prompt
    liked_slogan: Optional[str] = None   # optional style anchor to imitate
# Load models once at import time so requests reuse them.
nlp = spacy.load("en_core_web_sm")  # spaCy pipeline for keyword extraction

# Fine-tuned GPT-2 weights are expected in the current working directory
# (a previous revision loaded them from "./slogan_generator_medium").
model = GPT2LMHeadModel.from_pretrained("./")
tokenizer = GPT2Tokenizer.from_pretrained("./")

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # GPU if present, else CPU
)
# Sampling presets per requested tone: lower temperature/top_p narrows
# the distribution (restrained tones); a higher repetition penalty
# discourages the model from echoing itself.
TONE_PRESETS = {
    "playful": {"temperature": 0.95, "top_p": 0.95, "repetition_penalty": 1.2},
    "bold": {"temperature": 0.8, "top_p": 0.9, "repetition_penalty": 1.45},
    "minimalist": {"temperature": 0.6, "top_p": 0.8, "repetition_penalty": 1.5},
    "luxury": {"temperature": 0.7, "top_p": 0.85, "repetition_penalty": 1.35},
    "classic": {"temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.25},
}
def summarize_description(text: str) -> str:
    """Condense a brand description to its most salient words.

    Runs the module-level spaCy pipeline over *text*, keeps tokens
    tagged as nouns, proper nouns, or adjectives, and returns at most
    the first 12 of them joined by single spaces.
    """
    salient = [
        token.text
        for token in nlp(text)
        if token.pos_ in ("NOUN", "PROPN", "ADJ")
    ]
    return " ".join(salient[:12])
@app.get("/")  # register the landing route; the function was previously never wired to the app
def read_root():
    """Landing endpoint describing how to use the API."""
    return {"message": "Welcome to Slogan Generator API. Use POST / to generate slogans."}
@app.post("/generate-slogans")  # restored: the decorator was commented out, leaving the endpoint unreachable
async def generate_slogans(request: SloganRequest):
    """Generate brand slogans with the fine-tuned GPT-2 model.

    Builds two complementary prompts (style-anchored when the client
    supplies ``liked_slogan``), samples ``request.num`` completions from
    each, then cleans and deduplicates the generated text.

    Returns:
        ``{"slogans": [...]}`` with at most ``2 * num`` unique slogans.

    Raises:
        HTTPException: 400 for an unknown tone, 500 on generation failure.
    """
    # Validate the tone up front so a bad value is a client error (400),
    # not an opaque 500 from a KeyError inside the try block.
    if request.tone not in TONE_PRESETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown tone '{request.tone}'. Valid tones: {sorted(TONE_PRESETS)}",
        )
    num = request.num or 5  # guard against an explicit JSON null for "num"
    try:
        processed_desc = summarize_description(request.description)
        prompts = _build_prompts(request, processed_desc)
        gen_params = {
            **TONE_PRESETS[request.tone],
            "max_new_tokens": 25,
            "num_return_sequences": num,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        output_groups = [generator(prompt, **gen_params) for prompt in prompts]
        slogans = _extract_slogans(output_groups)
        return {"slogans": slogans[:num * 2]}
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e


def _build_prompts(request: SloganRequest, processed_desc: str) -> list:
    """Return the two generation prompts for this request.

    When the client supplies a liked slogan, both prompts anchor on its
    style; otherwise they anchor on the requested tone.
    """
    if request.liked_slogan:
        prompt1 = (
            f"Create {request.industry} brand slogans similar to: '{request.liked_slogan}'\n"
            f"Brand: {request.brand}\n"
            f"Key Attributes: {processed_desc}\n"
            "Slogan:"
        )
        prompt2 = (
            f"Generate slogans in the style of: '{request.liked_slogan}'\n"
            f"For: {request.brand}\n"
            f"Details: {processed_desc}\n"
            "Slogan:"
        )
    else:
        prompt1 = (
            f"Create a {request.industry} brand slogan that's {request.tone} and unique.\n"
            f"Brand: {request.brand}\n"
            f"Attributes: {processed_desc}\n"
            "Slogan:"
        )
        prompt2 = (
            f"Write {request.tone} marketing slogans for this {request.industry} brand:\n"
            f"Name: {request.brand}\n"
            f"About: {processed_desc}\n"
            "Slogan:"
        )
    return [prompt1, prompt2]


def _extract_slogans(output_groups) -> list:
    """Clean pipeline outputs and collect unique slogans in order.

    Takes the text after the last "Slogan:" marker, keeps only the first
    line/sentence, strips quotes and open parens, and drops entries that
    are too short (<= 4 chars) or already seen.
    """
    slogans = []
    for group in output_groups:
        for output in group:
            raw = output['generated_text'].split("Slogan:")[-1].strip()
            clean = raw.split("\n")[0].replace('"', '').replace('(', '').split(".")[0].strip()
            if len(clean) > 4 and clean not in slogans:
                slogans.append(clean)
    return slogans