前言
當使用大量的資料進行機器學習系統的訓練時,總是得花一些心力處理資料集的讀取。尤其如果有大量的圖片,則直接用檔案系統存放很多圖片不僅有些傳送和儲存的麻煩,也會讓讀檔變得沒有效率。
雖然也曾接觸過 HDF5 等資料格式,但不知為何存取起來相當緩慢。所以後來就都直接用 binary 檔案,自己寫資料格式的存取。然而,這樣每次不同的資料無法直接通用,且自己寫檔案格式也很難保證速度上達到最好的效果。
直到最近發現有個叫做 Zarr 的檔案存取格式,使用起來似乎相當快速。是以撰寫此文紀錄使用的方式。
本次的實驗程式放在 shaform/experiments/zarr-dataset。
環境設定
由於這次要使用 PyTorch,所以我將使用 Conda 來進行環境配置。首先安裝必要套件:
conda create -n zarr python=3.6
conda activate zarr # or source activate zarr or older version of conda
conda install zarr tqdm python-lmdb -c conda-forge
緊接著,我們下載 AnimeGAN 的動漫人臉資料集,將其解壓到目錄之下,所以目前的目錄看起來如下:
.
├── anime-faces
│ ├── 1boy
│ │ ├── danbooru_2637825_e3d8c4f9d55f25217cf5600874e664be.png
│ │ ├── danbooru_2637834_dceb8d822bd1326cb0865440d23d39b8.png
│ │ ├── ....
│ ├── 1girl
│ │ ├── danbooru_2635679_21741d5772cb3275165be0b68a286155.jpg
│ │ ├── ....
緊接著刪除一些奇怪的檔案:
rm -f anime-faces/aqua_eyes/._danbooru_2559693_dc628b766d7142f2d2d9c75559e36eb5.jpg
rm -f anime-faces/aqua_eyes/._danbooru_2560862_796530ab01cc7bfd8a03c8d05cc6953b.png
如此一來就準備好了。
資料處理
在動漫人臉資料集裡,每個資料夾代表一個分類,共有 126 個分類。而每個資料夾裡,有許多 96x96x3 的人臉。
現在我們要將檔案轉成 training set 和 validation set。首先利用 PyTorch 的 ImageFolder
將資料讀進來,他會自動依照資料夾給標籤。所以每個圖片會變成一對 (image, label)
,其中 image
是一個 PIL.Image
而 label
則是一個數字。
我們利用 as_array
把 PIL.Image
轉成 unit8
的 numpy array,然後把 channels 改成第一個維度,好符合 PyTorch 的慣用格式。
接著就用 torch.utils.data.random_split
幫我們分開資料。
# convert_anime_faces.py
import numpy as np
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
def as_array(image):
return np.asarray(image).swapaxes(2, 0)
def main():
data_set = ImageFolder(root='anime-faces', transform=as_array)
val_ratio = 0.1
val_size = int(len(data_set) * val_ratio)
train_size = len(data_set) - val_size
train_set, val_set = random_split(data_set, [train_size, val_size])
confs = [
('data/anime_faces/train.lmdb', train_set),
('data/anime_faces/val.lmdb', val_set),
]
for path, data_set in confs:
convert_data_set(path, data_set)
然後就可以開始把他轉成 zarr
資料集了。
Zarr 有支援各種不同的檔案儲存格式,這裡我們用 lmdb。注意到由於 lmdb 預設使用 sparse file,在 Linux 下你可能會看到 *.lmdb
資料夾足足佔了 1 TB 那麼大,但實際上使用 du -hs
檢查卻會發現其實沒有佔那麼大。這是正常的現象。
這邊主要注意的是因為 PyTorch 會任意存取不同 indices,所以 chunk
的第一個維度必須設為 1
,這樣才能達到最高的速度。其他維度因為每次都是一次讀出來,所以就設成 None
。chunk
的功用是,如果維度的數值是 n
的話,壓縮時就會把該維度的 n
個元素壓縮在一起,所以如果你常常一次讀 n
個元素的話,使用 chunk
就會比較快。設成 None
就表示把該維度所有資料都放在一起。
至於 u1
指的則是 1 byte 的 uint,也就是 uint8
。
# convert_anime_faces.py
def convert_data_set(path, data_set, batch_size=1000):
loader = DataLoader(
data_set, batch_size=batch_size, shuffle=False, num_workers=4)
num_examples = len(data_set)
os.makedirs(path, exist_ok=True)
with zarr.LMDBStore(path) as store:
root = zarr.group(store=store, overwrite=True)
images_set = root.zeros(
'images',
shape=(num_examples, 3, 96, 96),
chunks=(1, None, None, None),
dtype='u1')
labels_set = root.zeros(
'labels', shape=(num_examples, ), chunks=(1, ), dtype='u1')
current_iter = 0
for images, labels in tqdm(loader):
size = images.shape[0]
images_set[current_iter:current_iter + size] = images
labels_set[current_iter:current_iter + size] = labels
current_iter += size
資料讀取
那麼就可以讀取資料來進行訓練了。這裡我們讀出來的同時,也把檔案格式轉成範圍是 [0, 1]
的圖片,以及資料型態是 torch.long
的標籤。
注意到,如果要使用 num_works > 1
的話,必須要在 __getitem__
裡頭再開啟資料集,否則會有問題。而因為 zarr 支援多個程式同時讀寫檔案,所以同時讀取是沒問題的。
# test_anime_faces.py
import os
import zarr
import torch
from torch.utils.data import Dataset
class FaceDataset(Dataset):
def __init__(self, path, transforms=None):
self.path = path
self.keys = ('images', 'labels')
assert os.path.exists(path), 'file `{}` not exists!'.format(path)
with zarr.LMDBStore(path) as store:
zarr_db = zarr.group(store=store)
self.num_examples = zarr_db['labels'].shape[0]
self.datasets = None
if transforms is None:
transforms = {
'labels': lambda v: torch.tensor(v, dtype=torch.long),
'images': lambda v: torch.tensor((v - 127.5)/127.5, dtype=torch.float32)
}
self.transforms = transforms
def __len__(self):
return self.num_examples
def __getitem__(self, idx):
if self.datasets is None:
store = zarr.LMDBStore(self.path)
zarr_db = zarr.group(store=store)
self.datasets = {key: zarr_db[key] for key in self.keys}
items = []
for key in self.keys:
item = self.datasets[key][idx]
if key in self.transforms:
item = self.transforms[key](item)
items.append(item)
return items
最後我們寫一個簡單的 CNN 來進行測試:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
class Model(nn.Module):
def __init__(self, input_size=96 * 96 * 3, output_size=126,
hidden_size=25):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=6, stride=2, padding=2),
nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(
kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=6, stride=2, padding=2),
nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(
kernel_size=2, stride=2))
self.fc = nn.Linear(6 * 6 * 32, output_size)
self.criteria = nn.CrossEntropyLoss()
def forward(self, inputs):
outputs = self.layer1(inputs)
outputs = self.layer2(outputs)
outputs = outputs.reshape(outputs.size(0), -1)
outputs = self.fc(outputs)
return outputs
def main(batch_size=64, epochs=50):
data_train = FaceDataset('data/anime_faces/train.lmdb')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loader = DataLoader(data_train, batch_size=batch_size, num_workers=10)
model = Model()
model.to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in trange(epochs):
t = tqdm(loader)
for i, (images, labels) in enumerate(t):
images = images.to(device)
labels = labels.to(device)
optim.zero_grad()
logits = model(images)
loss = model.criteria(logits, labels)
loss.backward()
optim.step()
predicts = torch.argmax(F.softmax(logits, dim=1), dim=1)
accuracy = (predicts == labels).to(torch.float32).mean()
t.set_postfix(
epoch=epoch, i=i, loss=loss.item(), accuracy=accuracy.item())
最後就可以測試訓練完的模型:
data_val = FaceDataset('data/anime_faces/val.lmdb')
val_loader = DataLoader(data_val, batch_size=batch_size, num_workers=0)
total = len(data_val)
total_correct = 0
model.eval()
for images, labels in val_loader:
images = images.to(device)
labels = labels.to(device)
logits = model(images)
predicts = torch.argmax(F.softmax(logits, dim=1), dim=1)
correct = (predicts == labels).sum()
total_correct += correct.item()
print('Val accuracy = {}'.format(total_correct / total))
假設你真的照上述的程式跑的話,雖然訓練集的準確率可以達到 72% ,但測試集的結果準確率才 9.2 %,看來必須要做更好的處理來避免 over-fitting 才行。
結語
本次的實驗程式放在 shaform/experiments/zarr-dataset。