ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

多类分类及误差分析

2020-03-03 11:53:38  阅读:324  来源: 互联网

标签:分析 误差 digit clf 分类 some 分类器 train sgd


1、多类分类

二分类器只能区分两个类别,多分类器则可以区分多余两个类别

一些算法(比如随机森林分类器或者朴素贝叶斯分类器)可以直接处理多分类问题,而其他的一些算法(比如SVM分类器或者线性分类器)择时严格的二分类器。当然也有许多策略让二分类器去执行多分类问题

"一对所有"(OvA)策略:创建一个将图片分为10类(0到9)的系统的一个方法:训练10个二分类器,每一个对应一个数字(探测器0,探测器1,探测器2,以此类推),然后当你想对某张图片进行分类的时候,让每一个分类器对这张图片进行分类,选出决策分数最高的那个分类器。

“一对一”(OvO)策略:对每一对数字都训练一个二分类器,一个分类器用来处理数字0和数字1,一个用来处理数字0和数字2,以此类推,如果有N个类,那么就需要N*(N-1)/2个分类器。让这张图片在这些分类器上都跑一遍,看哪个类胜出。OvO策略的主要优点是:每个分类器都只需要在训练集的部分数据上面进行训练。这部分数据是它所需要区分的那两个类对应的数据。

Sklearn可以探测出你想用一个二分类器去实现多分类的任务,他会自动执行OvA(除了SVM分类器,它使用OvO)。现在让试下SGDCLassifier

sgd_clf.fit(X_train,y_train)
sgd_clf.predict([some_digit])

 

 

 SGDClassifier分类器将对数字0~9产生10个探测器,在训练集上训练10个二分类器,每个分类器都产生这张图片的决策数值,选择数值最高的那个类。

为证明这一点,可以调用decision_function()方法,会返回10个数值,每个数值对一个一个类

some_digit_scores = sgd_clf.decision_function([some_digit])
some_digit_scores

 

 

 最高值是对应类别5:

np.argmax(some_digit_scores)

一个分类器训练好之后,它会保存目标类别列表到属性classes_中去,按照值排序。在本例子当中,在 classes_ 数组当中的每个类的索引方便地匹配了类本身,比如,索引为 5 的类恰好是类别 5 本身。但通常不会这么幸运

sgd_clf.classes_

sgd_clf.classes_[5]

 

 

现在来看下如何强制Sklearn 使用OvO策略或者OvA策略。可以使用OneVsOneClassifier类或者OneVsRestClassifer类。创建一个样例,传递一个二分类器给他的构造函数。

from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SGDClassifier(random_state = 42))
ovo_clf.fit(X_train,y_train)
ovo_clf.predict([some_digit])

len(ovo_clf.estimators_)

 

 可以看到OvO策略下对数字0~9分类会产生10*(10 -1 )/2 = 45个检测器

使用RandomForestClassifier分类器试试:

forest_clf.fit(X_train,y_train)
forest_clf.predict([some_digit])

 

 由于RandomForestClassifier可以直接进行多分类,因此没必要去执行OvO或者OvA。可以调用predict_proba()可以得到样例对应的类别的概率值的列表

forest_clf.predict_proba([some_digit])

 

 在数组的索引 5 上的 0.8,意味着这个模型以80% 的概率估算这张图片代表数字 5。它也认为这个图片可能是数字 0 或者数字 3,分别都是 10% 的几率。

现在使用交叉验证来对SGDCLassifier进行精度评估

cross_val_score(sgd_clf,X_train,y_train,cv = 3,scoring = 'accuracy')

 

 在所有的测试者上,他有86%的精度。如何将精度提高到90%以上呢?

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf,X_train_scaled,y_train,cv = 3,scoring = 'accuracy')

 

 2、误差分析

使用混淆矩阵,首先要使用cross_val_predict进行预测,然后调用confusion_matrix()函数

y_train_pred = cross_val_predict(sgd_clf,X_train_scaled,y_train,cv = 3)
conf_mx = confusion_matrix(y_train,y_train_pred)
conf_mx

plt.matshow(conf_mx,cmap = plt.cm.gray)
plt.show()

 

 关注下包含错误数据的图像呈现,首先需要将混淆矩阵的每一个值除以相应类别图片的总数目,并用0来填充对角线,这样就只保留了被错误分类的数据。

row_sums = conf_mx.sum(axis = 1,keepdims = True)
norm_conf_mx = conf_mx/row_sums
np.fill_diagonal(norm_conf_mx,0)
plt.matshow(norm_conf_mx,cmap= plt.cm.gray)
plt.show()

 

标签:分析,误差,digit,clf,分类,some,分类器,train,sgd
来源: https://www.cnblogs.com/whiteBear/p/12401415.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有