用 TensorBox 製作簡易貓貓辨識器

作者: Yong-Siang Shih / Sat 05 November 2016 / 分類: Notes

cats, object detection, TensorBox, TensorFlow

最近恰好有個需求是要收集大量特定物體的圖片,直覺的想法就是訓練一個該物體的偵測器,然後再用這個偵測器從大量圖片中找出符合需求的區塊。經過一番搜尋,發現 TensorBox 似乎是個用來訓練單一物體偵測器的簡單套件,於是便利用貓咪來進行了簡單的嘗試。

Detect a Cat

本文使用的套件有的是只支援 Python 2,此時執行時將會以 python 程式執行。但筆者寫的 scripts 大多是 Python 3,本文中就以 python3 來表示。

本文將會用到不少套件,實際放置套件的資料夾結構如下:

./
./cats
./TensorBox
./labelImg
./labels
./data
./data/train
./data/test

實際的程式可在 shaform/experiments/cat 裡找到,本文只會顯示重點部份。

收集資料

首先我們使用 maxogden/cats 收集好的小規模貓貓照片來進行實驗,我們將利用 catmapper 資料夾中的照片做為訓練資料,最後再用 cat_photos 裡的照片來觀看結果。

值得注意的是,這樣的訓練資料只有大約三百多筆,而且每張照片裡幾乎都有貓。實際上若是真的要訓練高品質的偵測器,應當收集更多的訓練資料,而且應該要加入一些沒有貓,或者有跟貓很像但不是貓的物體,來加強模型的能力。不過本文純屬示範工具用法,未免麻煩就不做深入探討了。

由於 TensorBox 比較支援圖片大小為 32 倍數的照片,不符時須縮放或擷取。為了方便起見我就先把圖片都轉成適當大小。首先安裝 Pillow 套件:

pip3 install Pillow

緊接著寫一個程式可以將圖片裁成指定大小 608x608,然後轉成 png 檔案格式儲存:

import os

from PIL import Image

image_extensions = {'png', 'jpg', 'jpeg'}


def crop_center(img, width, height):
    x = img.size[0] // 2
    y = img.size[1] // 2

    x1 = x - (width // 2)
    x2 = x + (width - width // 2)

    y1 = y - (height // 2)
    y2 = y + (height - height // 2)

    try:
        return img.crop((x1, y1, x2, y2))
    except OSError as e:
        print(e)
        return None


def main(indir, outdir, width, height):
    for fname in os.listdir(indir):
        parts = fname.rsplit('.', 1)
        if len(parts) == 2 and parts[1] in image_extensions:
            name, ext = parts
            img = Image.open(os.path.join(indir, fname))

            if img.size[0] >= width and img.size[1] >= height:
                img = crop_center(img, width, height)
                if img:
                    img.save(os.path.join(outdir, name + '.png'))

並且將訓練和測試資料進行處理:

mkdir data/train data/test
python3 crop.py --indir cats/catmapper --outdir data/train
python3 crop.py --indir cats/cat_photos --outdir data/test

標記貓貓

緊接著,我們使用 labelImg 來針對訓練資料 cats/catmapper 做標記,並將標記存放在 labels 資料夾裡。

注意到 labelImg 只支援 Python 2,同時要先按照官方安裝步驟進行初始化:

cd labelImg
sudo apt-get install pyqt4-dev-tools
sudo pip install lxml
make all
./labelImg.py

可以利用 Ctrl-N, N, P 等快速鍵,建立框框來標記貓咪,以及切換上一張、下一張照片。

緊接著,撰寫一個程式將標記好的資料轉成 TensorBox 能夠讀取的格式:

import json
import os
import random
import xml.etree.ElementTree as ET


def main(indir, outdir, name, seed):
    random.seed(seed)
    if name is None:
        name = os.path.basename(indir)

    entries = []

    for path in sorted(os.listdir(indir)):
        if path.endswith('.xml'):
            tree = ET.parse(os.path.join(indir, path))
            root = tree.getroot()

            img_path = root.findtext('path')
            rects = []

            for obj in root.iter('object'):
                x1 = float(obj.find('bndbox').findtext('xmin'))
                y1 = float(obj.find('bndbox').findtext('ymin'))
                x2 = float(obj.find('bndbox').findtext('xmax'))
                y2 = float(obj.find('bndbox').findtext('ymax'))

                rects.append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2})

            entry = {'image_path': img_path, 'rects': rects}
            entries.append(entry)

    random.shuffle(entries)

    train_offset = int(len(entries) * 0.8)

    train = entries[:train_offset]
    val = entries[train_offset:]

    datasets = [('train', train), ('val', val)]

    for prefix_name, dataset in datasets:
        with open(
                os.path.join(outdir, '{}_{}_boxes.json'.format(
                    name, prefix_name)), 'w') as f:
            json.dump(dataset, f, indent=4, separators=(',', ': '))

將標記轉換到 data 資料夾之下:

python3 extract_label.py --indir labels --outdir data

安裝 TensorBox 並進行訓練

TensorBox 似乎是只支援 Python 2,所以首先設定好 TensorFlow 的 Python 2 版本。接著還要安裝一些相依套件(因為 OpenCV 似乎比較難設定,所以我安裝系統套件):

sudo apt install python-opencv
pip install Cython numpy scikit-image jupyter

接下來下載 TensorBox:

git clone http://github.com/russell91/Tensorbox
cd Tensorbox
./download_data.sh
cd utils && make && cd ..

緊接著修改 TensorBox/hypes/overfeat_rezoom.json 對應的項目:

注意到,train_idltest_idl 須使用絕對路徑。

"train_idl": "/path/to/data/labels_train_boxes.json",
"test_idl": "/path/to/data/labels_val_boxes.json",
"image_width": 608, 
"image_height": 608,
"grid_height": 19, 
"grid_width": 19,

最後就是實際進行訓練了!這個會跑很久,並且把訓練的不同階段的模型存在 TensorBox/output/overfeat_rezoom_{DATETIME}/save.ckpt-{iteration}。由於訓練資料實在很少,所以不須訓練到太後面,可以拿前面的版本就好。

cd TensorBox
python train.py --hypes hypes/overfeat_rezoom.json --gpu 0 --logdir output

測試結果

最後就用 Jupyter Notebook 來檢視我們的結果,先把之前提到的 save.ckpt-10000 以及 save.ckpt-10000.meta 存到 project 根目錄。然後在根目錄用跟 TensorBox 一樣的 Python 2 環境執行:

jupyter notebook

進行初始化:

%matplotlib inline

import json
import os
import random

import tensorflow as tf

import matplotlib.pyplot as plt

from scipy.misc import imread

方便起見將 TensorBox 直接加到路徑,平時請勿模仿:

import sys
sys.path.append('./TensorBox')

from train import build_forward
from evaluate import add_rectangles

最後照抄 TensorBox 的示範程式碼,更改一些檔案路徑:

model_path = './save.ckpt-10000'
image_dir = './data/test'
hypes_file = './TensorBox/hypes/overfeat_rezoom.json'

with open(hypes_file, 'r') as f:
    H = json.load(f)

tf.reset_default_graph()
x_in = tf.placeholder(
    tf.float32, name='x_in', shape=[H['image_height'], H['image_width'], 3])

if H['use_rezoom']:
    pred_boxes, pred_logits, pred_confidences, pred_confs_deltas, pred_boxes_deltas = build_forward(
        H, tf.expand_dims(x_in, 0), 'test', reuse=None)
    grid_area = H['grid_height'] * H['grid_width']
    pred_confidences = tf.reshape(
        tf.nn.softmax(
            tf.reshape(pred_confs_deltas, [grid_area * H['rnn_len'], 2])),
        [grid_area, H['rnn_len'], 2])
    if H['reregress']:
        pred_boxes = pred_boxes + pred_boxes_deltas
else:
    pred_boxes, pred_logits, pred_confidences = build_forward(
        H, tf.expand_dims(x_in, 0), 'test', reuse=None)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.restore(sess, model_path)

    images = os.listdir(image_dir)
    for i in range(5):
        fname = random.choice(images)
        path = os.path.join(image_dir, fname)
        img = imread(path)
        feed = {x_in: img}
        (np_pred_boxes, np_pred_confidences) = sess.run(
            [pred_boxes, pred_confidences], feed_dict=feed)
        new_img, rects = add_rectangles(
            H, [img],
            np_pred_confidences,
            np_pred_boxes,
            use_stitching=True,
            rnn_len=H['rnn_len'],
            min_conf=0.7,
            show_suppressed=False)

        fig = plt.figure(figsize=(12, 12))
        plt.imshow(new_img)

就可以在 Notebook 上看到成果了!!!

以下分別展示一張成功和一張失敗的照片,如果增加更多訓練資料或者使用更容易辨認的物體,效果應該可以更好才是。

Detect a Cat Successfully Failed to Detect a Cat

Yong-Siang Shih

作者

Yong-Siang Shih

軟體工程師,機器學習科學家,開放原始碼愛好者。曾在 Appier 從事機器學習系統開發,也曾在 Google, IBM, Microsoft 擔任軟體實習生。喜好探索學習新科技。* 在 GitHub 上追蹤我

載入 Disqus 評論