前言
使用 PyTorch 的 PackedSequence 雖然可以較快速的處理長短不一的序列資料,但是用起來有個不方便的地方。就是同一個 batch 裡的資料,長度必須由長到短排列。但是如果是在做機器翻譯之類的 Seq2Seq 應用,同時有輸入字串和輸出字串,兩者的長度排序不見得會完全一樣。此時簡單的作法是照輸入排序,然後在 encoder 使用 PackedSequence,但是 decoder 就不要使用。但是其實也可以在 decoder 輸入時,先把資料排序,然後等 decoder 輸出後,再重新把資料轉換為原本的順序。本文就紀錄這種作法。
註:新版的 PyTorch 將會內建這種功能,就不用再使用本文的做法了。
排序資料
首先定義一個 sort_sequences
函式,將輸入的 inputs
照長度排序,並且回傳排好的 inputs
, 排好的長度 lengths_sorted
, 以及可以用來把序列轉回原始排序的 unsorted_idx
。
def sort_sequences(inputs, lengths):
"""sort_sequences
Sort sequences according to lengths descendingly.
:param inputs (Tensor): input sequences, size [B, T, D]
:param lengths (Tensor): length of each sequence, size [B]
"""
lengths_sorted, sorted_idx = lengths.sort(descending=True)
_, unsorted_idx = sorted_idx.sort()
return inputs[sorted_idx], lengths_sorted, unsorted_idx
在 RNN 中的實際用法
在輸入 RNN 之前,先用 sort_sequences
把序列排好,然後再使用 pack_padded_sequence
將資料轉成 PackedSequence
。
輸出之後,利用 unsorted_idx
把資料再轉回原本的排序即可。
class RNN(nn.Module):
def forward(self, inputs, lengths=None, hidden=None):
if lengths is not None:
inputs, sorted_lengths, unsorted_idx = sort_sequences(
inputs, lengths)
inputs = torch.nn.utils.rnn.pack_padded_sequence(
inputs, sorted_lengths, batch_first=True)
outputs, (ht, ct) = self.lstm(inputs, hidden)
if lengths is not None:
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True)
outputs = outputs.index_select(0, unsorted_idx)
ht = ht.index_select(1, unsorted_idx)
ct = ct.index_select(1, unsorted_idx)
這種寫法依然還是可以正常的進行 backpropagation,所以轉回原本序列的 tensors 可以直接傳出去,外頭也不會發現在這裡我們曾經重新排序了兩次。