ICode9


Handwritten Digit Recognition with KNN

Published 2020-02-03 21:09:37

Tags: KNN csv train file test handwriting recognition array def


import csv

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def toInt(array):
    # Convert an array of numeric strings (as read by csv.reader) to integers.
    # The CSV files are expected to hold only integer strings; astype() raises
    # ValueError on anything else instead of silently storing 0.
    return np.asarray(array).astype(int)

def normalizing(array):
    # Binarize pixel intensities: every non-zero value becomes 1.
    return (array != 0).astype(int)

def readCsv(path):
    # Read all rows of a CSV file into a string array, skipping the header row.
    with open(path, newline="") as file:
        reader = csv.reader(file)
        next(reader)  # skip the header line
        return np.array(list(reader))

def loadTrainData():
    data = readCsv("train.csv")
    label = data[:, 0]    # first column: digit label
    pixels = data[:, 1:]  # remaining columns: pixel values
    return normalizing(toInt(pixels)), toInt(label)

def loadTestData():
    return normalizing(toInt(readCsv("test.csv")))

def loadTest_result():
    # The second column of test_result.csv holds the true label of each test image.
    return toInt(readCsv("test_result.csv")[:, 1])

def saveResult(result):
    # Write one prediction per row: a 1-based id, then the label, the same
    # column layout that loadTest_result() reads. The header names are assumed.
    with open("my_result.csv", "w", newline="") as myFile:
        myWriter = csv.writer(myFile)
        myWriter.writerow(["ImageId", "Label"])
        for imageId, label in enumerate(result, start=1):
            myWriter.writerow([imageId, int(label)])

def knnClassify(x_train, y_train, x_test):
    # scikit-learn defaults: k = 5 neighbours, uniform weights, Euclidean distance.
    estimator = KNeighborsClassifier()
    estimator.fit(x_train, y_train)
    y_test = estimator.predict(x_test)
    saveResult(y_test)
    return y_test

def digitRecognition():
    x_train, y_train = loadTrainData()
    x_test = loadTestData()
    predict = knnClassify(x_train, y_train, x_test)
    y_test = loadTest_result()
    m = x_test.shape[0]
    # Count mismatches between predicted and true labels.
    wrong = int(np.sum(predict != y_test))
    print("wrong = %d" % wrong)                               # author's run: 819
    print("right rate = %f%%" % (100.0 * (m - wrong) / m))    # author's run: 97.075%

if __name__ == "__main__":
    digitRecognition()
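The script hands the classification itself to scikit-learn's `KNeighborsClassifier`. To show what that call computes, here is a minimal brute-force sketch of k-nearest-neighbour majority voting in NumPy. It is an illustration under stated assumptions, not the author's code and not scikit-learn's implementation: `knn_predict` and the toy arrays are hypothetical, distances are squared Euclidean (which ranks neighbours the same way as Euclidean distance), and ties in the vote go to the smallest label.

```python
import numpy as np


def knn_predict(x_train, y_train, x_test, k=5):
    """Hypothetical brute-force k-NN classifier (majority vote).

    x_train: (n_train, n_features), y_train: (n_train,) integer labels,
    x_test: (n_test, n_features). Returns an (n_test,) array of labels.
    """
    x_train = np.asarray(x_train, dtype=float)
    x_test = np.asarray(x_test, dtype=float)
    y_train = np.asarray(y_train)
    predictions = np.empty(len(x_test), dtype=y_train.dtype)
    for i, sample in enumerate(x_test):
        # Squared Euclidean distance from this test sample to every training sample
        distances = np.sum((x_train - sample) ** 2, axis=1)
        # Indices of the k closest training samples
        nearest = np.argsort(distances, kind="stable")[:k]
        # Majority vote; np.unique returns labels in sorted order,
        # so argmax resolves ties in favour of the smallest label
        labels, counts = np.unique(y_train[nearest], return_counts=True)
        predictions[i] = labels[np.argmax(counts)]
    return predictions


# Toy example (hypothetical data): two clusters of 2-pixel "images"
x_train = np.array([[0, 0], [0, 1], [1, 0], [9, 9], [9, 8], [8, 9]])
y_train = np.array([0, 0, 0, 1, 1, 1])
x_test = np.array([[1, 1], [8, 8]])
print(knn_predict(x_train, y_train, x_test, k=3))  # prints [0 1]
```

On real image data this loop compares each test image against every training image, which is slow and memory-hungry; that is one reason to rely on the library class in practice.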
Author: 九克拉

Source: https://blog.csdn.net/qq_45807398/article/details/104161642
