Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,21 @@
|
|
1 |
-
# app.py
|
2 |
-
|
3 |
import torch
|
4 |
import re
|
|
|
|
|
5 |
from konlpy.tag import Okt
|
6 |
from transformers import AutoTokenizer, BertForSequenceClassification
|
7 |
-
import gradio as gr
|
8 |
|
9 |
# --- 1. ์ค์ ๋ฐ ์ ์ฒ๋ฆฌ ํจ์ ---
|
10 |
|
11 |
# ๋๋ฐ์ด์ค ์ค์
|
12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
|
14 |
-
#
|
15 |
-
|
|
|
16 |
|
17 |
-
#
|
18 |
-
# ์์: labels_ids = {'๋์ค': 0, '๋ฐ๋ผ๋': 1, '๋ก': 2, ...}
|
19 |
-
id_to_label = {
|
20 |
-
0: '๋ก/๋ฉํ', # ์ค์ ์ฅ๋ฅด ์ด๋ฆ์ผ๋ก ๋ณ๊ฒฝ
|
21 |
-
1: '๋์ค',
|
22 |
-
2: 'R&B/Soul',
|
23 |
-
3: '๋ฐ๋ผ๋',
|
24 |
-
4: '๋ฉ/ํํฉํฉ',
|
25 |
-
5: 'ํธ๋กํธ'
|
26 |
-
}
|
27 |
-
# โโโโโ ์ด ๋ถ๋ถ์ ์ค์ ์ฅ๋ฅด๋ช
์ผ๋ก ๊ผญ ์์ ํด์ฃผ์ธ์! โโโโโ
|
28 |
-
|
29 |
-
|
30 |
-
# ๋
ธํธ๋ถ์์ ์ฌ์ฉํ ์ ์ฒ๋ฆฌ ํจ์ (๊ทธ๋๋ก ๋ณต์ฌ)
|
31 |
okt = Okt()
|
32 |
-
|
33 |
def remove_english(text):
|
34 |
return re.sub(r'[A-Za-z]+', '', text)
|
35 |
|
@@ -38,69 +24,82 @@ def extract_pos(text):
|
|
38 |
text = remove_english(text)
|
39 |
return ' '.join([word for word, pos in okt.pos(text) if pos in allowed_pos])
|
40 |
|
41 |
-
# --- 2. ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ๋ก๋ ---
|
42 |
|
43 |
try:
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
except Exception as e:
|
50 |
print(f"๋ชจ๋ธ ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
|
51 |
-
|
52 |
-
tokenizer, model = None, None
|
53 |
-
|
54 |
-
# --- 3. ์์ธก ํจ์ ---
|
55 |
|
56 |
-
|
57 |
-
|
|
|
58 |
raise gr.Error("๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. Space์ ๋ก๊ทธ๋ฅผ ํ์ธํด์ฃผ์ธ์.")
|
59 |
|
60 |
-
# 1. ์
๋ ฅ๋ ๊ฐ์ฌ ์ ์ฒ๋ฆฌ
|
61 |
preprocessed_text = extract_pos(text)
|
62 |
|
63 |
-
# 2.
|
64 |
-
|
65 |
-
preprocessed_text,
|
66 |
-
return_tensors='pt',
|
67 |
-
truncation=True,
|
68 |
-
padding='max_length',
|
69 |
-
max_length=512 # ๋
ธํธ๋ถ์์ ์ค์ ํ MAX_LENGTH์ ๋์ผํ๊ฒ
|
70 |
).to(device)
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
# 3. ์์ธก
|
|
|
|
|
|
|
73 |
with torch.no_grad():
|
74 |
-
|
75 |
-
|
|
|
76 |
|
77 |
-
|
78 |
-
probabilities = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
|
79 |
-
|
80 |
-
# Gradio์ Label ์ปดํฌ๋ํธ์ ๋ง๊ฒ ์ถ๋ ฅ ํ์ ๋ณ๊ฒฝ
|
81 |
-
confidences = {id_to_label[i]: float(prob) for i, prob in enumerate(probabilities)}
|
82 |
-
|
83 |
-
return confidences
|
84 |
|
85 |
# --- 4. Gradio ์ธํฐํ์ด์ค ์์ฑ ---
|
86 |
-
|
87 |
-
|
88 |
-
description = "KoBERT๋ฅผ ํ์ธํ๋ํ์ฌ ๋ง๋ ๋
ธ๋ ๊ฐ์ฌ ์ฅ๋ฅด ๋ถ๋ฅ ๋ชจ๋ธ์
๋๋ค. ์๋์ ๊ฐ์ฌ๋ฅผ ์
๋ ฅํ๊ณ '๋ถ๋ฅํ๊ธฐ' ๋ฒํผ์ ๋๋ฅด๋ฉด ์ฅ๋ฅด๋ฅผ ์์ธกํด์ค๋๋ค."
|
89 |
examples = [
|
90 |
-
["
|
91 |
-
["
|
92 |
-
["
|
|
|
93 |
]
|
94 |
|
95 |
-
|
96 |
-
# Gradio ์ธํฐํ์ด์ค ์คํ
|
97 |
iface = gr.Interface(
|
98 |
-
fn=
|
99 |
inputs=gr.Textbox(lines=10, placeholder="์ฌ๊ธฐ์ ๋
ธ๋ ๊ฐ์ฌ๋ฅผ ์
๋ ฅํ์ธ์...", label="๋
ธ๋ ๊ฐ์ฌ"),
|
100 |
-
outputs
|
|
|
|
|
|
|
|
|
101 |
title=title,
|
102 |
description=description,
|
103 |
examples=examples
|
104 |
)
|
105 |
|
|
|
106 |
iface.launch()
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import re
|
3 |
+
import json
|
4 |
+
import gradio as gr
|
5 |
from konlpy.tag import Okt
|
6 |
from transformers import AutoTokenizer, BertForSequenceClassification
|
|
|
# --- 1. Configuration and preprocessing setup ---

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directories holding the two fine-tuned KoBERT checkpoints.
EMOTION_MODEL_DIR = './kobert_emotion_classifier_archive'
GENRE_MODEL_DIR = './kobert_genre_classifier_archive'

# Morphological analyzer shared by both classifiers' preprocessing
# (same one used in the training notebook).
okt = Okt()
|
|
def remove_english(text):
    """Strip every run of ASCII letters from *text*, leaving all other characters intact."""
    ascii_letters = re.compile(r'[A-Za-z]+')
    return ascii_letters.sub('', text)
|
|
|
24 |
text = remove_english(text)
|
25 |
return ' '.join([word for word, pos in okt.pos(text) if pos in allowed_pos])
|
26 |
|
# --- 2. Load both fine-tuned models and their tokenizers ---

def _load_classifier(model_dir):
    # Helper: load tokenizer + model from a checkpoint directory,
    # move the model to `device`, and switch it to inference mode.
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    mdl = BertForSequenceClassification.from_pretrained(model_dir, trust_remote_code=True)
    mdl.to(device)
    mdl.eval()
    return tok, mdl

try:
    # Emotion classifier: its label names come from the labels_ids.json
    # saved alongside the checkpoint at training time.
    emotion_tokenizer, emotion_model = _load_classifier(EMOTION_MODEL_DIR)
    with open(f"{EMOTION_MODEL_DIR}/labels_ids.json", "r", encoding="utf-8") as f:
        emotion_labels_ids = json.load(f)
    id_to_emotion_label = {idx: name for name, idx in emotion_labels_ids.items()}
    print("โ ๊ฐ์ ๋ถ๋ฅ ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")

    # Genre classifier: its label map is fixed in code.
    genre_tokenizer, genre_model = _load_classifier(GENRE_MODEL_DIR)
    id_to_genre_label = {0: '๋ก/๋ฉํ', 1: '๋์ค', 2: 'R&B/Soul', 3: '๋ฐ๋ผ๋', 4: '๋ฉ/ํํฉ', 5: 'ํธ๋กํธ'}
    print("โ ์ฅ๋ฅด ๋ถ๋ฅ ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")
except Exception as e:
    # Keep the app importable even when loading fails; the predict function
    # raises a user-facing error when the models are None.
    print(f"๋ชจ๋ธ ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    emotion_model, genre_model = None, None
|
|
|
|
|
|
# --- 3. Combined prediction function ---
def predict_emotion_and_genre(text):
    """Run both classifiers on one lyric and return (emotion, genre) confidence dicts.

    Each returned dict maps a label name to a float probability, the shape
    Gradio's Label component expects.
    """
    if not emotion_model or not genre_model:
        raise gr.Error("๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. Space์ ๋ก๊ทธ๋ฅผ ํ์ธํด์ฃผ์ธ์.")

    # Shared preprocessing: drop English, keep only the allowed POS tokens.
    cleaned = extract_pos(text)

    def _scores(tokenizer, model, label_map, max_len):
        # Tokenize, run the model without gradients, and map softmax
        # probabilities onto the human-readable label names.
        encoded = tokenizer(
            cleaned, return_tensors='pt', truncation=True, padding=True, max_length=max_len
        ).to(device)
        with torch.no_grad():
            logits = model(**encoded).logits
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
        return {label_map[i]: float(p) for i, p in enumerate(probs)}

    # NOTE(review): max_length differs per model (384 vs 512), matching the
    # training-time settings — confirm against the notebooks before changing.
    emotion_confidences = _scores(emotion_tokenizer, emotion_model, id_to_emotion_label, 384)
    genre_confidences = _scores(genre_tokenizer, genre_model, id_to_genre_label, 512)
    return emotion_confidences, genre_confidences
|
|
# --- 4. Gradio interface ---
title = "๐ค ํ๊ตญ์ด ๊ฐ์ฌ ๊ฐ์ ๋ฐ ์ฅ๋ฅด ๋์ ๋ถ์๊ธฐ ๐ถ"
description = "KoBERT๋ฅผ ํ์ธํ๋ํ์ฌ ๋ง๋ ๋ชจ๋ธ์๋๋ค. ๊ฐ์ฌ๋ฅผ ์๋ ฅํ๋ฉด ๊ฐ์ ๊ณผ ์ฅ๋ฅด๋ฅผ ๋์์ ์์ธกํฉ๋๋ค."

# Sample lyrics offered as one-click examples in the UI.
examples = [
    ["์ฌํ์ ๋ฐ๋ฐ๋ฅ์์ ๋ ๋๋ฅผ ๋ง๋"],
    ["๊ฐ์ด์ด ์์ฅํด์ง๋ค ์ด๊ฑด ๋ชป ์ฐธ์ง"],
    ["๋์ ํจ๊ป๋ผ๋ฉด ์ด๋๋ ๊ฐ ์ ์์ด"],
    ["์ค๋ ๋ฐค ์ฃผ์ธ๊ณต์ ๋์ผ ๋"]
]

# One textbox in, two Label panels out — each panel shows its top-3 classes.
iface = gr.Interface(
    fn=predict_emotion_and_genre,
    inputs=gr.Textbox(lines=10, placeholder="์ฌ๊ธฐ์ ๋ธ๋ ๊ฐ์ฌ๋ฅผ ์๋ ฅํ์ธ์...", label="๋ธ๋ ๊ฐ์ฌ"),
    outputs=[
        gr.Label(num_top_classes=3, label="๊ฐ์ ์์ธก ๊ฒฐ๊ณผ"),
        gr.Label(num_top_classes=3, label="์ฅ๋ฅด ์์ธก ๊ฒฐ๊ณผ")
    ],
    title=title,
    description=description,
    examples=examples
)

# Start the app server.
iface.launch()