ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

使用DSFD检测DarkFace数据集过程

2021-04-11 19:57:30  阅读:294  来源: 互联网

标签:img DarkFace 检测 det DSFD im max net shrink


1.下载Dark Face数据集,使用track2.2_test_sample文件中图片进行人脸检测测试。

2.修改DSFD源码中demo.py部分:

test_oneimage():
def test_oneimage():
    # load net

    # 影响网络的自动求导机制,使网络前向传播后不进行求导和反向传播(仅测试时使用)
    torch.set_grad_enabled(False)

    # 加载config配置参数
    cfg = widerface_640
    # 分类的类别数目---widerface.py
    num_classes = len(WIDERFace_CLASSES) + 1 # +1 background
    # 加载SSD网络模型,返回一个SSD实例
    net = build_ssd('test', cfg['min_dim'], num_classes) # initialize SSD
    # 加载预训练模型train_model
    net.load_state_dict(torch.load(args.trained_model))
    net.cuda()
    # 表示进入评估模式,神经网络中有train(),eval()两种模式,使用eval()可关闭dropout
    net.eval()
    print('Finished loading model!')

    # evaluation
    cuda = args.cuda
    transform = TestBaseTransform((104, 117, 123))
    thresh=cfg['conf_thresh']
    #save_path = args.save_folder
    #num_images = len(testset)

    # load data,从指定路径加载待测图像

    ''' 
    以此为界,前半部分为网络模型加载和初始化,后半部分为单张图片的人脸检测。此处为测试图片的路            
    径设置过程。

    修改部分,不适用arg.imag_root作为测试图像路径.
    遍历darkface数据集100张图片,依次读取并进行测试
    '''

    folder = './data/'

    #img_id = 'face'
    for i in range(100):

        img = cv2.imread(folder + str(i) + '_fake_B.jpg', cv2.IMREAD_COLOR)
        img_id = 'test' + str(i)
        

        # 单张图片的测试过程
        max_im_shrink = ( (2000.0*2000.0) / (img.shape[0] * img.shape[1])) ** 0.5
        shrink = max_im_shrink if max_im_shrink < 1 else 1

        det0 = infer(net , img , transform , thresh , cuda , shrink)
        det1 = infer_flip(net , img , transform , thresh , cuda , shrink)
        # shrink detecting and shrink only detect big face
        st = 0.5 if max_im_shrink >= 0.75 else 0.5 * max_im_shrink
        det_s = infer(net , img , transform , thresh , cuda , st)
        index = np.where(np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1) > 30)[0]
        det_s = det_s[index, :]
        # enlarge one times
        factor = 2
        bt = min(factor, max_im_shrink) if max_im_shrink > 1 else (st + max_im_shrink) / 2
        det_b = infer(net , img , transform , thresh , cuda , bt)
        # enlarge small iamge x times for small face
        if max_im_shrink > factor:
            bt *= factor
            while bt < max_im_shrink:
                det_b = np.row_stack((det_b, infer(net , img , transform , thresh , cuda , bt)))
                bt *= factor
            det_b = np.row_stack((det_b, infer(net , img , transform , thresh , cuda , max_im_shrink) ))
        # enlarge only detect small face
        if bt > 1:
            index = np.where(np.minimum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
            det_b = det_b[index, :]
        else:
            index = np.where(np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
            det_b = det_b[index, :]
        det = np.row_stack((det0, det1, det_s, det_b))
        det = bbox_vote(det)
        vis_detections(img , det , img_id, args.visual_threshold)

3.运行demo.py即可。能在arg.save_folder处得到100张dark_face人脸检测结果。

补充:

DSFD只接收输入格式为jpg的图片,因此对darkface数据集进行批量转换。

import os
from PIL import Image

dirname_read="/home/...DSFD/darkface_png/"   # png格式图片的输入路径
dirname_write="/home/...DSFD/data/"    # jpg图片的输出路径
names=os.listdir(dirname_read)
count=0
for name in names:
    img=Image.open(dirname_read+name)
    name=name.split(".")
    if name[-1] == "png":
        name[-1] = "jpg"
        name = str.join(".", name)
        #r,g,b,a=img.split()
        #img=Image.merge("RGB",(r,g,b))
        to_save_path = dirname_write + name
        img.save(to_save_path)
        count+=1
        print(to_save_path, "------conut:", count)
    else:
        continue

 

标签:img,DarkFace,检测,det,DSFD,im,max,net,shrink
来源: https://blog.csdn.net/weixin_45075887/article/details/115604706

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有