# -*- coding: utf-8 -*-
"""Multi_Class_text_classification.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1h-p9SDt9UO_RZZKoMDhAnHQ_wpAP_CyV

Trains a LinearSVC on TF-IDF features of jieba-tokenised reviews read from
``text_cat_all.csv`` and serves single-text category prediction through a
Gradio interface.
"""

# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline
import re

import gradio as gr
import jieba as jb
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

zhfont = matplotlib.font_manager.FontProperties(
    fname='/usr/share/fonts/truetype/liberation/simhei.ttf')
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

df = pd.read_csv('text_cat_all.csv')
df = df[['cat', 'review']]

# Per-category row counts (kept from the notebook; not used further below).
d = {'cat': df['cat'].value_counts().index, 'count': df['cat'].value_counts()}
df_cat = pd.DataFrame(data=d).reset_index(drop=True)

# Disabled notebook plot of the category distribution:
# fig = plt.figure(figsize=(8,6))
# df.groupby('cat').review.count().plot.bar(ylim=0)
# plt.title("类目数量分布")
# plt.ylabel('数量', fontsize=18)
# plt.xlabel('类目', fontsize=18)
# plt.show()

# Encode each category name as an integer id and build both lookup tables.
df['cat_id'] = df['cat'].factorize()[0]
cat_id_df = (df[['cat', 'cat_id']]
             .drop_duplicates()
             .sort_values('cat_id')
             .reset_index(drop=True))
cat_to_id = dict(cat_id_df.values)
id_to_cat = dict(cat_id_df[['cat_id', 'cat']].values)

# Matches every character that is not an ASCII letter, a digit, or a CJK
# ideograph in the range U+4E00..U+9FA5. Compiled once instead of per call.
_NON_WORD_RE = re.compile(u"[^a-zA-Z0-9\u4E00-\u9FA5]")


def remove_punctuation(line):
    """Strip everything except ASCII letters, digits and Chinese characters.

    Args:
        line: Any value; it is converted with ``str()`` before filtering.
            ``None`` and float NaN (a missing CSV cell) are treated as empty
            text rather than being turned into the words 'None' / 'nan'.

    Returns:
        The filtered string, or '' when the input is missing or blank.
    """
    if line is None or (isinstance(line, float) and line != line):
        return ''
    line = str(line)
    if line.strip() == '':
        return ''
    return _NON_WORD_RE.sub('', line)


def stopwordslist(filepath):
    """Load a stop-word list from a UTF-8 text file, one word per line.

    Args:
        filepath: Path of the stop-word file.

    Returns:
        A list with each line of the file stripped of surrounding whitespace.
    """
    # The context manager closes the handle; the original left it open.
    with open(filepath, 'r', encoding='utf-8') as handle:
        return [line.strip() for line in handle]


# Load the stop words. The list is kept under its original name; the frozenset
# gives O(1) membership tests when filtering tokens.
stopwords = stopwordslist("chineseStopWords.txt")
_STOPWORD_SET = frozenset(stopwords)


def _cut_and_filter(text):
    """Tokenise ``text`` with jieba and join the non-stop-word tokens by spaces."""
    return " ".join(w for w in jb.cut(text) if w not in _STOPWORD_SET)


# Remove every symbol other than letters, digits and Chinese characters.
df['clean_review'] = df['review'].apply(remove_punctuation)
# Tokenise and drop stop words.
df['cut_review'] = df['clean_review'].apply(_cut_and_filter)

# Word-cloud generation (disabled: kept as an inert string from the notebook).
'''
from collections import Counter
from wordcloud import WordCloud

def generate_wordcloud(tup):
    wordcloud = WordCloud(background_color='white',
                          font_path='/usr/share/fonts/truetype/liberation/simhei.ttf',
                          max_words=50, max_font_size=40,
                          random_state=45).generate(str(tup))
    return wordcloud

cat_desc = dict()
for cat in cat_id_df.cat.values:
    text = df.loc[df['cat']==cat, 'cut_review']
    text = (' '.join(map(str,text))).split(' ')
    cat_desc[cat]=text

fig,axes = plt.subplots(2, 3, figsize=(30, 38))
k=0
for i in range(2):
    for j in range(3):
        cat = id_to_cat[k]
        most100=Counter(cat_desc[cat]).most_common(100)
        ax = axes[i, j]
        ax.imshow(generate_wordcloud(most100), interpolation="bilinear")
        ax.axis('off')
        # ax.set_title("{} Top 100".format(cat), fontsize=30)
        k+=1
'''

# Build TF-IDF vectors (unigrams + bigrams) over the full corpus.
# NOTE(review): `features` and `labels` are not used by the classifier trained
# below; they are kept because they are module-level names from the notebook.
tfidf = TfidfVectorizer(norm='l2', ngram_range=(1, 2))
features = tfidf.fit_transform(df.cut_review)
labels = df.cat_id

# Disabled alternative: per-category train/test split.
# from sklearn.utils import shuffle
# def data_split_by_cat(df,cats,test_size=.2):
#     train = pd.DataFrame()
#     test = pd.DataFrame()
#     for cat in cats:
#         cat_df = df[df.cat==cat][['cut_review','cat_id']]
#         cat_train = cat_df.sample(int(len(cat_df)*(1-test_size)))
#         cat_test = cat_df[~cat_df.index.isin(cat_train.index)]
#         train = train.append(cat_train)
#         test = test.append(cat_test)
#     train = shuffle(train)
#     test = shuffle(test)
#     return train['cut_review'],test['cut_review'],train['cat_id'],test['cat_id']

# Stratified split so each category keeps its proportion in both partitions.
X_train, X_test, y_train, y_test = train_test_split(
    df['cut_review'], df['cat_id'], random_state=0, stratify=df['cat_id'])

# Bag-of-words counts -> TF-IDF weights -> linear SVM.
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(X_train)
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
clf = LinearSVC().fit(X_train_tfidf, y_train)


def myPredict(sec):
    """Predict the category name for one piece of raw text.

    The text goes through the same steps as the training data: symbol
    removal, jieba tokenisation with stop-word filtering, count
    vectorisation, and TF-IDF weighting.

    Args:
        sec: Raw input text.

    Returns:
        The category name looked up in ``id_to_cat`` for the predicted id.
    """
    format_sec = _cut_and_filter(remove_punctuation(sec))
    counts = count_vect.transform([format_sec])
    # The classifier was fitted on TF-IDF weighted features, so the fitted
    # transformer must be applied here too; the original passed raw counts.
    pred_cat_id = clf.predict(tfidf_transformer.transform(counts))
    return id_to_cat[pred_cat_id[0]]


# NOTE(review): `gr.inputs` / `gr.outputs` and `capture_session` look like the
# legacy Gradio API; confirm they exist in the installed Gradio version.
intext = gr.inputs.Textbox(lines=3, label="Text")
label = gr.outputs.Label(label="Category", num_top_classes=6)
gr.Interface(fn=myPredict, inputs=intext, outputs=label,
             capture_session=True).launch()