ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

基于SGD、ASGD算法的SVM分类器(OpenCV案例源码train_svmsgd.cpp解读)

2020-03-06 19:02:51  阅读:333  来源: 互联网

标签:SVM ASGD shift TermCriteria 分类器 SVMSGD 源码 weights svmsgd


此案例用于二分类问题(鼠标左键、右键点出两类点,会实时画出分界线),最终得到一条分界线(直线):f(x)=weights*x+shift

源码不再贴出,只讲解最核心的doTrain()里的内容。参数含义翻译自ml.hpp文件。

与SVM不同,SVMSGD不需要设置核函数。

【参数】默认值见下述代码

模型类型:SGD、ASGD(推荐)。随机梯度下降、平均随机梯度下降。
边界类型:HARD_MARGIN、SOFT_MARGIN(推荐),前者用于线性可分,后者用于非线性可分
边界规范化 lambda:推荐设为0.0001(对于SGD),0.00001(对于ASGD)。越小,异类被抛弃的越少。
步长 gamma_0
步长降低力度 c:推荐设置为1(对于SGD),0.75(对于ASGD)
终止条件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS

参数设置函数:

setSvmsgdType()
setMarginType()
setMarginRegularization()
setInitialStepSize()
setStepDecreasingPower()

【使用方式】

cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//创建对象
svmsgd->train(trainData);//训练
svmsgd->save("MySvmsgd.xml");//保存模型
svmsgd->load("MySvmsgd.xml");//加载模型
svmsgd->predict(samples, responses);//预测,结果保存到responses标签中

bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{
    //*创建SVMSGD对象
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); //创建SVMSGD对象
    //*设置参数,以下全是默认参数
    //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型类型
    //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //边界类型
    //svmsgd->setMarginRegularization(0.00001); //边界规范化
    //svmsgd->setInitialStepSize(0.05);//步长
    //svmsgd->setStepDecreasingPower(0.75); //步长减弱力度
    //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//终止条件,1000次迭代,0.001每次迭代的精度
    //*训练集
    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
    //*训练
    svmsgd->train(trainData);

    if (svmsgd->isTrained()) //获取分界线的系数,f(x)=weights*x+shift
    {
        weights = svmsgd->getWeights();//x系数
        shift = svmsgd->getShift();//常数项
        //*保存模型
        svmsgd->save("svmsgd.xml"); //保存训练好的模型
        
        return true;
    }
    return false;
}

得到的xml中,weights有两个数,shift有一个数。

 

 f(x)=weights*x+shift,不可以理解为y=kx+b,应该理解为Ax+By+C=0。weights的两个数就是A、B,shift是C。

Mat weights(1, 2, CV_32FC1); weights是一个1*2的向量,x也是1*2的向量(xi,xj)也就是(x,y)坐标。

公式写全了就是:f(x)=weights1*xi+weights2*xj+shift,其实就是weights与x这两个向量的内积(对应相乘在求和)

f(x)如果等于0,说明点在此直线上,大于0就在线的一边,小于0在线的另一边。

标签:SVM,ASGD,shift,TermCriteria,分类器,SVMSGD,源码,weights,svmsgd
来源: https://www.cnblogs.com/xixixing/p/12430202.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有