Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""Multi_Class_text_classification.ipynb | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/drive/1h-p9SDt9UO_RZZKoMDhAnHQ_wpAP_CyV | |
""" | |
# Commented out IPython magic to ensure Python compatibility. | |
# %matplotlib inline | |
import pandas as pd | |
import matplotlib | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import jieba as jb | |
import re | |
# Font that can render Chinese glyphs in matplotlib figures.
# NOTE(review): ``zhfont`` is not referenced anywhere else in the visible code.
zhfont = matplotlib.font_manager.FontProperties(fname='/usr/share/fonts/truetype/liberation/simhei.ttf')
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

# Load the labelled corpus and keep only the category label and the raw review.
df = pd.read_csv('text_cat_all.csv')
df = df[['cat', 'review']]
# Drop rows whose review text is missing (same rows as a pd.notnull filter).
df = df.dropna(subset=['review'])

# Per-category sample counts, most frequent first (only consumed by the
# commented-out bar chart below).
d = {'cat': df['cat'].value_counts().index, 'count': df['cat'].value_counts()}
df_cat = pd.DataFrame(data=d).reset_index(drop=True)
# fig = plt.figure(figsize=(8,6))
# df.groupby('cat').review.count().plot.bar(ylim=0)
# plt.title("类目数量分布")
# plt.ylabel('数量', fontsize=18)
# plt.xlabel('类目', fontsize=18)
# plt.show()

# Integer-encode each category; factorize numbers categories in order of
# first appearance.
df['cat_id'] = df['cat'].factorize()[0]
cat_id_df = df[['cat', 'cat_id']].drop_duplicates().sort_values('cat_id').reset_index(drop=True)
# Lookup tables in both directions: category name -> id and id -> name.
cat_to_id = dict(cat_id_df.values)
id_to_cat = dict(cat_id_df[['cat_id', 'cat']].values)
# Matches every character that is NOT an ASCII letter, an ASCII digit or a CJK
# Unified Ideograph in U+4E00..U+9FA5. Compiled once at import time because
# remove_punctuation is applied to every row of the corpus.
_NON_WORD_CHARS = re.compile(r"[^a-zA-Z0-9\u4E00-\u9FA5]")


def remove_punctuation(line):
    """Remove every symbol except letters, digits and Chinese characters.

    Args:
        line: Value to clean. It is converted with ``str()`` first, so
            non-string inputs are accepted.

    Returns:
        str: The input with all other characters deleted. Whitespace is
        deleted too, so a blank or whitespace-only input yields ``''``.
    """
    line = str(line)
    if not line.strip():
        return ''
    return _NON_WORD_CHARS.sub('', line)
# Stop-word list loader.
def stopwordslist(filepath):
    """Read a UTF-8 stop-word file, one entry per line.

    Args:
        filepath: Path to a text file containing one stop word per line.

    Returns:
        list[str]: Every line with surrounding whitespace stripped, in file
        order. Blank lines become empty strings.
    """
    # ``with`` guarantees the file handle is closed once the list is built.
    with open(filepath, 'r', encoding='utf-8') as fh:
        return [line.strip() for line in fh]
# Load the stop words into a set: membership is tested for every token of
# every review, and a set makes each test O(1) instead of a list scan.
# NOTE(review): this is a Google Colab Drive path; it will not exist outside
# Colab (e.g. in a deployed app). Confirm the intended location.
stopwords = set(stopwordslist("drive/MyDrive/chineseStopWords.txt"))

# Remove every symbol other than letters, digits and Chinese characters.
df['clean_review'] = df['review'].apply(remove_punctuation)

# Segment with jieba, drop stop words, and join the remaining tokens with
# single spaces into one string per review.
df['cut_review'] = df['clean_review'].apply(
    lambda text: " ".join(w for w in jb.cut(text) if w not in stopwords)
)
# Generate one word cloud per category.
# NOTE: this snippet is disabled. It sits inside a bare triple-quoted string
# that is evaluated and discarded, so none of it runs. Re-enabling it requires
# the third-party ``wordcloud`` package.
# The 2x3 grid reads id_to_cat[0..5]: it shows only the first six categories
# and raises KeyError if there are fewer than six.
# NOTE(review): ``generate(str(tup))`` builds the cloud from the textual repr
# of the (word, count) list returned by most_common(100), not from the counts
# themselves. ``WordCloud.generate_from_frequencies(dict(tup))`` is probably
# what was intended; confirm before re-enabling.
'''
from collections import Counter
from wordcloud import WordCloud
def generate_wordcloud(tup):
    wordcloud = WordCloud(background_color='white',
                          font_path='/usr/share/fonts/truetype/liberation/simhei.ttf',
                          max_words=50, max_font_size=40,
                          random_state=45).generate(str(tup))
    return wordcloud
cat_desc = dict()
for cat in cat_id_df.cat.values:
    text = df.loc[df['cat']==cat, 'cut_review']
    text = (' '.join(map(str,text))).split(' ')
    cat_desc[cat]=text
fig,axes = plt.subplots(2, 3, figsize=(30, 38))
k=0
for i in range(2):
    for j in range(3):
        cat = id_to_cat[k]
        most100=Counter(cat_desc[cat]).most_common(100)
        ax = axes[i, j]
        ax.imshow(generate_wordcloud(most100), interpolation="bilinear")
        ax.axis('off')
        # ax.set_title("{} Top 100".format(cat), fontsize=30)
        k+=1
'''
# Build a TF-IDF matrix over the whole corpus (unigrams and bigrams,
# L2-normalised rows) and print its shape and contents for inspection.
# NOTE(review): ``features`` / ``labels`` are not used by the classifier
# trained later in this script, which fits its own CountVectorizer +
# TfidfTransformer on the training split only.
from sklearn.feature_extraction.text import TfidfVectorizer

tfidf = TfidfVectorizer(ngram_range=(1, 2), norm='l2')
features = tfidf.fit_transform(df['cut_review'])
labels = df['cat_id']

print(features.shape)
print('-----------------------------')
print(features)
#from sklearn.utils import shuffle | |
# def data_split_by_cat(df,cats,test_size=.2): | |
# train = pd.DataFrame() | |
# test = pd.DataFrame() | |
# for cat in cats: | |
# cat_df = df[df.cat==cat][['cut_review','cat_id']] | |
# cat_train = cat_df.sample(int(len(cat_df)*(1-test_size))) | |
# cat_test = cat_df[~cat_df.index.isin(cat_train.index)] | |
# train = train.append(cat_train) | |
# test = test.append(cat_test) | |
# train = shuffle(train) | |
# test = shuffle(test) | |
# return train['cut_review'],test['cut_review'],train['cat_id'],test['cat_id'] | |
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.svm import LinearSVC

# Stratified split so every category keeps its proportion in both halves.
# The held-out X_test / y_test are not evaluated anywhere in this script.
X_train, X_test, y_train, y_test = train_test_split(
    df['cut_review'],
    df['cat_id'],
    stratify=df['cat_id'],
    random_state=0,
)

# Two-stage featurisation: raw token counts, then TF-IDF reweighting.
# Both stages are fitted on the training split only.
# NOTE(review): CountVectorizer's default token_pattern keeps only tokens of
# two or more word characters, so single-character Chinese words produced by
# jieba are silently dropped. Confirm that this is intended.
count_vect = CountVectorizer()
tfidf_transformer = TfidfTransformer()
X_train_counts = count_vect.fit_transform(X_train)
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)

# Linear SVM trained on the TF-IDF features.
clf = LinearSVC()
clf.fit(X_train_tfidf, y_train)
def myPredict(sec):
    """Predict the category name for one raw piece of text.

    The text is preprocessed like the training data (symbol removal, jieba
    segmentation, stop-word filtering) and then passed through the same
    CountVectorizer -> TfidfTransformer chain the classifier was fitted on.

    Args:
        sec: Raw input text.

    Returns:
        str: The predicted category name, looked up in ``id_to_cat``.
    """
    format_sec = " ".join(
        w for w in jb.cut(remove_punctuation(sec)) if w not in stopwords
    )
    # clf was fitted on TF-IDF weights, not raw counts, so the TF-IDF
    # transform must be applied at inference time as well.
    counts = count_vect.transform([format_sec])
    pred_cat_id = clf.predict(tfidf_transformer.transform(counts))
    return id_to_cat[pred_cat_id[0]]
# Web UI: a multi-line text box in, the predicted category out.
# Gradio is imported here because it is only needed for the UI and no
# earlier line of the script imports it.
import gradio as gr

# Top-level components replace the legacy ``gr.inputs`` / ``gr.outputs``
# namespaces (removed in Gradio 4). ``capture_session`` is not a parameter of
# current ``gr.Interface`` and is therefore omitted.
intext = gr.Textbox(lines=3, label="Text")
label = gr.Label(label="Category", num_top_classes=6)
gr.Interface(fn=myPredict, inputs=intext, outputs=label).launch()