ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

java手撕KMeans算法实现手写数字聚类(失败案例)

2022-03-19 21:02:22  阅读:171  来源: 互联网

标签:10 java int 28 KMeans ++ 聚类 new


最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。

前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。

一,什么是KMeans聚类算法??

非常传统的聚类算法,目的是将一堆数据进行分类。

它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?

举个例子:

a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。

so easy 是不是

二,使用的手写数字测试集??

我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。

但是一定有人会问到,mnist测试集应该怎么通过java使用呢?

不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。

具体效果像这样:

 这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。

我将txt命名为:数字名-标号的形式,方便之后训练和测试。

 三,java手撕KMeans算法

先摆上一个算法流程图

 1.首先定义:

           训练图片(50000 * 28 * 28 的三维数组)

           聚类中心(10 * 28 * 28的三维数组)

           每张图片到聚类中心的距离(50000 * 10 的二维数组)

           旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)

    static float[][][] num = new float[50000][28][28];
    static float[][][] center = new float[10][28][28];// 聚类中心
    static long[][] distance = new long[num.length][10];
    static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
    static ArrayList<Integer>[] newKinds = new ArrayList[10];

 2.定义方法:

        从Txt文件导入测试数据的方法

public static void getTXT(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    num[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }

        获得图片到聚类中心距离的方法

    // 得到距离
    public static long getDistance(float[][] n,float[][] k){
        long ret = 0;
        for (int i = 0; i < 28;i ++){
            for (int j = 0; j < 28; j ++){
                ret += Math.pow((n[i][j] - k[i][j]),2);
            }
        }
        return ret;
    }

        得到图片距离最近聚类中心索引的方法

    // 获得数组元素最小值对应的下标
    public static int getMinIndex(long dis[]){
        int index = -1;
        long min = Integer.MAX_VALUE;
        for(int i = 0; i < 10;i ++){
            if(dis[i] < min){
                index = i;
                min = dis[i];
            }
        }
        return index;
    }

        比较旧的聚类和新的聚类是否相同的方法

    public static boolean isSame(){
        for(int i = 0; i < 10 ;i ++){
            for(int j = 0; j < newKinds[i].size();j ++){
                if(newKinds[i].size() != oldKinds[i].size()) return false;
                if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
                    return false;
                }
            }
        }
        return true;
    }

需要注意的是!!!

两个Integer的比较需要通过.intValue()的方法先转换成为int!!!再进行比较,否则会因为内存什么什么奇奇怪怪的原因导致出现130 != 130这种很天真的错误。

我在这里被坑了一次,希望看到这片文章的人能够避一下坑。

3.开始while(true)死循环,直到旧类和新类相等不发生改变

        int kindTime = 0;
        while(true){
            // 3.计算每个文件和当前类中心之间的距离
            for (int i = 0; i < num.length; i++){
                for (int j = 0; j < 10; j++){
                    distance[i][j] = getDistance(num[i],center[j]);
                }
            }
            // 更新旧类
            for(int i = 0;i < 10;i ++){
                oldKinds[i].clear();
                for(int j = 0 ; j < newKinds[i].size();j ++){
                    oldKinds[i].add(newKinds[i].get(j));
                }
            }
            // 更新新类
            for (int i = 0; i < 10 ; i ++){
                newKinds[i].clear();
            }
            for (int i = 0; i < num.length; i ++){
                // 获得距离最小值,将其放到对应的类中
                newKinds[getMinIndex(distance[i])].add(i);
            }
            // 4.更新聚类中心
            for(int i = 0; i < 10; i ++){
                for(int x = 0; x < 28; x++){
                    for(int y = 0; y < 28;y ++){
                        center[i][x][y] = getAverage(newKinds[i],x,y);
                    }
                }
            }
            // 5.重复步骤,直到类不再发生改变
            if(isSame()){
                break;
            }
            System.out.println("第"+kindTime+"次聚类");
            kindTime++;
        }

4.保存类中心点

因为如果训练数据不变的话,聚类聚出的中心是不会变化的,所以为了避免之后聚类的重复操作,我们还是将得到的聚类中心点保存成为txt文件放到电脑上比较好。

    // 保存聚类中心点
    public static void saveKind(int index){
        FileWriter out = null;
        String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
        File file = new File(path);
        try {
            out = new FileWriter(file);
            //二维数组按行存入到文件中
            for (int i = 0; i < center[index].length; i++) {
                for (int j = 0; j < center[index][i].length; j++) {
                    //将每个元素转换为字符串
                    String content = String.valueOf(center[index][i][j]) + " ";
                    out.write(content + "\t");
                }
                out.write("\r\n");
            }
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

到现在,所有kmeans要求的操作我们都已经实现了。我们看看效果怎么样吧

1.我从test测试集(刚刚是train训练集)中导入了8000张图片,0到9每个数字各800张。

导入的方式和上文中的相同,这里就不在赘述了。

然后通过刚刚聚出来的类中心对测试数据进行聚类。(因为kmeans是无监督聚类吗,所以我也不知道每个类中心代表的哪个数字)

这是最后聚出来的结果:

发现大问题!!!我将每个类聚到的数字分别列出来。比如第0类,聚到4个数字0,3个数字1……

最后得到的结果,很!不!理!想!

 通过分析可以看到,数字1的聚类效果最好,800张图片中有787张被聚到第7类中了,但是第7类也混入了不少其他数字,还有129张2是什么鬼?!

其他的类就更不用说了,混杂了很多数字。

经过缜密思考之后,我认为是k的数值设置的问题,因为我们想要聚类出10个数字,所以很主观地将k设置成为了10,没有思考相同数字,因为书写原因而出现的数字内部聚类的问题。

就像数字0,分别被聚到了第1类和第4类中,这两类很少有其他数字。因此是将数字0进行了分类,把高的0矮的0胖的0瘦的0分开了!而不是将0之外的数字分开。

或许可以通过改变k的值进行改进呢!

这片文章才差不多就是这样了。最后贴上代码。

如果有朋友想要mnist手写数字数据集的txt文件,可以给我留言邮箱信息哦,我抽时间会发送的。

欢迎大佬们批评指正!

// 首先是kmeans聚类的代码
import java.io.*;
import java.util.ArrayList;

public class KMeans {
    // KMeans算法实现手写数字聚类
    static float[][][] num = new float[50000][28][28];
    static float[][][] center = new float[10][28][28];// 聚类中心
    static long[][] distance = new long[num.length][10];
    static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类
    static ArrayList<Integer>[] newKinds = new ArrayList[10];

    public static void main(String[] args) throws IOException {
        // 1.读取文件
        System.out.println("导入文件中……");
        for (int i = 0;i < num.length;i ++){
            getTXT("D:\\Python\\jupyter\\trains2\\" + Integer.toString(i/5000) + "-" + Integer.toString(i%5000 + 1) + ".txt",i,0,0);
            if(i % 1000 == 0) System.out.println("已导入文件:" + i);
        }
        System.out.println("导入文件成功!!!");
        // 随机选择聚类中心
        for(int i = 0; i < 10; i ++){
            oldKinds[i] = new ArrayList<>();
        }
        for(int i = 0 ; i < 10;i ++) {
            transTwoArray(num[i], center[i]);
            newKinds[i] = new ArrayList<>();
            newKinds[i].add(i);
        }

        int kindTime = 0;
        while(true){
            // 3.计算每个文件和当前类中心之间的距离
            for (int i = 0; i < num.length; i++){
                for (int j = 0; j < 10; j++){
                    distance[i][j] = getDistance(num[i],center[j]);
                }
            }
            // 更新旧类
            for(int i = 0;i < 10;i ++){
                oldKinds[i].clear();
                for(int j = 0 ; j < newKinds[i].size();j ++){
                    oldKinds[i].add(newKinds[i].get(j));
                }
            }
            // 更新新类
            for (int i = 0; i < 10 ; i ++){
                newKinds[i].clear();
            }
            for (int i = 0; i < num.length; i ++){
                // 获得距离最小值,将其放到对应的类中
                newKinds[getMinIndex(distance[i])].add(i);
            }
            // 4.更新聚类中心
            for(int i = 0; i < 10; i ++){
                for(int x = 0; x < 28; x++){
                    for(int y = 0; y < 28;y ++){
                        center[i][x][y] = getAverage(newKinds[i],x,y);
                    }
                }
            }
            // 5.重复步骤,直到类不再发生改变
            if(isSame()){
                break;
            }
            System.out.println("第"+kindTime+"次聚类");
            kindTime++;
        }
        // 保存聚类中心
        System.out.println("聚类成功!!!");
        System.out.println("-------------------------");
        System.out.println("保存类中心点中……");
        for(int i = 0; i < 10;i ++){
            saveKind(i);
        }
        System.out.println("保存类中心点成功!!!");
    }


    // 读取文件
    public static void getTXT(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    num[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }
    // 转移两个数组
    public static void transTwoArray(float[][] array1,float[][] array2){
        for(int i = 0; i < 28;i ++){
            for (int j = 0; j < 28;j ++){
                array2[i][j] = array1[i][j];
            }
        }
    }
    // 得到距离
    public static long getDistance(float[][] n,float[][] k){
        long ret = 0;
        for (int i = 0; i < 28;i ++){
            for (int j = 0; j < 28; j ++){
                ret += Math.pow((n[i][j] - k[i][j]),2);
            }
        }
        return ret;
    }
    // 获得数组元素最小值对应的下标
    public static int getMinIndex(long dis[]){
        int index = -1;
        long min = Integer.MAX_VALUE;
        for(int i = 0; i < 10;i ++){
            if(dis[i] < min){
                index = i;
                min = dis[i];
            }
        }
        return index;
    }
    // 计算均值
    public static float getAverage(ArrayList<Integer> arr,int x,int y){
        float ret = 0;
        for(int i = 0; i < arr.size(); i ++){
            ret += num[arr.get(i)][x][y];// 将同一类中所有相同位置元素相加
        }
        return ret / arr.size();
    }
    // 保存聚类中心点
    public static void saveKind(int index){
        FileWriter out = null;
        String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";
        File file = new File(path);
        try {
            out = new FileWriter(file);
            //二维数组按行存入到文件中
            for (int i = 0; i < center[index].length; i++) {
                for (int j = 0; j < center[index][i].length; j++) {
                    //将每个元素转换为字符串
                    String content = String.valueOf(center[index][i][j]) + " ";
                    out.write(content + "\t");
                }
                out.write("\r\n");
            }
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    // 是否相等
    public static boolean isSame(){
        for(int i = 0; i < 10 ;i ++){
            for(int j = 0; j < newKinds[i].size();j ++){
                if(newKinds[i].size() != oldKinds[i].size()) return false;
                if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {
                    return false;
                }
            }
        }
        return true;
    }
}

测试聚类中心的代码

import java.io.*;
import java.util.ArrayList;

public class myKMeansTest {
    static float[][][] kMeans = new float[10][28][28];
    static float[][][] test = new float[8000][28][28];// 测试数据,每个数字有800张
    static long[][] distance = new long[8000][10];// 每张图片聚类类中心的距离
    static ArrayList<Integer>[] kinds = new ArrayList[10];// 每个类中包含的图片索引

    public static void main(String[] args) throws IOException {
        System.out.println("-----获取文件中-----");
        // 读取聚类中心文件
        for(int i = 0; i < 10;i ++){
            String img = "D:\\java\\workSpace\\KMeans\\" + i + "kinds.txt";
            getKMeansTxt(img,i);
        }
        // 读取测试文件
        for(int i = 0;i < 8000;i ++){
            String img = "D:\\Python\\jupyter\\test\\" + i/800 + "-" + (i%800 + 1) + ".txt";
            getTestTxt(img,i,0,0);
            if(i % 800 == 0) System.out.println("已导入数据:"+i);
        }
        System.out.println("获取文件成功!!");
        // 进行测试
        System.out.println("开始聚类……");
        for(int i = 0; i < 10;i ++){
            kinds[i] = new ArrayList<>();
        }
        for(int i = 0; i < 8000;i ++){
            for (int j = 0; j < 10;j ++){
                distance[i][j] = GoodKMeans.getDistance(kMeans[j],test[i]);// 获得每张图片对应聚类中心的距离
            }
        }
        for(int i= 0;i< 8000;i++){
            kinds[GoodKMeans.getMinIndex(distance[i])].add(i);// 将图片归为最小距离的类中
        }
        System.out.println("聚类成功!!");

        int[][] ans = new int[10][10];
        for(int i = 0; i < 10;i ++){
            for(int j = 0; j < kinds[i].size();j ++){
                if(kinds[i].get(j) < 800) ans[i][0]++;
                else if(kinds[i].get(j) >= 800 && kinds[i].get(j) < 1600) ans[i][1]++;
                else if(kinds[i].get(j) >= 1600 && kinds[i].get(j)< 2400) ans[i][2]++;
                else if(kinds[i].get(j) >= 2400 && kinds[i].get(j)< 3200) ans[i][3]++;
                else if(kinds[i].get(j) >= 3200 && kinds[i].get(j)< 4000) ans[i][4]++;
                else if(kinds[i].get(j) >= 4000 && kinds[i].get(j)< 4800) ans[i][5]++;
                else if(kinds[i].get(j) >= 4800 && kinds[i].get(j)< 5600) ans[i][6]++;
                else if(kinds[i].get(j) >= 5600 && kinds[i].get(j)< 6400) ans[i][7]++;
                else if(kinds[i].get(j) >= 6400 && kinds[i].get(j)< 7200) ans[i][8]++;
                else if(kinds[i].get(j) >= 7200 && kinds[i].get(j)< 8000) ans[i][9]++;
            }
        }
        for (int i = 0; i < 10;i ++){
            System.out.print("第"+i+"类中:");
            for (int j = 0; j < 10;j ++){
                System.out.print(j+":");
                System.out.printf("%3d",ans[i][j]);
                System.out.print("\t");
            }
            System.out.println();
        }
    }

    // 获得聚类中心文件
    public static void getKMeansTxt(String img,int index) throws IOException {
        File file = new File(img);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        int x = 0;
        int y = 0;
        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i)-'0' <10 && line.charAt(i)-'0' >=0 && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    // 取数字
                    int j = i + 1;
                    while(j < line.length() && line.charAt(j) != ' '){
                        j++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    kMeans[index][x][y] = Float.valueOf(line.substring(i,j)).floatValue();
                    i = j;
                    y++;
                }
            }
        }
        br.close();
    }
    // 获得测试文件
    public static void getTestTxt(String path,int img,int x,int y) throws IOException {
        File file = new File(path);
        FileInputStream fis = new FileInputStream(file);
        InputStreamReader isr = new InputStreamReader(fis);
        BufferedReader br = new BufferedReader(isr);

        String line;
        while((line = br.readLine()) != null){
            boolean isNum = false;
            for(int i = 0;i < line.length();i ++){
                if(line.charAt(i) != ' ' && !isNum){
                    // 如果遇到数字
                    isNum = true;
                    float tempNum = 0;
                    // 取数字
                    while(i < line.length() && line.charAt(i) != ' '){
                        tempNum = tempNum * 10 + line.charAt(i) - '0';
                        i++;
                    }
                    isNum = false;
                    if(y < 28){
                    }
                    else{
                        y = 0;
                        x ++;
                    }
                    test[img][x][y] = tempNum;
                    y++;
                }
            }
        }
        br.close();
    }
}

标签:10,java,int,28,KMeans,++,聚类,new
来源: https://blog.csdn.net/m0_51418456/article/details/123601187

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有