ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Tensorflow的数据处理中的Dataset和Iterator

2021-11-12 16:30:56  阅读:162  来源: 互联网

标签:Iterator iterator dataset tf Tensorflow Dataset data


完整机器学习实现代码GitHub
欢迎转载,转载请注明出处https://www.cnblogs.com/huangyc/p/10339433.html


0. 目录

 


1. Tensorflow高效流水线Pipeline

2. Tensorflow的数据处理中的Dataset和Iterator

3. Tensorflow生成TFRecord

4. Tensorflow的Estimator实践原理

回到顶部

1. 前言

我们在训练模型的时候,必须经过的第一个步骤是数据处理。在机器学习领域有一个说法,数据处理的好坏直接影响了模型结果的好坏。数据处理是至关重要的一步。

我们今天关注数据处理的另一个问题:假设我们做深度学习,数据的量随随便便就到GB的级别,那数据处理的速度对于模型的训练也很重要。经常遇到的一个情况是,数据处理的时间占了训练整个模型的大部分。

今天介绍的是Tensorflow官方推荐的数据处理方式是用Dataset API同时支持从内存和硬盘的读取,相比之前的两种方法在语法上更加简洁易懂

回到顶部

2. Dataset原理

Google官方给出的Dataset API中的类图如下所示:

image

2.1 Dataset创建方法

Dataset API还提供了四种创建Dataset的方式:

  • tf.data.Dataset.from_tensor_slices():这个函数直接从内存中读取数据,数据的形式可以是数组、矩阵、dict等。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
#实例化make_one_shot_iterator对象,该对象只能读取一次
iterator = dataset.make_one_shot_iterator()
# 从iterator里取出一个元素
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))
  • tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
  • tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
  • tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。

2.2 Dataset数据进行转换(Transformation)

一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作,常用的Transformation有:

  • map:接收一个函数对象,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
  • apply:应用一个转换函数到dataset。
dataset = dataset.apply(group_by_window(key_func, reduce_func, window_size))
  • batch:根据接收的整数值将该数个元素组合成batch,如下面的程序将dataset中的元素组成了大小为32的batch。
dataset = dataset.batch(32)
  • shuffle:打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
  • repeat:整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch。
dataset = dataset.repeat(5)
# 如果repeat没有参数,则一直重复循环数据
dataset = dataset.repeat()
  • padded_batch:对dataset中的数据进行padding到一定的长度。
dataset.padded_batch(
    batch_size,
    padded_shapes=(
        tf.TensorShape([None]),  # src
        tf.TensorShape([]),  # tgt_output
        tf.TensorShape([]),
        tf.TensorShape([src_max_len])),  # src_len
    padding_values=(
        src_eos_id,  # src
        0,  # tgt_len -- unused
        0,  # src_len -- unused
        0)) # mask
  • shard:根据多GPU进行分片操作。
dataset.shard(num_shards, shard_index)

比较完整的生成dataset的代码。

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLengthFeature((), tf.string, ""),
    "label": tf.FixedLengthFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])
  image = _augment_helper(image)  # augments image using slice, reshape, resize_bilinear
  return image, parsed["label"]

#简单的生成input_fn
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset

回到顶部

3. Iterator原理

3.1 Iterator Init初始化

生成Iterator一共有4种,复杂程度递增,个人觉得掌握前两种应该够用了,Iterator还有一个优势,目前,单次迭代器是唯一易于与 Estimator 搭配使用的类型

  • one shot Iterator:one shot Iterator是最简单的一种Iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot Iterator不支参数化。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
value = sess.run(next_element)
assert i == value

  • initializable Iterator:Initializable Iterator 要求在使用之前显式的通过调用Iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value

  • reinitializable Iterator:可以被不同的dataset对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象。
  • feedable Iterator:可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。

3.2 Iterator get_next遍历数据

Iterator.get_next() 方法tf.Tensor 对象,每次tf.Session.run(Iterator.get_next())都会获取底层数据集中下一个元素的值。

如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 tf.errors.OutOfRangeError。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化。

sess.run(iterator.initializer)
while True:
  try:
    sess.run(getNextTensor)
  except tf.errors.OutOfRangeError:
    sess.run(iterator.initializer)

3.3 Iterator Save保存

tf.contrib.data.make_saveable_from_iterator 函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。

# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
  if should_checkpoint:
    saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)
回到顶部

4. 总结

本文介绍了创建不同种类的Dataset和Iterator对象的基础知识,熟悉这个数据处理的步骤后,不仅复用性比较强,而且效率也能成倍的提升。

标签:Iterator,iterator,dataset,tf,Tensorflow,Dataset,data
来源: https://blog.csdn.net/u014311125/article/details/121290743

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有