# Sort Sequences for PackedSequence in PyTorch

By Yong-Siang Shih / Tue 15 January 2019 / In categories Notes

Translations: ZH

## Introduction

Although `PackedSequence` in PyTorch lets RNNs process variable-length sequences faster (the padded positions are skipped), it comes with one inconvenience: the sequences in a batch must be sorted by length in descending order. In Seq2Seq tasks such as machine translation, there are both input and output sequences, and their lengths usually differ, so one ordering cannot suit both. We could simply sort the batch by input length and use `PackedSequence` only in the encoder, not in the decoder. Another approach is to sort the sequences again by output length for the decoder and unsort the results afterwards. This note is about the second approach.

P.S. The next release of PyTorch is expected to support this natively. Once it is released, we will no longer need to implement it ourselves.
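For reference, here is a minimal sketch of that built-in route, assuming PyTorch 1.1 or later, where `pack_padded_sequence` gained an `enforce_sorted` argument. With `enforce_sorted=False` the batch is sorted internally, and both `pad_packed_sequence` and `nn.LSTM` hand back outputs and hidden states in the original batch order. The toy sizes and tensors are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical toy batch: B=3 sequences, max length T=4, feature size D=2,
# with lengths deliberately NOT in descending order.
inputs = torch.randn(3, 4, 2)
lengths = torch.tensor([2, 4, 3])

lstm = nn.LSTM(input_size=2, hidden_size=5, batch_first=True)

# enforce_sorted=False lets pack_padded_sequence sort internally; the
# PackedSequence remembers the permutation (sorted_indices / unsorted_indices).
packed = pack_padded_sequence(inputs, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, (ht, ct) = lstm(packed)

# pad_packed_sequence restores the original batch order, and nn.LSTM
# returns ht / ct in the original batch order as well.
outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(outputs.shape)  # torch.Size([3, 4, 5])
```

No manual sorting or unsorting appears anywhere; the permutation is carried inside the `PackedSequence`.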

## Sort the Sequences

First, we define a function `sort_sequences` that sorts `inputs` by length. It returns three values: the sorted data, the sorted lengths `lengths_sorted`, and `unsorted_idx`, an index tensor that can be used later to restore the original order.

```python
def sort_sequences(inputs, lengths):
    """sort_sequences
    Sort sequences according to lengths in descending order.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param lengths (Tensor): length of each sequence, size [B]
    :return: sorted inputs, sorted lengths, and indices that undo the sort
    """
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    # Sorting the permutation itself yields its inverse, used later to unsort.
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx
```
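To see why `sorted_idx.sort()` gives the inverse permutation, here is a small round trip on a made-up batch (B=4, T=5, D=1) in which every sequence is filled with its own batch index, so the reordering is easy to follow. The function is repeated so the snippet runs on its own.

```python
import torch


def sort_sequences(inputs, lengths):
    """Same as above: sort a batch by length, descending."""
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()  # inverse of sorted_idx
    return inputs[sorted_idx], lengths_sorted, unsorted_idx


# Hypothetical batch: each sequence holds its own index as its only feature.
lengths = torch.tensor([2, 5, 1, 3])
inputs = torch.arange(4, dtype=torch.float).view(4, 1, 1).expand(4, 5, 1)

sorted_inputs, lengths_sorted, unsorted_idx = sort_sequences(inputs, lengths)
print(lengths_sorted)          # tensor([5, 3, 2, 1])
print(sorted_inputs[:, 0, 0])  # tensor([1., 3., 0., 2.])
print(unsorted_idx)            # tensor([2, 0, 3, 1])

# Indexing with unsorted_idx puts every sequence back where it started.
restored = sorted_inputs[unsorted_idx]
print(torch.equal(restored, inputs))  # True
```

For the decoder in a Seq2Seq model, the same function can be called again with the target lengths; each call's `unsorted_idx` only undoes that particular call's sort.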

## Using `sort_sequences` in RNNs

Before feeding the data into RNNs, use `sort_sequences` to sort the data, and then use `pack_padded_sequence` to convert the data into `PackedSequence`.

Afterwards, we can use `unsorted_idx` to unsort both the outputs and the final hidden states, so that they line up with the original batch order again.
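The sort, pack, run, unpack, unsort sequence can be sketched end to end on a toy batch. This is a minimal sketch, assuming a two-layer `nn.LSTM` created with `batch_first=True`; the sizes and tensors are made up, and `sort_sequences` is repeated so the snippet is self-contained. It spells out the `pad_packed_sequence` step and the fact that the hidden states keep their batch on dimension 1, not 0.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def sort_sequences(inputs, lengths):
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx


torch.manual_seed(0)
B, T, D, H = 3, 4, 2, 5  # hypothetical sizes
inputs = torch.randn(B, T, D)
lengths = torch.tensor([2, 4, 3])
lstm = nn.LSTM(D, H, num_layers=2, batch_first=True)

# 1. Sort by length so pack_padded_sequence accepts the batch.
sorted_inputs, sorted_lengths, unsorted_idx = sort_sequences(inputs, lengths)
packed = pack_padded_sequence(sorted_inputs, sorted_lengths, batch_first=True)

# 2. Run the LSTM on the packed (sorted) batch.
packed_out, (ht, ct) = lstm(packed)

# 3. Unpack to a padded [B, T, H] tensor, still in sorted order.
outputs, _ = pad_packed_sequence(packed_out, batch_first=True)

# 4. Unsort. outputs is batch-first (batch on dim 0), but ht / ct have shape
#    [num_layers * num_directions, B, H], so their batch dimension is dim 1.
outputs = outputs[unsorted_idx]
ht = ht[:, unsorted_idx]
ct = ct[:, unsorted_idx]
```

After step 4, row `i` of `outputs` and column `i` of `ht` / `ct` belong to the `i`-th sequence of the original batch. The module version follows the same order of operations inside `forward`.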

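```python
import torch.nn as nn


class RNN(nn.Module):
    # Assumes self.lstm is an nn.LSTM built with batch_first=True (inputs are
    # [B, T, D]) and that `hidden`, if given, is already in sorted batch order.
    def forward(self, inputs, lengths=None, hidden=None):
        if lengths is not None:
            # Sort by length, then pack so the LSTM skips the padding.
            inputs, sorted_lengths, unsorted_idx = sort_sequences(
                inputs, lengths)
            inputs = nn.utils.rnn.pack_padded_sequence(
                inputs, sorted_lengths, batch_first=True)

        outputs, (ht, ct) = self.lstm(inputs, hidden)

        if lengths is not None: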