# Source: add_classification / model.py
# Author: HuiC
# Commit cccc9ee: "Update model.py"
# -*- coding: utf-8 -*-
"""Multi_Class_text_classification.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1h-p9SDt9UO_RZZKoMDhAnHQ_wpAP_CyV
"""
# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline
import pandas as pd
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import jieba as jb
import re
# Font used for Chinese text in plots.
zhfont = matplotlib.font_manager.FontProperties(fname='/usr/share/fonts/truetype/liberation/simhei.ttf')
plt.rcParams['axes.unicode_minus'] = False  # so that minus signs display correctly
# Load the labelled reviews and keep only the category and text columns.
# (No trailing line-continuation here: a stray backslash previously glued this
# statement to the next line and made the file a SyntaxError.)
df = pd.read_csv('text_cat_all.csv')
df = df[['cat', 'review']]
# Drop rows that have no review text.
df = df[pd.notnull(df['review'])]
# Number of samples per category, most frequent first.
cat_counts = df['cat'].value_counts()
d = {'cat': cat_counts.index, 'count': cat_counts}
df_cat = pd.DataFrame(data=d).reset_index(drop=True)
# fig = plt.figure(figsize=(8,6))
# df.groupby('cat').review.count().plot.bar(ylim=0)
# plt.title("类目数量分布")
# plt.ylabel('数量', fontsize=18)
# plt.xlabel('类目', fontsize=18)
# plt.show()
# Encode each category as an integer id (ids follow first-appearance order),
# then build a one-row-per-category table and lookup dicts in both directions.
df['cat_id'] = pd.factorize(df['cat'])[0]
cat_id_df = df.loc[:, ['cat', 'cat_id']].drop_duplicates()
cat_id_df = cat_id_df.sort_values('cat_id').reset_index(drop=True)
cat_to_id = dict(zip(cat_id_df['cat'], cat_id_df['cat_id']))
id_to_cat = dict(zip(cat_id_df['cat_id'], cat_id_df['cat']))
# Pattern matching every character that is NOT an ASCII letter, an ASCII digit
# or a CJK ideograph in the U+4E00..U+9FA5 range; compiled once and reused.
_DISALLOWED_CHARS = re.compile(u"[^a-zA-Z0-9\u4E00-\u9FA5]")


def remove_punctuation(line):
    """Strip all symbols except letters, digits and Chinese characters.

    Args:
        line: any value; it is converted with ``str`` first.

    Returns:
        ``str(line)`` with every character outside [A-Za-z0-9] and
        U+4E00..U+9FA5 removed; whitespace-only input yields ``''``.
    """
    text = str(line)
    if not text.strip():
        return ''
    return _DISALLOWED_CHARS.sub('', text)
# Stop-word list loader.
def stopwordslist(filepath):
    """Read a stop-word file and return its entries as a list.

    Each line becomes one entry with surrounding whitespace stripped; blank
    lines yield empty strings.

    Args:
        filepath: path to a UTF-8 text file with one stop word per line.

    Returns:
        list of stripped lines, in file order.
    """
    # Context manager so the handle is closed deterministically instead of
    # being left open until garbage collection.
    with open(filepath, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
# Load the stop words. `stopwords` stays a list for existing users; a set copy
# is built once so per-token membership tests below are O(1) instead of a
# linear scan of the whole list for every token.
stopwords = stopwordslist("drive/MyDrive/chineseStopWords.txt")
stopword_set = set(stopwords)
# Remove every symbol other than letters, digits and Chinese characters.
df['clean_review'] = df['review'].apply(remove_punctuation)
# Segment with jieba, drop stop words, and join tokens with spaces so the
# scikit-learn vectorizers can split them on whitespace.
df['cut_review'] = df['clean_review'].apply(
    lambda text: " ".join(w for w in jb.cut(text) if w not in stopword_set)
)
# Word-cloud generation (disabled).
# NOTE(review): everything below is wrapped in a bare triple-quoted string, so
# it is never executed. Its indentation also appears to have been lost, and it
# imports `wordcloud`; it would need re-indenting (and that package installed)
# before it could be re-enabled.
'''
from collections import Counter
from wordcloud import WordCloud
def generate_wordcloud(tup):
wordcloud = WordCloud(background_color='white',
font_path='/usr/share/fonts/truetype/liberation/simhei.ttf',
max_words=50, max_font_size=40,
random_state=45).generate(str(tup))
return wordcloud
cat_desc = dict()
for cat in cat_id_df.cat.values:
text = df.loc[df['cat']==cat, 'cut_review']
text = (' '.join(map(str,text))).split(' ')
cat_desc[cat]=text
fig,axes = plt.subplots(2, 3, figsize=(30, 38))
k=0
for i in range(2):
for j in range(3):
cat = id_to_cat[k]
most100=Counter(cat_desc[cat]).most_common(100)
ax = axes[i, j]
ax.imshow(generate_wordcloud(most100), interpolation="bilinear")
ax.axis('off')
# ax.set_title("{} Top 100".format(cat), fontsize=30)
k+=1
'''
# Exploratory TF-IDF matrix (unigrams + bigrams, L2-normalised) over the full
# segmented corpus; the result is printed for inspection.
# NOTE(review): the classifier further down fits its own CountVectorizer +
# TfidfTransformer on the training split, so `features`/`labels` are not used
# for training anywhere visible in this file.
from sklearn.feature_extraction.text import TfidfVectorizer

tfidf = TfidfVectorizer(ngram_range=(1, 2), norm='l2')
corpus = df['cut_review']
features = tfidf.fit_transform(corpus)
labels = df['cat_id']
for item in (features.shape, '-----------------------------', features):
    print(item)
#from sklearn.utils import shuffle
# def data_split_by_cat(df,cats,test_size=.2):
# train = pd.DataFrame()
# test = pd.DataFrame()
# for cat in cats:
# cat_df = df[df.cat==cat][['cut_review','cat_id']]
# cat_train = cat_df.sample(int(len(cat_df)*(1-test_size)))
# cat_test = cat_df[~cat_df.index.isin(cat_train.index)]
# train = train.append(cat_train)
# test = test.append(cat_test)
# train = shuffle(train)
# test = shuffle(test)
# return train['cut_review'],test['cut_review'],train['cat_id'],test['cat_id']
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.svm import LinearSVC
# Stratified train/test split of the segmented review text.
# NOTE(review): X_test / y_test are not evaluated anywhere visible in this file.
X_train, X_test, y_train, y_test = train_test_split(
    df['cut_review'],
    df['cat_id'],
    stratify=df['cat_id'],
    random_state=0,
)
# Bag-of-words counts followed by TF-IDF re-weighting, both fitted on the
# training split only.
# NOTE(review): CountVectorizer's default token_pattern may drop
# single-character tokens, which are common in segmented Chinese — confirm
# this is intended.
count_vect = CountVectorizer()
tfidf_transformer = TfidfTransformer()
X_train_counts = count_vect.fit_transform(X_train)
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
# Linear SVM trained on the TF-IDF features.
clf = LinearSVC()
clf.fit(X_train_tfidf, y_train)
def myPredict(sec):
    """Predict the category name for one piece of raw review text.

    The text gets the same preprocessing as the training data (symbol
    stripping, jieba segmentation, stop-word removal). It is then passed
    through the same CountVectorizer -> TfidfTransformer pipeline that
    produced the classifier's training features.

    Args:
        sec: raw input text (converted with ``str`` by ``remove_punctuation``).

    Returns:
        The predicted category name looked up in ``id_to_cat``.
    """
    format_sec = " ".join(
        w for w in jb.cut(remove_punctuation(sec)) if w not in stopwords
    )
    # `clf` was fit on TF-IDF-weighted features, so the TF-IDF transform must
    # also be applied here. Feeding raw counts gives differently scaled,
    # un-weighted vectors and therefore wrong predictions.
    sec_counts = count_vect.transform([format_sec])
    sec_tfidf = tfidf_transformer.transform(sec_counts)
    pred_cat_id = clf.predict(sec_tfidf)
    return id_to_cat[pred_cat_id[0]]
# Gradio UI: a multi-line text box in, the predicted category label out.
# `gr` was referenced but never imported anywhere in this file (NameError),
# so import it here.
import gradio as gr

intext = gr.inputs.Textbox(lines=3, label="Text")
label = gr.outputs.Label(label="Category", num_top_classes=6)
# NOTE(review): gr.inputs / gr.outputs / capture_session are legacy Gradio
# APIs that may not exist in newer releases — confirm the pinned Gradio
# version or migrate to gr.Textbox / gr.Label.
gr.Interface(fn=myPredict, inputs=intext, outputs=label, capture_session=True).launch()