[关闭]
@chenyaofo 2021-09-29T11:29:42.000000Z 字数 7411 阅读 2081

Webdataset 加速深度学习数据加载

导言:在大规模数据上进行深度学习通常会因为IO瓶颈而拖慢训练的速度,本文介绍了webdataset是如何在深度学习中加速大规模数据加载的。

webdataset 简介

webdataset是什么webdataset是一个数据加载的库,其可以从tar文件中直接读取数据样本而无需将tar包中的所有文件释放出来。从某个角度看,webdataset制定了一种基于tar包的大规模数据格式,其实就是翻版的tfrecord,只不过tfrecord是google专门搞出来的格式,而webdataset直接使用tar这种通用的数据格式,没有自己另外再搞一种二进制格式。此外,webdataset是专门为PyTorch写的,可以很容易集成到已有的PyTorch代码中(其实稍微改改应该很容易集成到任何深度学习框架中)。webdataset的主要目的是为了解决传统数据加载方式(就是直接从磁盘中加载大量数据集文件)存在的一些问题。

传统数据加载方式有什么问题?当今的大规模数据集包含了大量的数据样本,例如ImageNet包括约130万图片,OpenImage包括约900万图片,这还只是开胃菜,在大公司里面还有比这些大得多的数据集。如果这些图片样本直接存放在文件系统/对象存储系统上,数据读取会给这些系统带来极大的压力。原因包括以下几点:

webdataset是如何解决上述问题的:webdataset将数据样本文件打包,但是这里注意不是将所有文件打成一个特别大的包,而是将其打成若干个包。以ImageNet为例,我们可以将130万个文件打包为256个tar包,平均每个tar文件包含5k个样本。在读取的时候,webdataset将这256个tar包顺序打乱,然后按照打乱的顺序依次读取tar包。在读取每一个tar包的时候,里面存储的样本将会被顺序读取(因此很快),但是这样的话达不到打乱整个数据集的目的。因此webdataset维护了一个buffer,新读取的样本将会和buffer中的一个随机样本交换,达到打乱数据集的目的:

  1. # read sample from the given tar file
  2. k = rng.randint(0, len(buf) - 1)
  3. sample, buf[k] = buf[k], sample
  4. # return sample here

这样webdataset将把上述传统数据加载方提到的缺点都解决了,需要注意的是webdataset的数据集打乱程度是和这个buffer的大小有关系,在实际中需要设置一个足够大的数值。其实webdataset的工作原理和tfrecord是一模一样的,用tensorflow的同学应该是很容易理解。

性能对比:webdataset vs. 原生数据加载

为了突显webdataset的优秀特性,我们将imagenet打包为webdataset支持的一系列tar包并比较使用webdataset加载和使用pytorch原生的ImageFolder加载的速度。注意这里为了更好的对比IO速度,我们把图片文件的所有字节加载到内存中就够了,并没有进行图片解码和任何的预处理操作。下面我们分别给出了在机械硬盘和固态硬盘上用webdataset和原生数据加载方式的速度对比。

机械硬盘对比结果:在机械硬盘上,webdataset基本上带来了10倍的读取速度提升。如此巨大的性能提升是因为机械硬盘的顺序读取速度比随机读取快太多了,而webdataset这个库很好地利用了这一点,几乎把所有的文件读取都变成了顺序读取。从每秒加载的图片文件大小来看,webdataset已经非常接近这块机械硬盘的读取上限(~170MB/s),基本做到了极致。

每秒加载的图片数量对比:

线程数 1 2 4 6 8
原生加载 83.20 86.19 104.40 112.96 120.42
Webdataset加载 1447.39 1423.57 1215.70 1160.79 1020.50

每秒加载的图片文件大小 (MB/s) 对比:

线程数 1 2 4 6 8
原生加载 9.17 9.36 11.48 12.35 13.31
Webdataset加载 159.77 155.47 134.51 125.99 112.43

固态硬盘对比结果:在固态硬盘上,webdataset带来了从27%到56%不等的读取速度提升,这个提升远没有机械硬盘来的惊艳,但是有提升总好过没有是不是。提升比较小地原因时固态硬盘的随机读写性能相对于机械硬盘已经好了太多太多(见附录中的硬盘读写性能测试)。另外要说一点这块是SATA的固态硬盘,如果是NVME的固态硬盘,这一差距还会继续的缩小。

每秒加载的图片数量对比:

线程数 1 2 4 6 8
原生加载 1936.34 2339.95 3299.90 3515.42 3536.51
Webdataset加载 2567.04 3665.97 4383.44 4539.08 4503.89

每秒加载的图片文件大小 (MB/s) 对比:

线程数 1 2 4 6 8
原生加载 153.54 255.70 361.70 385.44 387.74
Webdataset加载 281.63 403.49 482.24 496.40 495.29

结论:在机械硬盘上强烈推荐使用webdataset作为数据加载方式,在固态硬盘上也十分推荐(其实固态硬盘的原生ImageFolder的加载速度已经非常够用了)。此外,上面我们讨论的主要是本地加载数据的情况,如果在云上进行机器学习模型训练,数据文件往往会直接从分布式文件系统或者对象存储上进行读取。如果数据集文件过多也会导致分布式文件/对象存储处理过多无用的元数据,并且小文件过多也无法一直保持网络带宽的最大化利用,这些问题都会导致数据加载变成训练过程中的瓶颈,而webdataset也能很好处理这个场景(事实上,这个库就是为这类场景发明的,所以叫做webdataset)。

附录

硬件环境:这里列出本次测试的硬件环境:

硬件名称 具体型号
CPU Intel(R) Core(TM) i5-8500 CPU @ 3.00GHz
内存 Kingston 2666MHz 8G x 2
机械硬盘 Western Digital Blue 1T 7200 rpm
固态硬盘 Intel 545s Series 256G

本次对比测试将会在机械硬盘和固态硬盘上进行,下面给出一些fio脚本测试得到的数据,以便更好对后面的实验对比结果进行分析:

  1. |Name | Read(MB/s)| Write(MB/s)|
  2. |--------------|------------|------------|
  3. | SEQ1M Q1 T1 | 484.345| 352.739|
  4. | SEQ1M Q8 T1 | 516.135| 425.775|
  5. | RND4K Q32T16 | 311.884| 282.172|
  6. | . IOPS | 76143.543| 68889.666|
  7. | . latency us | 6.709| 7.422|
  8. | RND4K Q1 T1 | 42.470| 106.964|
  9. | . IOPS | 10368.547| 26114.252|
  10. | . latency us | 0.096| 0.037|
  1. |Name | Read(MB/s)| Write(MB/s)|
  2. |--------------|------------|------------|
  3. | SEQ1M Q1 T1 | 179.546| 79.943|
  4. | SEQ1M Q8 T1 | 172.259| 79.952|
  5. | RND4K Q32T16 | 2.335| 0.929|
  6. | . IOPS | 569.994| 226.896|
  7. | . latency us | 823.919| 1837.834|
  8. | RND4K Q1 T1 | 0.686| 1.094|
  9. | . IOPS | 167.532| 267.147|
  10. | . latency us | 5.962| 3.734|

测试说明:为了测试的公平性,每一次测试之前都会使用命令sync; echo 3 > /proc/sys/vm/drop_caches清空所有缓冲区,包括页面缓存,目录项和inode以保证数据确实是从硬盘加载而不是来自于内存缓存。

webdataset 构建:我们使用下面的代码将imagenet的train部分打包为tar包,这里的打包代码是自己写的,用了多进程,这个库给的代码是单进程的,慢得离谱。。。

  1. import os
  2. import random
  3. import datetime
  4. from multiprocessing import Process
  5. from torchvision import datasets
  6. from torchvision.datasets import ImageNet
  7. from torchvision.datasets.folder import ImageFolder
  8. from webdataset import TarWriter
  9. def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs):
  10. random.shuffle(samples)
  11. samples_per_shards = [samples[i::num_shards] for i in range(num_shards)]
  12. shard_ids = list(range(num_shards))
  13. processes = [
  14. Process(
  15. target=write_partial_samples,
  16. args=(
  17. pattern,
  18. shard_ids[i::num_workers],
  19. samples_per_shards[i::num_workers],
  20. map_func,
  21. kwargs
  22. )
  23. )
  24. for i in range(num_workers)]
  25. for p in processes:
  26. p.start()
  27. for p in processes:
  28. p.join()
  29. def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs):
  30. for shard_id, samples in zip(shard_ids, samples):
  31. write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs)
  32. def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs):
  33. fname = pattern % shard_id
  34. print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}")
  35. stream = TarWriter(fname, **kwargs)
  36. size = 0
  37. for item in samples:
  38. size += stream.write(map_func(item))
  39. stream.close()
  40. print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}")
  41. return size
  42. if __name__ == "__main__":
  43. root = "/gdata/ImageNet2012/train"
  44. items = []
  45. dataset = ImageFolder(root=root, loader=lambda x:x)
  46. for i in range(len(dataset)):
  47. items.append(dataset[i])
  48. print(dataset[0],os.path.splitext(os.path.basename(dataset[0][0]))[0])
  49. def map_func(item):
  50. name, class_idx = item
  51. with open(os.path.join(name), "rb") as stream:
  52. image = stream.read()
  53. sample = {
  54. "__key__": os.path.splitext(os.path.basename(name))[0],
  55. "jpg": image,
  56. "cls": str(class_idx).encode("ascii")
  57. }
  58. return sample
  59. make_wds_shards(
  60. pattern="/userhome/tars/imagenet-1k-%06d.tar",
  61. num_shards=256, # 设置分片数量
  62. num_workers=8, # 设置创建wds数据集的进程数
  63. samples=items,
  64. map_func=map_func,
  65. )

测试代码:我们测试随机从数据集中读取N张图的耗时(在机械硬盘上N=30000,在固态硬盘上N=300000),根据耗时计算每秒读取图片的数量和吞吐量

  1. import os
  2. import time
  3. import torch
  4. import webdataset as wds
  5. from torchvision.datasets import ImageFolder
  6. from torch.utils.data import DataLoader
  7. def get_ori_loader(disk, num_workers):
  8. def read_bytes(path):
  9. with open(path, "rb") as f:
  10. return f.read()
  11. root = "/mnt/extend/imagenet/train" if disk == "hdd" else "/home/chenyaofo/webdataset-test/train"
  12. dataset = ImageFolder(root, loader=read_bytes, is_valid_file=lambda x: True)
  13. dataloader = DataLoader(dataset, num_workers=num_workers, shuffle=True, batch_size=128)
  14. return dataloader
  15. def get_wds_loader(disk, num_workers):
  16. url = "/mnt/extend/tars/imagenet-1k-{000000..000256}.tar" if disk == "hdd" else "/home/chenyaofo/webdataset-test/tars/imagenet-1k-{000000..000256}.tar"
  17. def my_decoder(key, value):
  18. if not key.endswith(".jpg"):
  19. return None
  20. assert isinstance(value, bytes)
  21. return value
  22. dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)
  23. dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=128)
  24. return dataloader
  25. def run_test(loader, disk):
  26. N_stop = 30000 if disk == "hdd" else 300000
  27. start = time.perf_counter()
  28. total_batch_size = 0
  29. total_bytes = 0
  30. for items in loader:
  31. if isinstance(items, dict):
  32. batch_size = len(items['jpeg.cls'])
  33. n_bytes = sum(map(lambda x: len(x), items['jpeg.jpg']))
  34. else:
  35. batch_size = len(items[1])
  36. n_bytes = sum(map(lambda x: len(x), items[0]))
  37. total_batch_size += batch_size
  38. # print(total_batch_size)
  39. total_bytes += n_bytes
  40. if total_batch_size > N_stop:
  41. end = time.perf_counter()
  42. return total_batch_size, total_bytes, end-start
  43. for disk in ["ssd", "hdd"]:
  44. for get_loader in [get_ori_loader, get_wds_loader]:
  45. for num_workers in [1, 2, 4, 6, 8]:
  46. os.system("sync; echo 3 > /proc/sys/vm/drop_caches")
  47. loader = get_loader(disk, num_workers)
  48. total_batch_size, total_bytes, time_cost = run_test(loader, disk)
  49. print(f"{disk}, {get_loader.__name__}, num_workers={num_workers}, fps={total_batch_size/time_cost:.2f}, throughput={total_bytes/(1024)**2/time_cost:.2f} MB/s")
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注