Introduction
Although PackedSequence in PyTorch speeds up processing of variable-length sequential data, it comes with an inconvenience: the sequences in a batch must be sorted by length in descending order. In a Seq2Seq task such as machine translation, each example has an input sequence and an output sequence, and sorting by input length generally does not sort the outputs by length too. One option is to sort the batch by input length and use PackedSequence only in the encoder, not in the decoder. Another is to sort the sequences again for the decoder and unsort the output afterwards. This note covers the second approach.
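To make the mismatch concrete, here is a minimal sketch with made-up lengths (src_lengths and tgt_lengths are hypothetical values, not from a real dataset). It only shows that the descending sort order of the inputs differs from that of the outputs.

```python
import torch

# Hypothetical lengths for a batch of three sentence pairs
src_lengths = torch.tensor([5, 3, 4])  # input (source) lengths
tgt_lengths = torch.tensor([2, 6, 4])  # output (target) lengths

# Batch order that sorts each side by descending length
_, src_order = src_lengths.sort(descending=True)
_, tgt_order = tgt_lengths.sort(descending=True)

print(src_order.tolist())  # [0, 2, 1]
print(tgt_order.tolist())  # [1, 2, 0]
```

Since the two orders differ, a batch sorted for the encoder is not sorted for the decoder.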
P.S. A newer version of PyTorch is going to have built-in support for this; once it is released, we will no longer need to implement it ourselves.
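For readers on a PyTorch release that already includes this support, the sketch below shows what I understand the built-in route to be: the enforce_sorted=False argument of pack_padded_sequence (available, as far as I know, since PyTorch 1.1). The tensor sizes and the LSTM configuration are arbitrary choices for illustration.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=4, hidden_size=3, batch_first=True)

inputs = torch.randn(3, 5, 4)      # padded batch, size [B, T, D]
lengths = torch.tensor([2, 5, 3])  # deliberately not sorted

# enforce_sorted=False asks PyTorch to sort internally and remember the order
packed = pack_padded_sequence(
    inputs, lengths, batch_first=True, enforce_sorted=False)
outputs, (ht, ct) = lstm(packed)

# Outputs and lengths come back in the original batch order
outputs, out_lengths = pad_packed_sequence(outputs, batch_first=True)

print(outputs.shape)         # torch.Size([3, 5, 3])
print(out_lengths.tolist())  # [2, 5, 3]
```

With this argument, the manual sorting and unsorting described in the rest of this note should not be necessary.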
Sort the Sequences
First, we define a function sort_sequences that sorts inputs by length in descending order. It returns three values: the sorted inputs, the sorted lengths lengths_sorted, and unsorted_idx, an index that can be used to restore the original order later.
def sort_sequences(inputs, lengths):
    """sort_sequences
    Sort sequences according to lengths descendingly.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param lengths (Tensor): length of each sequence, size [B]
    """
    # sorted_idx[i] is the original position of the i-th longest sequence
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    # Sorting sorted_idx yields its inverse permutation
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx
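As a quick sanity check, the sketch below runs sort_sequences on a small batch and confirms that indexing with unsorted_idx restores the original order. The function is repeated so the snippet runs on its own; the batch size and lengths are arbitrary illustrative values.

```python
import torch


def sort_sequences(inputs, lengths):
    # Same function as above, repeated to keep this snippet self-contained
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx


inputs = torch.randn(3, 4, 2)      # [B, T, D]
lengths = torch.tensor([2, 4, 3])  # original, unsorted lengths

sorted_inputs, lengths_sorted, unsorted_idx = sort_sequences(inputs, lengths)

print(lengths_sorted.tolist())  # [4, 3, 2]
print(unsorted_idx.tolist())    # [2, 0, 1]

# Indexing the sorted batch with unsorted_idx gives back the original batch
restored = sorted_inputs[unsorted_idx]
print(torch.equal(restored, inputs))  # True
```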
Using sort_sequences in RNNs
Before feeding the data into an RNN, sort it with sort_sequences and then convert it into a PackedSequence with pack_padded_sequence. Afterwards, use unsorted_idx to unsort both the outputs and the final hidden states.
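import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        # Assumed constructor: the original snippet only showed forward()
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs, lengths=None, hidden=None):
        # Note: hidden, if given, is passed through as-is and is not reordered
        if lengths is not None:
            # Sort by length, then pack so the LSTM skips the padding
            inputs, sorted_lengths, unsorted_idx = sort_sequences(
                inputs, lengths)
            inputs = torch.nn.utils.rnn.pack_padded_sequence(
                inputs, sorted_lengths, batch_first=True)
        outputs, (ht, ct) = self.lstm(inputs, hidden)
        if lengths is not None:
            outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)
            # Restore the original batch order (batch is dim 1 of ht and ct)
            outputs = outputs.index_select(0, unsorted_idx)
            ht = ht.index_select(1, unsorted_idx)
            ct = ct.index_select(1, unsorted_idx)
        return outputs, (ht, ct)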
Indexing with index_select is differentiable, so backpropagation works as usual on the resulting tensors, and the unsorted tensors can be returned directly.
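The following sketch checks both claims on a toy batch: the unsorted outputs match what the LSTM produces for each sequence on its own, and gradients reach the original inputs. The sizes, the seed and the tolerance are arbitrary assumptions, and the sort, pack and unsort steps are inlined so the snippet is self-contained.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def sort_sequences(inputs, lengths):
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx


torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=4, hidden_size=3, batch_first=True)
inputs = torch.randn(3, 5, 4, requires_grad=True)  # [B, T, D]
lengths = torch.tensor([2, 5, 3])

# Sort, pack, run the LSTM, unpack, then unsort
sorted_inputs, sorted_lengths, unsorted_idx = sort_sequences(inputs, lengths)
packed = pack_padded_sequence(sorted_inputs, sorted_lengths, batch_first=True)
outputs, (ht, ct) = lstm(packed)
outputs, _ = pad_packed_sequence(outputs, batch_first=True)
outputs = outputs.index_select(0, unsorted_idx)
ht = ht.index_select(1, unsorted_idx)

# Row i should match running sequence i alone, without padding
for i, n in enumerate(lengths.tolist()):
    single, (h_single, _) = lstm(inputs[i:i + 1, :n])
    assert torch.allclose(outputs[i, :n], single[0], atol=1e-5)
    assert torch.allclose(ht[0, i], h_single[0, 0], atol=1e-5)

# Backpropagate through the unsorted outputs to the original inputs
outputs.sum().backward()
# Padded time steps were never seen by the LSTM, so their gradient is zero
grad_is_zero_on_padding = bool((inputs.grad[0, 2:] == 0).all())
print(inputs.grad.shape, grad_is_zero_on_padding)
```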