ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

MindSpore易点通·精讲系列--数据集加载之CSVDataset

2022-07-15 15:00:06  阅读:178  来源: 互联网

标签:iris -- 精讲 dataset column file 易点通 csv data


Dive Into MindSpore – CSVDataset For Dataset Load

MindSpore精讲系列 – 数据集加载之CSVDataset

本文开发环境

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

本文内容摘要

  • 先看API
  • 数据准备
  • 两种试错
  • 正确示例
  • 本文总结
  • 问题改进
  • 本文参考

1. 先看API

老传统,先看看官方文档:

参数解读:

  • dataset_files – 数据集文件路径,可以单文件也可以是文件列表

  • filed_delim – 字段分割符,默认为","

  • column_defaults – 一个巨坑的参数,留待后面解读

  • column_names – 字段名,用于后续数据字段的key

  • num_paraller_workers – 不再解释

  • shuffle – 是否打乱数据,三种选择[False, Shuffle.GLOBAL, Shuffle.FILES]

    • Shuffle.GLOBAL – 混洗文件和文件中的数据,默认
    • Shuffle.FILES – 仅混洗文件
  • num_shards – 不再解释

  • shard_id – 不再解释

  • cache – 不再解释

2. 数据准备

2.1 数据下载

说明:

数据下载地址:UCI Machine Learning Repository: Iris Data Set

使用如下命令下载数据iris.datairis.names到目标目录:

mkdir iris && cd iris
wget -c https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
wget -c https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names

**备注:**如果受系统限制,无法使用wget命令,可以考虑用浏览器下载,下载地址见说明。

2.2 数据简介

Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

更详细的介绍参见官方说明:

5. Number of Instances: 150 (50 in each of three classes)

6. Number of Attributes: 4 numeric, predictive attributes and the class

7. Attribute Information:
   1. sepal length in cm
   2. sepal width in cm
   3. petal length in cm
   4. petal width in cm
   5. class:
      -- Iris Setosa
      -- Iris Versicolour
      -- Iris Virginica

8. Missing Attribute Values: None

Summary Statistics:
	         Min  Max   Mean    SD   Class Correlation
   sepal length: 4.3  7.9   5.84  0.83    0.7826
    sepal width: 2.0  4.4   3.05  0.43   -0.4194
   petal length: 1.0  6.9   3.76  1.76    0.9490  (high!)
    petal width: 0.1  2.5   1.20  0.76    0.9565  (high!)

9. Class Distribution: 33.3% for each of 3 classes.

2.3 数据分配

这里对数据进行初步分配,分成训练集和测试集,分配比例为4:1。

相关处理代码如下:

from random import shuffle


def preprocess_iris_data(iris_data_file, train_file, test_file, header=True):
    cls_0 = "Iris-setosa"
    cls_1 = "Iris-versicolor"
    cls_2 = "Iris-virginica"

    cls_0_samples = []
    cls_1_samples = []
    cls_2_samples = []

    with open(iris_data_file, "r", encoding="UTF8") as fp:
        lines = fp.readlines()
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if cls_0 in line:
                cls_0_samples.append(line)
                continue
            if cls_1 in line:
                cls_1_samples.append(line)
                continue
            if cls_2 in line:
                cls_2_samples.append(line)

    shuffle(cls_0_samples)
    shuffle(cls_1_samples)
    shuffle(cls_2_samples)

    print("number of class 0: {}".format(len(cls_0_samples)), flush=True)
    print("number of class 1: {}".format(len(cls_1_samples)), flush=True)
    print("number of class 2: {}".format(len(cls_2_samples)), flush=True)

    train_samples = cls_0_samples[:40] + cls_1_samples[:40] + cls_2_samples[:40]
    test_samples = cls_0_samples[40:] + cls_1_samples[40:] + cls_2_samples[40:]

    header_content = "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Classes"

    with open(train_file, "w", encoding="UTF8") as fp:
        if header:
            fp.write("{}\n".format(header_content))
        for sample in train_samples:
            fp.write("{}\n".format(sample))

    with open(test_file, "w", encoding="UTF8") as fp:
        if header:
            fp.write("{}\n".format(header_content))
        for sample in test_samples:
            fp.write("{}\n".format(sample))


def main():
    iris_data_file = "{your_path}/iris/iris.data"
    iris_train_file = "{your_path}/iris/iris_train.csv"
    iris_test_file = "{your_path}/iris/iris_test.csv"

    preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file)


if __name__ == "__main__":
    main()

将以上代码保存到preprocess.py文件,使用如下命令运行:

注意修改相关数据文件路径

python3 preprocess.py

输出内容如下:

number of class 0: 50
number of class 1: 50
number of class 2: 50

同时在目标目录生成iris_train.csviris_test.csv文件,目录内容如下所示:

.
├── iris.data
├── iris.names
├── iris_test.csv
└── iris_train.csv

3. 两种试错

下面通过几种**错误(带引号)**用法,来初步认识一下CSVDataset

3.1 column_defaults是哪样

首先,先来个简单加载,代码如下:

为方便读者复现,这里将shuffle设置为False。

from mindspore.dataset import CSVDataset


def dataset_load(data_files):
    column_defaults = [float, float, float, float, str]
    column_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Classes"]

    dataset = CSVDataset(
        dataset_files=data_files,
        field_delim=",",
        column_defaults=column_defaults,
        column_names=column_names,
        num_samples=None,
        shuffle=False)

    data_iter = dataset.create_dict_iterator()
    item = None
    for data in data_iter:
        item = data
        break

    print("====== sample ======\n{}".format(item), flush=True)


def main():
    iris_train_file = "{your_path}/iris/iris_train.csv"

    dataset_load(data_files=iris_train_file)


if __name__ == "__main__":
    main()

将以上代码保存到load.py文件,运行命令:

注意修改数据文件路径

python3 load.py

纳尼,报错,来看看报错内容:

Traceback (most recent call last):
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in <module>
    main()
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
    dataset_load(data_files=iris_train_file)
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 75, in dataset_load
    dataset = CSVDataset(
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/validators.py", line 1634, in new_method
    raise TypeError("column type in column_defaults is invalid.")
TypeError: column type in column_defaults is invalid.

看看引发报错的源码,mindspore/dataset/engine/validators.py 中1634行,相关代码如下:

        # check column_defaults
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type in column_defaults is invalid.")

3.1.1 报错分析

更多关于column_defaults参数的分析请参考6.1节。

还记得官方参数说明吗,不记得没关系,这里再列出来。

  • column_defaults (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值:None,不指定。如果未指定该参数,则所有列的数据类型将被视为string。

很显然,官方参数说明是数据类型,但是到mindspore/dataset/engine/validators.py代码里面,却检测的是数据实例类型。明确了这点,将代码:

column_defaults = [float, float, float, float, str]

修改为:

这里的数值取自iris.names文件,详情参考该文件。

column_defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]

再次运行代码,再次报错:

WARNING: Logging before InitGoogleLogging() is written to STDERR
[ERROR] MD(13306,0x70000269b000,Python):2022-06-14-16:51:59.681.109 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:217] InterruptMaster] Task is terminated with err msg(more detail in info level log):Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File         : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc

Traceback (most recent call last):
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in <module>
    main()
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
    dataset_load(data_files=iris_train_file)
  File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 90, in dataset_load
    for data in data_iter:
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 147, in __next__
    data = self._get_next()
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 211, in _get_next
    raise err
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 204, in _get_next
    return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
RuntimeError: Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File         : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc

好了,这个错误我们到3.2部分进行分析。

3.2 header要不要

3.1中,我们根据对报错源码的分析,明确了column_defaults的正确用法,但是依然存在一个错误。

3.2.1 报错分析

根据报错信息,发现是mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc中506行的报错,相关源码如下:

Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  CsvParser csv_parser(worker_id, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
  RETURN_IF_NOT_OK(csv_parser.InitCsvParser());
  csv_parser.SetStartOffset(start_offset);
  csv_parser.SetEndOffset(end_offset);

  auto realpath = FileUtils::GetRealPath(file.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
  }

  std::ifstream ifs;
  ifs.open(realpath.value(), std::ifstream::in);
  if (!ifs.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + file + ", the file is damaged or permission denied.");
  }
  if (column_name_list_.empty()) {
    std::string tmp;
    getline(ifs, tmp);
  }
  csv_parser.Reset();
  try {
    while (ifs.good()) {
      // when ifstream reaches the end of file, the function get() return std::char_traits<char>::eof()
      // which is a 32-bit -1, it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use
      // int to receive its return value.
      int chr = ifs.get();
      int err = csv_parser.ProcessMessage(chr);
      if (err != 0) {
        // if error code is -2, the returned error is interrupted
        if (err == -2) return Status(kMDInterrupted);
        RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse csv file: " + file + " at line " +
                                 std::to_string(csv_parser.GetTotalRows() + 1) +
                                 ". Error message: " + csv_parser.GetErrorMessage());
      }
    }
  } catch (std::invalid_argument &ia) {
    std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
    RETURN_STATUS_UNEXPECTED("Invalid csv, csv file: " + file + " parse failed at line " + err_row +
                             ", type does not match.");
  } catch (std::out_of_range &oor) {
    std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
    RETURN_STATUS_UNEXPECTED("Invalid csv, " + file + " parse failed at line " + err_row + " : value out of range.");
  }
  return Status::OK();
}

通过阅读上面的源码,发现源码中没有处理header行的代码,即默认所有行都是数据行。还记得2.3中数据分配部分的代码,我们写入了header信息,而CSVDataset并不提供处理header行的能力。

现在根据报错分析定位,对2.3的数据分配代码进行修改,将代码

preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file)

修改为

preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file, header=False)

再次运行preprocess.py文件,生成新的数据。

然后运行load.py文件(这里并不需要再改代码),输出内容如下:

说明:

  1. 为方便读者查看,这里对格式进行了人为处理,内容不变。
  2. 这里已经能够正确读取数据,数据包含5个字段。
  3. 数据字段名已经根据指定的column_names做了处理。
====== sample ======
{'Sepal.Length': Tensor(shape=[], dtype=Float32, value= 5.5), 'Sepal.Width': Tensor(shape=[], dtype=Float32, value= 4.2), 'Petal.Length': Tensor(shape=[], dtype=Float32, value= 1.4), 'Petal.Width': Tensor(shape=[], dtype=Float32, value= 0.2), 
'Classes': Tensor(shape=[], dtype=String, value= 'Iris-setosa')}

4. 正确示例

通过3中的两种试错,我们对CSVDataset有了初步认识,细心的读者可能会发现,3中依然有一个问题,那就是Classes字段没有进行数值化,下面我们就来介绍一种对其数值化的方法。

源码如下:

from mindspore.dataset import CSVDataset
from mindspore.dataset.text import Lookup, Vocab


def dataset_load(data_files):
    column_defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]
    column_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Classes"]

    dataset = CSVDataset(
        dataset_files=data_files,
        field_delim=",",
        column_defaults=column_defaults,
        column_names=column_names,
        num_samples=None,
        shuffle=False)

    cls_to_id_dict = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
    vocab = Vocab.from_dict(word_dict=cls_to_id_dict)
    lookup = Lookup(vocab)
    dataset = dataset.map(input_columns="Classes", operations=lookup)

    data_iter = dataset.create_dict_iterator()
    item = None
    for data in data_iter:
        item = data
        break

    print("====== sample ======\n{}".format(item), flush=True)


def main():
    iris_train_file = "{your_path}/iris/iris_train.csv"

    dataset_load(data_files=iris_train_file)


if __name__ == "__main__":
    main()

将以上代码保存到load.py文件,运行命令:

注意修改数据文件路径

python3 load.py

输出内容如下:

说明:

  1. 数据包含5个字段。
  2. Classes字段已经根据cls_to_id_dict = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}进行了数值化转换。
  3. 数值化转换用到了mindspore.dataset.text的有关方法,读者可以自行查阅,后续会出相关的解读文章。
====== sample ======
{'Sepal.Length': Tensor(shape=[], dtype=Float32, value= 5.5), 'Sepal.Width': Tensor(shape=[], dtype=Float32, value= 4.2), 'Petal.Length': Tensor(shape=[], dtype=Float32, value= 1.4), 'Petal.Width': Tensor(shape=[], dtype=Float32, value= 0.2), 
'Classes': Tensor(shape=[], dtype=Int32, value= 0)}

后续:

  • 这里还存在其他字段的数据归一化,就留待读者去尝试了。
  • 数值化转换部分,也可以通过在数据分配部分增加代码来提前转换,读者也可以进行尝试。

5. 本文总结

本文对MindSpore中的CSVDataset数据集接口进行了探索和示例展示。通过错误试探,发现目前CSVDataset的文档和功能还相对较弱,只能说是可用。

6. 问题改进

6.1 column_defaults文档错误

英文文档

column_defaults (list, optional) – List of default values for the CSV field (default=None). Each item in the list is either a valid type (float, int, or string). If this is not provided, treats all columns as string type.

中文文档

column_defaults (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值:None,不指定。如果未指定该参数,则所有列的数据类型将被视为string。

这里中文翻译有误。其实英文API就有一定的歧义性,前面说了是每个字段的默认值(CSV文件中存在字段为空的情况),后面又说如果为空,则按照string类型处理,让人分不清究竟是数据类型实例还是数据类型。

**注意:**其实这里既有数据类型实例的意思,又有数据类型的意思。当指定了column_defaults参数,则字段的默认值为column_defaults中相应位置的值,字段的类型为column_defaults相应位置值的数据类型。例如:某CSV文件包含三个字段,指定column_defaults为[2.0, 1, “x”],则读取该文件时,三个字段的类型会被识别为float、int、str,如果某行中第二个字段为空,则就用默认值1填充。

6.2 不支持文件含有header

如题

6.3 不支持读取指定字段

如题,API层面不显式支持,不过可以通过后续的数据处理来支持。

7. 本文参考

本文为原创文章,版权归作者所有,未经授权不得转载!

标签:iris,--,精讲,dataset,column,file,易点通,csv,data
来源: https://www.cnblogs.com/skytier/p/16481392.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有