Spaces:
Running
Running
import os | |
import openai | |
import gradio as gr | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
from dotenv import load_dotenv | |
import torch | |
from PIL import Image # PILμ μ¬μ©νμ¬ μ΄λ―Έμ§λ₯Ό μ΄κΈ° μν΄ μΆκ° | |
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Read and validate the OpenAI API key. An empty value (e.g. a bare
# "OPENAI_API_KEY=" line in .env) is rejected as well as a missing one, so a
# misconfiguration fails here at startup instead of on the first API request.
API_KEY = os.getenv("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError("OPENAI_API_KEY νκ²½ λ³μκ° μ€μ λμ§ μμμ΅λλ€.")
openai.api_key = API_KEY

# Load the BLIP processor and model used for image captioning.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_model.to(device)
# Chat-completion helper with adjustable sampling parameters.
def call_api(content, system_message, max_tokens=500, temperature=0.6, top_p=1.0):
    """Send one system + user message pair to the chat model and return its reply.

    Args:
        content: Text sent as the user message.
        system_message: Text sent as the system message.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling value forwarded to the API.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or a string starting with "OpenAI API Error: " when the
        OpenAI library raises ``openai.OpenAIError``.
    """
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": content},
    ]
    request_options = {
        "model": "gpt-4o-mini",
        "messages": conversation,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    try:
        completion = openai.ChatCompletion.create(**request_options)
        reply_text = completion.choices[0].message['content']
        return reply_text.strip()
    except openai.OpenAIError as err:
        # Report API failures as text so the UI shows them instead of crashing.
        return f"OpenAI API Error: {err}"
def generate_blog_post_in_korean(image_path, user_input, style):
    """Caption an image with BLIP, then ask the chat model for a styled description.

    Args:
        image_path: Value passed to ``Image.open`` (a filesystem path or a
            readable file object).
        user_input: User-supplied description; embedded in the prompt.
        style: One of the two style labels offered by the UI radio button.
            The first selects low-variance sampling (temperature 0.2,
            top_p 0.7); the second selects more varied sampling
            (temperature 0.7, top_p 0.9).

    Returns:
        The text returned by ``call_api``.

    Raises:
        ValueError: If ``style`` is not one of the two supported labels.
    """
    factual_style = "μ¬μ€μ μΈ"
    emotional_style = "κ°μ±μ μΈ"

    # Validate before any model work: previously an unknown style only failed
    # after captioning, with an UnboundLocalError on the prompt variables.
    if style not in (factual_style, emotional_style):
        raise ValueError(
            f"Unsupported style: {style!r}; expected {factual_style!r} or {emotional_style!r}"
        )

    # 1. Open the image and preprocess it for BLIP. The context manager
    #    releases the underlying file once the tensors have been built.
    with Image.open(image_path) as image:
        inputs = blip_processor(image, return_tensors="pt").to(device)

    # 2. Generate the caption and decode it to plain text.
    out = blip_model.generate(**inputs)
    image_caption = blip_processor.decode(out[0], skip_special_tokens=True)

    # 3. Build the prompt. Both styles share the same header (caption + user
    #    input) and differ in the instructions and sampling parameters.
    prompt_header = (
        f"μ΄λ―Έμ§ μ€λͺ : {image_caption}\n"
        f"μ¬μ©μ μ λ ₯: {user_input}\n\n"
    )
    if style == factual_style:
        combined_prompt = prompt_header + (
            "μ΄ λ μ€λͺ μ κΈ°λ°μΌλ‘ μλ κ·Έλλ‘μ μ¬μ€λ§ κ°κ²°νκ³ μ ννκ² λ¬μ¬ν΄ μ£ΌμΈμ. "
            "λΆνμν λ°°κ²½ μ€λͺ μ΄λ μΆλ‘ μ νΌνκ³ , μ₯λ©΄μ λν μ νν μ λ³΄λ§ μ κ³΅ν΄ μ£ΌμΈμ.\n\n"
            "μμ: 'ν μ΄λΈ μμ μ¬λ¬ κ·Έλ¦μ λμ₯μ°κ°μ λ€μν μμλ€μ΄ λμ¬μ Έ μλ€. "
            "μ€μμ λλ°°κΈ°μ λ΄κΈ΄ λμ₯μ°κ°κ° μκ³ , κ·Έ μμλ κ°μ’ λ°μ°¬λ€μ΄ λμ¬ μμ΅λλ€.'"
        )
        temperature = 0.2  # keep the output close to the stated facts
        top_p = 0.7  # restrict sampling diversity
    else:
        combined_prompt = prompt_header + (
            "μ΄ λ μ€λͺ μ μ°Έκ³ ν΄μ μΌμμ μ΄κ³ λ°λ»ν λΆμκΈ°μ κΈλ‘ ννν΄ μ£ΌμΈμ. "
            "μΆκ°μ μΈ μ€λͺ μ΄λ 배경보λ€λ μ₯λ©΄κ³Ό κ°μ μ μμ°μ€λ½κ² μ λ¬νλ κΈμ μ¨ μ£ΌμΈμ.\n\n"
            "μμ: 'λμ₯μ°κ°κ° λμΈ ν μ΄λΈμλ λ€μν μμλ€μ΄ μ κ°νκ² μ°¨λ €μ Έ μμ΅λλ€. "
            "λ¨λν λμ₯μ°κ°μμλ ꡬμν ν₯μ΄ νκΈ°κ³ , κ·Έ μμλ κ³ κΈ°μ μ±μκ° λ¬λΏ λ΄κΈ΄ λ°μ°¬λ€μ΄ λμ¬ μμ΄μ. "
            "λ°₯κ³Ό ν¨κ» λ¨ΉκΈ° μ’μ μμλ€μ΄ μ€λΉλμ΄ μκ³ , μ§μμ μ μ±μ€λ½κ² λ§λ λ°λ»ν λλμ΄ λλλ€.'"
        )
        temperature = 0.7  # allow more creative, expressive wording
        top_p = 0.9  # permit a wider range of expressions

    # 4. Ask the chat model to write the final description.
    system_message = "You are an AI assistant that generates either factual or emotional descriptions based on image descriptions and user input."
    translated_caption = call_api(combined_prompt, system_message, temperature=temperature, top_p=top_p)
    return translated_caption
def generate_blog_post_single(image, desc, style):
    """Generate a description for a single image, or return "" when input is missing.

    Args:
        image: Uploaded image as received from the UI; forwarded unchanged to
            ``generate_blog_post_in_korean``. ``None`` means nothing was uploaded.
        desc: User-entered description. ``None``, empty, and whitespace-only
            values all count as missing.
        style: Style label forwarded to ``generate_blog_post_in_korean``.

    Returns:
        The generated description, or an empty string when the image or the
        description is missing.
    """
    # Check `not desc` first so a None description returns "" instead of
    # raising AttributeError on .strip().
    if image is None or not desc or not desc.strip():
        return ""
    return generate_blog_post_in_korean(image, desc, style)
# Gradio interface: takes one uploaded image, one free-text description and a
# style choice, and shows the generated description in a textbox.
_input_components = [
    # The upload uses gr.File rather than gr.Image.
    gr.File(label="μ΄λ―Έμ§ μ λ‘λ"),
    gr.Textbox(
        label="μ¬μ§μ λν μ€λͺ μ λ ₯",
        placeholder="μ¬μ§ μ€λͺ μ μ λ ₯νμΈμ",
    ),
    # The initially selected option is given via `value`.
    gr.Radio(
        ["μ¬μ€μ μΈ", "κ°μ±μ μΈ"],
        label="μ€λͺ μ€νμΌ μ ν",
        value="μ¬μ€μ μΈ",
    ),
]
_output_component = gr.Textbox(label="μ΄λ―Έμ§ μ€λͺ κ²°κ³Ό")

iface = gr.Interface(
    fn=generate_blog_post_single,
    inputs=_input_components,
    outputs=_output_component,
    title="μ΄λ―Έμ§ μ€λͺ μμ±κΈ°",
    description="νλμ μ΄λ―Έμ§μ ν μ€νΈλ₯Ό λ°νμΌλ‘ μ΅μμ νκ΅μ΄λ‘ ννν©λλ€.",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch(share=True)