Sort Sequences for PackedSequence in PyTorch

By Yong-Siang Shih / Tue 15 January 2019 / In categories Notes

PackedSequence, PyTorch

Translations: ZH


Although using PackedSequence in PyTorch allows faster processing of sequential data, it comes with an inconvenience: the sequences in a batch must be sorted by length. In Seq2Seq tasks such as machine translation, there are both input and output sequences, and their lengths might not match. One option is to sort the batch by input length and use PackedSequence only in the encoder, not in the decoder. Another approach is to sort the sequences again for the decoder and unsort the outputs afterwards. This note describes the second approach.

P.S. A newer version of PyTorch is going to have built-in support for this; once it is released, there will be no need to implement it ourselves.
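For reference, recent PyTorch releases (1.1 and later) expose this through the `enforce_sorted=False` argument of `pack_padded_sequence`, which sorts the batch internally and records how to unsort it; the shapes and lengths below are just illustrative:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# A hypothetical padded batch: B=3 sequences, T=5 steps, D=4 features.
inputs = torch.randn(3, 5, 4)
lengths = torch.tensor([3, 5, 2])  # not sorted

# PyTorch >= 1.1: pack_padded_sequence can sort (and remember how to
# unsort) internally, so no manual sorting is needed.
packed = pack_padded_sequence(
    inputs, lengths, batch_first=True, enforce_sorted=False)

# packed.sorted_indices records the descending-length order,
# and packed.batch_sizes covers every non-padded timestep.
```

With `enforce_sorted=False`, `pad_packed_sequence` later restores the original batch order automatically.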


Sort the Sequences

First, we define a function sort_sequences that sorts inputs by length. It returns three values: the sorted data, the sorted lengths lengths_sorted, and an index tensor unsorted_idx that can be used to restore the original order later.

def sort_sequences(inputs, lengths):
    """Sort sequences by length in descending order.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param lengths (Tensor): length of each sequence, size [B]
    """
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx
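As a quick sanity check, indexing the sorted batch with unsorted_idx gives back the original order. A small round-trip example (the tensor values are arbitrary):

```python
import torch

def sort_sequences(inputs, lengths):
    # Same helper as defined above.
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx

inputs = torch.arange(6.0).view(3, 2, 1)   # [B=3, T=2, D=1]
lengths = torch.tensor([1, 2, 1])

sorted_inputs, lengths_sorted, unsorted_idx = sort_sequences(inputs, lengths)

# Lengths are now descending: [2, 1, 1].
# Indexing with unsorted_idx restores the original batch order:
restored = sorted_inputs[unsorted_idx]
```

Here `sorted_idx.sort()` works because sorting the permutation's values yields, as its index output, the inverse permutation.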

Using sort_sequences in RNNs

Before feeding the data into the RNN, use sort_sequences to sort it, and then use pack_padded_sequence to convert it into a PackedSequence.

Afterwards, we use unsorted_idx to restore the original order of both the outputs and the hidden states.

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs, lengths=None, hidden=None):
        if lengths is not None:
            # Sort the batch by length and pack it for the LSTM.
            inputs, sorted_lengths, unsorted_idx = sort_sequences(
                inputs, lengths)
            inputs = torch.nn.utils.rnn.pack_padded_sequence(
                inputs, sorted_lengths, batch_first=True)

        outputs, (ht, ct) = self.lstm(inputs, hidden)

        if lengths is not None:
            # Pad the packed outputs back, then restore the original order.
            outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)
            outputs = outputs.index_select(0, unsorted_idx)
            # Hidden states are [num_layers, B, H], so unsort along dim 1.
            ht = ht.index_select(1, unsorted_idx)
            ct = ct.index_select(1, unsorted_idx)

        return outputs, (ht, ct)

The unsorted tensors still support backpropagation as usual, so they can be returned directly.
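Putting it together, a minimal end-to-end sketch (the input and hidden sizes are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

def sort_sequences(inputs, lengths):
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs, lengths=None, hidden=None):
        if lengths is not None:
            inputs, sorted_lengths, unsorted_idx = sort_sequences(
                inputs, lengths)
            inputs = torch.nn.utils.rnn.pack_padded_sequence(
                inputs, sorted_lengths, batch_first=True)
        outputs, (ht, ct) = self.lstm(inputs, hidden)
        if lengths is not None:
            outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)
            outputs = outputs.index_select(0, unsorted_idx)
            ht = ht.index_select(1, unsorted_idx)
            ct = ct.index_select(1, unsorted_idx)
        return outputs, (ht, ct)

rnn = RNN(input_size=4, hidden_size=8)
inputs = torch.randn(3, 5, 4)          # [B, T, D], padded
lengths = torch.tensor([3, 5, 2])      # unsorted lengths

outputs, (ht, ct) = rnn(inputs, lengths)
# outputs has shape [3, 5, 8], in the original batch order;
# timesteps past each sequence's length are zero-padded.
```

Because the outputs are unsorted back before being returned, the caller never needs to know that any sorting happened inside forward.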


Yong-Siang Shih

Software Engineer, Machine Learning Scientist, Open Source Enthusiast. Worked at Appier building machine learning systems, and interned at Google, IBM, and Microsoft as a software engineer. Loves to learn and build things.
* Follow me on GitHub