├── README.md
└── pad_packed_demo.py


/README.md:
--------------------------------------------------------------------------------

## Minimal tutorial on packing and unpacking sequences in PyTorch

> This is a fork of [@Tushar-N's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e). I have added comments and extra diagrams that should (hopefully) make it easier to understand. The forked gist is [here](https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec), but I prefer markdown for tutorials! :)

We want to run an LSTM on a batch of 3 character sequences, `['long_str', 'tiny', 'medium']`. Here are the steps; you are probably interested in the starred (\*) steps only.
- **Step 1**: Construct Vocabulary
- **Step 2**: Load indexed data (list of instances, where each instance is a list of character indices)
- **Step 3**: Make Model
- **Step 4**: **\*** Pad instances with 0s up to the max sequence length
- **Step 5**: **\*** Sort instances by sequence length in descending order
- **Step 6**: **\*** Embed the instances
- **Step 7**: **\*** Call pack_padded_sequence with the embedded instances and sequence lengths
- **Step 8**: **\*** Forward with LSTM
- **Step 9**: **\*** Call pad_packed_sequence if required / or just pick the last hidden vector
- \* Summary of Shape Transformations

### We want to run an LSTM on a batch of the following 3 character sequences

```python
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6
```

### Step 1: Construct Vocabulary
```python
# make sure the padding token '' gets index 0
vocab = [''] + sorted(set([char for seq in seqs for char in seq]))
# => ['', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']
```

### Step 2: Load indexed data (list of instances, where each instance is a list of character indices)
```python
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]
```

### Step 3: Make Model
```python
embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5
```

### Step 4: Pad instances with 0s up to the max sequence length
```python
# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [ 8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
```

### Step 5: Sort instances by sequence length in descending order
```python
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
```
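
*Added note (not in the original gist):* on PyTorch 1.1 or newer you can skip this manual sorting step entirely and let `pack_padded_sequence` reorder the batch internally in Step 7:

```python
# assuming PyTorch >= 1.1 (where the enforce_sorted argument was added)
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths,
                                    batch_first=True, enforce_sorted=False)
```
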
### Step 6: Embed the instances
```python
embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l
#   [-0.23622951 2.0361056 0.15435742 -0.04513785] o
#   [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n
#   [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] g
#   [-1.6334635 -0.6100042 1.7509955 -1.931793 ] _
#   [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] s
#   [ 0.64004815 0.45813003 0.3476034 -0.03451729] t
#   [-0.22739866 -0.45782727 -0.6643252 0.25129375]] r

#  [[ 0.16031227 -0.08209462 -0.16297023 0.48121014] m
#   [-0.7303265 -0.857339 0.58913064 -1.1068314 ] e
#   [ 0.48159844 -1.4886451 0.92639893 0.76906884] d
#   [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] u
#   [ 0.16031227 -0.08209462 -0.16297023 0.48121014] m
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]]     <- padding

#  [[ 0.64004815 0.45813003 0.3476034 -0.03451729] t
#   [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i
#   [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n
#   [-1.284392 0.68294704 1.4064184 -0.42879772] y
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]]]    <- padding
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)
```

### Step 7: Call pack_padded_sequence with the embedded instances and sequence lengths
```python
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
# packed_input is a PackedSequence (a NamedTuple with 2 attributes: data and batch_sizes)
#
# packed_input.data =>
# [[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] m
#  [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] t
#  [ 0.16031227 -0.08209462 -0.16297023 0.48121014] o
#  [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] e
#  [-1.284392 0.68294704 1.4064184 -0.42879772] i
#  [ 0.64004815 0.45813003 0.3476034 -0.03451729] n
#  [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] d
#  [ 0.64004815 0.45813003 0.3476034 -0.03451729] n
#  [-0.23622951 2.0361056 0.15435742 -0.04513785] g
#  [ 0.16031227 -0.08209462 -0.16297023 0.48121014] i
#  [-0.22739866 -0.45782727 -0.6643252 0.25129375] y
#  [-0.7303265 -0.857339 0.58913064 -1.1068314 ] _
#  [-1.6334635 -0.6100042 1.7509955 -1.931793 ] u
#  [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] s
#  [-0.6000342 1.1732816 0.19938554 -1.5976517 ] m
#  [-0.6000342 1.1732816 0.19938554 -1.5976517 ] t
#  [ 0.48159844 -1.4886451 0.92639893 0.76906884]] r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1]
# visualization :
# l o n g _ s t r   #(long_str)
# m e d i u m       #(medium)
# t i n y           #(tiny)
# 3 3 3 3 2 2 1 1   (sum = 18 [batch_sum_seq_len])
```
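
*Added sanity check (not in the original gist):* `batch_sizes[t]` is simply the number of sequences that still have a real token at time step `t`, so it can be recomputed directly from the sorted lengths:

```python
expected_batch_sizes = [int((seq_lengths > t).sum()) for t in range(int(seq_lengths.max()))]
# => [3, 3, 3, 3, 2, 2, 1, 1], matching packed_input.batch_sizes
```
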
### Step 8: Forward with LSTM
```python
packed_output, (ht, ct) = lstm(packed_input)
# packed_output is a PackedSequence (a NamedTuple with 2 attributes: data and batch_sizes)
#
# packed_output.data =>
# (rows are in packed, time-major order: the t=0 step of every sequence, then t=1, and so on)
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
#  [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
#  [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
#  [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
#  [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
#  [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
#  [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
#  [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
#  [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
#  [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
#  [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
#  [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
#  [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
#  [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
#  [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
#  [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
#  [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
#  [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l o n g _ s t r   #(long_str)
# m e d i u m       #(medium)
# t i n y           #(tiny)
# 3 3 3 3 2 2 1 1   (sum = 18 [batch_sum_seq_len])
```

### Step 9: Call pad_packed_sequence if required / or just pick the last hidden vector

```python
# unpack (pad) your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
# [[[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
#   [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
#   [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
#   [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
#   [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
#   [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
#   [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
#   [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r

#  [[ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
#   [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
#   [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
#   [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
#   [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
#   [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]]

#  [[ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
#   [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
#   [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
#   [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]]]
# output.shape : (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or, if you just want the final hidden state:
print(ht[-1])
```
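
*Added example (not in the original gist), assuming PyTorch >= 0.4 so Variables and plain tensors behave the same:* gathering each sequence's last valid time step from the padded `output` gives the same vectors as `ht[-1]` for this single-layer, unidirectional LSTM:

```python
last_idx = (seq_lengths - 1).view(-1, 1, 1).expand(-1, 1, output.size(2))
last_outputs = output.gather(1, last_idx).squeeze(1)
# last_outputs.shape : (batch_size X hidden_dim) = (3 X 5)
# rows are in the sorted order (long_str, medium, tiny) and equal ht[-1]
```
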
### Summary of Shape Transformations

```python
# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) ---> Pack ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim) ---> LSTM ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim) ---> UnPack ---> (batch_size X max_seq_len X hidden_dim)
```

--------------------------------------------------------------------------------
/pad_packed_demo.py:
--------------------------------------------------------------------------------

import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable  # note: Variable is a no-op wrapper on PyTorch >= 0.4; plain tensors work too
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## We want to run an LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
# Step 1: Construct Vocabulary
# Step 2: Load indexed data (list of instances, where each instance is a list of character indices)
# Step 3: Make Model
# * Step 4: Pad instances with 0s up to the max sequence length
# * Step 5: Sort instances by sequence length in descending order
# * Step 6: Embed the instances
# * Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
# * Step 8: Forward with LSTM
# * Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
# * Summary of Shape Transformations

# We want to run an LSTM on a batch of the following 3 character sequences
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6


## Step 1: Construct Vocabulary ##
##------------------------------##
# make sure the padding token '' gets index 0
vocab = [''] + sorted(set([char for seq in seqs for char in seq]))
# => ['', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']


## Step 2: Load indexed data (list of instances, where each instance is a list of character indices) ##
##---------------------------------------------------------------------------------------------------##
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]


## Step 3: Make Model ##
##--------------------##
embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5


## Step 4: Pad instances with 0s up to the max sequence length ##
##--------------------------------------------------------------##

# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [ 8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

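
# --- Added example (not part of the original gist) ----------------------------
# Recent PyTorch versions also ship a helper, pad_sequence, that does the same
# padding as the loop above in one call. Shown on a separate variable so the
# original demo stays unchanged.
from torch.nn.utils.rnn import pad_sequence
padded_alternative = pad_sequence([LongTensor(s) for s in vectorized_seqs], batch_first=True)
# padded_alternative => the same (3 X 8) zero-padded tensor as seq_tensor above
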
## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)


## Step 6: Embed the instances ##
##-----------------------------##

embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l
#   [-0.23622951 2.0361056 0.15435742 -0.04513785] o
#   [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n
#   [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] g
#   [-1.6334635 -0.6100042 1.7509955 -1.931793 ] _
#   [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] s
#   [ 0.64004815 0.45813003 0.3476034 -0.03451729] t
#   [-0.22739866 -0.45782727 -0.6643252 0.25129375]] r

#  [[ 0.16031227 -0.08209462 -0.16297023 0.48121014] m
#   [-0.7303265 -0.857339 0.58913064 -1.1068314 ] e
#   [ 0.48159844 -1.4886451 0.92639893 0.76906884] d
#   [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] u
#   [ 0.16031227 -0.08209462 -0.16297023 0.48121014] m
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]]     <- padding

#  [[ 0.64004815 0.45813003 0.3476034 -0.03451729] t
#   [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i
#   [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n
#   [-1.284392 0.68294704 1.4064184 -0.42879772] y
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]      <- padding
#   [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]]]    <- padding
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)


## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
##--------------------------------------------------------------------------------##

packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
# packed_input is a PackedSequence (a NamedTuple with 2 attributes: data and batch_sizes)
#
# packed_input.data =>
# [[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] m
#  [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] t
#  [ 0.16031227 -0.08209462 -0.16297023 0.48121014] o
#  [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] e
#  [-1.284392 0.68294704 1.4064184 -0.42879772] i
#  [ 0.64004815 0.45813003 0.3476034 -0.03451729] n
#  [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] d
#  [ 0.64004815 0.45813003 0.3476034 -0.03451729] n
#  [-0.23622951 2.0361056 0.15435742 -0.04513785] g
#  [ 0.16031227 -0.08209462 -0.16297023 0.48121014] i
#  [-0.22739866 -0.45782727 -0.6643252 0.25129375] y
#  [-0.7303265 -0.857339 0.58913064 -1.1068314 ] _
#  [-1.6334635 -0.6100042 1.7509955 -1.931793 ] u
#  [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] s
#  [-0.6000342 1.1732816 0.19938554 -1.5976517 ] m
#  [-0.6000342 1.1732816 0.19938554 -1.5976517 ] t
#  [ 0.48159844 -1.4886451 0.92639893 0.76906884]] r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1]
# visualization :
# l o n g _ s t r   #(long_str)
# m e d i u m       #(medium)
# t i n y           #(tiny)
# 3 3 3 3 2 2 1 1   (sum = 18 [batch_sum_seq_len])

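
# --- Added aside (not part of the original gist) ------------------------------
# Packing takes care of the LSTM itself, but an explicit padding mask is often
# needed alongside it (e.g. for attention or masked losses). It can be built
# straight from the sorted lengths:
mask = torch.arange(int(seq_lengths.max()))[None, :] < seq_lengths[:, None]
# mask => [[1 1 1 1 1 1 1 1]    # long_str
#          [1 1 1 1 1 1 0 0]    # medium
#          [1 1 1 1 0 0 0 0]]   # tiny
# mask.shape : (batch_size X max_seq_len) = (3 X 8), True at non-pad positions
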
## Step 8: Forward with LSTM ##
##---------------------------##

packed_output, (ht, ct) = lstm(packed_input)
# packed_output is a PackedSequence (a NamedTuple with 2 attributes: data and batch_sizes)
#
# packed_output.data =>
# (rows are in packed, time-major order: the t=0 step of every sequence, then t=1, and so on)
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
#  [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
#  [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
#  [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
#  [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
#  [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
#  [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
#  [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
#  [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
#  [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
#  [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
#  [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
#  [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
#  [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
#  [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
#  [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
#  [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
#  [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l o n g _ s t r   #(long_str)
# m e d i u m       #(medium)
# t i n y           #(tiny)
# 3 3 3 3 2 2 1 1   (sum = 18 [batch_sum_seq_len])


## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##
##----------------------------------------------------------------------------------##

# unpack (pad) your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
# [[[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
#   [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
#   [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
#   [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
#   [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
#   [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
#   [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
#   [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r

#  [[ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
#   [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
#   [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
#   [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
#   [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
#   [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]]

#  [[ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
#   [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
#   [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
#   [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]
#   [ 0. 0. 0. 0. 0. ]]]
# output.shape : (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or, if you just want the final hidden state:
print(ht[-1])

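
# --- Added example (not part of the original gist) ----------------------------
# ht (and the rows of `output`) follow the *sorted* batch order (long_str,
# medium, tiny). To get final hidden states back in the original order
# ['long_str', 'tiny', 'medium'], undo the Step 5 permutation:
_, unperm_idx = perm_idx.sort(0)
ht_last = ht[-1][unperm_idx]
# ht_last[i] is now the final hidden state for seqs[i]
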
## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) ---> Pack ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim) ---> LSTM ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim) ---> UnPack ---> (batch_size X max_seq_len X hidden_dim)

--------------------------------------------------------------------------------