ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

data_loader读取器

2021-08-06 18:03:11  阅读:170  来源: 互联网

标签:df batch loader train test import data 读取器


import random
import numpy as np
import pandas as pd
import cv2
def date_loader(image_dir, file_name, batch_size=1, mode='train'):
    train_dir_list = []
    train_label = []
    test_dir_list = []
    test_label = []
    val_dir_list = []
    val_label = []
    df = pd.read_csv(file_name)
    
    # 生成训练和测试数据集 0.8 /0.2
    df = df.sample(frac=1)
    for i in range(len(df)):
        if i <= (len(df)*0.8-1):
            dir =  image_dir+ '/' + df.iloc[i][0] + '.jpg'
            train_dir_list.append(dir)
            train_label.append(int(df.iloc[i][1]-1))
        else: 
            dir =  image_dir+ '/' + df.iloc[i][0] + '.jpg'
            test_dir_list.append(dir)
            test_label.append(int(df.iloc[i][1]-1))
    
    # 生成随机验证集,比列0.2
    df1 = df.sample(frac=0.2)
    for i in range(len(df1)):
        dir =  image_dir+ '/' + df1.iloc[i][0] + '.jpg'
        val_dir_list.append(dir)
        val_label.append(int(df.iloc[i][1]-1))
    
    def reader():
        batch_img = []
        batch_label = []
        if mode == 'train':
            count = 0
            for i in range(len(train_dir_list)):
                img = cv2.imread(train_dir_list[i])
                img = cv2.resize(img, (224,224), interpolation=cv2.INTER_CUBIC)/255
                img = np.transpose(img, (2,0,1))
                batch_img.append(img)
                batch_label.append(train_label[i])
                count +=1
                if (count %batch_size==0):
                    # print(len(train_label))
                    yield np.array(batch_img).astype('float32'), np.asarray(batch_label).astype('int64').reshape(batch_size,1)
                    batch_img = []
                    batch_label = []
        elif mode == 'test':
            count = 0
            for i in range(len(test_dir_list)):
                img = cv2.imread(test_dir_list[i])
                img = cv2.resize(img, (224,224), interpolation=cv2.INTER_CUBIC)/255
                img = np.transpose(img, (2,0,1))
                batch_img.append(img)
                batch_label.append(test_label[i])
                count +=1
                if (count %batch_size==0):
                    # print(len(test_label))
                    yield np.array(batch_img).astype('float32'), np.asarray(batch_label).astype('int64').reshape(batch_size,1)
                    batch_img = []
                    batch_label = []
        elif mode == 'val':
            count = 0
            for i in range(len(val_dir_list)):
                img = cv2.imread(val_dir_list[i])
                img = cv2.resize(img, (224,224), interpolation=cv2.INTER_CUBIC)/255
                img = np.transpose(img, (2,0,1))
                batch_img.append(img)
                batch_label.append(val_label[i])
                count +=1
                if (count %batch_size==0):
                    # print(len(val_dir_list))
                    yield np.array(batch_img).astype('float32'), np.asarray(batch_label).astype('int64').reshape(batch_size,1)
                    batch_img = []
                    batch_label = []
    return reader

a = date_loader('image2_100','a_100_drop_p.csv',mode='test')
for n , data in enumerate(a()):
    images, label = data
    # print(label)
    break

train_reader = paddle.batch(date_loader('image2_100','a_100_drop_p.csv',mode='train'), batch_size=10)
test_reader = paddle.batch(date_loader('image2_100','a_100_drop_p.csv',mode='test'), batch_size=10)

标签:df,batch,loader,train,test,import,data,读取器
来源: https://www.cnblogs.com/mumuzifeng/p/15109802.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有