Spaces:
Running
Running
import torch | |
import re | |
import json | |
import gradio as gr | |
from konlpy.tag import Okt | |
from transformers import AutoTokenizer, BertForSequenceClassification | |
# --- 1. Settings and preprocessing helpers ---
# Pick the compute device once: GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Locations handed to from_pretrained() further down (presumably directories
# shipped alongside this app — TODO confirm): the shared tokenizer, the
# emotion classifier and the genre classifier.
BASE_TOKENIZER_DIR = 'base'
EMOTION_MODEL_DIR = 'kobert_emotion_classifier'
GENRE_MODEL_DIR = 'kobert_genre_classifier_archive'

# One Okt (KoNLPy) analyzer built at import time and reused by extract_pos()
# for every request.
okt = Okt()
def remove_english(text):
    """Return *text* with every run of ASCII Latin letters (A-Z, a-z) deleted.

    Only the letters themselves are removed; spaces, digits, punctuation and
    Hangul are left in place, so dropping a word can leave doubled spaces.
    Non-ASCII Latin (e.g. full-width letters) is not touched.
    """
    ascii_letter_runs = re.compile(r'[A-Za-z]+')
    return ascii_letter_runs.sub('', text)
def extract_pos(text):
    """Reduce *text* to its nouns, verbs and adjectives, joined by single spaces.

    ASCII Latin letters are stripped first via remove_english(). The remainder
    is tagged with the shared Okt analyzer, and only tokens whose tag is
    'Noun', 'Verb' or 'Adjective' are kept, in their original order.
    """
    content_tags = {'Noun', 'Verb', 'Adjective'}
    tagged_tokens = okt.pos(remove_english(text))
    kept_tokens = (token for token, tag in tagged_tokens if tag in content_tags)
    return ' '.join(kept_tokens)
# --- 2. Load the shared tokenizer and both classifier models ---
def _load_classifier(model_dir):
    """Load a fine-tuned BertForSequenceClassification from *model_dir*.

    The model is moved to the module-level ``device`` and put in eval mode
    (dropout off) so it is ready for inference.
    """
    model = BertForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()
    return model


def _require_label_coverage(model, id_to_label, model_dir):
    """Raise ValueError unless *id_to_label* names every output class of *model*.

    Prediction looks up each class index 0..num_labels-1 in the label map, so a
    missing id would otherwise only surface as a KeyError on the first request.
    """
    missing = [i for i in range(model.config.num_labels) if i not in id_to_label]
    if missing:
        raise ValueError(f"{model_dir}: no label defined for class id(s) {missing}")


try:
    tokenizer = AutoTokenizer.from_pretrained(BASE_TOKENIZER_DIR, trust_remote_code=True)
    print("✅ 공용 토크나이저 로드 성공")

    # Emotion classifier; its labels are read from a file next to the weights.
    emotion_model = _load_classifier(EMOTION_MODEL_DIR)
    with open(f"{EMOTION_MODEL_DIR}/labels_ids.json", "r", encoding="utf-8") as f:
        emotion_labels_ids = json.load(f)
    # The file is inverted here, so it presumably maps label -> integer id;
    # TODO confirm against the training script.
    id_to_emotion_label = {v: k for k, v in emotion_labels_ids.items()}
    _require_label_coverage(emotion_model, id_to_emotion_label, EMOTION_MODEL_DIR)
    print("✅ 감정 분류 모델 로드 성공")

    # Genre classifier; its labels are defined inline rather than read from disk.
    genre_model = _load_classifier(GENRE_MODEL_DIR)
    id_to_genre_label = {
        0: '록/메탈',    # rock/metal
        1: '댄스',       # dance
        2: 'R&B/Soul',
        3: '발라드',     # ballad
        4: '랩/힙합',    # rap/hip-hop
        5: '트로트',     # trot
    }
    _require_label_coverage(genre_model, id_to_genre_label, GENRE_MODEL_DIR)
    print("✅ 장르 분류 모델 로드 성공 (레이블 직접 정의)")
except Exception as e:
    # Deliberate best-effort: keep the module importable so the UI still
    # launches; predict_emotion_and_genre() reports the failure via gr.Error
    # when it sees the models set to None.
    print(f"모델 또는 토크나이저 로딩 중 오류 발생: {e}")
    tokenizer = None  # keep the name bound even if the tokenizer itself failed
    emotion_model, genre_model = None, None
# --- 3. ํตํฉ ์์ธก ํจ์ (์ดํ ๋์ผ) --- | |
def predict_emotion_and_genre(text): | |
if not emotion_model or not genre_model: | |
raise gr.Error("๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. Space์ ๋ก๊ทธ๋ฅผ ํ์ธํด์ฃผ์ธ์.") | |
preprocessed_text = extract_pos(text) | |
# ๊ฐ์ ์์ธก | |
emotion_inputs = tokenizer(preprocessed_text, return_tensors='pt', truncation=True, padding=True, max_length=384).to(device) | |
with torch.no_grad(): | |
emotion_logits = emotion_model(**emotion_inputs).logits | |
emotion_probs = torch.softmax(emotion_logits, dim=1).squeeze().cpu().numpy() | |
emotion_confidences = {id_to_emotion_label[i]: float(prob) for i, prob in enumerate(emotion_probs)} | |
# ์ฅ๋ฅด ์์ธก | |
genre_inputs = tokenizer(preprocessed_text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device) | |
with torch.no_grad(): | |
genre_logits = genre_model(**genre_inputs).logits | |
genre_probs = torch.softmax(genre_logits, dim=1).squeeze().cpu().numpy() | |
genre_confidences = {id_to_genre_label[i]: float(prob) for i, prob in enumerate(genre_probs)} | |
return emotion_confidences, genre_confidences | |
# --- 4. Gradio ์ธํฐํ์ด์ค (์ดํ ๋์ผ) --- | |
title = "๐ค ํ๊ตญ์ด ๊ฐ์ฌ ๊ฐ์ ๋ฐ ์ฅ๋ฅด ๋์ ๋ถ์๊ธฐ ๐ถ" | |
description = "KoBERT๋ฅผ ํ์ธํ๋ํ์ฌ ๋ง๋ ๋ชจ๋ธ์ ๋๋ค. ๊ฐ์ฌ๋ฅผ ์ ๋ ฅํ๋ฉด ๊ฐ์ ๊ณผ ์ฅ๋ฅด๋ฅผ ๋์์ ์์ธกํฉ๋๋ค." | |
examples = [ | |
["์ฌํ์ ๋ฐ๋ฐ๋ฅ์์ ๋ ๋๋ฅผ ๋ง๋"], | |
["๊ฐ์ด์ด ์ ์ฅํด์ง๋ค ์ด๊ฑด ๋ชป ์ฐธ์ง"], | |
["๋์ ํจ๊ป๋ผ๋ฉด ์ด๋๋ ๊ฐ ์ ์์ด"], | |
["์ค๋ ๋ฐค ์ฃผ์ธ๊ณต์ ๋์ผ ๋"] | |
] | |
iface = gr.Interface( | |
fn=predict_emotion_and_genre, | |
inputs=gr.Textbox(lines=10, placeholder="์ฌ๊ธฐ์ ๋ ธ๋ ๊ฐ์ฌ๋ฅผ ์ ๋ ฅํ์ธ์...", label="๋ ธ๋ ๊ฐ์ฌ"), | |
outputs=[ | |
gr.Label(num_top_classes=3, label="๊ฐ์ ์์ธก ๊ฒฐ๊ณผ"), | |
gr.Label(num_top_classes=3, label="์ฅ๋ฅด ์์ธก ๊ฒฐ๊ณผ") | |
], | |
title=title, | |
description=description, | |
examples=examples | |
) | |
iface.launch() |