Although using PackedSequence in PyTorch speeds up processing of sequential data, it comes with an inconvenience: the sequences in a batch must be sorted by length in descending order. In Seq2Seq tasks such as machine translation, each example has an input and an output sequence whose lengths need not match, so a batch sorted by input length is generally not sorted by output length. One option is to sort the batch by input length and use PackedSequence only in the encoder, not in the decoder. Another approach is to sort the sequences again for the decoder and unsort the outputs afterwards. This note is about the second approach.
P.S. Newer versions of PyTorch are going to have built-in support for this; once they are released, there will be no need to implement it ourselves.
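For reference, recent PyTorch releases expose this built-in support as the enforce_sorted argument of pack_padded_sequence. With enforce_sorted=False, the batch is sorted and unsorted internally, and both pad_packed_sequence and the LSTM's final hidden state come back in the original batch order. A minimal sketch, with sizes chosen arbitrarily for illustration:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# A batch of 3 padded sequences, deliberately NOT sorted by length.
inputs = torch.randn(3, 5, 4)          # [B, T, D]
lengths = torch.tensor([2, 5, 3])      # true length of each sequence (CPU tensor)

lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)

# enforce_sorted=False lets PyTorch sort (and later unsort) the batch itself.
packed = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
packed_out, (ht, ct) = lstm(packed)

# The padded outputs, their lengths and the final states are in the original order.
outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(outputs.shape)  # torch.Size([3, 5, 6])
print(ht.shape)       # torch.Size([1, 3, 6])
```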
Sort the Sequences
First, we define a function sort_sequences that sorts inputs by lengths in descending order. It returns three values: the sorted inputs, the sorted lengths lengths_sorted, and unsorted_idx, which can later be used to restore the original order.
```python
def sort_sequences(inputs, lengths):
    """sort_sequences
    Sort sequences according to lengths in descending order.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param lengths (Tensor): length of each sequence, size [B]
    :return: sorted inputs, sorted lengths, and unsorted_idx, the indices
        that restore the original order
    """
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    # Sorting the permutation itself yields its inverse.
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx
```
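A quick sanity check of the index bookkeeping, on a toy batch whose values are arbitrary: sorting by length and then indexing with unsorted_idx should give back the original batch. Each sequence is filled with its own batch index so the reordering is easy to read.

```python
import torch

def sort_sequences(inputs, lengths):
    """Sort sequences by length in descending order (same as above)."""
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx

# Toy batch: B=3 sequences padded to T=5 steps, D=1 feature.
inputs = torch.arange(3, dtype=torch.float).view(3, 1, 1).expand(3, 5, 1)
lengths = torch.tensor([2, 5, 3])

sorted_inputs, lengths_sorted, unsorted_idx = sort_sequences(inputs, lengths)
print(lengths_sorted)          # tensor([5, 3, 2])
print(sorted_inputs[:, 0, 0])  # tensor([1., 2., 0.])
print(unsorted_idx)            # tensor([2, 0, 1])

# Indexing with unsorted_idx restores the original order.
restored = sorted_inputs.index_select(0, unsorted_idx)
print(torch.equal(restored, inputs))  # True
```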
Using sort_sequences in RNNs
Before feeding the data into the RNN, use sort_sequences to sort it, then use pack_padded_sequence to convert it into a PackedSequence. After the RNN, convert the output back to a padded tensor with pad_packed_sequence, and use unsorted_idx to unsort both the outputs and the final hidden states.
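```python
import torch
import torch.nn as nn

class RNN(nn.Module):
    # Hypothetical constructor; the original snippet only shows forward().
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs, lengths=None, hidden=None):
        if lengths is not None:
            # Sort the batch by length so that it can be packed.
            inputs, sorted_lengths, unsorted_idx = sort_sequences(
                inputs, lengths)
            if hidden is not None:
                # An initial state must be reordered like the inputs;
                # the inverse of unsorted_idx is the sorting permutation.
                _, sorted_idx = unsorted_idx.sort()
                hidden = tuple(h.index_select(1, sorted_idx) for h in hidden)
            inputs = torch.nn.utils.rnn.pack_padded_sequence(
                inputs, sorted_lengths, batch_first=True)
        outputs, (ht, ct) = self.lstm(inputs, hidden)
        if lengths is not None:
            outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)
            # Restore the original batch order: batch is dim 0 of outputs,
            # dim 1 of the states, which have size [layers * directions, B, H].
            outputs = outputs.index_select(0, unsorted_idx)
            ht = ht.index_select(1, unsorted_idx)
            ct = ct.index_select(1, unsorted_idx)
        return outputs, (ht, ct)
```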
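To tie this back to the Seq2Seq motivation, here is a minimal sketch assuming a single-layer LSTM encoder and a teacher-forced LSTM decoder; the helper name run_packed and all sizes are hypothetical. The encoder packs the batch by source lengths and the decoder re-sorts it by target lengths. Because every result is unsorted back to the original batch order, the encoder's final state can be handed straight to the decoder, which sorts it again along with its own inputs.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def sort_sequences(inputs, lengths):
    lengths_sorted, sorted_idx = lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx

def run_packed(lstm, inputs, lengths, hidden=None):
    """Hypothetical helper: run a batch_first LSTM on a padded batch in any order."""
    inputs, sorted_lengths, unsorted_idx = sort_sequences(inputs, lengths)
    _, sorted_idx = unsorted_idx.sort()  # inverse permutation = sorting order
    if hidden is not None:
        # The initial state must follow the same (sorted) order as the inputs.
        hidden = tuple(h.index_select(1, sorted_idx) for h in hidden)
    packed = pack_padded_sequence(inputs, sorted_lengths, batch_first=True)
    outputs, (ht, ct) = lstm(packed, hidden)
    outputs, _ = pad_packed_sequence(outputs, batch_first=True)
    # Undo the sort so everything is back in the original batch order.
    return (outputs.index_select(0, unsorted_idx),
            (ht.index_select(1, unsorted_idx), ct.index_select(1, unsorted_idx)))

torch.manual_seed(0)
encoder = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
decoder = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

src = torch.randn(3, 6, 4)             # padded source batch
src_lengths = torch.tensor([6, 2, 4])  # sorted order would be [0, 2, 1]
tgt = torch.randn(3, 5, 4)             # padded target batch (decoder inputs)
tgt_lengths = torch.tensor([3, 5, 1])  # sorted order would be [1, 0, 2]

_, enc_state = run_packed(encoder, src, src_lengths)
dec_out, _ = run_packed(decoder, tgt, tgt_lengths, hidden=enc_state)
print(dec_out.shape)  # torch.Size([3, 5, 8])
```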
The unsorting is done with index_select, which is differentiable, so backpropagation works as usual and the unsorted tensors can be returned directly.
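A small check of this, with arbitrary values: a loss computed on the unsorted tensor sends each row's gradient back to the matching row of the sorted tensor. Weighting each original-order row differently makes the routing visible.

```python
import torch

sorted_idx = torch.tensor([1, 2, 0])
_, unsorted_idx = sorted_idx.sort()  # tensor([2, 0, 1])

# Stand-in for an RNN output in sorted order, shape [B, H].
sorted_out = torch.randn(3, 2, requires_grad=True)
unsorted_out = sorted_out.index_select(0, unsorted_idx)

# Give each row (in original order) its own weight.
weights = torch.tensor([[1.0], [2.0], [3.0]])
(unsorted_out * weights).sum().backward()

# Row j of sorted_out is original row sorted_idx[j],
# so it receives the weight weights[sorted_idx[j]].
print(sorted_out.grad[:, 0])  # tensor([2., 3., 1.])
```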