├── README.md
└── transducer_tutorial_example.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # transducer-tutorial
2 | This notebook implements the Transducer sequence-to-sequence model from scratch in PyTorch, including the loss function, the greedy search algorithm, and a complete example of training the model on a sequence transduction task.
3 |
4 | See the corresponding blog post [here](https://lorenlugosch.github.io/posts/2020/11/transducer/).
5 |
6 | _If you found this tutorial helpful and would like to cite it, you can use the following BibTeX entry:_
7 |
8 | ```
9 | @misc{
10 | lugosch_2020,
11 | title={Sequence-to-sequence learning with Transducers},
12 | url={https://lorenlugosch.github.io/posts/2020/11/transducer/},
13 | author={Lugosch, Loren},
14 | year={2020},
15 | month={Nov}
16 | }
17 | ```
18 |
--------------------------------------------------------------------------------
/transducer_tutorial_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "BbwgR5UdNkkm"
17 | },
18 | "source": [
19 | "# Transducer implementation in PyTorch\n",
20 | "\n",
21 | "*by Loren Lugosch*\n"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {
27 | "id": "yBlJNKsjTtaZ"
28 | },
29 | "source": [
30 | "\n",
31 | "In this notebook, we will implement a Transducer sequence-to-sequence model for inserting missing vowels into a sentence (\"Hll, Wrld\" --> \"Hello, World\")."
32 | ]
33 | },
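To make the task concrete, here is a minimal sketch of how one input/output pair could be built. The helper name `strip_vowels` is made up for illustration; the vowel set is the same one the dataset class later in the notebook removes from each line.

```python
VOWELS = "AEIOUaeiou"  # same vowel set the notebook's dataset strips


def strip_vowels(line: str) -> str:
    """Return the line with ASCII vowels removed (the model's input side)."""
    return "".join(c for c in line if c not in VOWELS)


target = "Hello, World"        # what the model should output (length U)
source = strip_vowels(target)  # what the model receives (length T)

print(source)
print("T =", len(source), "U =", len(target))
```

The output is longer than the input here, which is one reason a Transducer (which may emit several labels per input step) suits this task.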
34 | {
35 | "cell_type": "code",
36 | "execution_count": 1,
37 | "metadata": {
38 | "colab": {
39 | "base_uri": "https://localhost:8080/"
40 | },
41 | "id": "Q-iHU02C7fAj",
42 | "outputId": "6ce0dae6-5036-4fb4-fe33-45aceec263a6"
43 | },
44 | "outputs": [
45 | {
46 | "name": "stdout",
47 | "output_type": "stream",
48 | "text": [
49 | "Requirement already satisfied: unidecode in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (1.2.0)\n",
50 | "--2024-02-10 15:18:35-- https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt\n",
51 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
52 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
53 | "HTTP request sent, awaiting response... 200 OK\n",
54 | "Length: 3196229 (3.0M) [text/plain]\n",
55 | "Saving to: ‘war_and_peace.txt.5’\n",
56 | "\n",
57 | "war_and_peace.txt.5 100%[===================>] 3.05M --.-KB/s in 0.07s \n",
58 | "\n",
59 | "2024-02-10 15:18:35 (45.4 MB/s) - ‘war_and_peace.txt.5’ saved [3196229/3196229]\n",
60 | "\n",
61 | "/home/ubuntu/transducer_fix\n"
62 | ]
63 | }
64 | ],
65 | "source": [
66 | "import torch\n",
67 | "import string\n",
68 | "import numpy as np\n",
69 | "import itertools\n",
70 | "from collections import Counter\n",
71 | "from tqdm import tqdm\n",
72 | "!pip install unidecode\n",
73 | "import unidecode\n",
74 | "\n",
75 | "# Some training data.\n",
76 | "# Poor Tolstoy, once again reduced to grist for the neural network mill!\n",
77 | "!wget https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt\n",
78 | "!pwd\n"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {
84 | "id": "CTfRgwxmjv1B"
85 | },
86 | "source": [
87 | "# Building blocks\n",
88 | "\n",
89 | "First, we will define the encoder, predictor, and joiner using standard neural nets.\n",
90 | "\n",
91 | ""
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 2,
97 | "metadata": {
98 | "id": "B7mLFyUG7kJH"
99 | },
100 | "outputs": [],
101 | "source": [
102 | "NULL_INDEX = 0\n",
103 | "\n",
104 | "encoder_dim = 1024\n",
105 | "predictor_dim = 1024\n",
106 | "joiner_dim = 1024"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {
112 | "id": "MABMTjrGY4vz"
113 | },
114 | "source": [
115 | "The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 3,
121 | "metadata": {
122 | "id": "KE7j2T5EY33-"
123 | },
124 | "outputs": [],
125 | "source": [
126 | "class Encoder(torch.nn.Module):\n",
127 | "    def __init__(self, num_inputs):\n",
128 | "        super(Encoder, self).__init__()\n",
129 | "        self.embed = torch.nn.Embedding(num_inputs, encoder_dim)\n",
130 | "        self.rnn = torch.nn.GRU(input_size=encoder_dim, hidden_size=encoder_dim, num_layers=3, batch_first=True, bidirectional=True, dropout=0.1)\n",
131 | "        self.linear = torch.nn.Linear(encoder_dim*2, joiner_dim)\n",
132 | "\n",
133 | "    def forward(self, x):\n",
134 | "        out = x\n",
135 | "        out = self.embed(out)\n",
136 | "        out = self.rnn(out)[0]\n",
137 | "        out = self.linear(out)\n",
138 | "        return out"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {
144 | "id": "BRknN6QRY9-g"
145 | },
146 | "source": [
147 | "The predictor is any _causal_ network (one that cannot look at future labels): for example, a unidirectional RNN, causal convolutions, or masked self-attention."
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 4,
153 | "metadata": {
154 | "id": "hPARF5LmY7-r"
155 | },
156 | "outputs": [],
157 | "source": [
158 | "class Predictor(torch.nn.Module):\n",
159 | "    def __init__(self, num_outputs):\n",
160 | "        super(Predictor, self).__init__()\n",
161 | "        self.embed = torch.nn.Embedding(num_outputs, predictor_dim)\n",
162 | "        self.rnn = torch.nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)\n",
163 | "        self.linear = torch.nn.Linear(predictor_dim, joiner_dim)\n",
164 | "\n",
165 | "        self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))\n",
166 | "        self.start_symbol = NULL_INDEX # In the original paper, a vector of 0s is used; just using the null index instead is easier when using an Embedding layer.\n",
167 | "\n",
168 | "    def forward_one_step(self, input, previous_state):\n",
169 | "        embedding = self.embed(input)\n",
170 | "        state = self.rnn.forward(embedding, previous_state)\n",
171 | "        out = self.linear(state)\n",
172 | "        return out, state\n",
173 | "\n",
174 | "    def forward(self, y):\n",
175 | "        batch_size = y.shape[0]\n",
176 | "        U = y.shape[1]\n",
177 | "        outs = []\n",
178 | "        state = torch.stack([self.initial_state] * batch_size).to(y.device)\n",
179 | "        for u in range(U+1): # need U+1 to get null output for final timestep\n",
180 | "            if u == 0:\n",
181 | "                decoder_input = torch.tensor([self.start_symbol] * batch_size, device=y.device)\n",
182 | "            else:\n",
183 | "                decoder_input = y[:,u-1]\n",
184 | "            out, state = self.forward_one_step(decoder_input, state)\n",
185 | "            outs.append(out)\n",
186 | "        out = torch.stack(outs, dim=1)\n",
187 | "        return out"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {
193 | "id": "ZHPZ3PATZEAW"
194 | },
195 | "source": [
196 | "The joiner is a feedforward network/MLP with one hidden layer applied independently to each $(t,u)$ index.\n",
197 | "\n",
198 | "(The linear part of the hidden layer is contained in the encoder and predictor, so we just do the nonlinearity here and then the output layer.)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 5,
204 | "metadata": {
205 | "id": "Vlzca1orZDLa"
206 | },
207 | "outputs": [],
208 | "source": [
209 | "class Joiner(torch.nn.Module):\n",
210 | "    def __init__(self, num_outputs):\n",
211 | "        super(Joiner, self).__init__()\n",
212 | "        self.linear = torch.nn.Linear(joiner_dim, num_outputs)\n",
213 | "\n",
214 | "    def forward(self, encoder_out, predictor_out):\n",
215 | "        out = encoder_out + predictor_out\n",
216 | "        out = torch.nn.functional.relu(out)\n",
217 | "        out = self.linear(out)\n",
218 | "        return out"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {
224 | "id": "a_-INbhSTApv"
225 | },
226 | "source": [
227 | "# Transducer model + loss function\n",
228 | "\n",
229 | "Using the encoder, predictor, and joiner, we will implement the Transducer model and its loss function.\n",
230 | "\n",
231 | ""
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {
237 | "id": "bdcKwA_lkzxJ"
238 | },
239 | "source": [
240 | "We can use a simple PyTorch implementation of the loss function, relying on automatic differentiation to give us gradients."
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 6,
246 | "metadata": {
247 | "id": "sYSagKi-gHM4"
248 | },
249 | "outputs": [],
250 | "source": [
251 | "class Transducer(torch.nn.Module):\n",
252 | "    def __init__(self, num_inputs, num_outputs):\n",
253 | "        super(Transducer, self).__init__()\n",
254 | "        self.encoder = Encoder(num_inputs)\n",
255 | "        self.predictor = Predictor(num_outputs)\n",
256 | "        self.joiner = Joiner(num_outputs)\n",
257 | "\n",
258 | "        if torch.cuda.is_available(): self.device = \"cuda:0\"\n",
259 | "        else: self.device = \"cpu\"\n",
260 | "        self.to(self.device)\n",
261 | "\n",
262 | "    def compute_forward_prob(self, joiner_out, T, U, y):\n",
263 | "        \"\"\"\n",
264 | "        joiner_out: tensor of log-probabilities, shape (B, T_max, U_max+1, #labels)\n",
265 | "        T: list of input lengths\n",
266 | "        U: list of output lengths\n",
267 | "        y: label tensor (B, U_max)\n",
268 | "        \"\"\"\n",
269 | "        B = joiner_out.shape[0]\n",
270 | "        T_max = joiner_out.shape[1]\n",
271 | "        U_max = joiner_out.shape[2] - 1\n",
272 | "        log_alpha = torch.zeros(B, T_max, U_max+1, device=joiner_out.device) # use the input's device, not a global model\n",
273 | "        for t in range(T_max):\n",
274 | "            for u in range(U_max+1):\n",
275 | "                if u == 0:\n",
276 | "                    if t == 0:\n",
277 | "                        log_alpha[:, t, u] = 0.\n",
278 | "\n",
279 | "                    else: #t > 0\n",
280 | "                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX]\n",
281 | "\n",
282 | "                else: #u > 0\n",
283 | "                    if t == 0:\n",
284 | "                        log_alpha[:, t, u] = log_alpha[:, t,u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)\n",
285 | "\n",
286 | "                    else: #t > 0\n",
287 | "                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([\n",
288 | "                            log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],\n",
289 | "                            log_alpha[:, t, u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)\n",
290 | "                        ]), dim=0)\n",
291 | "\n",
292 | "        log_probs = []\n",
293 | "        for b in range(B):\n",
294 | "            log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]\n",
295 | "            log_probs.append(log_prob)\n",
296 | "        log_probs = torch.stack(log_probs)\n",
297 | "        return log_probs\n",
298 | "\n",
299 | "    def compute_loss(self, x, y, T, U):\n",
300 | "        encoder_out = self.encoder.forward(x)\n",
301 | "        predictor_out = self.predictor.forward(y)\n",
302 | "        joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)\n",
303 | "        loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()\n",
304 | "        return loss"
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "metadata": {
310 | "id": "IK0c2S2xaARd"
311 | },
312 | "source": [
313 | "Let's first verify that the forward algorithm actually correctly computes the sum (in log space, the [logsumexp](https://lorenlugosch.github.io/posts/2020/06/logsumexp/)) of all possible alignments, using a short input/output pair for which computing all possible alignments is feasible.\n",
314 | "\n",
315 | ""
316 | ]
317 | },
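The check described above can also be sketched without PyTorch. In this toy version, `log_null` and `log_label` are hypothetical tables of arbitrary log-scores standing in for the joiner output, and the function names are illustrative assumptions rather than part of the notebook. The only point is that the dynamic-programming recursion and the brute-force sum over alignments give the same number.

```python
import itertools
import math
import random


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_prob(log_null, log_label):
    """Forward algorithm over a T x (U+1) lattice.

    log_null[t][u]:  log-score of emitting null at node (t, u)
    log_label[t][u]: log-score of emitting the next target label at (t, u)
    """
    T, U = len(log_null), len(log_null[0]) - 1
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            terms = []
            if t > 0:  # arrived by emitting null at (t-1, u)
                terms.append(alpha[t - 1][u] + log_null[t - 1][u])
            if u > 0:  # arrived by emitting a label at (t, u-1)
                terms.append(alpha[t][u - 1] + log_label[t][u - 1])
            alpha[t][u] = logsumexp(terms)
    # every alignment ends with a final null emitted at (T-1, U)
    return alpha[T - 1][U] + log_null[T - 1][U]


def enumerate_log_prob(log_null, log_label):
    """Brute force: sum over every distinct ordering of T-1 nulls and U labels."""
    T, U = len(log_null), len(log_null[0]) - 1
    alignments = set(itertools.permutations([0] * (T - 1) + [1] * U))
    totals = []
    for z in alignments:
        t = u = 0
        log_p = 0.0
        for step in z:
            if step == 0:  # null: move forward in time
                log_p += log_null[t][u]
                t += 1
            else:          # label: move forward in the output
                log_p += log_label[t][u]
                u += 1
        totals.append(log_p + log_null[T - 1][U])
    return logsumexp(totals)


random.seed(0)
T, U = 4, 3  # same sizes as the "CAT" example
log_null = [[random.uniform(-3.0, 0.0) for _ in range(U + 1)] for _ in range(T)]
log_label = [[random.uniform(-3.0, 0.0) for _ in range(U + 1)] for _ in range(T)]

print("enumeration:", -enumerate_log_prob(log_null, log_label))
print("forward:    ", -forward_log_prob(log_null, log_label))
```

Enumeration visits every alignment, so its cost grows combinatorially with $T$ and $U$, while the forward recursion touches each lattice node once.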
318 | {
319 | "cell_type": "code",
320 | "execution_count": 7,
321 | "metadata": {
322 | "id": "RWtkoXH6U8Pm"
323 | },
324 | "outputs": [],
325 | "source": [
326 | "def compute_single_alignment_prob(self, encoder_out, predictor_out, T, U, z, y):\n",
327 | "    \"\"\"\n",
328 | "    Computes the log-probability of one alignment, z.\n",
329 | "    \"\"\"\n",
330 | "    t = 0; u = 0\n",
331 | "    t_u_indices = []\n",
332 | "    y_expanded = []\n",
333 | "    for step in z:\n",
334 | "        t_u_indices.append((t,u))\n",
335 | "        if step == 0: # right (null)\n",
336 | "            y_expanded.append(NULL_INDEX)\n",
337 | "            t += 1\n",
338 | "        if step == 1: # down (label)\n",
339 | "            y_expanded.append(y[u])\n",
340 | "            u += 1\n",
341 | "    t_u_indices.append((T-1,U))\n",
342 | "    y_expanded.append(NULL_INDEX)\n",
343 | "\n",
344 | "    t_indices = [t for (t,u) in t_u_indices]\n",
345 | "    u_indices = [u for (t,u) in t_u_indices]\n",
346 | "    encoder_out_expanded = encoder_out[t_indices]\n",
347 | "    predictor_out_expanded = predictor_out[u_indices]\n",
348 | "    joiner_out = self.joiner.forward(encoder_out_expanded, predictor_out_expanded).log_softmax(1)\n",
349 | "    logprob = -torch.nn.functional.nll_loss(input=joiner_out, target=torch.tensor(y_expanded).long().to(self.device), reduction=\"sum\")\n",
350 | "    return logprob\n",
351 | "\n",
352 | "Transducer.compute_single_alignment_prob = compute_single_alignment_prob"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 8,
358 | "metadata": {
359 | "colab": {
360 | "base_uri": "https://localhost:8080/"
361 | },
362 | "id": "e8xzM0dZfea9",
363 | "outputId": "241648b8-5484-4220-d6a5-469bb05f5253"
364 | },
365 | "outputs": [
366 | {
367 | "name": "stdout",
368 | "output_type": "stream",
369 | "text": [
370 | "Loss computed by enumerating all possible alignments: tensor(21.8215)\n",
371 | "Loss computed using the forward algorithm: tensor([21.8215], device='cuda:0', grad_fn=)\n"
372 | ]
373 | }
374 | ],
375 | "source": [
376 | "# Generate example inputs/outputs\n",
377 | "num_outputs = len(string.ascii_uppercase) + 1 # [null, A, B, ... Z]\n",
378 | "model = Transducer(1, num_outputs)\n",
379 | "y_letters = \"CAT\"\n",
380 | "y = torch.tensor([string.ascii_uppercase.index(l) + 1 for l in y_letters]).unsqueeze(0).to(model.device)\n",
381 | "T = torch.tensor([4]); U = torch.tensor([len(y_letters)]); B = 1\n",
382 | "\n",
383 | "encoder_out = torch.randn(B, T, joiner_dim).to(model.device)\n",
384 | "predictor_out = torch.randn(B, U+1, joiner_dim).to(model.device)\n",
385 | "joiner_out = model.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)\n",
386 | "\n",
387 | "#######################################################\n",
388 | "# Compute loss by enumerating all possible alignments #\n",
389 | "#######################################################\n",
390 | "all_permutations = list(itertools.permutations([0]*(T-1) + [1]*U))\n",
391 | "all_distinct_permutations = list(Counter(all_permutations).keys())\n",
392 | "alignment_probs = []\n",
393 | "for z in all_distinct_permutations:\n",
394 | "    alignment_prob = model.compute_single_alignment_prob(encoder_out[0], predictor_out[0], T.item(), U.item(), z, y[0])\n",
395 | "    alignment_probs.append(alignment_prob)\n",
396 | "loss_enumerate = -torch.tensor(alignment_probs).logsumexp(0)\n",
397 | "\n",
398 | "#######################################################\n",
399 | "# Compute loss using the forward algorithm #\n",
400 | "#######################################################\n",
401 | "loss_forward = -model.compute_forward_prob(joiner_out, T, U, y)\n",
402 | "\n",
403 | "print(\"Loss computed by enumerating all possible alignments: \", loss_enumerate)\n",
404 | "print(\"Loss computed using the forward algorithm: \", loss_forward)"
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "metadata": {
410 | "id": "WSBAwQONf3z9"
411 | },
412 | "source": [
413 | "Now let's add the greedy search algorithm for predicting an output sequence.\n",
414 | "\n",
415 | "(Note that I've assumed we're using RNNs for the predictor here. You would have to modify this code a bit if you want to use convolutions/self-attention instead.) \n",
416 | "\n",
417 | ""
418 | ]
419 | },
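The control flow of greedy search can be sketched on a toy lattice before looking at the model version. Here `toy_argmax` is a hand-written, hypothetical table that plays the role of the joiner's argmax at each $(t,u)$ node, and `NULL` stands in for the null index; nothing in it comes from a trained model.

```python
NULL = None  # stands in for the notebook's NULL_INDEX

# Hypothetical "argmax of the joiner" at lattice node (t, u).
# Any node not listed emits null.
toy_argmax = {(0, 0): "H", (0, 1): "e", (1, 2): "l", (2, 3): "l"}


def step_fn(t, u):
    """Return the most likely symbol at node (t, u) in the toy table."""
    return toy_argmax.get((t, u), NULL)


def greedy_decode(num_frames, step, max_labels=200):
    """Walk the lattice: null moves forward in time, a label extends the output."""
    t, labels = 0, []
    while t < num_frames and len(labels) < max_labels:
        symbol = step(t, len(labels))  # u is the number of labels emitted so far
        if symbol is NULL:
            t += 1                     # consume one input step
        else:
            labels.append(symbol)      # stay on the same input step
    return labels


print("".join(greedy_decode(3, step_fn)))
```

The `max_labels` cap mirrors the role of `U_max` in the notebook: without it, a model that never emits null would loop forever on the same input step.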
420 | {
421 | "cell_type": "code",
422 | "execution_count": 9,
423 | "metadata": {
424 | "id": "V0xeyb7Jf18_"
425 | },
426 | "outputs": [],
427 | "source": [
428 | "def greedy_search(self, x, T):\n",
429 | "    y_batch = []\n",
430 | "    B = len(x)\n",
431 | "    encoder_out = self.encoder.forward(x)\n",
432 | "    U_max = 200\n",
433 | "    for b in range(B):\n",
434 | "        t = 0; u = 0; y = [self.predictor.start_symbol]; predictor_state = self.predictor.initial_state.unsqueeze(0)\n",
435 | "        while t < T[b] and u < U_max:\n",
436 | "            predictor_input = torch.tensor([ y[-1] ], device = x.device)\n",
437 | "            g_u, next_state = self.predictor.forward_one_step(predictor_input, predictor_state)\n",
438 | "            f_t = encoder_out[b, t]\n",
439 | "            h_t_u = self.joiner.forward(f_t, g_u)\n",
440 | "            argmax = h_t_u.max(-1)[1].item()\n",
441 | "            if argmax == NULL_INDEX:\n",
442 | "                t += 1 # null: move to the next input step, keep the predictor state\n",
443 | "            else: # argmax == a label\n",
444 | "                u += 1; y.append(argmax)\n",
445 | "                predictor_state = next_state # advance the predictor only after emitting a label, as in training\n",
446 | "        y_batch.append(y[1:]) # remove start symbol\n",
447 | "    return y_batch\n",
448 | "\n",
449 | "Transducer.greedy_search = greedy_search"
450 | ]
451 | },
452 | {
453 | "cell_type": "markdown",
454 | "metadata": {
455 | "id": "82XU9-gr3goI"
456 | },
457 | "source": [
458 | "The code above will work, but training will be very slow because the loss visits every $(t,u)$ lattice node in a Python loop. You can use the fast implementation from SpeechBrain instead by running the block below."
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 10,
464 | "metadata": {
465 | "colab": {
466 | "base_uri": "https://localhost:8080/"
467 | },
468 | "id": "qhUQMJ-23f2y",
469 | "outputId": "85b406cd-8cfd-4c8c-817b-0cb851531f56"
470 | },
471 | "outputs": [
472 | {
473 | "name": "stdout",
474 | "output_type": "stream",
475 | "text": [
476 | "Requirement already satisfied: speechbrain in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (0.5.16)\n",
477 | "Requirement already satisfied: hyperpyyaml in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (1.2.2)\n",
478 | "Requirement already satisfied: joblib in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (1.3.2)\n",
479 | "Requirement already satisfied: numpy in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (1.23.1)\n",
480 | "Requirement already satisfied: packaging in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (23.1)\n",
481 | "Requirement already satisfied: scipy in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (1.12.0)\n",
482 | "Requirement already satisfied: sentencepiece in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (0.1.99)\n",
483 | "Requirement already satisfied: torch>=1.9 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (2.2.0)\n",
484 | "Requirement already satisfied: torchaudio in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (2.2.0)\n",
485 | "Requirement already satisfied: tqdm in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (4.66.1)\n",
486 | "Requirement already satisfied: huggingface-hub in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from speechbrain) (0.20.3)\n",
487 | "Requirement already satisfied: filelock in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (3.13.1)\n",
488 | "Requirement already satisfied: typing-extensions>=4.8.0 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (4.9.0)\n",
489 | "Requirement already satisfied: sympy in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (1.12)\n",
490 | "Requirement already satisfied: networkx in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (3.1)\n",
491 | "Requirement already satisfied: jinja2 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (3.1.3)\n",
492 | "Requirement already satisfied: fsspec in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from torch>=1.9->speechbrain) (2024.2.0)\n",
493 | "Requirement already satisfied: requests in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from huggingface-hub->speechbrain) (2.31.0)\n",
494 | "Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from huggingface-hub->speechbrain) (6.0.1)\n",
495 | "Requirement already satisfied: ruamel.yaml>=0.17.28 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from hyperpyyaml->speechbrain) (0.18.6)\n",
496 | "Requirement already satisfied: ruamel.yaml.clib>=0.2.7 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from ruamel.yaml>=0.17.28->hyperpyyaml->speechbrain) (0.2.8)\n",
497 | "Requirement already satisfied: MarkupSafe>=2.0 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from jinja2->torch>=1.9->speechbrain) (2.1.3)\n",
498 | "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from requests->huggingface-hub->speechbrain) (2.0.4)\n",
499 | "Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from requests->huggingface-hub->speechbrain) (3.4)\n",
500 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from requests->huggingface-hub->speechbrain) (2.1.0)\n",
501 | "Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from requests->huggingface-hub->speechbrain) (2024.2.2)\n",
502 | "Requirement already satisfied: mpmath>=0.19 in /home/ubuntu/miniconda3/envs/transducer_fix/lib/python3.9/site-packages (from sympy->torch>=1.9->speechbrain) (1.3.0)\n"
503 | ]
504 | }
505 | ],
506 | "source": [
507 | "!pip install speechbrain\n",
508 | "# you may also need numba if it is not installed on your system, you can install it with conda:\n",
509 | "#!conda install -y numba cudatoolkit=9.0\n",
510 | "# you may also have to downgrade numpy in order to account for numba requirements:\n",
511 | "#!conda install -y numpy=1.21.1\n",
512 | "\n",
513 | "from speechbrain.nnet.loss.transducer_loss import TransducerLoss\n",
514 | "transducer_loss = TransducerLoss(0)\n",
515 | "\n",
516 | "def compute_loss(self, x, y, T, U):\n",
517 | "    encoder_out = self.encoder.forward(x)\n",
518 | "    predictor_out = self.predictor.forward(y)\n",
519 | "    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)\n",
520 | "    #loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()\n",
521 | "    T = T.to(joiner_out.device)\n",
522 | "    U = U.to(joiner_out.device)\n",
523 | "    loss = transducer_loss(joiner_out, y, T, U) #, blank_index=NULL_INDEX, reduction=\"mean\")\n",
524 | "    return loss\n",
525 | "\n",
526 | "Transducer.compute_loss = compute_loss"
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {
532 | "id": "Ff9raB0jVGzN"
533 | },
534 | "source": [
535 | "# Some utilities\n",
536 | "\n",
537 | "Here we will add a bit of boilerplate code for training and loading data."
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": 11,
543 | "metadata": {
544 | "colab": {
545 | "base_uri": "https://localhost:8080/"
546 | },
547 | "id": "5b17OQm4WdVy",
548 | "outputId": "25498f94-9543-40f8-dd1e-fbfd7857debf"
549 | },
550 | "outputs": [
551 | {
552 | "data": {
553 | "text/plain": [
554 | "('\"Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th',\n",
555 | " '\"Well, Prince, so Genoa and Lucca are now just family estates of the')"
556 | ]
557 | },
558 | "execution_count": 11,
559 | "metadata": {},
560 | "output_type": "execute_result"
561 | }
562 | ],
563 | "source": [
564 | "class TextDataset(torch.utils.data.Dataset):\n",
565 | "    def __init__(self, lines, batch_size):\n",
566 | "        lines = list(filter((\"\\n\").__ne__, lines))\n",
567 | "\n",
568 | "        self.lines = lines # list of strings\n",
569 | "        collate = Collate()\n",
570 | "        self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=0, shuffle=True, collate_fn=collate)\n",
571 | "\n",
572 | "    def __len__(self):\n",
573 | "        return len(self.lines)\n",
574 | "\n",
575 | "    def __getitem__(self, idx):\n",
576 | "        line = self.lines[idx].replace(\"\\n\", \"\")\n",
577 | "        line = unidecode.unidecode(line) # remove special characters\n",
578 | "        x = \"\".join(c for c in line if c not in \"AEIOUaeiou\") # remove vowels from input\n",
579 | "        y = line\n",
580 | "        return (x,y)\n",
581 | "\n",
582 | "def encode_string(s):\n",
583 | "    for c in s:\n",
584 | "        if c not in string.printable:\n",
585 | "            print(s)\n",
586 | "    return [string.printable.index(c) + 1 for c in s]\n",
587 | "\n",
588 | "def decode_labels(l):\n",
589 | "    return \"\".join([string.printable[c - 1] for c in l])\n",
590 | "\n",
591 | "class Collate:\n",
592 | "    def __call__(self, batch):\n",
593 | "        \"\"\"\n",
594 | "        batch: list of tuples (input string, output string)\n",
595 | "        Returns a minibatch of strings, encoded as labels and padded to have the same length.\n",
596 | "        \"\"\"\n",
597 | "        x = []; y = []\n",
598 | "        batch_size = len(batch)\n",
599 | "        for index in range(batch_size):\n",
600 | "            x_,y_ = batch[index]\n",
601 | "            x.append(encode_string(x_))\n",
602 | "            y.append(encode_string(y_))\n",
603 | "\n",
604 | "        # pad all sequences to have same length\n",
605 | "        T = [len(x_) for x_ in x]\n",
606 | "        U = [len(y_) for y_ in y]\n",
607 | "        T_max = max(T)\n",
608 | "        U_max = max(U)\n",
609 | "        for index in range(batch_size):\n",
610 | "            x[index] += [NULL_INDEX] * (T_max - len(x[index]))\n",
611 | "            x[index] = torch.tensor(x[index])\n",
612 | "            y[index] += [NULL_INDEX] * (U_max - len(y[index]))\n",
613 | "            y[index] = torch.tensor(y[index])\n",
614 | "\n",
615 | "        # stack into single tensor\n",
616 | "        x = torch.stack(x)\n",
617 | "        y = torch.stack(y)\n",
618 | "        T = torch.tensor(T)\n",
619 | "        U = torch.tensor(U)\n",
620 | "\n",
621 | "        return (x,y,T,U)\n",
622 | "\n",
623 | "with open(\"war_and_peace.txt\", \"r\") as f:\n",
624 | "    lines = f.readlines()\n",
625 | "\n",
626 | "end = round(0.9 * len(lines))\n",
627 | "train_lines = lines[:end]\n",
628 | "test_lines = lines[end:]\n",
629 | "train_set = TextDataset(train_lines, batch_size=64) #8)\n",
630 | "test_set = TextDataset(test_lines, batch_size=64) #8)\n",
631 | "train_set.__getitem__(0)"
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": 12,
637 | "metadata": {
638 | "id": "gaZEQYzfFEQ0"
639 | },
640 | "outputs": [],
641 | "source": [
642 | "class Trainer:\n",
643 | "    def __init__(self, model, lr):\n",
644 | "        self.model = model\n",
645 | "        self.lr = lr\n",
646 | "        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)\n",
647 | "\n",
648 | "    def train(self, dataset, print_interval = 20):\n",
649 | "        train_loss = 0\n",
650 | "        num_samples = 0\n",
651 | "        self.model.train()\n",
652 | "        pbar = tqdm(dataset.loader)\n",
653 | "        for idx, batch in enumerate(pbar):\n",
654 | "            x,y,T,U = batch\n",
655 | "            x = x.to(self.model.device); y = y.to(self.model.device)\n",
656 | "            batch_size = len(x)\n",
657 | "            num_samples += batch_size\n",
658 | "            loss = self.model.compute_loss(x,y,T,U)\n",
659 | "            self.optimizer.zero_grad()\n",
660 | "            pbar.set_description(\"%.2f\" % loss.item())\n",
661 | "            loss.backward()\n",
662 | "            self.optimizer.step()\n",
663 | "            train_loss += loss.item() * batch_size\n",
664 | "            if idx % print_interval == 0:\n",
665 | "                self.model.eval()\n",
666 | "                guesses = self.model.greedy_search(x,T)\n",
667 | "                self.model.train()\n",
668 | "                print(\"\\n\")\n",
669 | "                for b in range(2):\n",
670 | "                    print(\"input:\", decode_labels(x[b,:T[b]]))\n",
671 | "                    print(\"guess:\", decode_labels(guesses[b]))\n",
672 | "                    print(\"truth:\", decode_labels(y[b,:U[b]]))\n",
673 | "                    print(\"\")\n",
674 | "        train_loss /= num_samples\n",
675 | "        return train_loss\n",
676 | "\n",
677 | "    def test(self, dataset, print_interval=1):\n",
678 | "        test_loss = 0\n",
679 | "        num_samples = 0\n",
680 | "        self.model.eval()\n",
681 | "        pbar = tqdm(dataset.loader)\n",
682 | "        with torch.no_grad():\n",
683 | "            for idx, batch in enumerate(pbar):\n",
684 | "                x,y,T,U = batch\n",
685 | "                x = x.to(self.model.device); y = y.to(self.model.device)\n",
686 | "                batch_size = len(x)\n",
687 | "                num_samples += batch_size\n",
688 | "                loss = self.model.compute_loss(x,y,T,U)\n",
689 | "                pbar.set_description(\"%.2f\" % loss.item())\n",
690 | "                test_loss += loss.item() * batch_size\n",
691 | "                if idx % print_interval == 0:\n",
692 | "                    print(\"\\n\")\n",
693 | "                    print(\"input:\", decode_labels(x[0,:T[0]]))\n",
694 | "                    print(\"guess:\", decode_labels(self.model.greedy_search(x,T)[0]))\n",
695 | "                    print(\"truth:\", decode_labels(y[0,:U[0]]))\n",
696 | "                    print(\"\")\n",
697 | "        test_loss /= num_samples\n",
698 | "        return test_loss\n",
699 | " "
700 | ]
701 | },
702 | {
703 | "cell_type": "markdown",
704 | "metadata": {
705 | "id": "q4PupgBKWe6p"
706 | },
707 | "source": [
708 | "# Training the model\n",
709 | "\n",
710 | "Now we will train a model. This will generate some output sequences every 20 batches."
711 | ]
712 | },
713 | {
714 | "cell_type": "code",
715 | "execution_count": 13,
716 | "metadata": {
717 | "colab": {
718 | "base_uri": "https://localhost:8080/"
719 | },
720 | "id": "1TSrbH9xGPEC",
721 | "outputId": "102f780d-145c-481f-d667-1cd2b76d6111"
722 | },
723 | "outputs": [
724 | {
725 | "name": "stderr",
726 | "output_type": "stream",
727 | "text": [
728 | "9.62: 0%| | 1/709 [00:02<26:40, 2.26s/it]"
729 | ]
730 | },
731 | {
732 | "name": "stdout",
733 | "output_type": "stream",
734 | "text": [
735 | "\n",
736 | "\n",
737 | "input: g wy.\n",
738 | "guess: \n",
739 | "truth: go away.\n",
740 | "\n",
741 | "input: \" knw frm rlbl srcs tht th Dwgr mprss s tkng kn\n",
742 | "guess: \n",
743 | "truth: \"I know from reliable sources that the Dowager Empress is taking a keen\n",
744 | "\n"
745 | ]
746 | },
747 | {
748 | "name": "stderr",
749 | "output_type": "stream",
750 | "text": [
751 | "3.46: 3%|█ | 21/709 [00:22<22:19, 1.95s/it]"
752 | ]
753 | },
754 | {
755 | "name": "stdout",
756 | "output_type": "stream",
757 | "text": [
758 | "\n",
759 | "\n",
760 | "input: t f ths pstn; tht hr ws th Tln tht wld lft hm frm\n",
761 | "guess: \n",
762 | "truth: out of this position; that here was the Toulon that would lift him from\n",
763 | "\n",
764 | "input: rthr rnd shldrs nd glncd t Nvsltsv wh ws nr hm, s\n",
765 | "guess: \n",
766 | "truth: rather round shoulders and glanced at Novosiltsev who was near him, as\n",
767 | "\n"
768 | ]
769 | },
770 | {
771 | "name": "stderr",
772 | "output_type": "stream",
773 | "text": [
774 | "1.81: 6%|██ | 41/709 [00:47<34:44, 3.12s/it]"
775 | ]
776 | },
777 | {
778 | "name": "stdout",
779 | "output_type": "stream",
780 | "text": [
781 | "\n",
782 | "\n",
783 | "input: thm tht sh mght pry fr thm. Bt nthr cld sh dbt th\n",
784 | "guess: the th sh mh pprey fof the. to nother cold sh dad th\n",
785 | "truth: them that she might pray for them. But neither could she doubt the\n",
786 | "\n",
787 | "input: fr dng. H mgnd ll srts f pssbl cntngncs, jst lk\n",
788 | "guess: fofre dad momand lol sert fof ps conting, ast lol\n",
789 | "truth: for doing. He imagined all sorts of possible contingencies, just like\n",
790 | "\n"
791 | ]
792 | },
793 | {
794 | "name": "stderr",
795 | "output_type": "stream",
796 | "text": [
797 | "1.18: 9%|███ | 61/709 [01:12<33:48, 3.13s/it]"
798 | ]
799 | },
800 | {
801 | "name": "stdout",
802 | "output_type": "stream",
803 | "text": [
804 | "\n",
805 | "\n",
806 | "input: mprr, wld stbbrnly nsst n th crrctnss r flsty f sm\n",
807 | "guess: mere, wald steberinly nost no th corctins ree folsty of some\n",
808 | "truth: Emperor, would stubbornly insist on the correctness or falsity of some\n",
809 | "\n",
810 | "input: Prr hd mt n Ptrsbrg scty. n th Prsdnt's chr st \n",
811 | "guess: Pere had mat no Pitersbrig sicy. no th Pirsedint's cher st \n",
812 | "truth: Pierre had met in Petersburg society. In the President's chair sat a\n",
813 | "\n"
814 | ]
815 | },
816 | {
817 | "name": "stderr",
818 | "output_type": "stream",
819 | "text": [
820 | "0.83: 11%|████ | 81/709 [01:36<32:28, 3.10s/it]"
821 | ]
822 | },
823 | {
824 | "name": "stdout",
825 | "output_type": "stream",
826 | "text": [
827 | "\n",
828 | "\n",
829 | "input: shld mk t gd trgt fr th Frnch, bsds m frd \n",
830 | "guess: sholed me to god a terg fer the Frinch, buseds a me fared \n",
831 | "truth: should make too good a target for the French, besides I am afraid I\n",
832 | "\n",
833 | "input: lnchng wth hm nd Brs.\n",
834 | "guess: lonching with him and Bors.\n",
835 | "truth: lunching with him and Boris.\n",
836 | "\n"
837 | ]
838 | },
839 | {
840 | "name": "stderr",
841 | "output_type": "stream",
842 | "text": [
843 | "0.72: 14%|████▉ | 101/709 [02:00<30:27, 3.01s/it]"
844 | ]
845 | },
846 | {
847 | "name": "stdout",
848 | "output_type": "stream",
849 | "text": [
850 | "\n",
851 | "\n",
852 | "input: hr trnk hd bn tkn dwn frm ts crt, nd ll th lds wr\n",
853 | "guess: her trenk had been taken don for ites cort, and lil th lod were\n",
854 | "truth: her trunk had been taken down from its cart, and all the loads were\n",
855 | "\n",
856 | "input: spcl grtty fr hs fthfl srvcs, wll knwng tht t tht\n",
857 | "guess: spicl garty for his fothifl serices, wil knwing that to that\n",
858 | "truth: special gratuity for his faithful services, well knowing that at that\n",
859 | "\n"
860 | ]
861 | },
862 | {
863 | "name": "stderr",
864 | "output_type": "stream",
865 | "text": [
866 | "0.66: 17%|█████▉ | 121/709 [02:18<15:06, 1.54s/it]"
867 | ]
868 | },
869 | {
870 | "name": "stdout",
871 | "output_type": "stream",
872 | "text": [
873 | "\n",
874 | "\n",
875 | "input: s th nly thng w nd fr, r h s chngng hs pstn.\" (H\n",
876 | "guess: as the noly thing w and for, rea he as chinging his postion.\" ((H\n",
877 | "truth: is the only thing we need fear, or he is changing his position.\" (He\n",
878 | "\n",
879 | "input: stt f xctmnt s whn sh hd slmmd th dr f hr rm.\n",
880 | "guess: st of exctemant as when sh had seled the dor of her rom.\n",
881 | "truth: state of excitement as when she had slammed the door of her room.\n",
882 | "\n"
883 | ]
884 | },
885 | {
886 | "name": "stderr",
887 | "output_type": "stream",
888 | "text": [
889 | "0.51: 20%|██████▉ | 141/709 [02:33<14:35, 1.54s/it]"
890 | ]
891 | },
892 | {
893 | "name": "stdout",
894 | "output_type": "stream",
895 | "text": [
896 | "\n",
897 | "\n",
898 | "input: s, n lks dctn nd wll-brd ppl.' t s fr yr sk \n",
899 | "guess: see, in lokes deacten ane well-bered peple. to is for your ske a\n",
900 | "truth: see, one likes education and well-bred people.' It is for your sake I\n",
901 | "\n",
902 | "input: Nchls dd nt thnk t th prpr thng t vst hr gn; bt ll\n",
903 | "guess: Nichal did not think to th preper thing to vist her gon; but al\n",
904 | "truth: Nicholas did not think it the proper thing to visit her again; but all\n",
905 | "\n"
906 | ]
907 | },
908 | {
909 | "name": "stderr",
910 | "output_type": "stream",
911 | "text": [
912 | "0.51: 23%|███████▉ | 161/709 [02:48<13:35, 1.49s/it]"
913 | ]
914 | },
915 | {
916 | "name": "stdout",
917 | "output_type": "stream",
918 | "text": [
919 | "\n",
920 | "\n",
921 | "input: th prvncs t dstnc frm Mscw, lds, nd gntlmn n\n",
922 | "guess: th pevances it a distance for Mosow, lods, ane gentleman in\n",
923 | "truth: the provinces at a distance from Moscow, ladies, and gentlemen in\n",
924 | "\n",
925 | "input: mnd tht ws lrdy trd f t ll, nd tht w mst ll d. nc\n",
926 | "guess: moned that a was alead tered of it al, ane that w most al doo. noce\n",
927 | "truth: mind that I was already tired of it all, and that we must all die. Once\n",
928 | "\n"
929 | ]
930 | },
931 | {
932 | "name": "stderr",
933 | "output_type": "stream",
934 | "text": [
935 | "0.51: 26%|████████▉ | 181/709 [03:03<12:59, 1.48s/it]"
936 | ]
937 | },
938 | {
939 | "name": "stdout",
940 | "output_type": "stream",
941 | "text": [
942 | "\n",
943 | "\n",
944 | "input: tht h ws t th frnt nd n th pwr f mn twrd whm h nw\n",
945 | "guess: that he was it th forint ane in th por of a ma towared whm he now\n",
946 | "truth: that he was at the front and in the power of a man toward whom he now\n",
947 | "\n",
948 | "input: \"Bt hw s t th dctr frm Mscw s nt hr yt?\" sd th\n",
949 | "guess: \"But how is it th decter from Mosw is not her yet?\" said th\n",
950 | "truth: \"But how is it the doctor from Moscow is not here yet?\" said the\n",
951 | "\n"
952 | ]
953 | },
954 | {
955 | "name": "stderr",
956 | "output_type": "stream",
957 | "text": [
958 | "0.47: 28%|█████████▉ | 201/709 [03:17<12:20, 1.46s/it]"
959 | ]
960 | },
961 | {
962 | "name": "stdout",
963 | "output_type": "stream",
964 | "text": [
965 | "\n",
966 | "\n",
967 | "input: ntrm Prnc ndrw hrd ld vcs nd rngng stcct lgh--\n",
968 | "guess: inter Price und hered lod voces ane a renging stect lagh--\n",
969 | "truth: anteroom Prince Andrew heard loud voices and a ringing staccato laugh--a\n",
970 | "\n",
971 | "input: nnr lbrtry f hs mnd n prtbl frm s f ntntnlly, s\n",
972 | "guess: nener liberty of his mined in a peatable from as of intentenlly, as\n",
973 | "truth: inner laboratory of his mind in a portable form as if intentionally, so\n",
974 | "\n"
975 | ]
976 | },
977 | {
978 | "name": "stderr",
979 | "output_type": "stream",
980 | "text": [
981 | "0.41: 31%|██████████▉ | 221/709 [03:32<12:25, 1.53s/it]"
982 | ]
983 | },
984 | {
985 | "name": "stdout",
986 | "output_type": "stream",
987 | "text": [
988 | "\n",
989 | "\n",
990 | "input: prvt rm nd, cllng hs djtnt, skd fr sm pprs rltng\n",
991 | "guess: perite room ane, coling his adant, aske for sam apeas relating\n",
992 | "truth: private room and, calling his adjutant, asked for some papers relating\n",
993 | "\n",
994 | "input: th rd t th plc t whch h hd bn snt, h hd gllpd t th\n",
995 | "guess: th rod to th place it which he had been sent, he had gloped to the\n",
996 | "truth: the road to the place to which he had been sent, he had galloped to the\n",
997 | "\n"
998 | ]
999 | },
1000 | {
1001 | "name": "stderr",
1002 | "output_type": "stream",
1003 | "text": [
1004 | "0.40: 34%|███████████▉ | 241/709 [03:47<11:36, 1.49s/it]"
1005 | ]
1006 | },
1007 | {
1008 | "name": "stdout",
1009 | "output_type": "stream",
1010 | "text": [
1011 | "\n",
1012 | "\n",
1013 | "input: Drng th hrs f sltd, sffrng, nd prtl dlrm h spnt\n",
1014 | "guess: Deang th here of siled, sufering, ane peatel delorom he spent\n",
1015 | "truth: During the hours of solitude, suffering, and partial delirium he spent\n",
1016 | "\n",
1017 | "input: fxd n chsn spt, vgrsly hldng ts tl rct, shnng nd\n",
1018 | "guess: feed in a chesion spet, vigersly holing ites tel recot, shing ane\n",
1019 | "truth: fixed in a chosen spot, vigorously holding its tail erect, shining and\n",
1020 | "\n"
1021 | ]
1022 | },
1023 | {
1024 | "name": "stderr",
1025 | "output_type": "stream",
1026 | "text": [
1027 | "0.40: 37%|████████████▉ | 261/709 [04:02<10:02, 1.34s/it]"
1028 | ]
1029 | },
1030 | {
1031 | "name": "stdout",
1032 | "output_type": "stream",
1033 | "text": [
1034 | "\n",
1035 | "\n",
1036 | "input: t nc shwd tslf by cntlss sgns. Ths sgns wr: Lrstn's\n",
1037 | "guess: to noce showed itel by countles sige. This sig wer: Lesion's\n",
1038 | "truth: at once showed itself by countless signs. These signs were: Lauriston's\n",
1039 | "\n",
1040 | "input: prtns ccrdng t hm wr cnnctd wth sch cmplctd\n",
1041 | "guess: peatens acoriding to him were conced with suh copliced\n",
1042 | "truth: operations according to him were connected with such complicated\n",
1043 | "\n"
1044 | ]
1045 | },
1046 | {
1047 | "name": "stderr",
1048 | "output_type": "stream",
1049 | "text": [
1050 | "0.35: 40%|█████████████▊ | 281/709 [04:15<08:16, 1.16s/it]"
1051 | ]
1052 | },
1053 | {
1054 | "name": "stdout",
1055 | "output_type": "stream",
1056 | "text": [
1057 | "\n",
1058 | "\n",
1059 | "input: \"nd ths s grttd--ths s rcgntn fr ths wh hv\n",
1060 | "guess: \"Ane thes is greted--thes is recogantion for thes who have\n",
1061 | "truth: \"And this is gratitude--this is recognition for those who have\n",
1062 | "\n",
1063 | "input: t th plc, hwvr, n ffcl cm rnnng t t mt hm, nd\n",
1064 | "guess: at the pale, hower, in ofical came rining at to meet him, ane\n",
1065 | "truth: to the palace, however, an official came running out to meet him, and\n",
1066 | "\n"
1067 | ]
1068 | },
1069 | {
1070 | "name": "stderr",
1071 | "output_type": "stream",
1072 | "text": [
1073 | "0.30: 42%|██████████████▊ | 301/709 [04:28<08:03, 1.19s/it]"
1074 | ]
1075 | },
1076 | {
1077 | "name": "stdout",
1078 | "output_type": "stream",
1079 | "text": [
1080 | "\n",
1081 | "\n",
1082 | "input: bng n nfrm mprtd smthng strng nd fntstc t ths\n",
1083 | "guess: being no unifor impored something stering ane fainstic to thes\n",
1084 | "truth: being in uniform imparted something strange and fantastic to these\n",
1085 | "\n",
1086 | "input: srprs, s h prtndd nt t s d Bsst nd clld Fbvr t\n",
1087 | "guess: sers, so he petened not it is do Bes ane calle Fabover to\n",
1088 | "truth: a surprise, so he pretended not to see de Beausset and called Fabvier to\n",
1089 | "\n"
1090 | ]
1091 | },
1092 | {
1093 | "name": "stderr",
1094 | "output_type": "stream",
1095 | "text": [
1096 | "0.31: 45%|███████████████▊ | 321/709 [04:41<07:19, 1.13s/it]"
1097 | ]
1098 | },
1099 | {
1100 | "name": "stdout",
1101 | "output_type": "stream",
1102 | "text": [
1103 | "\n",
1104 | "\n",
1105 | "input: Th prsnrs wr plcd n crtn rdr, ccrdng t th lst\n",
1106 | "guess: The pesoners were pled in a caraten orie, acriding to the last\n",
1107 | "truth: The prisoners were placed in a certain order, according to the list\n",
1108 | "\n",
1109 | "input: t, bt ch tm sh lkd t th prcssn hr ys sght tht\n",
1110 | "guess: it, but ch tem she loked to the perision her eyes siht that\n",
1111 | "truth: it, but each time she looked at the procession her eyes sought that\n",
1112 | "\n"
1113 | ]
1114 | },
1115 | {
1116 | "name": "stderr",
1117 | "output_type": "stream",
1118 | "text": [
1119 | "0.28: 48%|████████████████▊ | 341/709 [04:54<06:53, 1.12s/it]"
1120 | ]
1121 | },
1122 | {
1123 | "name": "stdout",
1124 | "output_type": "stream",
1125 | "text": [
1126 | "\n",
1127 | "\n",
1128 | "input: xprssn, vdntly hbtl t hm whn cnvrsng wth wmn.\n",
1129 | "guess: exrission, vedently hable to him when conersing with woman.\n",
1130 | "truth: expression, evidently habitual to him when conversing with women.\n",
1131 | "\n",
1132 | "input: pt ths wh dd nt tk n ctv prt n th ffrs f th ldgs\n",
1133 | "guess: put thes wh did not took in acte peat in th afers of th lods\n",
1134 | "truth: put those who did not take an active part in the affairs of the lodges\n",
1135 | "\n"
1136 | ]
1137 | },
1138 | {
1139 | "name": "stderr",
1140 | "output_type": "stream",
1141 | "text": [
1142 | "0.30: 51%|█████████████████▊ | 361/709 [05:07<06:33, 1.13s/it]"
1143 | ]
1144 | },
1145 | {
1146 | "name": "stdout",
1147 | "output_type": "stream",
1148 | "text": [
1149 | "\n",
1150 | "\n",
1151 | "input: n th hghr cmmnd thr wr tw shrply dfnd prts: Ktzv's\n",
1152 | "guess: In th hige comaned ther were tw shery defened peates: Kutov's\n",
1153 | "truth: In the higher command there were two sharply defined parties: Kutuzov's\n",
1154 | "\n",
1155 | "input: m, Prr. shll nt frgt yr ntrsts.\"\n",
1156 | "guess: me, Piere. I shel not foret your inters.\"\n",
1157 | "truth: me, Pierre. I shall not forget your interests.\"\n",
1158 | "\n"
1159 | ]
1160 | },
1161 | {
1162 | "name": "stderr",
1163 | "output_type": "stream",
1164 | "text": [
1165 | "0.29: 54%|██████████████████▊ | 381/709 [05:20<06:16, 1.15s/it]"
1166 | ]
1167 | },
1168 | {
1169 | "name": "stdout",
1170 | "output_type": "stream",
1171 | "text": [
1172 | "\n",
1173 | "\n",
1174 | "input: whch cn cnvnc m, dr frnd--lf nd dth r wht cnvnc.\n",
1175 | "guess: which can conce me, dear fried-life ane death are what conce.\n",
1176 | "truth: which can convince me, dear friend--life and death are what convince.\n",
1177 | "\n",
1178 | "input: N, 'll tll y nw. Y knw Sny's my drst frnd. Sch frnd\n",
1179 | "guess: No, I'l tel you now. You kn Sonys my dear fried. Suc a friened\n",
1180 | "truth: No, I'll tell you now. You know Sonya's my dearest friend. Such a friend\n",
1181 | "\n"
1182 | ]
1183 | },
1184 | {
1185 | "name": "stderr",
1186 | "output_type": "stream",
1187 | "text": [
1188 | "0.33: 57%|███████████████████▊ | 401/709 [05:33<05:45, 1.12s/it]"
1189 | ]
1190 | },
1191 | {
1192 | "name": "stdout",
1193 | "output_type": "stream",
1194 | "text": [
1195 | "\n",
1196 | "\n",
1197 | "input: thrshd hm bcs h hd t gt n s qckly s pssbl. nd ,\"\n",
1198 | "guess: thrished him becase he had to got in is quicly is posible. ane I,\n",
1199 | "truth: thrashed him because he had to get on as quickly as possible. And I,\"\n",
1200 | "\n",
1201 | "input: sm t hv sm gnrlly sy. Thy sd tht ndbtdly wr,\n",
1202 | "guess: sam to have sam generally sa. They said that unedly were,\n",
1203 | "truth: seem to have some generally say. They said that undoubtedly war,\n",
1204 | "\n"
1205 | ]
1206 | },
1207 | {
1208 | "name": "stderr",
1209 | "output_type": "stream",
1210 | "text": [
1211 | "0.26: 59%|████████████████████▊ | 421/709 [05:46<05:44, 1.20s/it]"
1212 | ]
1213 | },
1214 | {
1215 | "name": "stdout",
1216 | "output_type": "stream",
1217 | "text": [
1218 | "\n",
1219 | "\n",
1220 | "input: hwvr, fr h wll lv s gn t tk prt n ths nhppy wr\n",
1221 | "guess: hower, fea he wil love is gan to take peat in thes unhapy were\n",
1222 | "truth: however, for he will leave us again to take part in this unhappy war\n",
1223 | "\n",
1224 | "input: hs nd bck gn, shtng cnfsd nstrctns t th hrryng\n",
1225 | "guess: his ane baco again, shating consed intericons to the hurying\n",
1226 | "truth: house and back again, shouting confused instructions to the hurrying\n",
1227 | "\n"
1228 | ]
1229 | },
1230 | {
1231 | "name": "stderr",
1232 | "output_type": "stream",
1233 | "text": [
1234 | "0.31: 62%|█████████████████████▊ | 441/709 [06:07<15:04, 3.38s/it]"
1235 | ]
1236 | },
1237 | {
1238 | "name": "stdout",
1239 | "output_type": "stream",
1240 | "text": [
1241 | "\n",
1242 | "\n",
1243 | "input: drwng rm s wt.\"\n",
1244 | "guess: dring rom is wit.\"\n",
1245 | "truth: drawing room so wet.\"\n",
1246 | "\n",
1247 | "input: gv n frthr rdrs nd slntly cntnd t wlk n n frnt f th\n",
1248 | "guess: geve in father ories ane silently contened it wal no no front of the\n",
1249 | "truth: gave no further orders and silently continued to walk on in front of the\n",
1250 | "\n"
1251 | ]
1252 | },
1253 | {
1254 | "name": "stderr",
1255 | "output_type": "stream",
1256 | "text": [
1257 | "0.23: 65%|██████████████████████▊ | 461/709 [06:33<14:07, 3.42s/it]"
1258 | ]
1259 | },
1260 | {
1261 | "name": "stdout",
1262 | "output_type": "stream",
1263 | "text": [
1264 | "\n",
1265 | "\n",
1266 | "input: srvng, nd msndrstndng rs. Th tw cmmndrs wr mch\n",
1267 | "guess: seng, ane a menedersteng rose. The tw comaneders were muh\n",
1268 | "truth: serving, and a misunderstanding arose. The two commanders were much\n",
1269 | "\n",
1270 | "input: hm. Bndrchk's hrs swrvd nd gllpd pst.\n",
1271 | "guess: him. Bondarchk's hore swered ane gloped pas.\n",
1272 | "truth: him. Bondarchuk's horse swerved and galloped past.\n",
1273 | "\n"
1274 | ]
1275 | },
1276 | {
1277 | "name": "stderr",
1278 | "output_type": "stream",
1279 | "text": [
1280 | "0.22: 68%|███████████████████████▋ | 481/709 [06:58<12:06, 3.19s/it]"
1281 | ]
1282 | },
1283 | {
1284 | "name": "stdout",
1285 | "output_type": "stream",
1286 | "text": [
1287 | "\n",
1288 | "\n",
1289 | "input: \"Ys, f t nly wr s!\" sd Prnc ndrw. \"Hwvr, t s tm t\n",
1290 | "guess: \"Yes, if it ony were is! sa Pric nder. \"Hower, it is tem to\n",
1291 | "truth: \"Yes, if it only were so!\" said Prince Andrew. \"However, it is time to\n",
1292 | "\n",
1293 | "input: \" m sd t smthng swt. Rsns, fn ns... tk thm ll!\" h\n",
1294 | "guess: \"I am side to something swee. Rusian, fain nose. took them al!\" he\n",
1295 | "truth: \"I am used to something sweet. Raisins, fine ones... take them all!\" he\n",
1296 | "\n"
1297 | ]
1298 | },
1299 | {
1300 | "name": "stderr",
1301 | "output_type": "stream",
1302 | "text": [
1303 | "0.20: 71%|████████████████████████▋ | 501/709 [07:24<12:13, 3.53s/it]"
1304 | ]
1305 | },
1306 | {
1307 | "name": "stdout",
1308 | "output_type": "stream",
1309 | "text": [
1310 | "\n",
1311 | "\n",
1312 | "input: \"G, dr,\" sd Prncss Mry.\n",
1313 | "guess: \"Go, dear, sa Prins Mary.\n",
1314 | "truth: \"Go, dear,\" said Princess Mary.\n",
1315 | "\n",
1316 | "input: f w nd vn srvlty. dfrnc sch s h hd nvr bfr\n",
1317 | "guess: of w ane ven serily. I deferec suh is he had never befer\n",
1318 | "truth: of awe and even servility. A deference such as he had never before\n",
1319 | "\n"
1320 | ]
1321 | },
1322 | {
1323 | "name": "stderr",
1324 | "output_type": "stream",
1325 | "text": [
1326 | "0.17: 73%|█████████████████████████▋ | 521/709 [07:51<11:34, 3.70s/it]"
1327 | ]
1328 | },
1329 | {
1330 | "name": "stdout",
1331 | "output_type": "stream",
1332 | "text": [
1333 | "\n",
1334 | "\n",
1335 | "input: Th prncss rstd hr br rnd rm n lttl tbl nd cnsdrd\n",
1336 | "guess: The penes resed her ber rouned rom in a litle table ane coned\n",
1337 | "truth: The princess rested her bare round arm on a little table and considered\n",
1338 | "\n",
1339 | "input: dmnstrtns f lv f Brnn.\n",
1340 | "guess: demensteritens of love of Borin.\n",
1341 | "truth: demonstrations of love of Bourienne.\n",
1342 | "\n"
1343 | ]
1344 | },
1345 | {
1346 | "name": "stderr",
1347 | "output_type": "stream",
1348 | "text": [
1349 | "0.30: 76%|██████████████████████████▋ | 541/709 [08:10<05:52, 2.10s/it]"
1350 | ]
1351 | },
1352 | {
1353 | "name": "stdout",
1354 | "output_type": "stream",
1355 | "text": [
1356 | "\n",
1357 | "\n",
1358 | "input: hr.\n",
1359 | "guess: her.\n",
1360 | "truth: her.\n",
1361 | "\n",
1362 | "input: s fr s Mscw nd t sty wth hm thr. Mtng cmrd t th\n",
1363 | "guess: so fer so Mosow ane to sta with him ther. Meeting a comered to the\n",
1364 | "truth: as far as Moscow and to stay with him there. Meeting a comrade at the\n",
1365 | "\n"
1366 | ]
1367 | },
1368 | {
1369 | "name": "stderr",
1370 | "output_type": "stream",
1371 | "text": [
1372 | "0.22: 79%|███████████████████████████▋ | 561/709 [08:36<07:57, 3.23s/it]"
1373 | ]
1374 | },
1375 | {
1376 | "name": "stdout",
1377 | "output_type": "stream",
1378 | "text": [
1379 | "\n",
1380 | "\n",
1381 | "input: slf-pssssn mngld t tht mmnt n Rstv's sl.\n",
1382 | "guess: sef-posion minged to that ment in Rostov's sol.\n",
1383 | "truth: self-possession mingled at that moment in Rostov's soul.\n",
1384 | "\n",
1385 | "input: vdntly dd nt xpct.\n",
1386 | "guess: evidently did not expecoat.\n",
1387 | "truth: evidently did not expect.\n",
1388 | "\n"
1389 | ]
1390 | },
1391 | {
1392 | "name": "stderr",
1393 | "output_type": "stream",
1394 | "text": [
1395 | "0.24: 82%|████████████████████████████▋ | 581/709 [09:02<07:24, 3.47s/it]"
1396 | ]
1397 | },
1398 | {
1399 | "name": "stdout",
1400 | "output_type": "stream",
1401 | "text": [
1402 | "\n",
1403 | "\n",
1404 | "input: n nn Pvlvn's drwng rm. Prr rmvd hs ft frm th sf.\n",
1405 | "guess: in nin Palovna's drawing rom. Piere remod his feet frem th sof.\n",
1406 | "truth: in Anna Pavlovna's drawing room. Pierre removed his feet from the sofa.\n",
1407 | "\n",
1408 | "input: srvnts. Th ftmn wh hd gn t nnnc thm ws stppd by\n",
1409 | "guess: sents. The fotemen who had gon to anone them was steped by\n",
1410 | "truth: servants. The footman who had gone to announce them was stopped by\n",
1411 | "\n"
1412 | ]
1413 | },
1414 | {
1415 | "name": "stderr",
1416 | "output_type": "stream",
1417 | "text": [
1418 | "0.21: 85%|█████████████████████████████▋ | 601/709 [09:29<06:37, 3.68s/it]"
1419 | ]
1420 | },
1421 | {
1422 | "name": "stdout",
1423 | "output_type": "stream",
1424 | "text": [
1425 | "\n",
1426 | "\n",
1427 | "input: Sddnly th snd f frng f cnnn ws hrd frm th mbnkmnt,\n",
1428 | "guess: Sedenly the sone if a fearin of can was heared frem the mankement,\n",
1429 | "truth: Suddenly the sound of a firing of cannon was heard from the embankment,\n",
1430 | "\n",
1431 | "input: th gtwy f th crtyrd. n th dns wvrng smk sm f th\n",
1432 | "guess: the gety of the courity. in the denes wering sme som of the\n",
1433 | "truth: the gateway of the courtyard. In the dense wavering smoke some of the\n",
1434 | "\n"
1435 | ]
1436 | },
1437 | {
1438 | "name": "stderr",
1439 | "output_type": "stream",
1440 | "text": [
1441 | "0.22: 88%|██████████████████████████████▋ | 621/709 [09:55<05:13, 3.57s/it]"
1442 | ]
1443 | },
1444 | {
1445 | "name": "stdout",
1446 | "output_type": "stream",
1447 | "text": [
1448 | "\n",
1449 | "\n",
1450 | "input: Wllrsk ws slnt thrght th drv. T Prr's nqrs s t\n",
1451 | "guess: Wlerski was silet thrig th drove. To Piere's inirs is to\n",
1452 | "truth: Willarski was silent throughout the drive. To Pierre's inquiries as to\n",
1453 | "\n",
1454 | "input: lstnd t th dr, nd t smd t hr tht hs mttrngs wr\n",
1455 | "guess: lised to the dear, need it seeme to her that his mutterings were\n",
1456 | "truth: listened at the door, and it seemed to her that his mutterings were\n",
1457 | "\n"
1458 | ]
1459 | },
1460 | {
1461 | "name": "stderr",
1462 | "output_type": "stream",
1463 | "text": [
1464 | "0.17: 90%|███████████████████████████████▋ | 641/709 [10:16<02:00, 1.77s/it]"
1465 | ]
1466 | },
1467 | {
1468 | "name": "stdout",
1469 | "output_type": "stream",
1470 | "text": [
1471 | "\n",
1472 | "\n",
1473 | "input: nd f th ml. H wnkd t th btlr, whsprd drctns t th\n",
1474 | "guess: ane if th mile. He wined to th boler, whisered deritenes to the\n",
1475 | "truth: end of the meal. He winked at the butler, whispered directions to the\n",
1476 | "\n",
1477 | "input: nstrctns t hnd t t th Gvrnr nd t cm bck s qckly s\n",
1478 | "guess: intericatens to hane it to th Governer need to come bo so quicly so\n",
1479 | "truth: instructions to hand it to the Governor and to come back as quickly as\n",
1480 | "\n"
1481 | ]
1482 | },
1483 | {
1484 | "name": "stderr",
1485 | "output_type": "stream",
1486 | "text": [
1487 | "0.18: 93%|████████████████████████████████▋ | 661/709 [10:42<02:54, 3.63s/it]"
1488 | ]
1489 | },
1490 | {
1491 | "name": "stdout",
1492 | "output_type": "stream",
1493 | "text": [
1494 | "\n",
1495 | "\n",
1496 | "input: tlk wth Brs. H nd nt cm s ftn....\"\n",
1497 | "guess: tak with Boris. He need not came so foten..\"\n",
1498 | "truth: talk with Boris. He need not come so often....\"\n",
1499 | "\n",
1500 | "input: prncss' rdr.\n",
1501 | "guess: pines' orie.\n",
1502 | "truth: princess' order.\n",
1503 | "\n"
1504 | ]
1505 | },
1506 | {
1507 | "name": "stderr",
1508 | "output_type": "stream",
1509 | "text": [
1510 | "0.20: 96%|█████████████████████████████████▌ | 681/709 [11:09<01:42, 3.67s/it]"
1511 | ]
1512 | },
1513 | {
1514 | "name": "stdout",
1515 | "output_type": "stream",
1516 | "text": [
1517 | "\n",
1518 | "\n",
1519 | "input: Cptn vn Tll chncd t rd t th sm spt, nd sng th\n",
1520 | "guess: Capaten van Tol chaned to rod to the sam spet, aned seeing the\n",
1521 | "truth: Captain von Toll chanced to ride to the same spot, and seeing the\n",
1522 | "\n",
1523 | "input: ftr rrngng hs clths, h tk th pstl nd ws bt t g t.\n",
1524 | "guess: Afer aranging his coles, he took th pistal aned was but to go it.\n",
1525 | "truth: After arranging his clothes, he took the pistol and was about to go out.\n",
1526 | "\n"
1527 | ]
1528 | },
1529 | {
1530 | "name": "stderr",
1531 | "output_type": "stream",
1532 | "text": [
1533 | "0.17: 99%|██████████████████████████████████▌| 701/709 [11:36<00:30, 3.76s/it]"
1534 | ]
1535 | },
1536 | {
1537 | "name": "stdout",
1538 | "output_type": "stream",
1539 | "text": [
1540 | "\n",
1541 | "\n",
1542 | "input: rtrtng strght bck frm th nvdrs, dvtd frm tht drct\n",
1543 | "guess: retreting striht bak frem the ineders, devoted frem that derit\n",
1544 | "truth: retreating straight back from the invaders, deviated from that direct\n",
1545 | "\n",
1546 | "input: sldm xprncd, cm pn hr. Sh frd t lk rnd, t smd\n",
1547 | "guess: selem expenced, came upen her. She fared to lok roned, it seemed\n",
1548 | "truth: seldom experienced, came upon her. She feared to look round, it seemed\n",
1549 | "\n"
1550 | ]
1551 | },
1552 | {
1553 | "name": "stderr",
1554 | "output_type": "stream",
1555 | "text": [
1556 | "0.17: 100%|███████████████████████████████████| 709/709 [11:43<00:00, 1.01it/s]\n",
1557 | "0.17: 0%| | 0/81 [00:00<?, ?it/s]"
1558 | ]
1559 | },
1560 | {
1561 | "name": "stdout",
1562 | "output_type": "stream",
1563 | "text": [
1564 | "\n",
1565 | "\n",
1566 | "input: thr dghtrs bsds sn fr whm sh hd lngd nd whm sh ws\n"
1567 | ]
1568 | },
1569 | {
1570 | "name": "stderr",
1571 | "output_type": "stream",
1572 | "text": [
1573 | "0.17: 1%|▍ | 1/81 [00:09<12:27, 9.34s/it]"
1574 | ]
1575 | },
1576 | {
1577 | "name": "stdout",
1578 | "output_type": "stream",
1579 | "text": [
1580 | "guess: oter daughters besid a son fer whom she had longed aned whom she was\n",
1581 | "truth: three daughters besides a son for whom she had longed and whom she was\n",
1582 | "\n"
1583 | ]
1584 | },
1585 | {
1586 | "name": "stderr",
1587 | "output_type": "stream",
1588 | "text": [
1589 | "0.20: 1%|▍ | 1/81 [00:09<12:27, 9.34s/it]"
1590 | ]
1591 | },
1592 | {
1593 | "name": "stdout",
1594 | "output_type": "stream",
1595 | "text": [
1596 | "\n",
1597 | "\n",
1598 | "input: ccptns s hndrncs t lf, nd cnsdrd tht thy wr ll\n"
1599 | ]
1600 | },
1601 | {
1602 | "name": "stderr",
1603 | "output_type": "stream",
1604 | "text": [
1605 | "0.20: 2%|▉ | 2/81 [00:18<12:07, 9.21s/it]"
1606 | ]
1607 | },
1608 | {
1609 | "name": "stdout",
1610 | "output_type": "stream",
1611 | "text": [
1612 | "guess: acupations so hinces to life, aned condered t the were al\n",
1613 | "truth: occupations as hindrances to life, and considered that they were all\n",
1614 | "\n"
1615 | ]
1616 | },
1617 | {
1618 | "name": "stderr",
1619 | "output_type": "stream",
1620 | "text": [
1621 | "0.20: 2%|▉ | 2/81 [00:18<12:07, 9.21s/it]"
1622 | ]
1623 | },
1624 | {
1625 | "name": "stdout",
1626 | "output_type": "stream",
1627 | "text": [
1628 | "\n",
1629 | "\n",
1630 | "input: mt by rcgnzng dvnty whch sbjctd th ntns t th wll f\n"
1631 | ]
1632 | },
1633 | {
1634 | "name": "stderr",
1635 | "output_type": "stream",
1636 | "text": [
1637 | "0.20: 4%|█▍ | 3/81 [00:25<10:56, 8.41s/it]"
1638 | ]
1639 | },
1640 | {
1641 | "name": "stdout",
1642 | "output_type": "stream",
1643 | "text": [
1644 | "guess: met by reconizing a dvinty whic suje th natene to the will if\n",
1645 | "truth: met by recognizing a divinity which subjected the nations to the will of\n",
1646 | "\n"
1647 | ]
1648 | },
1649 | {
1650 | "name": "stderr",
1651 | "output_type": "stream",
1652 | "text": [
1653 | "0.20: 4%|█▍ | 3/81 [00:26<10:56, 8.41s/it]"
1654 | ]
1655 | },
1656 | {
1657 | "name": "stdout",
1658 | "output_type": "stream",
1659 | "text": [
1660 | "\n",
1661 | "\n",
1662 | "input: ndls, nd whch, whn thy wr rdy, sh lwys trmphntly drw,\n"
1663 | ]
1664 | },
1665 | {
1666 | "name": "stderr",
1667 | "output_type": "stream",
1668 | "text": [
1669 | "0.20: 5%|█▉ | 4/81 [00:29<08:19, 6.48s/it]"
1670 | ]
1671 | },
1672 | {
1673 | "name": "stdout",
1674 | "output_type": "stream",
1675 | "text": [
1676 | "guess: aneds, aned whic, when the were read, she alys trupently drow,\n",
1677 | "truth: needles, and which, when they were ready, she always triumphantly drew,\n",
1678 | "\n"
1679 | ]
1680 | },
1681 | {
1682 | "name": "stderr",
1683 | "output_type": "stream",
1684 | "text": [
1685 | "0.25: 5%|█▉ | 4/81 [00:29<08:19, 6.48s/it]"
1686 | ]
1687 | },
1688 | {
1689 | "name": "stdout",
1690 | "output_type": "stream",
1691 | "text": [
1692 | "\n",
1693 | "\n",
1694 | "input: nd Prss. nd thr t h klld grt mny. n Rss thr ws\n"
1695 | ]
1696 | },
1697 | {
1698 | "name": "stderr",
1699 | "output_type": "stream",
1700 | "text": [
1701 | "0.25: 6%|██▎ | 5/81 [00:35<07:50, 6.19s/it]"
1702 | ]
1703 | },
1704 | {
1705 | "name": "stdout",
1706 | "output_type": "stream",
1707 | "text": [
1708 | "guess: ane Pris. aned there to he kled a gerit man. in Rusia there was\n",
1709 | "truth: and Prussia. And there too he killed a great many. In Russia there was\n",
1710 | "\n"
1711 | ]
1712 | },
1713 | {
1714 | "name": "stderr",
1715 | "output_type": "stream",
1716 | "text": [
1717 | "0.17: 6%|██▎ | 5/81 [00:35<07:50, 6.19s/it]"
1718 | ]
1719 | },
1720 | {
1721 | "name": "stdout",
1722 | "output_type": "stream",
1723 | "text": [
1724 | "\n",
1725 | "\n",
1726 | "input: wlfr nd cvlztn f hmnty n gnrl, by whch s slly\n"
1727 | ]
1728 | },
1729 | {
1730 | "name": "stderr",
1731 | "output_type": "stream",
1732 | "text": [
1733 | "0.17: 7%|██▊ | 6/81 [00:44<09:17, 7.43s/it]"
1734 | ]
1735 | },
1736 | {
1737 | "name": "stdout",
1738 | "output_type": "stream",
1739 | "text": [
1740 | "guess: weler aned covation of huminty in genera, by whic is soly\n",
1741 | "truth: welfare and civilization of humanity in general, by which is usually\n",
1742 | "\n"
1743 | ]
1744 | },
1745 | {
1746 | "name": "stderr",
1747 | "output_type": "stream",
1748 | "text": [
1749 | "0.18: 7%|██▊ | 6/81 [00:45<09:17, 7.43s/it]"
1750 | ]
1751 | },
1752 | {
1753 | "name": "stdout",
1754 | "output_type": "stream",
1755 | "text": [
1756 | "\n",
1757 | "\n",
1758 | "input: hs vc sd crssly:\n"
1759 | ]
1760 | },
1761 | {
1762 | "name": "stderr",
1763 | "output_type": "stream",
1764 | "text": [
1765 | "0.18: 9%|███▎ | 7/81 [00:53<09:45, 7.91s/it]"
1766 | ]
1767 | },
1768 | {
1769 | "name": "stdout",
1770 | "output_type": "stream",
1771 | "text": [
1772 | "guess: his voce sa corsly:\n",
1773 | "truth: his voice said crossly:\n",
1774 | "\n"
1775 | ]
1776 | },
1777 | {
1778 | "name": "stderr",
1779 | "output_type": "stream",
1780 | "text": [
1781 | "0.19: 9%|███▎ | 7/81 [00:54<09:45, 7.91s/it]"
1782 | ]
1783 | },
1784 | {
1785 | "name": "stdout",
1786 | "output_type": "stream",
1787 | "text": [
1788 | "\n",
1789 | "\n",
1790 | "input: wh hd sccdd t th pst whn Drn dd nd wh ws ccsd f\n"
1791 | ]
1792 | },
1793 | {
1794 | "name": "stderr",
1795 | "output_type": "stream",
1796 | "text": [
1797 | "0.19: 10%|███▊ | 8/81 [01:02<09:48, 8.05s/it]"
1798 | ]
1799 | },
1800 | {
1801 | "name": "stdout",
1802 | "output_type": "stream",
1803 | "text": [
1804 | "guess: who had suced to the pas when Dron did aned who was acesed of\n",
1805 | "truth: who had succeeded to the post when Dron died and who was accused of\n",
1806 | "\n"
1807 | ]
1808 | },
1809 | {
1810 | "name": "stderr",
1811 | "output_type": "stream",
1812 | "text": [
1813 | "0.17: 10%|███▊ | 8/81 [01:02<09:48, 8.05s/it]"
1814 | ]
1815 | },
1816 | {
1817 | "name": "stdout",
1818 | "output_type": "stream",
1819 | "text": [
1820 | "\n",
1821 | "\n",
1822 | "input: h hd bn lk tht hmslf bt shrt tm bfr.\n"
1823 | ]
1824 | },
1825 | {
1826 | "name": "stderr",
1827 | "output_type": "stream",
1828 | "text": [
1829 | "0.17: 11%|████▏ | 9/81 [01:10<09:50, 8.20s/it]"
1830 | ]
1831 | },
1832 | {
1833 | "name": "stdout",
1834 | "output_type": "stream",
1835 | "text": [
1836 | "guess: he had been lok t himef but a shire tem befer.\n",
1837 | "truth: he had been like that himself but a short time before.\n",
1838 | "\n"
1839 | ]
1840 | },
1841 | {
1842 | "name": "stderr",
1843 | "output_type": "stream",
1844 | "text": [
1845 | "0.21: 11%|████▏ | 9/81 [01:11<09:50, 8.20s/it]"
1846 | ]
1847 | },
1848 | {
1849 | "name": "stdout",
1850 | "output_type": "stream",
1851 | "text": [
1852 | "\n",
1853 | "\n",
1854 | "input: prpr hm t tk n hmslf th whl rspnsblty fr wht s\n"
1855 | ]
1856 | },
1857 | {
1858 | "name": "stderr",
1859 | "output_type": "stream",
1860 | "text": [
1861 | "0.21: 12%|████▌ | 10/81 [01:19<09:50, 8.32s/it]"
1862 | ]
1863 | },
1864 | {
1865 | "name": "stdout",
1866 | "output_type": "stream",
1867 | "text": [
1868 | "guess: preper him to took in himef the whel respensibility fer what is\n",
1869 | "truth: prepare him to take on himself the whole responsibility for what is\n",
1870 | "\n"
1871 | ]
1872 | },
1873 | {
1874 | "name": "stderr",
1875 | "output_type": "stream",
1876 | "text": [
1877 | "0.28: 12%|████▌ | 10/81 [01:19<09:50, 8.32s/it]"
1878 | ]
1879 | },
1880 | {
1881 | "name": "stdout",
1882 | "output_type": "stream",
1883 | "text": [
1884 | "\n",
1885 | "\n",
1886 | "input: thghts h xprssd n cnvrstn sppsd hs wshs t b. nd sh\n"
1887 | ]
1888 | },
1889 | {
1890 | "name": "stderr",
1891 | "output_type": "stream",
1892 | "text": [
1893 | "0.28: 14%|█████ | 11/81 [01:28<09:54, 8.49s/it]"
1894 | ]
1895 | },
1896 | {
1897 | "name": "stdout",
1898 | "output_type": "stream",
1899 | "text": [
1900 | "guess: thoughts he exressed in conerasten supsed his wishes to be. aned she\n",
1901 | "truth: thoughts he expressed in conversation supposed his wishes to be. And she\n",
1902 | "\n"
1903 | ]
1904 | },
1905 | {
1906 | "name": "stderr",
1907 | "output_type": "stream",
1908 | "text": [
1909 | "0.26: 14%|█████ | 11/81 [01:28<09:54, 8.49s/it]"
1910 | ]
1911 | },
1912 | {
1913 | "name": "stdout",
1914 | "output_type": "stream",
1915 | "text": [
1916 | "\n",
1917 | "\n",
1918 | "input: Sddnly h smd t rmmbr; scrcly prcptbl sml flshd\n"
1919 | ]
1920 | },
1921 | {
1922 | "name": "stderr",
1923 | "output_type": "stream",
1924 | "text": [
1925 | "0.26: 15%|█████▍ | 12/81 [01:37<09:52, 8.58s/it]"
1926 | ]
1927 | },
1928 | {
1929 | "name": "stdout",
1930 | "output_type": "stream",
1931 | "text": [
1932 | "guess: Sdenly he seemed to remeber; a scery periple sme fashed\n",
1933 | "truth: Suddenly he seemed to remember; a scarcely perceptible smile flashed\n",
1934 | "\n"
1935 | ]
1936 | },
1937 | {
1938 | "name": "stderr",
1939 | "output_type": "stream",
1940 | "text": [
1941 | "0.23: 15%|█████▍ | 12/81 [01:37<09:52, 8.58s/it]"
1942 | ]
1943 | },
1944 | {
1945 | "name": "stdout",
1946 | "output_type": "stream",
1947 | "text": [
1948 | "\n",
1949 | "\n",
1950 | "input: shws s tht nthr Ls X nr Mttrnch, wh rld vr\n"
1951 | ]
1952 | },
1953 | {
1954 | "name": "stderr",
1955 | "output_type": "stream",
1956 | "text": [
1957 | "0.23: 16%|█████▉ | 13/81 [01:45<09:49, 8.67s/it]"
1958 | ]
1959 | },
1960 | {
1961 | "name": "stdout",
1962 | "output_type": "stream",
1963 | "text": [
1964 | "guess: shows so that another I Lise I ner a Materich, who reled over\n",
1965 | "truth: shows us that neither a Louis XI nor a Metternich, who ruled over\n",
1966 | "\n"
1967 | ]
1968 | },
1969 | {
1970 | "name": "stderr",
1971 | "output_type": "stream",
1972 | "text": [
1973 | "0.21: 16%|█████▉ | 13/81 [01:46<09:49, 8.67s/it]"
1974 | ]
1975 | },
1976 | {
1977 | "name": "stdout",
1978 | "output_type": "stream",
1979 | "text": [
1980 | "\n",
1981 | "\n",
1982 | "input: hm nd drw frm hm n msd nd gntl sml.\n"
1983 | ]
1984 | },
1985 | {
1986 | "name": "stderr",
1987 | "output_type": "stream",
1988 | "text": [
1989 | "0.21: 17%|██████▍ | 14/81 [01:54<09:43, 8.72s/it]"
1990 | ]
1991 | },
1992 | {
1993 | "name": "stdout",
1994 | "output_type": "stream",
1995 | "text": [
1996 | "guess: him aned drow frem him in amude aned gentle sme.\n",
1997 | "truth: him and drew from him an amused and gentle smile.\n",
1998 | "\n"
1999 | ]
2000 | },
2001 | {
2002 | "name": "stderr",
2003 | "output_type": "stream",
2004 | "text": [
2005 | "0.27: 17%|██████▍ | 14/81 [01:54<09:43, 8.72s/it]"
2006 | ]
2007 | },
2008 | {
2009 | "name": "stdout",
2010 | "output_type": "stream",
2011 | "text": [
2012 | "\n",
2013 | "\n",
2014 | "input: t lng-stndng mprssn rltd t hs nnrmst flngs, hd ts\n"
2015 | ]
2016 | },
2017 | {
2018 | "name": "stderr",
2019 | "output_type": "stream",
2020 | "text": [
2021 | "0.27: 19%|██████▊ | 15/81 [02:01<08:48, 8.00s/it]"
2022 | ]
2023 | },
2024 | {
2025 | "name": "stdout",
2026 | "output_type": "stream",
2027 | "text": [
2028 | "guess: It a long-steng imersion related to his nanemest feelings, had ites\n",
2029 | "truth: to a long-standing impression related to his innermost feelings, had its\n",
2030 | "\n"
2031 | ]
2032 | },
2033 | {
2034 | "name": "stderr",
2035 | "output_type": "stream",
2036 | "text": [
2037 | "0.20: 19%|██████▊ | 15/81 [02:01<08:48, 8.00s/it]"
2038 | ]
2039 | },
2040 | {
2041 | "name": "stdout",
2042 | "output_type": "stream",
2043 | "text": [
2044 | "\n",
2045 | "\n",
2046 | "input: ctvty s cnsdrd by th hstrns f cltr t b th cs r\n"
2047 | ]
2048 | },
2049 | {
2050 | "name": "stderr",
2051 | "output_type": "stream",
2052 | "text": [
2053 | "0.20: 20%|███████▎ | 16/81 [02:04<07:14, 6.68s/it]"
2054 | ]
2055 | },
2056 | {
2057 | "name": "stdout",
2058 | "output_type": "stream",
2059 | "text": [
2060 | "guess: acovity so condered by th hiserans of cule to be the case our\n",
2061 | "truth: activity is considered by the historians of culture to be the cause or\n",
2062 | "\n"
2063 | ]
2064 | },
2065 | {
2066 | "name": "stderr",
2067 | "output_type": "stream",
2068 | "text": [
2069 | "0.19: 20%|███████▎ | 16/81 [02:04<07:14, 6.68s/it]"
2070 | ]
2071 | },
2072 | {
2073 | "name": "stdout",
2074 | "output_type": "stream",
2075 | "text": [
2076 | "\n",
2077 | "\n",
2078 | "input: rtnd; bt wtht rftng thm t wld sm mpssbl t cntn\n"
2079 | ]
2080 | },
2081 | {
2082 | "name": "stderr",
2083 | "output_type": "stream",
2084 | "text": [
2085 | "0.19: 21%|███████▊ | 17/81 [02:12<07:21, 6.89s/it]"
2086 | ]
2087 | },
2088 | {
2089 | "name": "stdout",
2090 | "output_type": "stream",
2091 | "text": [
2092 | "guess: retened but wit refaing them it woled some imopsible to conten\n",
2093 | "truth: retained; but without refuting them it would seem impossible to continue\n",
2094 | "\n"
2095 | ]
2096 | },
2097 | {
2098 | "name": "stderr",
2099 | "output_type": "stream",
2100 | "text": [
2101 | "0.18: 21%|███████▊ | 17/81 [02:12<07:21, 6.89s/it]"
2102 | ]
2103 | },
2104 | {
2105 | "name": "stdout",
2106 | "output_type": "stream",
2107 | "text": [
2108 | "\n",
2109 | "\n",
2110 | "input: hm by Gd. Bt s sn s w d nt dmt tht, t bcms ssntl t\n"
2111 | ]
2112 | },
2113 | {
2114 | "name": "stderr",
2115 | "output_type": "stream",
2116 | "text": [
2117 | "0.18: 22%|████████▏ | 18/81 [02:20<07:51, 7.48s/it]"
2118 | ]
2119 | },
2120 | {
2121 | "name": "stdout",
2122 | "output_type": "stream",
2123 | "text": [
2124 | "guess: him by God. But so son so we do not adim that, to become sentel to\n",
2125 | "truth: him by God. But as soon as we do not admit that, it becomes essential to\n",
2126 | "\n"
2127 | ]
2128 | },
2129 | {
2130 | "name": "stderr",
2131 | "output_type": "stream",
2132 | "text": [
2133 | "0.20: 22%|████████▏ | 18/81 [02:21<07:51, 7.48s/it]"
2134 | ]
2135 | },
2136 | {
2137 | "name": "stdout",
2138 | "output_type": "stream",
2139 | "text": [
2140 | "\n",
2141 | "\n",
2142 | "input: tht sh nd my b mn nd wf,\" h tld hmslf.\n"
2143 | ]
2144 | },
2145 | {
2146 | "name": "stderr",
2147 | "output_type": "stream",
2148 | "text": [
2149 | "0.20: 23%|████████▋ | 19/81 [02:29<08:04, 7.82s/it]"
2150 | ]
2151 | },
2152 | {
2153 | "name": "stdout",
2154 | "output_type": "stream",
2155 | "text": [
2156 | "guess: t she need I may be man aned wife,\" he tole himef.\n",
2157 | "truth: that she and I may be man and wife,\" he told himself.\n",
2158 | "\n"
2159 | ]
2160 | },
2161 | {
2162 | "name": "stderr",
2163 | "output_type": "stream",
2164 | "text": [
2165 | "0.22: 23%|████████▋ | 19/81 [02:29<08:04, 7.82s/it]"
2166 | ]
2167 | },
2168 | {
2169 | "name": "stdout",
2170 | "output_type": "stream",
2171 | "text": [
2172 | "\n",
2173 | "\n",
2174 | "input: \"Wll, shll lv. h, hw splndd!\"\n"
2175 | ]
2176 | },
2177 | {
2178 | "name": "stderr",
2179 | "output_type": "stream",
2180 | "text": [
2181 | "0.22: 25%|█████████▏ | 20/81 [02:37<08:06, 7.98s/it]"
2182 | ]
2183 | },
2184 | {
2185 | "name": "stdout",
2186 | "output_type": "stream",
2187 | "text": [
2188 | "guess: \"Well, I shel love. Ah, how selend!\"\n",
2189 | "truth: \"Well, I shall live. Ah, how splendid!\"\n",
2190 | "\n"
2191 | ]
2192 | },
2193 | {
2194 | "name": "stderr",
2195 | "output_type": "stream",
2196 | "text": [
2197 | "0.20: 25%|█████████▏ | 20/81 [02:38<08:06, 7.98s/it]"
2198 | ]
2199 | },
2200 | {
2201 | "name": "stdout",
2202 | "output_type": "stream",
2203 | "text": [
2204 | "\n",
2205 | "\n",
2206 | "input: dscrb th hrrrs nd sffrngs h hd wtnssd h ws\n"
2207 | ]
2208 | },
2209 | {
2210 | "name": "stderr",
2211 | "output_type": "stream",
2212 | "text": [
2213 | "0.20: 26%|█████████▌ | 21/81 [02:47<08:21, 8.35s/it]"
2214 | ]
2215 | },
2216 | {
2217 | "name": "stdout",
2218 | "output_type": "stream",
2219 | "text": [
2220 | "guess: drib th horors aned suferings he had witensed he was\n",
2221 | "truth: describe the horrors and sufferings he had witnessed he was\n",
2222 | "\n"
2223 | ]
2224 | },
2225 | {
2226 | "name": "stderr",
2227 | "output_type": "stream",
2228 | "text": [
2229 | "0.22: 26%|█████████▌ | 21/81 [02:47<08:21, 8.35s/it]"
2230 | ]
2231 | },
2232 | {
2233 | "name": "stdout",
2234 | "output_type": "stream",
2235 | "text": [
2236 | "\n",
2237 | "\n",
2238 | "input: th rsltnt f vrs frcs, nd thn hs pwr s tslf frc\n"
2239 | ]
2240 | },
2241 | {
2242 | "name": "stderr",
2243 | "output_type": "stream",
2244 | "text": [
2245 | "0.22: 27%|██████████ | 22/81 [02:55<08:10, 8.31s/it]"
2246 | ]
2247 | },
2248 | {
2249 | "name": "stdout",
2250 | "output_type": "stream",
2251 | "text": [
2252 | "guess: th resulent of various forices, aned than his per so itel a fori\n",
2253 | "truth: the resultant of various forces, and then his power is itself a force\n",
2254 | "\n"
2255 | ]
2256 | },
2257 | {
2258 | "name": "stderr",
2259 | "output_type": "stream",
2260 | "text": [
2261 | "0.19: 27%|██████████ | 22/81 [02:55<08:10, 8.31s/it]"
2262 | ]
2263 | },
2264 | {
2265 | "name": "stdout",
2266 | "output_type": "stream",
2267 | "text": [
2268 | "\n",
2269 | "\n",
2270 | "input: Ntsh, wh ws sttng ppst t hm wth hr ldst dghtr n hr\n"
2271 | ]
2272 | },
2273 | {
2274 | "name": "stderr",
2275 | "output_type": "stream",
2276 | "text": [
2277 | "0.19: 28%|██████████▌ | 23/81 [03:05<08:27, 8.74s/it]"
2278 | ]
2279 | },
2280 | {
2281 | "name": "stdout",
2282 | "output_type": "stream",
2283 | "text": [
2284 | "guess: Natasha, who was siting oposite to him with her lodes daughter in her\n",
2285 | "truth: Natasha, who was sitting opposite to him with her eldest daughter on her\n",
2286 | "\n"
2287 | ]
2288 | },
2289 | {
2290 | "name": "stderr",
2291 | "output_type": "stream",
2292 | "text": [
2293 | "0.20: 28%|██████████▌ | 23/81 [03:05<08:27, 8.74s/it]"
2294 | ]
2295 | },
2296 | {
2297 | "name": "stdout",
2298 | "output_type": "stream",
2299 | "text": [
2300 | "\n",
2301 | "\n",
2302 | "input: lttr:\n"
2303 | ]
2304 | },
2305 | {
2306 | "name": "stderr",
2307 | "output_type": "stream",
2308 | "text": [
2309 | "0.20: 30%|██████████▉ | 24/81 [03:14<08:29, 8.93s/it]"
2310 | ]
2311 | },
2312 | {
2313 | "name": "stdout",
2314 | "output_type": "stream",
2315 | "text": [
2316 | "guess: later:\n",
2317 | "truth: letter:\n",
2318 | "\n"
2319 | ]
2320 | },
2321 | {
2322 | "name": "stderr",
2323 | "output_type": "stream",
2324 | "text": [
2325 | "0.21: 30%|██████████▉ | 24/81 [03:14<08:29, 8.93s/it]"
2326 | ]
2327 | },
2328 | {
2329 | "name": "stdout",
2330 | "output_type": "stream",
2331 | "text": [
2332 | "\n",
2333 | "\n",
2334 | "input: n t f th thr, n th chldrn's prsnc.\n"
2335 | ]
2336 | },
2337 | {
2338 | "name": "stderr",
2339 | "output_type": "stream",
2340 | "text": [
2341 | "0.21: 31%|███████████▍ | 25/81 [03:23<08:18, 8.90s/it]"
2342 | ]
2343 | },
2344 | {
2345 | "name": "stdout",
2346 | "output_type": "stream",
2347 | "text": [
2348 | "guess: on it if th other, in the chilerin's perince.\n",
2349 | "truth: one out of the other, in the children's presence.\n",
2350 | "\n"
2351 | ]
2352 | },
2353 | {
2354 | "name": "stderr",
2355 | "output_type": "stream",
2356 | "text": [
2357 | "0.18: 31%|███████████▍ | 25/81 [03:23<08:18, 8.90s/it]"
2358 | ]
2359 | },
2360 | {
2361 | "name": "stdout",
2362 | "output_type": "stream",
2363 | "text": [
2364 | "\n",
2365 | "\n",
2366 | "input: \"Y lwys hv sch strng fncs! ddn't vn thnk f bng\n"
2367 | ]
2368 | },
2369 | {
2370 | "name": "stderr",
2371 | "output_type": "stream",
2372 | "text": [
2373 | "0.18: 32%|███████████▉ | 26/81 [03:32<08:15, 9.01s/it]"
2374 | ]
2375 | },
2376 | {
2377 | "name": "stdout",
2378 | "output_type": "stream",
2379 | "text": [
2380 | "guess: \"You alys have suh stering fains! I dn' ven thin if being\n",
2381 | "truth: \"You always have such strange fancies! I didn't even think of being\n",
2382 | "\n"
2383 | ]
2384 | },
2385 | {
2386 | "name": "stderr",
2387 | "output_type": "stream",
2388 | "text": [
2389 | "0.24: 32%|███████████▉ | 26/81 [03:32<08:15, 9.01s/it]"
2390 | ]
2391 | },
2392 | {
2393 | "name": "stdout",
2394 | "output_type": "stream",
2395 | "text": [
2396 | "\n",
2397 | "\n",
2398 | "input: slsh t nd drwn n nthr. Thy klld th kng nd mny thr\n"
2399 | ]
2400 | },
2401 | {
2402 | "name": "stderr",
2403 | "output_type": "stream",
2404 | "text": [
2405 | "0.24: 33%|████████████▎ | 27/81 [03:36<06:51, 7.63s/it]"
2406 | ]
2407 | },
2408 | {
2409 | "name": "stdout",
2410 | "output_type": "stream",
2411 | "text": [
2412 | "guess: sel it aned drown no nother. They kled th akin aned man ther\n",
2413 | "truth: slash at and drown one another. They killed the king and many other\n",
2414 | "\n"
2415 | ]
2416 | },
2417 | {
2418 | "name": "stderr",
2419 | "output_type": "stream",
2420 | "text": [
2421 | "0.23: 33%|████████████▎ | 27/81 [03:37<06:51, 7.63s/it]"
2422 | ]
2423 | },
2424 | {
2425 | "name": "stdout",
2426 | "output_type": "stream",
2427 | "text": [
2428 | "\n",
2429 | "\n",
2430 | "input: nvtblty tslf, tht s, mr frm wtht cntnt.\n"
2431 | ]
2432 | },
2433 | {
2434 | "name": "stderr",
2435 | "output_type": "stream",
2436 | "text": [
2437 | "0.23: 35%|████████████▊ | 28/81 [03:41<05:57, 6.75s/it]"
2438 | ]
2439 | },
2440 | {
2441 | "name": "stdout",
2442 | "output_type": "stream",
2443 | "text": [
2444 | "guess: inevitility itel, that so, I mere frem wit content.\n",
2445 | "truth: inevitability itself, that is, a mere form without content.\n",
2446 | "\n"
2447 | ]
2448 | },
2449 | {
2450 | "name": "stderr",
2451 | "output_type": "stream",
2452 | "text": [
2453 | "0.20: 35%|████████████▊ | 28/81 [03:41<05:57, 6.75s/it]"
2454 | ]
2455 | },
2456 | {
2457 | "name": "stdout",
2458 | "output_type": "stream",
2459 | "text": [
2460 | "\n",
2461 | "\n",
2462 | "input: dscvry f th lw f Cprncs th Ptlmc wrlds wr stll\n"
2463 | ]
2464 | },
2465 | {
2466 | "name": "stderr",
2467 | "output_type": "stream",
2468 | "text": [
2469 | "0.20: 36%|█████████████▏ | 29/81 [03:50<06:28, 7.47s/it]"
2470 | ]
2471 | },
2472 | {
2473 | "name": "stdout",
2474 | "output_type": "stream",
2475 | "text": [
2476 | "guess: dry if the low if Cepeances th Ptamac worids were stle\n",
2477 | "truth: discovery of the law of Copernicus the Ptolemaic worlds were still\n",
2478 | "\n"
2479 | ]
2480 | },
2481 | {
2482 | "name": "stderr",
2483 | "output_type": "stream",
2484 | "text": [
2485 | "0.20: 36%|█████████████▏ | 29/81 [03:51<06:28, 7.47s/it]"
2486 | ]
2487 | },
2488 | {
2489 | "name": "stdout",
2490 | "output_type": "stream",
2491 | "text": [
2492 | "\n",
2493 | "\n",
2494 | "input: hm. Sh wtd n th ld cntss, pttd nd spld th chldrn,\n"
2495 | ]
2496 | },
2497 | {
2498 | "name": "stderr",
2499 | "output_type": "stream",
2500 | "text": [
2501 | "0.20: 37%|█████████████▋ | 30/81 [03:59<06:40, 7.86s/it]"
2502 | ]
2503 | },
2504 | {
2505 | "name": "stdout",
2506 | "output_type": "stream",
2507 | "text": [
2508 | "guess: him. She waid in the oled counes, pated aned sped the chilerin,\n",
2509 | "truth: home. She waited on the old countess, petted and spoiled the children,\n",
2510 | "\n"
2511 | ]
2512 | },
2513 | {
2514 | "name": "stderr",
2515 | "output_type": "stream",
2516 | "text": [
2517 | "0.24: 37%|█████████████▋ | 30/81 [03:59<06:40, 7.86s/it]"
2518 | ]
2519 | },
2520 | {
2521 | "name": "stdout",
2522 | "output_type": "stream",
2523 | "text": [
2524 | "\n",
2525 | "\n",
2526 | "input: \" Lrd, Lrd! Hw strry t s! Trmnds! Tht mns hrd\n"
2527 | ]
2528 | },
2529 | {
2530 | "name": "stderr",
2531 | "output_type": "stream",
2532 | "text": [
2533 | "0.24: 38%|██████████████▏ | 31/81 [04:09<07:05, 8.51s/it]"
2534 | ]
2535 | },
2536 | {
2537 | "name": "stdout",
2538 | "output_type": "stream",
2539 | "text": [
2540 | "guess: \"I Lored, I Lored How stery it see! Terineds! That mene I heared\n",
2541 | "truth: \"O Lord, O Lord! How starry it is! Tremendous! That means a hard\n",
2542 | "\n"
2543 | ]
2544 | },
2545 | {
2546 | "name": "stderr",
2547 | "output_type": "stream",
2548 | "text": [
2549 | "0.19: 38%|██████████████▏ | 31/81 [04:09<07:05, 8.51s/it]"
2550 | ]
2551 | },
2552 | {
2553 | "name": "stdout",
2554 | "output_type": "stream",
2555 | "text": [
2556 | "\n",
2557 | "\n",
2558 | "input: vdntly nbl t rprss hs vxtn, ftr th prncss' crrg\n"
2559 | ]
2560 | },
2561 | {
2562 | "name": "stderr",
2563 | "output_type": "stream",
2564 | "text": [
2565 | "0.19: 40%|██████████████▌ | 32/81 [04:19<07:13, 8.84s/it]"
2566 | ]
2567 | },
2568 | {
2569 | "name": "stdout",
2570 | "output_type": "stream",
2571 | "text": [
2572 | "guess: evintly noble to repress his vexan, afare th pins' cariage\n",
2573 | "truth: evidently unable to repress his vexation, after the princess' carriage\n",
2574 | "\n"
2575 | ]
2576 | },
2577 | {
2578 | "name": "stderr",
2579 | "output_type": "stream",
2580 | "text": [
2581 | "0.20: 40%|██████████████▌ | 32/81 [04:19<07:13, 8.84s/it]"
2582 | ]
2583 | },
2584 | {
2585 | "name": "stdout",
2586 | "output_type": "stream",
2587 | "text": [
2588 | "\n",
2589 | "\n",
2590 | "input: mch bttr.\n"
2591 | ]
2592 | },
2593 | {
2594 | "name": "stderr",
2595 | "output_type": "stream",
2596 | "text": [
2597 | "0.20: 41%|███████████████ | 33/81 [04:27<06:57, 8.70s/it]"
2598 | ]
2599 | },
2600 | {
2601 | "name": "stdout",
2602 | "output_type": "stream",
2603 | "text": [
2604 | "guess: muh beter.\n",
2605 | "truth: much better.\n",
2606 | "\n"
2607 | ]
2608 | },
2609 | {
2610 | "name": "stderr",
2611 | "output_type": "stream",
2612 | "text": [
2613 | "0.22: 41%|███████████████ | 33/81 [04:27<06:57, 8.70s/it]"
2614 | ]
2615 | },
2616 | {
2617 | "name": "stdout",
2618 | "output_type": "stream",
2619 | "text": [
2620 | "\n",
2621 | "\n",
2622 | "input: sch nd sch dcrs nd rdrs t th rmy, th flt, th\n"
2623 | ]
2624 | },
2625 | {
2626 | "name": "stderr",
2627 | "output_type": "stream",
2628 | "text": [
2629 | "0.22: 42%|███████████████▌ | 34/81 [04:36<06:46, 8.65s/it]"
2630 | ]
2631 | },
2632 | {
2633 | "name": "stdout",
2634 | "output_type": "stream",
2635 | "text": [
2636 | "guess: suh aned suh deceres aned ories to the army, the fel, the\n",
2637 | "truth: such and such decrees and orders to the army, the fleet, the\n",
2638 | "\n"
2639 | ]
2640 | },
2641 | {
2642 | "name": "stderr",
2643 | "output_type": "stream",
2644 | "text": [
2645 | "0.21: 42%|███████████████▌ | 34/81 [04:36<06:46, 8.65s/it]"
2646 | ]
2647 | },
2648 | {
2649 | "name": "stdout",
2650 | "output_type": "stream",
2651 | "text": [
2652 | "\n",
2653 | "\n",
2654 | "input: rfs hr, ngrly skng hr nt t ntrfr n wht ws nt hr\n"
2655 | ]
2656 | },
2657 | {
2658 | "name": "stderr",
2659 | "output_type": "stream",
2660 | "text": [
2661 | "0.21: 43%|███████████████▉ | 35/81 [04:44<06:40, 8.70s/it]"
2662 | ]
2663 | },
2664 | {
2665 | "name": "stdout",
2666 | "output_type": "stream",
2667 | "text": [
2668 | "guess: refes here, angery aski her not to intere in what was not her\n",
2669 | "truth: refuse her, angrily asking her not to interfere in what was not her\n",
2670 | "\n"
2671 | ]
2672 | },
2673 | {
2674 | "name": "stderr",
2675 | "output_type": "stream",
2676 | "text": [
2677 | "0.22: 43%|███████████████▉ | 35/81 [04:45<06:40, 8.70s/it]"
2678 | ]
2679 | },
2680 | {
2681 | "name": "stdout",
2682 | "output_type": "stream",
2683 | "text": [
2684 | "\n",
2685 | "\n",
2686 | "input: dssmntn f ds, th prntng prss cld hv ccmplshd tht\n"
2687 | ]
2688 | },
2689 | {
2690 | "name": "stderr",
2691 | "output_type": "stream",
2692 | "text": [
2693 | "0.22: 44%|████████████████▍ | 36/81 [04:53<06:33, 8.74s/it]"
2694 | ]
2695 | },
2696 | {
2697 | "name": "stdout",
2698 | "output_type": "stream",
2699 | "text": [
2700 | "guess: desemanten of ideas, the periting pres coule have acomilid t\n",
2701 | "truth: dissemination of ideas, the printing press could have accomplished that\n",
2702 | "\n"
2703 | ]
2704 | },
2705 | {
2706 | "name": "stderr",
2707 | "output_type": "stream",
2708 | "text": [
2709 | "0.23: 44%|████████████████▍ | 36/81 [04:53<06:33, 8.74s/it]"
2710 | ]
2711 | },
2712 | {
2713 | "name": "stdout",
2714 | "output_type": "stream",
2715 | "text": [
2716 | "\n",
2717 | "\n",
2718 | "input: t ths mmnt wr skd: 'Wld y rthr b wht y wr bfr\n"
2719 | ]
2720 | },
2721 | {
2722 | "name": "stderr",
2723 | "output_type": "stream",
2724 | "text": [
2725 | "0.23: 46%|████████████████▉ | 37/81 [05:02<06:23, 8.71s/it]"
2726 | ]
2727 | },
2728 | {
2729 | "name": "stdout",
2730 | "output_type": "stream",
2731 | "text": [
2732 | "guess: to this moment I were aske: 'ouled you rather be what you were befer\n",
2733 | "truth: at this moment I were asked: 'Would you rather be what you were before\n",
2734 | "\n"
2735 | ]
2736 | },
2737 | {
2738 | "name": "stderr",
2739 | "output_type": "stream",
2740 | "text": [
2741 | "0.19: 46%|████████████████▉ | 37/81 [05:02<06:23, 8.71s/it]"
2742 | ]
2743 | },
2744 | {
2745 | "name": "stdout",
2746 | "output_type": "stream",
2747 | "text": [
2748 | "\n",
2749 | "\n",
2750 | "input: Nxt dy Prr cm t sy gd-by. Ntsh ws lss nmtd thn sh\n"
2751 | ]
2752 | },
2753 | {
2754 | "name": "stderr",
2755 | "output_type": "stream",
2756 | "text": [
2757 | "0.19: 47%|█████████████████▎ | 38/81 [05:09<05:49, 8.12s/it]"
2758 | ]
2759 | },
2760 | {
2761 | "name": "stdout",
2762 | "output_type": "stream",
2763 | "text": [
2764 | "guess: Nex da Piere came to sa god-boy. Natasha was los animated than she\n",
2765 | "truth: Next day Pierre came to say good-by. Natasha was less animated than she\n",
2766 | "\n"
2767 | ]
2768 | },
2769 | {
2770 | "name": "stderr",
2771 | "output_type": "stream",
2772 | "text": [
2773 | "0.21: 47%|█████████████████▎ | 38/81 [05:09<05:49, 8.12s/it]"
2774 | ]
2775 | },
2776 | {
2777 | "name": "stdout",
2778 | "output_type": "stream",
2779 | "text": [
2780 | "\n",
2781 | "\n",
2782 | "input: strng chnc n n prcvs ths. Hs prt s nt yt ndd. Th mn\n"
2783 | ]
2784 | },
2785 | {
2786 | "name": "stderr",
2787 | "output_type": "stream",
2788 | "text": [
2789 | "0.21: 48%|█████████████████▊ | 39/81 [05:12<04:45, 6.79s/it]"
2790 | ]
2791 | },
2792 | {
2793 | "name": "stdout",
2794 | "output_type": "stream",
2795 | "text": [
2796 | "guess: stering chance in in perives this. His parit is not yet nod. The man\n",
2797 | "truth: strange chance no one perceives this. His part is not yet ended. The man\n",
2798 | "\n"
2799 | ]
2800 | },
2801 | {
2802 | "name": "stderr",
2803 | "output_type": "stream",
2804 | "text": [
2805 | "0.22: 48%|█████████████████▊ | 39/81 [05:12<04:45, 6.79s/it]"
2806 | ]
2807 | },
2808 | {
2809 | "name": "stdout",
2810 | "output_type": "stream",
2811 | "text": [
2812 | "\n",
2813 | "\n",
2814 | "input: \"Nw, Prr nrss thm splnddly,\" sd Ntsh. \"H sys hs hnd s\n"
2815 | ]
2816 | },
2817 | {
2818 | "name": "stderr",
2819 | "output_type": "stream",
2820 | "text": [
2821 | "0.22: 49%|██████████████████▎ | 40/81 [05:17<04:17, 6.27s/it]"
2822 | ]
2823 | },
2824 | {
2825 | "name": "stdout",
2826 | "output_type": "stream",
2827 | "text": [
2828 | "guess: \"Now, Piere neress them spendly,\" sa Natasha. \"He sas his haned s\n",
2829 | "truth: \"Now, Pierre nurses them splendidly,\" said Natasha. \"He says his hand is\n",
2830 | "\n"
2831 | ]
2832 | },
2833 | {
2834 | "name": "stderr",
2835 | "output_type": "stream",
2836 | "text": [
2837 | "0.23: 49%|██████████████████▎ | 40/81 [05:18<04:17, 6.27s/it]"
2838 | ]
2839 | },
2840 | {
2841 | "name": "stdout",
2842 | "output_type": "stream",
2843 | "text": [
2844 | "\n",
2845 | "\n",
2846 | "input: crcl nd, s h sppsd, f smlr ntrsts.\n"
2847 | ]
2848 | },
2849 | {
2850 | "name": "stderr",
2851 | "output_type": "stream",
2852 | "text": [
2853 | "0.23: 51%|██████████████████▋ | 41/81 [05:26<04:39, 6.98s/it]"
2854 | ]
2855 | },
2856 | {
2857 | "name": "stdout",
2858 | "output_type": "stream",
2859 | "text": [
2860 | "guess: cirical aned, so he suped, if siler interits.\n",
2861 | "truth: circle and, as he supposed, of similar interests.\n",
2862 | "\n"
2863 | ]
2864 | },
2865 | {
2866 | "name": "stderr",
2867 | "output_type": "stream",
2868 | "text": [
2869 | "0.20: 51%|██████████████████▋ | 41/81 [05:26<04:39, 6.98s/it]"
2870 | ]
2871 | },
2872 | {
2873 | "name": "stdout",
2874 | "output_type": "stream",
2875 | "text": [
2876 | "\n",
2877 | "\n",
2878 | "input: crtn dfnt dpndnc xsts btwn th tw.\n"
2879 | ]
2880 | },
2881 | {
2882 | "name": "stderr",
2883 | "output_type": "stream",
2884 | "text": [
2885 | "0.20: 52%|███████████████████▏ | 42/81 [05:35<05:01, 7.73s/it]"
2886 | ]
2887 | },
2888 | {
2889 | "name": "stdout",
2890 | "output_type": "stream",
2891 | "text": [
2892 | "guess: ceriten defint depenedenc exits beteen the tw.\n",
2893 | "truth: certain definite dependence exists between the two.\n",
2894 | "\n"
2895 | ]
2896 | },
2897 | {
2898 | "name": "stderr",
2899 | "output_type": "stream",
2900 | "text": [
2901 | "0.20: 52%|███████████████████▏ | 42/81 [05:36<05:01, 7.73s/it]"
2902 | ]
2903 | },
2904 | {
2905 | "name": "stdout",
2906 | "output_type": "stream",
2907 | "text": [
2908 | "\n",
2909 | "\n",
2910 | "input: wth hs wf st dwn t th lng tbl ld fr twnty prsns, t\n"
2911 | ]
2912 | },
2913 | {
2914 | "name": "stderr",
2915 | "output_type": "stream",
2916 | "text": [
2917 | "0.20: 53%|███████████████████▋ | 43/81 [05:45<05:10, 8.17s/it]"
2918 | ]
2919 | },
2920 | {
2921 | "name": "stdout",
2922 | "output_type": "stream",
2923 | "text": [
2924 | "guess: with his wife sa down to the long tal lod fer twinty perins, to\n",
2925 | "truth: with his wife sat down at the long table laid for twenty persons, at\n",
2926 | "\n"
2927 | ]
2928 | },
2929 | {
2930 | "name": "stderr",
2931 | "output_type": "stream",
2932 | "text": [
2933 | "0.21: 53%|███████████████████▋ | 43/81 [05:45<05:10, 8.17s/it]"
2934 | ]
2935 | },
2936 | {
2937 | "name": "stdout",
2938 | "output_type": "stream",
2939 | "text": [
2940 | "\n",
2941 | "\n",
2942 | "input: mn's mrl rspnsblts frm thr pth.\n"
2943 | ]
2944 | },
2945 | {
2946 | "name": "stderr",
2947 | "output_type": "stream",
2948 | "text": [
2949 | "0.21: 54%|████████████████████ | 44/81 [05:53<05:09, 8.36s/it]"
2950 | ]
2951 | },
2952 | {
2953 | "name": "stdout",
2954 | "output_type": "stream",
2955 | "text": [
2956 | "guess: man's mora respensits frem ther path.\n",
2957 | "truth: men's moral responsibilities from their path.\n",
2958 | "\n"
2959 | ]
2960 | },
2961 | {
2962 | "name": "stderr",
2963 | "output_type": "stream",
2964 | "text": [
2965 | "0.20: 54%|████████████████████ | 44/81 [05:54<05:09, 8.36s/it]"
2966 | ]
2967 | },
2968 | {
2969 | "name": "stdout",
2970 | "output_type": "stream",
2971 | "text": [
2972 | "\n",
2973 | "\n",
2974 | "input: th hstrn wh jdgs lxndr wll ls ftr th lps f sm\n"
2975 | ]
2976 | },
2977 | {
2978 | "name": "stderr",
2979 | "output_type": "stream",
2980 | "text": [
2981 | "0.20: 56%|████████████████████▌ | 45/81 [06:03<05:10, 8.63s/it]"
2982 | ]
2983 | },
2984 | {
2985 | "name": "stdout",
2986 | "output_type": "stream",
2987 | "text": [
2988 | "guess: th hiseran who juds Aleneder will lose afare th lips if some\n",
2989 | "truth: the historian who judges Alexander will also after the lapse of some\n",
2990 | "\n"
2991 | ]
2992 | },
2993 | {
2994 | "name": "stderr",
2995 | "output_type": "stream",
2996 | "text": [
2997 | "0.22: 56%|████████████████████▌ | 45/81 [06:03<05:10, 8.63s/it]"
2998 | ]
2999 | },
3000 | {
3001 | "name": "stdout",
3002 | "output_type": "stream",
3003 | "text": [
3004 | "\n",
3005 | "\n",
3006 | "input: f n thsnd yrs vn n mn n mlln cld ct frly, tht\n"
3007 | ]
3008 | },
3009 | {
3010 | "name": "stderr",
3011 | "output_type": "stream",
3012 | "text": [
3013 | "0.22: 57%|█████████████████████ | 46/81 [06:12<05:10, 8.86s/it]"
3014 | ]
3015 | },
3016 | {
3017 | "name": "stdout",
3018 | "output_type": "stream",
3019 | "text": [
3020 | "guess: of in a thousaned youres ven no man in a milen coule cot fery, that\n",
3021 | "truth: If in a thousand years even one man in a million could act freely, that\n",
3022 | "\n"
3023 | ]
3024 | },
3025 | {
3026 | "name": "stderr",
3027 | "output_type": "stream",
3028 | "text": [
3029 | "0.20: 57%|█████████████████████ | 46/81 [06:12<05:10, 8.86s/it]"
3030 | ]
3031 | },
3032 | {
3033 | "name": "stdout",
3034 | "output_type": "stream",
3035 | "text": [
3036 | "\n",
3037 | "\n",
3038 | "input: cntss cld nt cncv f lf wtht th lxrs cndtns sh\n"
3039 | ]
3040 | },
3041 | {
3042 | "name": "stderr",
3043 | "output_type": "stream",
3044 | "text": [
3045 | "0.20: 58%|█████████████████████▍ | 47/81 [06:21<05:03, 8.91s/it]"
3046 | ]
3047 | },
3048 | {
3049 | "name": "stdout",
3050 | "output_type": "stream",
3051 | "text": [
3052 | "guess: counes coul not cone of life wit the lexeres conedions she\n",
3053 | "truth: countess could not conceive of life without the luxurious conditions she\n",
3054 | "\n"
3055 | ]
3056 | },
3057 | {
3058 | "name": "stderr",
3059 | "output_type": "stream",
3060 | "text": [
3061 | "0.21: 58%|█████████████████████▍ | 47/81 [06:21<05:03, 8.91s/it]"
3062 | ]
3063 | },
3064 | {
3065 | "name": "stdout",
3066 | "output_type": "stream",
3067 | "text": [
3068 | "\n",
3069 | "\n",
3070 | "input: h cmmts, whch n th mths f ths rnd hm s nt t nc\n"
3071 | ]
3072 | },
3073 | {
3074 | "name": "stderr",
3075 | "output_type": "stream",
3076 | "text": [
3077 | "0.21: 59%|█████████████████████▉ | 48/81 [06:30<04:50, 8.80s/it]"
3078 | ]
3079 | },
3080 | {
3081 | "name": "stdout",
3082 | "output_type": "stream",
3083 | "text": [
3084 | "guess: he comets, which in the moths if thes rouned him so not to noce\n",
3085 | "truth: he commits, which in the mouths of those around him is not at once\n",
3086 | "\n"
3087 | ]
3088 | },
3089 | {
3090 | "name": "stderr",
3091 | "output_type": "stream",
3092 | "text": [
3093 | "0.22: 59%|█████████████████████▉ | 48/81 [06:30<04:50, 8.80s/it]"
3094 | ]
3095 | },
3096 | {
3097 | "name": "stdout",
3098 | "output_type": "stream",
3099 | "text": [
3100 | "\n",
3101 | "\n",
3102 | "input: cnvrstn frm chngng ts rdnry chrctr f gssp bt th\n"
3103 | ]
3104 | },
3105 | {
3106 | "name": "stderr",
3107 | "output_type": "stream",
3108 | "text": [
3109 | "0.22: 60%|██████████████████████▍ | 49/81 [06:39<04:46, 8.96s/it]"
3110 | ]
3111 | },
3112 | {
3113 | "name": "stdout",
3114 | "output_type": "stream",
3115 | "text": [
3116 | "guess: conerasten frem changing ites orinery charicer of gosip but the\n",
3117 | "truth: conversation from changing its ordinary character of gossip about the\n",
3118 | "\n"
3119 | ]
3120 | },
3121 | {
3122 | "name": "stderr",
3123 | "output_type": "stream",
3124 | "text": [
3125 | "0.22: 60%|██████████████████████▍ | 49/81 [06:39<04:46, 8.96s/it]"
3126 | ]
3127 | },
3128 | {
3129 | "name": "stdout",
3130 | "output_type": "stream",
3131 | "text": [
3132 | "\n",
3133 | "\n",
3134 | "input: \"nd bcs,\" Prr cntnd, \"nly n wh blvs tht thr s \n"
3135 | ]
3136 | },
3137 | {
3138 | "name": "stderr",
3139 | "output_type": "stream",
3140 | "text": [
3141 | "0.22: 62%|██████████████████████▊ | 50/81 [06:45<04:07, 7.99s/it]"
3142 | ]
3143 | },
3144 | {
3145 | "name": "stdout",
3146 | "output_type": "stream",
3147 | "text": [
3148 | "guess: \"Aned becase,\" Piere contened, \"ony in who beles that there is a\n",
3149 | "truth: \"And because,\" Pierre continued, \"only one who believes that there is a\n",
3150 | "\n"
3151 | ]
3152 | },
3153 | {
3154 | "name": "stderr",
3155 | "output_type": "stream",
3156 | "text": [
3157 | "0.33: 62%|██████████████████████▊ | 50/81 [06:45<04:07, 7.99s/it]"
3158 | ]
3159 | },
3160 | {
3161 | "name": "stdout",
3162 | "output_type": "stream",
3163 | "text": [
3164 | "\n",
3165 | "\n",
3166 | "input: pnd th crkng dr, wnt p t th sf wth nrgtc stps f\n"
3167 | ]
3168 | },
3169 | {
3170 | "name": "stderr",
3171 | "output_type": "stream",
3172 | "text": [
3173 | "0.20: 63%|███████████████████████▎ | 51/81 [06:48<03:15, 6.52s/it]"
3174 | ]
3175 | },
3176 | {
3177 | "name": "stdout",
3178 | "output_type": "stream",
3179 | "text": [
3180 | "guess: opene the coring dear, went up to the sof with nerig stes of\n",
3181 | "truth: opened the creaking door, went up to the sofa with energetic steps of\n",
3182 | "\n",
3183 | "\n",
3184 | "\n",
3185 | "input: lwys hd ldy cmpnns, bt wh thy wr nd wht thy wr lk h\n"
3186 | ]
3187 | },
3188 | {
3189 | "name": "stderr",
3190 | "output_type": "stream",
3191 | "text": [
3192 | "0.20: 64%|███████████████████████▊ | 52/81 [06:57<03:30, 7.26s/it]"
3193 | ]
3194 | },
3195 | {
3196 | "name": "stdout",
3197 | "output_type": "stream",
3198 | "text": [
3199 | "guess: alys had lay copins, but who the were need what the were lok he\n",
3200 | "truth: always had lady companions, but who they were and what they were like he\n",
3201 | "\n"
3202 | ]
3203 | },
3204 | {
3205 | "name": "stderr",
3206 | "output_type": "stream",
3207 | "text": [
3208 | "0.22: 64%|███████████████████████▊ | 52/81 [06:57<03:30, 7.26s/it]"
3209 | ]
3210 | },
3211 | {
3212 | "name": "stdout",
3213 | "output_type": "stream",
3214 | "text": [
3215 | "\n",
3216 | "\n",
3217 | "input: ncntstbly rsn nd xprmnt my prv t hm tht t s\n"
3218 | ]
3219 | },
3220 | {
3221 | "name": "stderr",
3222 | "output_type": "stream",
3223 | "text": [
3224 | "0.22: 65%|████████████████████████▏ | 53/81 [07:05<03:33, 7.63s/it]"
3225 | ]
3226 | },
3227 | {
3228 | "name": "stdout",
3229 | "output_type": "stream",
3230 | "text": [
3231 | "guess: ancestibly reasen aned exriment my prove to him t it so\n",
3232 | "truth: incontestably reason and experiment may prove to him that it is\n",
3233 | "\n"
3234 | ]
3235 | },
3236 | {
3237 | "name": "stderr",
3238 | "output_type": "stream",
3239 | "text": [
3240 | "0.21: 65%|████████████████████████▏ | 53/81 [07:06<03:33, 7.63s/it]"
3241 | ]
3242 | },
3243 | {
3244 | "name": "stdout",
3245 | "output_type": "stream",
3246 | "text": [
3247 | "\n",
3248 | "\n",
3249 | "input: mght hp fr hlp frm hs fllws nd th dfnt plc h hld\n"
3250 | ]
3251 | },
3252 | {
3253 | "name": "stderr",
3254 | "output_type": "stream",
3255 | "text": [
3256 | "0.21: 67%|████████████████████████▋ | 54/81 [07:14<03:35, 7.99s/it]"
3257 | ]
3258 | },
3259 | {
3260 | "name": "stdout",
3261 | "output_type": "stream",
3262 | "text": [
3263 | "guess: might hope fer hel frem his fellows aned the defint pale he hel\n",
3264 | "truth: might hope for help from his fellows and the definite place he held\n",
3265 | "\n"
3266 | ]
3267 | },
3268 | {
3269 | "name": "stderr",
3270 | "output_type": "stream",
3271 | "text": [
3272 | "0.24: 67%|████████████████████████▋ | 54/81 [07:14<03:35, 7.99s/it]"
3273 | ]
3274 | },
3275 | {
3276 | "name": "stdout",
3277 | "output_type": "stream",
3278 | "text": [
3279 | "\n",
3280 | "\n",
3281 | "input: Blv dmrd th prsnts nd ws dlghtd wth hr drss mtrl.\n"
3282 | ]
3283 | },
3284 | {
3285 | "name": "stderr",
3286 | "output_type": "stream",
3287 | "text": [
3288 | "0.24: 68%|█████████████████████████ | 55/81 [07:23<03:35, 8.28s/it]"
3289 | ]
3290 | },
3291 | {
3292 | "name": "stdout",
3293 | "output_type": "stream",
3294 | "text": [
3295 | "guess: Bele dired th perits aned was delited with her dres matera.\n",
3296 | "truth: Belova admired the presents and was delighted with her dress material.\n",
3297 | "\n"
3298 | ]
3299 | },
3300 | {
3301 | "name": "stderr",
3302 | "output_type": "stream",
3303 | "text": [
3304 | "0.18: 68%|█████████████████████████ | 55/81 [07:23<03:35, 8.28s/it]"
3305 | ]
3306 | },
3307 | {
3308 | "name": "stdout",
3309 | "output_type": "stream",
3310 | "text": [
3311 | "\n",
3312 | "\n",
3313 | "input: nd ths smpl wrds, hr lk, nd th xprssn n hr fc whch\n"
3314 | ]
3315 | },
3316 | {
3317 | "name": "stderr",
3318 | "output_type": "stream",
3319 | "text": [
3320 | "0.18: 69%|█████████████████████████▌ | 56/81 [07:33<03:35, 8.62s/it]"
3321 | ]
3322 | },
3323 | {
3324 | "name": "stdout",
3325 | "output_type": "stream",
3326 | "text": [
3327 | "guess: ane this siple wored, her lok, aned the exression in her face which\n",
3328 | "truth: And these simple words, her look, and the expression on her face which\n",
3329 | "\n"
3330 | ]
3331 | },
3332 | {
3333 | "name": "stderr",
3334 | "output_type": "stream",
3335 | "text": [
3336 | "0.22: 69%|█████████████████████████▌ | 56/81 [07:33<03:35, 8.62s/it]"
3337 | ]
3338 | },
3339 | {
3340 | "name": "stdout",
3341 | "output_type": "stream",
3342 | "text": [
3343 | "\n",
3344 | "\n",
3345 | "input: n nthr thr yrs, by 1820, h hd s mngd hs ffrs tht h\n"
3346 | ]
3347 | },
3348 | {
3349 | "name": "stderr",
3350 | "output_type": "stream",
3351 | "text": [
3352 | "0.22: 70%|██████████████████████████ | 57/81 [07:42<03:35, 8.96s/it]"
3353 | ]
3354 | },
3355 | {
3356 | "name": "stdout",
3357 | "output_type": "stream",
3358 | "text": [
3359 | "guess: one nother ther yeres, by 18 he had so managed his afaires that he\n",
3360 | "truth: In another three years, by 1820, he had so managed his affairs that he\n",
3361 | "\n"
3362 | ]
3363 | },
3364 | {
3365 | "name": "stderr",
3366 | "output_type": "stream",
3367 | "text": [
3368 | "0.21: 70%|██████████████████████████ | 57/81 [07:43<03:35, 8.96s/it]"
3369 | ]
3370 | },
3371 | {
3372 | "name": "stdout",
3373 | "output_type": "stream",
3374 | "text": [
3375 | "\n",
3376 | "\n",
3377 | "input: mst sm t b gns. nd t mst ppr n stnshng cnjnctn\n"
3378 | ]
3379 | },
3380 | {
3381 | "name": "stderr",
3382 | "output_type": "stream",
3383 | "text": [
3384 | "0.21: 72%|██████████████████████████▍ | 58/81 [07:51<03:24, 8.88s/it]"
3385 | ]
3386 | },
3387 | {
3388 | "name": "stdout",
3389 | "output_type": "stream",
3390 | "text": [
3391 | "guess: mes som to be a genes. need it mis paper in sonishing concation\n",
3392 | "truth: must seem to be a genius. And it must appear an astonishing conjunction\n",
3393 | "\n"
3394 | ]
3395 | },
3396 | {
3397 | "name": "stderr",
3398 | "output_type": "stream",
3399 | "text": [
3400 | "0.19: 72%|██████████████████████████▍ | 58/81 [07:51<03:24, 8.88s/it]"
3401 | ]
3402 | },
3403 | {
3404 | "name": "stdout",
3405 | "output_type": "stream",
3406 | "text": [
3407 | "\n",
3408 | "\n",
3409 | "input: Wllrsk ws gng t Mscw nd thy grd t trvl tgthr.\n"
3410 | ]
3411 | },
3412 | {
3413 | "name": "stderr",
3414 | "output_type": "stream",
3415 | "text": [
3416 | "0.19: 73%|██████████████████████████▉ | 59/81 [07:59<03:10, 8.68s/it]"
3417 | ]
3418 | },
3419 | {
3420 | "name": "stdout",
3421 | "output_type": "stream",
3422 | "text": [
3423 | "guess: Wlarski was gong to Moscow aned the geared to triva toger.\n",
3424 | "truth: Willarski was going to Moscow and they agreed to travel together.\n",
3425 | "\n"
3426 | ]
3427 | },
3428 | {
3429 | "name": "stderr",
3430 | "output_type": "stream",
3431 | "text": [
3432 | "0.19: 73%|██████████████████████████▉ | 59/81 [07:59<03:10, 8.68s/it]"
3433 | ]
3434 | },
3435 | {
3436 | "name": "stdout",
3437 | "output_type": "stream",
3438 | "text": [
3439 | "\n",
3440 | "\n",
3441 | "input: hm. \"Why? Tll m. Y mst tll m!\"\n"
3442 | ]
3443 | },
3444 | {
3445 | "name": "stderr",
3446 | "output_type": "stream",
3447 | "text": [
3448 | "0.19: 74%|███████████████████████████▍ | 60/81 [08:08<03:01, 8.63s/it]"
3449 | ]
3450 | },
3451 | {
3452 | "name": "stdout",
3453 | "output_type": "stream",
3454 | "text": [
3455 | "guess: him. \"Why? Tel me. You mis tel me!\"\n",
3456 | "truth: him. \"Why? Tell me. You must tell me!\"\n",
3457 | "\n"
3458 | ]
3459 | },
3460 | {
3461 | "name": "stderr",
3462 | "output_type": "stream",
3463 | "text": [
3464 | "0.23: 74%|███████████████████████████▍ | 60/81 [08:08<03:01, 8.63s/it]"
3465 | ]
3466 | },
3467 | {
3468 | "name": "stdout",
3469 | "output_type": "stream",
3470 | "text": [
3471 | "\n",
3472 | "\n",
3473 | "input: t rmmbr wht thy hr t nrch thr mnds nd whn pprtnty\n"
3474 | ]
3475 | },
3476 | {
3477 | "name": "stderr",
3478 | "output_type": "stream",
3479 | "text": [
3480 | "0.21: 75%|███████████████████████████▊ | 61/81 [08:16<02:49, 8.47s/it]"
3481 | ]
3482 | },
3483 | {
3484 | "name": "stdout",
3485 | "output_type": "stream",
3486 | "text": [
3487 | "guess: to remeber what the here to norich ther mineds aned when oportenty\n",
3488 | "truth: to remember what they hear to enrich their minds and when opportunity\n",
3489 | "\n",
3490 | "\n",
3491 | "\n",
3492 | "input: stsfd....\"\n"
3493 | ]
3494 | },
3495 | {
3496 | "name": "stderr",
3497 | "output_type": "stream",
3498 | "text": [
3499 | "0.25: 77%|████████████████████████████▎ | 62/81 [08:19<02:11, 6.94s/it]"
3500 | ]
3501 | },
3502 | {
3503 | "name": "stdout",
3504 | "output_type": "stream",
3505 | "text": [
3506 | "guess: sasified..\"\n",
3507 | "truth: satisfied....\"\n",
3508 | "\n",
3509 | "\n",
3510 | "\n",
3511 | "input: \"Why ths,\" bgn Prr, nt sttng dwn bt pcng th rm,\n"
3512 | ]
3513 | },
3514 | {
3515 | "name": "stderr",
3516 | "output_type": "stream",
3517 | "text": [
3518 | "0.25: 78%|████████████████████████████▊ | 63/81 [08:24<01:54, 6.37s/it]"
3519 | ]
3520 | },
3521 | {
3522 | "name": "stdout",
3523 | "output_type": "stream",
3524 | "text": [
3525 | "guess: \"W this,\" began Piere, not siting down but pang the rom,\n",
3526 | "truth: \"Why this,\" began Pierre, not sitting down but pacing the room,\n",
3527 | "\n"
3528 | ]
3529 | },
3530 | {
3531 | "name": "stderr",
3532 | "output_type": "stream",
3533 | "text": [
3534 | "0.27: 78%|████████████████████████████▊ | 63/81 [08:25<01:54, 6.37s/it]"
3535 | ]
3536 | },
3537 | {
3538 | "name": "stdout",
3539 | "output_type": "stream",
3540 | "text": [
3541 | "\n",
3542 | "\n",
3543 | "input: tht...\" nd rmmbrng hs frmr tndrnss, nd lkng nw t hs\n"
3544 | ]
3545 | },
3546 | {
3547 | "name": "stderr",
3548 | "output_type": "stream",
3549 | "text": [
3550 | "0.27: 79%|█████████████████████████████▏ | 64/81 [08:33<02:01, 7.13s/it]"
3551 | ]
3552 | },
3553 | {
3554 | "name": "stdout",
3555 | "output_type": "stream",
3556 | "text": [
3557 | "guess: that..\" aned remebering his fore tenedernes, aned loking now to his\n",
3558 | "truth: that...\" And remembering his former tenderness, and looking now at his\n",
3559 | "\n"
3560 | ]
3561 | },
3562 | {
3563 | "name": "stderr",
3564 | "output_type": "stream",
3565 | "text": [
3566 | "0.24: 79%|█████████████████████████████▏ | 64/81 [08:33<02:01, 7.13s/it]"
3567 | ]
3568 | },
3569 | {
3570 | "name": "stdout",
3571 | "output_type": "stream",
3572 | "text": [
3573 | "\n",
3574 | "\n",
3575 | "input: cntrmvmnt s thn ccmplshd frm st t wst wth \n"
3576 | ]
3577 | },
3578 | {
3579 | "name": "stderr",
3580 | "output_type": "stream",
3581 | "text": [
3582 | "0.24: 80%|█████████████████████████████▋ | 65/81 [08:41<01:58, 7.44s/it]"
3583 | ]
3584 | },
3585 | {
3586 | "name": "stdout",
3587 | "output_type": "stream",
3588 | "text": [
3589 | "guess: a conterovement so than acomilied frem sa it wis with a\n",
3590 | "truth: A countermovement is then accomplished from east to west with a\n",
3591 | "\n"
3592 | ]
3593 | },
3594 | {
3595 | "name": "stderr",
3596 | "output_type": "stream",
3597 | "text": [
3598 | "0.21: 80%|█████████████████████████████▋ | 65/81 [08:42<01:58, 7.44s/it]"
3599 | ]
3600 | },
3601 | {
3602 | "name": "stdout",
3603 | "output_type": "stream",
3604 | "text": [
3605 | "\n",
3606 | "\n",
3607 | "input: rmmbrng; bt snc tht... t's nly bn trmntng flk.\"\n"
3608 | ]
3609 | },
3610 | {
3611 | "name": "stderr",
3612 | "output_type": "stream",
3613 | "text": [
3614 | "0.21: 81%|██████████████████████████████▏ | 66/81 [08:50<01:57, 7.86s/it]"
3615 | ]
3616 | },
3617 | {
3618 | "name": "stdout",
3619 | "output_type": "stream",
3620 | "text": [
3621 | "guess: rembering; but sine that.. ites ony been terinting fok.\"\n",
3622 | "truth: remembering; but since that... it's only been tormenting folk.\"\n",
3623 | "\n"
3624 | ]
3625 | },
3626 | {
3627 | "name": "stderr",
3628 | "output_type": "stream",
3629 | "text": [
3630 | "0.22: 81%|██████████████████████████████▏ | 66/81 [08:50<01:57, 7.86s/it]"
3631 | ]
3632 | },
3633 | {
3634 | "name": "stdout",
3635 | "output_type": "stream",
3636 | "text": [
3637 | "\n",
3638 | "\n",
3639 | "input: nt t b trfld wth thr--n wrd, h ws rl mstr!\"\n"
3640 | ]
3641 | },
3642 | {
3643 | "name": "stderr",
3644 | "output_type": "stream",
3645 | "text": [
3646 | "0.22: 83%|██████████████████████████████▌ | 67/81 [08:59<01:54, 8.17s/it]"
3647 | ]
3648 | },
3649 | {
3650 | "name": "stdout",
3651 | "output_type": "stream",
3652 | "text": [
3653 | "guess: not to be trefoled with ther-in a wored, he was a rel miser!\"\n",
3654 | "truth: not to be trifled with either--in a word, he was a real master!\"\n",
3655 | "\n"
3656 | ]
3657 | },
3658 | {
3659 | "name": "stderr",
3660 | "output_type": "stream",
3661 | "text": [
3662 | "0.18: 83%|██████████████████████████████▌ | 67/81 [08:59<01:54, 8.17s/it]"
3663 | ]
3664 | },
3665 | {
3666 | "name": "stdout",
3667 | "output_type": "stream",
3668 | "text": [
3669 | "\n",
3670 | "\n",
3671 | "input: plyng ptnc, nd s--thgh by frc f hbt sh grtd hm wth\n"
3672 | ]
3673 | },
3674 | {
3675 | "name": "stderr",
3676 | "output_type": "stream",
3677 | "text": [
3678 | "0.18: 84%|███████████████████████████████ | 68/81 [09:08<01:49, 8.39s/it]"
3679 | ]
3680 | },
3681 | {
3682 | "name": "stdout",
3683 | "output_type": "stream",
3684 | "text": [
3685 | "guess: plying patence, aned so-thouh by foric of habout she gered him with\n",
3686 | "truth: playing patience, and so--though by force of habit she greeted him with\n",
3687 | "\n"
3688 | ]
3689 | },
3690 | {
3691 | "name": "stderr",
3692 | "output_type": "stream",
3693 | "text": [
3694 | "0.19: 84%|███████████████████████████████ | 68/81 [09:08<01:49, 8.39s/it]"
3695 | ]
3696 | },
3697 | {
3698 | "name": "stdout",
3699 | "output_type": "stream",
3700 | "text": [
3701 | "\n",
3702 | "\n",
3703 | "input: brghtly. \" nly wntd t tll y bt Pty: tdy nrs ws cmng\n"
3704 | ]
3705 | },
3706 | {
3707 | "name": "stderr",
3708 | "output_type": "stream",
3709 | "text": [
3710 | "0.19: 85%|███████████████████████████████▌ | 69/81 [09:16<01:40, 8.39s/it]"
3711 | ]
3712 | },
3713 | {
3714 | "name": "stdout",
3715 | "output_type": "stream",
3716 | "text": [
3717 | "guess: brihtly. \"I ony waned to tel you but Pety: toy nores was coming\n",
3718 | "truth: brightly. \"I only wanted to tell you about Petya: today nurse was coming\n",
3719 | "\n"
3720 | ]
3721 | },
3722 | {
3723 | "name": "stderr",
3724 | "output_type": "stream",
3725 | "text": [
3726 | "0.18: 85%|███████████████████████████████▌ | 69/81 [09:17<01:40, 8.39s/it]"
3727 | ]
3728 | },
3729 | {
3730 | "name": "stdout",
3731 | "output_type": "stream",
3732 | "text": [
3733 | "\n",
3734 | "\n",
3735 | "input: sn hs scrfcd hmslf fr hs mthr,\" s ppl wr syng.\n"
3736 | ]
3737 | },
3738 | {
3739 | "name": "stderr",
3740 | "output_type": "stream",
3741 | "text": [
3742 | "0.18: 86%|███████████████████████████████▉ | 70/81 [09:26<01:35, 8.64s/it]"
3743 | ]
3744 | },
3745 | {
3746 | "name": "stdout",
3747 | "output_type": "stream",
3748 | "text": [
3749 | "guess: seen his sarfed himef fer his mother,\" so pepe were sang.\n",
3750 | "truth: son has sacrificed himself for his mother,\" as people were saying.\n",
3751 | "\n"
3752 | ]
3753 | },
3754 | {
3755 | "name": "stderr",
3756 | "output_type": "stream",
3757 | "text": [
3758 | "0.18: 86%|███████████████████████████████▉ | 70/81 [09:26<01:35, 8.64s/it]"
3759 | ]
3760 | },
3761 | {
3762 | "name": "stdout",
3763 | "output_type": "stream",
3764 | "text": [
3765 | "\n",
3766 | "\n",
3767 | "input: H ws prd f hr ntllgnc nd gdnss, rcgnzd hs wn\n"
3768 | ]
3769 | },
3770 | {
3771 | "name": "stderr",
3772 | "output_type": "stream",
3773 | "text": [
3774 | "0.18: 88%|████████████████████████████████▍ | 71/81 [09:34<01:25, 8.59s/it]"
3775 | ]
3776 | },
3777 | {
3778 | "name": "stdout",
3779 | "output_type": "stream",
3780 | "text": [
3781 | "guess: He was peared of her intlence aned godnes, reconized his owin\n",
3782 | "truth: He was proud of her intelligence and goodness, recognized his own\n",
3783 | "\n"
3784 | ]
3785 | },
3786 | {
3787 | "name": "stderr",
3788 | "output_type": "stream",
3789 | "text": [
3790 | "0.25: 88%|████████████████████████████████▍ | 71/81 [09:34<01:25, 8.59s/it]"
3791 | ]
3792 | },
3793 | {
3794 | "name": "stdout",
3795 | "output_type": "stream",
3796 | "text": [
3797 | "\n",
3798 | "\n",
3799 | "input: cntnlly sght t fnd--th m f lf--n lngr xstd fr hm\n"
3800 | ]
3801 | },
3802 | {
3803 | "name": "stderr",
3804 | "output_type": "stream",
3805 | "text": [
3806 | "0.25: 89%|████████████████████████████████▉ | 72/81 [09:43<01:17, 8.59s/it]"
3807 | ]
3808 | },
3809 | {
3810 | "name": "stdout",
3811 | "output_type": "stream",
3812 | "text": [
3813 | "guess: conally siht to faine-th am if lif-no longer exised fer him\n",
3814 | "truth: continually sought to find--the aim of life--no longer existed for him\n",
3815 | "\n"
3816 | ]
3817 | },
3818 | {
3819 | "name": "stderr",
3820 | "output_type": "stream",
3821 | "text": [
3822 | "0.19: 89%|████████████████████████████████▉ | 72/81 [09:43<01:17, 8.59s/it]"
3823 | ]
3824 | },
3825 | {
3826 | "name": "stdout",
3827 | "output_type": "stream",
3828 | "text": [
3829 | "\n",
3830 | "\n",
3831 | "input: tm.\n"
3832 | ]
3833 | },
3834 | {
3835 | "name": "stderr",
3836 | "output_type": "stream",
3837 | "text": [
3838 | "0.21: 90%|█████████████████████████████████▎ | 73/81 [09:51<01:07, 8.46s/it]"
3839 | ]
3840 | },
3841 | {
3842 | "name": "stdout",
3843 | "output_type": "stream",
3844 | "text": [
3845 | "guess: tem.\n",
3846 | "truth: time.\n",
3847 | "\n",
3848 | "\n",
3849 | "\n",
3850 | "input: jstl, vrtk n nthr, nd fght, nd t wld b qlly\n"
3851 | ]
3852 | },
3853 | {
3854 | "name": "stderr",
3855 | "output_type": "stream",
3856 | "text": [
3857 | "0.21: 91%|█████████████████████████████████▊ | 74/81 [09:54<00:48, 6.95s/it]"
3858 | ]
3859 | },
3860 | {
3861 | "name": "stdout",
3862 | "output_type": "stream",
3863 | "text": [
3864 | "guess: jusa, overite in another, aned fit, need it woled be quily\n",
3865 | "truth: jostle, overtake one another, and fight, and it would be equally\n",
3866 | "\n"
3867 | ]
3868 | },
3869 | {
3870 | "name": "stderr",
3871 | "output_type": "stream",
3872 | "text": [
3873 | "0.20: 91%|█████████████████████████████████▊ | 74/81 [09:54<00:48, 6.95s/it]"
3874 | ]
3875 | },
3876 | {
3877 | "name": "stdout",
3878 | "output_type": "stream",
3879 | "text": [
3880 | "\n",
3881 | "\n",
3882 | "input: nthr tm.\"\n"
3883 | ]
3884 | },
3885 | {
3886 | "name": "stderr",
3887 | "output_type": "stream",
3888 | "text": [
3889 | "0.20: 93%|██████████████████████████████████▎ | 75/81 [09:59<00:37, 6.25s/it]"
3890 | ]
3891 | },
3892 | {
3893 | "name": "stdout",
3894 | "output_type": "stream",
3895 | "text": [
3896 | "guess: another tem.\"\n",
3897 | "truth: another time.\"\n",
3898 | "\n"
3899 | ]
3900 | },
3901 | {
3902 | "name": "stderr",
3903 | "output_type": "stream",
3904 | "text": [
3905 | "0.17: 93%|██████████████████████████████████▎ | 75/81 [09:59<00:37, 6.25s/it]"
3906 | ]
3907 | },
3908 | {
3909 | "name": "stdout",
3910 | "output_type": "stream",
3911 | "text": [
3912 | "\n",
3913 | "\n",
3914 | "input: th nstttns f stt nd chrch r rctd.\n"
3915 | ]
3916 | },
3917 | {
3918 | "name": "stderr",
3919 | "output_type": "stream",
3920 | "text": [
3921 | "0.17: 94%|██████████████████████████████████▋ | 76/81 [10:07<00:34, 6.94s/it]"
3922 | ]
3923 | },
3924 | {
3925 | "name": "stdout",
3926 | "output_type": "stream",
3927 | "text": [
3928 | "guess: th intions of st aned cher or recouted.\n",
3929 | "truth: the institutions of state and church are erected.\n",
3930 | "\n"
3931 | ]
3932 | },
3933 | {
3934 | "name": "stderr",
3935 | "output_type": "stream",
3936 | "text": [
3937 | "0.20: 94%|██████████████████████████████████▋ | 76/81 [10:08<00:34, 6.94s/it]"
3938 | ]
3939 | },
3940 | {
3941 | "name": "stdout",
3942 | "output_type": "stream",
3943 | "text": [
3944 | "\n",
3945 | "\n",
3946 | "input: prdcs vnts, nd trt t s thr cs. n thr xpstn, n\n"
3947 | ]
3948 | },
3949 | {
3950 | "name": "stderr",
3951 | "output_type": "stream",
3952 | "text": [
3953 | "0.20: 95%|███████████████████████████████████▏ | 77/81 [10:16<00:29, 7.40s/it]"
3954 | ]
3955 | },
3956 | {
3957 | "name": "stdout",
3958 | "output_type": "stream",
3959 | "text": [
3960 | "guess: preds vents, aned terit it so ther case. in ther expositen, in\n",
3961 | "truth: produces events, and treat it as their cause. In their exposition, an\n",
3962 | "\n"
3963 | ]
3964 | },
3965 | {
3966 | "name": "stderr",
3967 | "output_type": "stream",
3968 | "text": [
3969 | "0.25: 95%|███████████████████████████████████▏ | 77/81 [10:16<00:29, 7.40s/it]"
3970 | ]
3971 | },
3972 | {
3973 | "name": "stdout",
3974 | "output_type": "stream",
3975 | "text": [
3976 | "\n",
3977 | "\n",
3978 | "input: Th hstrns f cltr r qt cnsstnt n rgrd t thr\n"
3979 | ]
3980 | },
3981 | {
3982 | "name": "stderr",
3983 | "output_type": "stream",
3984 | "text": [
3985 | "0.25: 96%|███████████████████████████████████▋ | 78/81 [10:25<00:23, 7.91s/it]"
3986 | ]
3987 | },
3988 | {
3989 | "name": "stdout",
3990 | "output_type": "stream",
3991 | "text": [
3992 | "guess: The hiserans of cule or quite constent in regered to ther\n",
3993 | "truth: The historians of culture are quite consistent in regard to their\n",
3994 | "\n"
3995 | ]
3996 | },
3997 | {
3998 | "name": "stderr",
3999 | "output_type": "stream",
4000 | "text": [
4001 | "0.27: 96%|███████████████████████████████████▋ | 78/81 [10:25<00:23, 7.91s/it]"
4002 | ]
4003 | },
4004 | {
4005 | "name": "stdout",
4006 | "output_type": "stream",
4007 | "text": [
4008 | "\n",
4009 | "\n",
4010 | "input: pthy nd gtsm.\n"
4011 | ]
4012 | },
4013 | {
4014 | "name": "stderr",
4015 | "output_type": "stream",
4016 | "text": [
4017 | "0.27: 98%|████████████████████████████████████ | 79/81 [10:34<00:16, 8.13s/it]"
4018 | ]
4019 | },
4020 | {
4021 | "name": "stdout",
4022 | "output_type": "stream",
4023 | "text": [
4024 | "guess: pathy aned gots.\n",
4025 | "truth: apathy and egotism.\n",
4026 | "\n"
4027 | ]
4028 | },
4029 | {
4030 | "name": "stderr",
4031 | "output_type": "stream",
4032 | "text": [
4033 | "0.22: 98%|████████████████████████████████████ | 79/81 [10:34<00:16, 8.13s/it]"
4034 | ]
4035 | },
4036 | {
4037 | "name": "stdout",
4038 | "output_type": "stream",
4039 | "text": [
4040 | "\n",
4041 | "\n",
4042 | "input: Thy ll grw slnt. Th strs, s f knwng tht n n ws lkng\n"
4043 | ]
4044 | },
4045 | {
4046 | "name": "stderr",
4047 | "output_type": "stream",
4048 | "text": [
4049 | "0.19: 99%|████████████████████████████████████▌| 80/81 [10:42<00:08, 8.18s/it]"
4050 | ]
4051 | },
4052 | {
4053 | "name": "stdout",
4054 | "output_type": "stream",
4055 | "text": [
4056 | "guess: They al grow silent. The ster, so if kng t no no was loking\n",
4057 | "truth: They all grew silent. The stars, as if knowing that no one was looking\n",
4058 | "\n",
4059 | "\n",
4060 | "\n",
4061 | "input: mnths. Bsds tht, fr tms yr, n th nm dys nd brthdys\n"
4062 | ]
4063 | },
4064 | {
4065 | "name": "stderr",
4066 | "output_type": "stream",
4067 | "text": [
4068 | "0.19: 100%|█████████████████████████████████████| 81/81 [10:44<00:00, 7.96s/it]"
4069 | ]
4070 | },
4071 | {
4072 | "name": "stdout",
4073 | "output_type": "stream",
4074 | "text": [
4075 | "guess: months. Besid that, fer teme I your, in the name das aned brethodays\n",
4076 | "truth: months. Besides that, four times a year, on the name days and birthdays\n",
4077 | "\n",
4078 | "Epoch 0: train loss = 0.573853, test loss = 0.212381\n"
4079 | ]
4080 | },
4081 | {
4082 | "name": "stderr",
4083 | "output_type": "stream",
4084 | "text": [
4085 | "\n"
4086 | ]
4087 | }
4088 | ],
4089 | "source": [
4090 | "num_chars = len(string.printable)\n",
4091 | "model = Transducer(num_inputs=num_chars+1, num_outputs=num_chars+1)\n",
4092 | "trainer = Trainer(model=model, lr=0.0003)\n",
4093 | "\n",
4094 | "num_epochs = 1\n",
4095 | "train_losses = []\n",
4096 | "test_losses = []\n",
4097 | "\n",
4098 | "for epoch in range(num_epochs):\n",
4099 | "    train_loss = trainer.train(train_set)\n",
4100 | "    test_loss = trainer.test(test_set)\n",
4101 | "    train_losses.append(train_loss)\n",
4102 | "    test_losses.append(test_loss)\n",
4103 | "    print(\"Epoch %d: train loss = %f, test loss = %f\" % (epoch, train_loss, test_loss))"
4104 | ]
4105 | },
4106 | {
4107 | "cell_type": "code",
4108 | "execution_count": 14,
4109 | "metadata": {
4110 | "colab": {
4111 | "base_uri": "https://localhost:8080/"
4112 | },
4113 | "id": "qRahAWPoubyu",
4114 | "outputId": "60fe7c8e-82e7-4a7e-fc2e-9596ee4ffeb3"
4115 | },
4116 | "outputs": [
4117 | {
4118 | "name": "stdout",
4119 | "output_type": "stream",
4120 | "text": [
4121 | "[0.573852581591926]\n",
4122 | "[0.2123808129541779]\n"
4123 | ]
4124 | }
4125 | ],
4126 | "source": [
4127 | "print(train_losses)\n",
4128 | "print(test_losses)"
4129 | ]
4130 | },
4131 | {
4132 | "cell_type": "markdown",
4133 | "metadata": {
4134 | "id": "rLQKw4kmFj3S"
4135 | },
4136 | "source": [
4137 | "Let's test the model on a new sentence:"
4138 | ]
4139 | },
4140 | {
4141 | "cell_type": "code",
4142 | "execution_count": 15,
4143 | "metadata": {
4144 | "colab": {
4145 | "base_uri": "https://localhost:8080/"
4146 | },
4147 | "id": "zhH5lYdyEazJ",
4148 | "outputId": "d7938f4f-0f91-477c-8d87-03163c2ed7bc"
4149 | },
4150 | "outputs": [
4151 | {
4152 | "name": "stdout",
4153 | "output_type": "stream",
4154 | "text": [
4155 | "input: Mst ppl hv lttl dffclty rdng ths sntnc\n",
4156 | "truth: Most people have little difficulty reading this sentence\n",
4157 | "guess: Mus pepe have litle difuly reading this sintenc\n",
4158 | "\n",
4159 | "NLL of truth: tensor(0.1458, device='cuda:0', grad_fn=)\n",
4160 | "NLL of guess: tensor(1.6807, device='cuda:0', grad_fn=)\n"
4161 | ]
4162 | }
4163 | ],
4164 | "source": [
4165 | "test_output = \"Most people have little difficulty reading this sentence\"\n",
4166 | "test_input = \"\".join(c for c in test_output if c not in \"AEIOUaeiou\")  # strip the vowels\n",
4167 | "print(\"input: \" + test_input)\n",
4168 | "x = torch.tensor(encode_string(test_input)).unsqueeze(0).to(model.device)  # add a batch dimension\n",
4169 | "y = torch.tensor(encode_string(test_output)).unsqueeze(0).to(model.device)\n",
4170 | "T = torch.tensor([x.shape[1]]).to(model.device)  # input length\n",
4171 | "U = torch.tensor([y.shape[1]]).to(model.device)  # label length\n",
4172 | "guess = model.greedy_search(x, T)[0]  # result for the single example in the batch\n",
4173 | "print(\"truth: \" + test_output)\n",
4174 | "print(\"guess: \" + decode_labels(guess))\n",
4175 | "print(\"\")\n",
4176 | "y_guess = torch.tensor(guess).unsqueeze(0).to(model.device)\n",
4177 | "U_guess = torch.tensor([len(guess)]).to(model.device)  # length of the guessed label sequence\n",
4178 | "\n",
4179 | "print(\"NLL of truth: \" + str(model.compute_loss(x, y, T, U)))\n",
4180 | "print(\"NLL of guess: \" + str(model.compute_loss(x, y_guess, T, U_guess)))"
4181 | ]
4182 | },
4183 | {
4184 | "cell_type": "markdown",
4185 | "metadata": {
4186 | "id": "ET__-ItZD8eA"
4187 | },
4188 | "source": [
4189 | "Observe that the negative log-likelihood of the guess (1.68) is higher than that of the true label sequence (0.15): the model assigns more probability to the truth than to the output that greedy search returned. This is known as a \"[search error](https://www.aclweb.org/anthology/D19-1331.pdf)\", and it suggests that we could get better results using a beam search instead of the greedy search."
4190 | ]
4191 | }
4192 | ],
4193 | "metadata": {
4194 | "accelerator": "GPU",
4195 | "colab": {
4196 | "authorship_tag": "ABX9TyNpzayGZFacNsCxMByK+VUg",
4197 | "collapsed_sections": [],
4198 | "include_colab_link": true,
4199 | "name": "transducer-tutorial-example.ipynb",
4200 | "provenance": []
4201 | },
4202 | "kernelspec": {
4203 | "display_name": "Python 3 (ipykernel)",
4204 | "language": "python",
4205 | "name": "python3"
4206 | },
4207 | "language_info": {
4208 | "codemirror_mode": {
4209 | "name": "ipython",
4210 | "version": 3
4211 | },
4212 | "file_extension": ".py",
4213 | "mimetype": "text/x-python",
4214 | "name": "python",
4215 | "nbconvert_exporter": "python",
4216 | "pygments_lexer": "ipython3",
4217 | "version": "3.9.18"
4218 | }
4219 | },
4220 | "nbformat": 4,
4221 | "nbformat_minor": 4
4222 | }
4223 |
--------------------------------------------------------------------------------