在 PyTorch 中重新排序資料來使用 PackedSequence

作者: Yong-Siang Shih / Tue 15 January 2019 / 分類: Notes

PackedSequence, PyTorch

翻譯: EN

前言

使用 PyTorchPackedSequence 雖然可以較快速的處理長短不一的序列資料,但是用起來有個不方便的地方。就是同一個 batch 裡的資料,長度必須由長到短排列。但是如果是在做機器翻譯之類的 Seq2Seq 應用,同時有輸入字串和輸出字串,兩者的長度排序不見得會完全一樣。此時簡單的作法是照輸入排序,然後在 encoder 使用 PackedSequence,但是 decoder 就不要使用。但是其實也可以在 decoder 輸入時,先把資料排序,然後等 decoder 輸出後,再重新把資料轉換為原本的順序。本文就紀錄這種作法。

註:新版的 PyTorch 將會內建這種功能,就不用再使用本文的做法了。

Order

排序資料

首先定義一個 sort_sequences 函式,將輸入的 inputs 照長度排序,並且回傳排好的 inputs, 排好的長度 lengths_sorted, 以及可以用來把序列轉回原始排序的 unsorted_idx

def sort_sequences(inputs, lengths):
    """sort_sequences
    Sort sequences according to lengths descendingly.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param lengths (Tensor): length of each sequence, size [B]
    """
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx

在 RNN 中的實際用法

在輸入 RNN 之前,先用 sort_sequences 把序列排好,然後再使用 pack_padded_sequence 將資料轉成 PackedSequence

輸出之後,利用 unsorted_idx 把資料再轉回原本的排序即可。

class RNN(nn.Module):
  def forward(self, inputs, lengths=None, hidden=None):
      if lengths is not None:
          inputs, sorted_lengths, unsorted_idx = sort_sequences(
              inputs, lengths)
          inputs = torch.nn.utils.rnn.pack_padded_sequence(
              inputs, sorted_lengths, batch_first=True)

      outputs, (ht, ct) = self.lstm(inputs, hidden)

      if lengths is not None:
          outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
              outputs, batch_first=True)
          outputs = outputs.index_select(0, unsorted_idx)
          ht = ht.index_select(1, unsorted_idx)
          ct = ct.index_select(1, unsorted_idx)

這種寫法依然還是可以正常的進行 backpropagation,所以轉回原本序列的 tensors 可以直接傳出去,外頭也不會發現在這裡我們曾經重新排序了兩次。

Yong-Siang Shih

作者

Yong-Siang Shih

軟體工程師,機器學習科學家,開放原始碼愛好者。曾在 Appier 從事機器學習系統開發,也曾在 Google, IBM, Microsoft 擔任軟體實習生。喜好探索學習新科技。* 在 GitHub 上追蹤我

載入 Disqus 評論