使用 Zarr 儲存資料集並用 PyTorch Dataset 讀取

作者: Yong-Siang Shih / Sun 11 November 2018 / 分類: Notes

PyTorch, hdf5, zarr

前言

當使用大量的資料進行機器學習系統的訓練時,總是得花一些心力處理資料集的讀取。尤其如果有大量的圖片,則直接用檔案系統存放很多圖片不僅有些傳送和儲存的麻煩,也會讓讀檔變得沒有效率。

雖然也曾接觸過 HDF5 等資料格式,但不知為何存取起來相當緩慢。所以後來就都直接用 binary 檔案,自己寫資料格式的存取。然而,這樣每次不同的資料無法直接通用,且自己寫檔案格式也很難保證速度上達到最好的效果。

直到最近發現有個叫做 Zarr 的檔案存取格式,使用起來似乎相當快速。是以撰寫此文紀錄使用的方式。

本次的實驗程式放在 shaform/experiments/zarr-dataset

環境設定

由於這次要使用 PyTorch,所以我將使用 Conda 來進行環境配置。首先安裝必要套件:

conda create -n zarr python=3.6
conda activate zarr # or source activate zarr or older version of conda
conda install zarr tqdm python-lmdb -c conda-forge

緊接著,我們下載 AnimeGAN動漫人臉資料集,將其解壓到目錄之下,所以目前的目錄看起來如下:

.
├── anime-faces
│   ├── 1boy
│   │   ├── danbooru_2637825_e3d8c4f9d55f25217cf5600874e664be.png
│   │   ├── danbooru_2637834_dceb8d822bd1326cb0865440d23d39b8.png
│   │   ├── ....
│   ├── 1girl
│   │   ├── danbooru_2635679_21741d5772cb3275165be0b68a286155.jpg
│   │   ├── ....

緊接著刪除一些奇怪的檔案:

rm -f anime-faces/aqua_eyes/._danbooru_2559693_dc628b766d7142f2d2d9c75559e36eb5.jpg
rm -f anime-faces/aqua_eyes/._danbooru_2560862_796530ab01cc7bfd8a03c8d05cc6953b.png

如此一來就準備好了。

資料處理

在動漫人臉資料集裡,每個資料夾代表一個分類,共有 126 個分類。而每個資料夾裡,有許多 96x96x3 的人臉。

現在我們要將檔案轉成 training set 和 validation set。首先利用 PyTorch 的 ImageFolder 將資料讀進來,他會自動依照資料夾給標籤。所以每個圖片會變成一對 (image, label),其中 image 是一個 PIL.Imagelabel 則是一個數字。

我們利用 as_arrayPIL.Image 轉成 unit8 的 numpy array,然後把 channels 改成第一個維度,好符合 PyTorch 的慣用格式。

接著就用 torch.utils.data.random_split 幫我們分開資料。

# convert_anime_faces.py

import numpy as np
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split


def as_array(image):
    return np.asarray(image).swapaxes(2, 0)


def main():
    data_set = ImageFolder(root='anime-faces', transform=as_array)

    val_ratio = 0.1
    val_size = int(len(data_set) * val_ratio)
    train_size = len(data_set) - val_size

    train_set, val_set = random_split(data_set, [train_size, val_size])

    confs = [
        ('data/anime_faces/train.lmdb', train_set),
        ('data/anime_faces/val.lmdb', val_set),
    ]
    for path, data_set in confs:
        convert_data_set(path, data_set)

然後就可以開始把他轉成 zarr 資料集了。

Zarr 有支援各種不同的檔案儲存格式,這裡我們用 lmdb。注意到由於 lmdb 預設使用 sparse file,在 Linux 下你可能會看到 *.lmdb 資料夾足足佔了 1 TB 那麼大,但實際上使用 du -hs 檢查卻會發現其實沒有佔那麼大。這是正常的現象。

這邊主要注意的是因為 PyTorch 會任意存取不同 indices,所以 chunk 的第一個維度必須設為 1,這樣才能達到最高的速度。其他維度因為每次都是一次讀出來,所以就設成 Nonechunk 的功用是,如果維度的數值是 n 的話,壓縮時就會把該維度的 n 個元素壓縮在一起,所以如果你常常一次讀 n 個元素的話,使用 chunk 就會比較快。設成 None 就表示把該維度所有資料都放在一起。

至於 u1 指的則是 1 byte 的 uint,也就是 uint8

# convert_anime_faces.py

def convert_data_set(path, data_set, batch_size=1000):
    loader = DataLoader(
        data_set, batch_size=batch_size, shuffle=False, num_workers=4)

    num_examples = len(data_set)

    os.makedirs(path, exist_ok=True)
    with zarr.LMDBStore(path) as store:
        root = zarr.group(store=store, overwrite=True)
        images_set = root.zeros(
            'images',
            shape=(num_examples, 3, 96, 96),
            chunks=(1, None, None, None),
            dtype='u1')
        labels_set = root.zeros(
            'labels', shape=(num_examples, ), chunks=(1, ), dtype='u1')
        current_iter = 0
        for images, labels in tqdm(loader):
            size = images.shape[0]
            images_set[current_iter:current_iter + size] = images
            labels_set[current_iter:current_iter + size] = labels
            current_iter += size

資料讀取

那麼就可以讀取資料來進行訓練了。這裡我們讀出來的同時,也把檔案格式轉成範圍是 [0, 1] 的圖片,以及資料型態是 torch.long 的標籤。

注意到,如果要使用 num_works > 1 的話,必須要在 __getitem__ 裡頭再開啟資料集,否則會有問題。而因為 zarr 支援多個程式同時讀寫檔案,所以同時讀取是沒問題的。

# test_anime_faces.py

import os

import zarr
import torch
from torch.utils.data import Dataset


class FaceDataset(Dataset):
    def __init__(self, path, transforms=None):
        self.path = path
        self.keys = ('images', 'labels')
        assert os.path.exists(path), 'file `{}` not exists!'.format(path)

        with zarr.LMDBStore(path) as store:
            zarr_db = zarr.group(store=store)
            self.num_examples = zarr_db['labels'].shape[0]
        self.datasets = None

        if transforms is None:
            transforms = {
                'labels': lambda v: torch.tensor(v, dtype=torch.long),
                'images': lambda v: torch.tensor((v - 127.5)/127.5, dtype=torch.float32)
            }
        self.transforms = transforms

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        if self.datasets is None:
            store = zarr.LMDBStore(self.path)
            zarr_db = zarr.group(store=store)
            self.datasets = {key: zarr_db[key] for key in self.keys}

        items = []
        for key in self.keys:
            item = self.datasets[key][idx]
            if key in self.transforms:
                item = self.transforms[key](item)
            items.append(item)
        return items

最後我們寫一個簡單的 CNN 來進行測試:

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm, trange


class Model(nn.Module):
    def __init__(self, input_size=96 * 96 * 3, output_size=126,
                 hidden_size=25):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=6, stride=2, padding=2),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(
                kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=6, stride=2, padding=2),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(
                kernel_size=2, stride=2))
        self.fc = nn.Linear(6 * 6 * 32, output_size)
        self.criteria = nn.CrossEntropyLoss()

    def forward(self, inputs):
        outputs = self.layer1(inputs)
        outputs = self.layer2(outputs)
        outputs = outputs.reshape(outputs.size(0), -1)
        outputs = self.fc(outputs)
        return outputs


def main(batch_size=64, epochs=50):
    data_train = FaceDataset('data/anime_faces/train.lmdb')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    loader = DataLoader(data_train, batch_size=batch_size, num_workers=10)
    model = Model()
    model.to(device)
    model.train()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in trange(epochs):
        t = tqdm(loader)
        for i, (images, labels) in enumerate(t):
            images = images.to(device)
            labels = labels.to(device)

            optim.zero_grad()
            logits = model(images)
            loss = model.criteria(logits, labels)
            loss.backward()
            optim.step()

            predicts = torch.argmax(F.softmax(logits, dim=1), dim=1)
            accuracy = (predicts == labels).to(torch.float32).mean()
            t.set_postfix(
                epoch=epoch, i=i, loss=loss.item(), accuracy=accuracy.item())

最後就可以測試訓練完的模型:

    data_val = FaceDataset('data/anime_faces/val.lmdb')
    val_loader = DataLoader(data_val, batch_size=batch_size, num_workers=0)
    total = len(data_val)
    total_correct = 0
    model.eval()
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        predicts = torch.argmax(F.softmax(logits, dim=1), dim=1)
        correct = (predicts == labels).sum()
        total_correct += correct.item()
    print('Val accuracy = {}'.format(total_correct / total))

假設你真的照上述的程式跑的話,雖然訓練集的準確率可以達到 72% ,但測試集的結果準確率才 9.2 %,看來必須要做更好的處理來避免 over-fitting 才行。

結語

本次的實驗程式放在 shaform/experiments/zarr-dataset

Yong-Siang Shih

作者

Yong-Siang Shih

軟體工程師,機器學習科學家,開放原始碼愛好者。曾在 Appier 從事機器學習系統開發,也曾在 Google, IBM, Microsoft 擔任軟體實習生。喜好探索學習新科技。

載入 Disqus 評論