├── InstructPix2Pix_using_diffusers.ipynb ├── README.md ├── image_2_image_using_diffusers.ipynb ├── in_painting_with_stable_diffusion_using_diffusers.ipynb ├── longformer_qa_training.ipynb └── onnx_t5.ipynb /README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /onnx_t5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.7.8" 21 | }, 22 | "colab": { 23 | "name": "onnx_t5.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "QMprDxVmbIGp" 32 | }, 33 | "source": [ 34 | "!pip install -qqq --upgrade torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html\n", 35 | "!pip install -qqq -U transformers onnxruntime>=1.4.0 onnxruntime-tools>=1.4.2 psutil" 36 | ], 37 | "execution_count": null, 38 | "outputs": [] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "utCSDezubIGt" 44 | }, 45 | "source": [ 46 | "import inspect\n", 47 | "import logging\n", 48 | "import os\n", 49 | "from pathlib import Path\n", 50 | "\n", 51 | "import torch\n", 52 | "from psutil import cpu_count\n", 53 | "from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer\n", 54 | "from transformers.generation_utils import GenerationMixin\n", 55 | "from transformers.modeling_outputs import BaseModelOutputWithPast, Seq2SeqLMOutput\n", 56 | "\n", 57 | "# Constants from the performance optimization available in onnxruntime\n", 58 | "# It needs to be done before importing onnxruntime\n", 59 | "os.environ[\"OMP_NUM_THREADS\"] = str(cpu_count(logical=True))\n", 60 | "os.environ[\"OMP_WAIT_POLICY\"] = \"ACTIVE\"\n", 61 | "\n", 62 | "from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers" 63 | ], 64 | "execution_count": 1, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "oW667kSxbIGx" 71 | }, 72 | "source": [ 73 | "logger = logging.getLogger(__name__)" 74 | ], 75 | "execution_count": 2, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "nki_oYQHbIG2" 82 | }, 83 | "source": [ 84 | "class T5Encoder(torch.nn.Module):\n", 85 | " def __init__(self, encoder):\n", 86 | " super().__init__()\n", 87 | " self.encoder = encoder\n", 88 | "\n", 89 | " def forward(self, input_ids, attention_mask):\n", 90 | " return self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]\n", 91 | "\n", 92 | "\n", 93 | "class T5Decoder(torch.nn.Module):\n", 94 | " def __init__(self, decoder, config):\n", 95 | " super().__init__()\n", 96 | " self.decoder = decoder\n", 97 | " self.config = config\n", 98 | "\n", 99 | " def forward(self, input_ids, encoder_hidden_states, attention_mask=None, past_key_values=None):\n", 100 | " past_arg_key = (\n", 101 | " \"past_key_value_states\"\n", 102 | " if \"past_key_value_states\" in inspect.getfullargspec(self.decoder.forward).args\n", 103 | " else \"past_key_values\"\n", 104 | " )\n", 105 | " past_arg = {past_arg_key: past_key_values}\n", 106 | " decoder_output = self.decoder(\n", 107 | " input_ids=input_ids,\n", 108 | " encoder_attention_mask=attention_mask,\n", 109 | " encoder_hidden_states=encoder_hidden_states,\n", 110 | " use_cache=True,\n", 111 | " return_dict=True,\n", 112 | " **past_arg,\n", 113 | " )\n", 114 | " past_key_values = decoder_output.past_key_values\n", 115 | " sequence_output = decoder_output.last_hidden_state\n", 116 | " sequence_output = sequence_output * (self.config.d_model ** -0.5)\n", 117 | " return sequence_output, past_key_values\n", 118 | "\n", 119 | "\n", 120 | "class T5LMHead(torch.nn.Module):\n", 121 | " def __init__(self, lm_head):\n", 122 | " super().__init__()\n", 123 | " self.lm_head = lm_head\n", 124 | "\n", 125 | " def forward(self, decoder_output):\n", 126 | " return self.lm_head(decoder_output)\n" 127 | ], 128 | "execution_count": 3, 129 | "outputs": [] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "IdKdzHn5bIG5" 135 | }, 136 | "source": [ 137 | "def create_t5_encoder_decoder(model=\"t5-base\"):\n", 138 | " \"\"\"Generates an encoder and a decoder model with a language model head from a pretrained huggingface model\n", 139 | " Args:\n", 140 | " model (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5\n", 141 | " Returns:\n", 142 | " t5_encoder: pytorch t5 encoder with a wrapper to output only the hidden states\n", 143 | " t5_decoder: pytorch t5 decoder with a language modeling head\n", 144 | " \"\"\"\n", 145 | "\n", 146 | " # T5 is an encoder / decoder model with a language modeling head on top.\n", 147 | " # We need to separate those out for efficient language generation\n", 148 | " if isinstance(model, str):\n", 149 | " model = T5ForConditionalGeneration.from_pretrained(model)\n", 150 | "\n", 151 | " encoder = model.encoder\n", 152 | " decoder = model.decoder\n", 153 | " lm_head = model.lm_head\n", 154 | "\n", 155 | " t5_encoder = T5Encoder(encoder).eval()\n", 156 | " t5_decoder = T5Decoder(decoder, model.config).eval()\n", 157 | " t5_lm_head = T5LMHead(lm_head).eval()\n", 158 | " return t5_encoder, t5_decoder, t5_lm_head" 159 | ], 160 | "execution_count": 4, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "jfnVmeG5bIG8" 167 | }, 168 | "source": [ 169 | "def generate_onnx_representation(model, encoder_path, lm_path):\n", 170 | " \"\"\"Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx\n", 171 | " Args:\n", 172 | " pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5\n", 173 | " output_prefix (str): Path to the onnx file\n", 174 | " \"\"\"\n", 175 | "\n", 176 | " simplified_encoder, decoder, lm_head = create_t5_encoder_decoder(model)\n", 177 | "\n", 178 | " # Example sequence\n", 179 | " tok = T5Tokenizer.from_pretrained(model)\n", 180 | " enc = tok(\"42 is the answer\", return_tensors=\"pt\")\n", 181 | " input_ids = enc[\"input_ids\"]\n", 182 | " attention_mask = enc[\"attention_mask\"]\n", 183 | "\n", 184 | " # Exports to ONNX\n", 185 | " _ = torch.onnx._export(\n", 186 | " simplified_encoder,\n", 187 | " (input_ids, attention_mask),\n", 188 | " encoder_path,\n", 189 | " export_params=True,\n", 190 | " opset_version=12,\n", 191 | " input_names=[\"input_ids\", \"attention_mask\"],\n", 192 | " output_names=[\"encoder_hidden_states\"],\n", 193 | " dynamic_axes={\n", 194 | " \"input_ids\": {0: \"batch\", 1: \"sequence\"},\n", 195 | " \"attention_mask\": {0: \"batch\", 1: \"sequence\"},\n", 196 | " \"encoder_hidden_states\": {0: \"batch\", 1: \"sequence\"},\n", 197 | " },\n", 198 | " )\n", 199 | "\n", 200 | " encoder_out = simplified_encoder(input_ids, attention_mask)\n", 201 | " decoder_out, _ = decoder(input_ids, encoder_out)\n", 202 | " _ = torch.onnx.export(\n", 203 | " lm_head,\n", 204 | " decoder_out,\n", 205 | " lm_path,\n", 206 | " export_params=True,\n", 207 | " opset_version=12,\n", 208 | " input_names=[\"decoder_output\"],\n", 209 | " output_names=[\"lm_logits\"],\n", 210 | " dynamic_axes={\n", 211 | " \"decoder_output\": {0: \"batch\", 1: \"sequence\"},\n", 212 | " \"lm_logits\": {0: \"batch\", 1: \"sequence\"},\n", 213 | " },\n", 214 | " )" 215 | ], 216 | "execution_count": 5, 217 | "outputs": [] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "metadata": { 222 | "id": "GAILXi1dbIHA" 223 | }, 224 | "source": [ 225 | "def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:\n", 226 | "\n", 227 | " assert provider in get_all_providers(), f\"provider {provider} not found, {get_all_providers()}\"\n", 228 | "\n", 229 | " # Few properties that might have an impact on performances (provided by MS)\n", 230 | " options = SessionOptions()\n", 231 | " options.intra_op_num_threads = 1\n", 232 | " options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL\n", 233 | "\n", 234 | " # Load the model as a graph and prepare the CPU backend\n", 235 | " session = InferenceSession(model_path, options, providers=[provider])\n", 236 | " session.disable_fallback()\n", 237 | "\n", 238 | " return session" 239 | ], 240 | "execution_count": 6, 241 | "outputs": [] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "metadata": { 246 | "id": "8P0hx04UbIHD" 247 | }, 248 | "source": [ 249 | "class OnnxT5(GenerationMixin):\n", 250 | " def __init__(self, model_name_or_path, onnx_path):\n", 251 | "\n", 252 | " self.model_name_or_path = Path(model_name_or_path)\n", 253 | " self.onnx_path = Path(onnx_path)\n", 254 | " self.mode_base_name = self.model_name_or_path.stem\n", 255 | " self.encoder_path = self.onnx_path.joinpath(f\"{self.mode_base_name}_encoder.onnx\")\n", 256 | " self.lm_head_path = self.onnx_path.joinpath(f\"{self.mode_base_name}_lm_head.onnx\")\n", 257 | "\n", 258 | " if not (self.encoder_path.exists() and self.lm_head_path.exists()):\n", 259 | " self._export_onnx_graph()\n", 260 | "\n", 261 | " self.encoder_sess = create_model_for_provider(self.encoder_path.as_posix(), \"CPUExecutionProvider\")\n", 262 | " self.lm_sess = create_model_for_provider(self.lm_head_path.as_posix(), \"CPUExecutionProvider\")\n", 263 | "\n", 264 | " self.config = T5Config.from_pretrained(model_name_or_path)\n", 265 | " decoder = T5ForConditionalGeneration.from_pretrained(model_name_or_path).decoder\n", 266 | " self.decoder = T5Decoder(decoder, self.config).eval()\n", 267 | "\n", 268 | " self._warmup_onnx_graph()\n", 269 | "\n", 270 | " @torch.no_grad()\n", 271 | " def __call__(\n", 272 | " self,\n", 273 | " input_ids=None,\n", 274 | " attention_mask=None,\n", 275 | " decoder_input_ids=None,\n", 276 | " encoder_outputs=None,\n", 277 | " past_key_values=None,\n", 278 | " **kwargs,\n", 279 | " ):\n", 280 | " if input_ids is not None:\n", 281 | " return self._encoder_forward(input_ids=input_ids, attention_mask=attention_mask)\n", 282 | "\n", 283 | " decoder_output, past = self.decoder(decoder_input_ids, encoder_outputs, attention_mask, past_key_values)\n", 284 | "\n", 285 | " inputs = {\"decoder_output\": decoder_output.cpu().detach().numpy()}\n", 286 | " lm_logits = self.lm_sess.run(None, inputs)[0]\n", 287 | " lm_logits = torch.from_numpy(lm_logits)\n", 288 | " return Seq2SeqLMOutput(logits=lm_logits, past_key_values=past)\n", 289 | "\n", 290 | " def _encoder_forward(self, input_ids=None, attention_mask=None):\n", 291 | " inputs = {\n", 292 | " \"input_ids\": input_ids.cpu().detach().numpy(),\n", 293 | " \"attention_mask\": attention_mask.cpu().detach().numpy(),\n", 294 | " }\n", 295 | " last_hidden_state = self.encoder_sess.run(None, inputs)[0]\n", 296 | " last_hidden_state = torch.from_numpy(last_hidden_state)\n", 297 | " return BaseModelOutputWithPast(last_hidden_state=last_hidden_state)\n", 298 | "\n", 299 | " def get_encoder(self):\n", 300 | " return self\n", 301 | "\n", 302 | " def get_output_embeddings(self):\n", 303 | " return self\n", 304 | "\n", 305 | " def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):\n", 306 | " if past is not None:\n", 307 | " input_ids = input_ids[:, -1:]\n", 308 | " return {\n", 309 | " \"decoder_input_ids\": input_ids,\n", 310 | " \"past_key_values\": past,\n", 311 | " \"encoder_outputs\": encoder_outputs.last_hidden_state,\n", 312 | " \"attention_mask\": attention_mask,\n", 313 | " \"use_cache\": True,\n", 314 | " }\n", 315 | "\n", 316 | " def parameters(self):\n", 317 | " return iter(torch.tensor([42, 42]))\n", 318 | "\n", 319 | " def _export_onnx_graph(self):\n", 320 | " self.onnx_path.mkdir(parents=True, exist_ok=True)\n", 321 | " generate_onnx_representation(\n", 322 | " self.model_name_or_path.as_posix(), self.encoder_path.as_posix(), self.lm_head_path.as_posix()\n", 323 | " )\n", 324 | "\n", 325 | " def _reorder_cache(self, past, beam_idx):\n", 326 | " # if decoder past is not included in output\n", 327 | " # speedy decoding is disabled and no need to reorder\n", 328 | " if past is None:\n", 329 | " logger.warning(\"You might want to consider setting `use_cache=True` to speed up decoding\")\n", 330 | " return past\n", 331 | "\n", 332 | " reordered_decoder_past = ()\n", 333 | " for layer_past_states in past:\n", 334 | " # get the correct batch idx from layer past batch dim\n", 335 | " # batch dim of `past` is at 2nd position\n", 336 | " reordered_layer_past_states = ()\n", 337 | " for layer_past_state in layer_past_states:\n", 338 | " # need to set correct `past` for each of the four key / value states\n", 339 | " reordered_layer_past_states = reordered_layer_past_states + (\n", 340 | " layer_past_state.index_select(0, beam_idx),\n", 341 | " )\n", 342 | "\n", 343 | " assert reordered_layer_past_states[0].shape == layer_past_states[0].shape\n", 344 | " assert len(reordered_layer_past_states) == len(layer_past_states)\n", 345 | "\n", 346 | " reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)\n", 347 | " return reordered_decoder_past\n", 348 | "\n", 349 | " def _warmup_onnx_graph(self):\n", 350 | " input_ids = torch.ones(1, 512, dtype=torch.long)\n", 351 | " attention_mask = torch.ones(1, 512, dtype=torch.long)\n", 352 | " for _ in range(10):\n", 353 | " encoder_outputs = self._encoder_forward(\n", 354 | " input_ids=input_ids, attention_mask=attention_mask\n", 355 | " ).last_hidden_state\n", 356 | "\n", 357 | " decoder_output, _ = self.decoder(input_ids, encoder_outputs, attention_mask)\n", 358 | " inputs = {\"decoder_output\": decoder_output.cpu().detach().numpy()}\n", 359 | " for _ in range(10):\n", 360 | " self.lm_sess.run(None, inputs)" 361 | ], 362 | "execution_count": 7, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "id": "JazbVdT3bIHG" 369 | }, 370 | "source": [ 371 | "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM" 372 | ], 373 | "execution_count": 8, 374 | "outputs": [] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "QAUi_LC-bIHJ" 380 | }, 381 | "source": [ 382 | "tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n", 383 | "enc = tokenizer(\"translate English to French: 42 is the answer to life the universe and everything.\", return_tensors=\"pt\")" 384 | ], 385 | "execution_count": 9, 386 | "outputs": [] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "metadata": { 391 | "id": "sdgmJgsUbIHM" 392 | }, 393 | "source": [ 394 | "onnx_model = OnnxT5(model_name_or_path=\"t5-small\", onnx_path=\"onnx_models\")" 395 | ], 396 | "execution_count": 10, 397 | "outputs": [] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "metadata": { 402 | "id": "_Qz3W1W9bIHP", 403 | "outputId": "585e5ab0-3a10-47bb-fc90-e5985ae717a1", 404 | "colab": { 405 | "base_uri": "https://localhost:8080/" 406 | } 407 | }, 408 | "source": [ 409 | "tokens = onnx_model.generate(**enc, num_beams=2, use_cache=True) # same HF's generate method\n", 410 | "tokenizer.batch_decode(tokens)" 411 | ], 412 | "execution_count": 11, 413 | "outputs": [ 414 | { 415 | "output_type": "execute_result", 416 | "data": { 417 | "text/plain": [ 418 | "[\"42 est la réponse à la vie l'univers et tout.\"]" 419 | ] 420 | }, 421 | "metadata": { 422 | "tags": [] 423 | }, 424 | "execution_count": 11 425 | } 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": { 431 | "id": "2Fg5IQKybIHS" 432 | }, 433 | "source": [ 434 | "### Benchmark" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "metadata": { 440 | "id": "oGI2fDlMbIHT" 441 | }, 442 | "source": [ 443 | "%matplotlib inline\n", 444 | "\n", 445 | "import matplotlib\n", 446 | "import matplotlib.pyplot as plt\n", 447 | "import seaborn as sns\n", 448 | "sns.set()\n", 449 | "import numpy as np\n", 450 | "import os\n", 451 | "\n", 452 | "\n", 453 | "def plot_benchmark(results):\n", 454 | " # Compute average inference time + std\n", 455 | " time_results = {k: np.mean(v.model_inference_time) * 1e3 for k, v in results.items()}\n", 456 | " time_results_std = np.std([v.model_inference_time for v in results.values()]) * 1000\n", 457 | "\n", 458 | " plt.rcdefaults()\n", 459 | " fig, ax = plt.subplots(figsize=(10, 8))\n", 460 | " ax.set_ylabel(\"Avg Inference time (ms)\")\n", 461 | " ax.set_title(\"Average inference time (ms) for each provider\")\n", 462 | " ax.bar(time_results.keys(), time_results.values(), yerr=time_results_std)\n", 463 | " plt.show()\n", 464 | "\n", 465 | "from contextlib import contextmanager\n", 466 | "from dataclasses import dataclass\n", 467 | "from time import time\n", 468 | "from tqdm import trange\n", 469 | "\n", 470 | "@contextmanager\n", 471 | "def track_infer_time(buffer: [int]):\n", 472 | " start = time()\n", 473 | " yield\n", 474 | " end = time()\n", 475 | "\n", 476 | " buffer.append(end - start)\n", 477 | "\n", 478 | "@dataclass\n", 479 | "class OnnxInferenceResult:\n", 480 | " model_inference_time: [int] \n", 481 | " optimized_model_path: str" 482 | ], 483 | "execution_count": null, 484 | "outputs": [] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "metadata": { 489 | "id": "Vlm6I33CbIHW" 490 | }, 491 | "source": [ 492 | "torch_model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")" 493 | ], 494 | "execution_count": null, 495 | "outputs": [] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "metadata": { 500 | "id": "GlVF-M-nbIHZ" 501 | }, 502 | "source": [ 503 | "ARTICLE_SUBWAY = 'summarize: New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared \"I do\" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her \"first and only\" marriage. Barrientos, now 39, is facing two criminal counts of \"offering a false instrument for filing in the first degree,\" referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\\'s Investigation Division. Seven of the men are from so-called \"red-flagged\" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'\n", 504 | "enc = tokenizer(ARTICLE_SUBWAY, truncation=True, return_tensors=\"pt\")" 505 | ], 506 | "execution_count": null, 507 | "outputs": [] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "metadata": { 512 | "id": "l-o34YgWbIHc", 513 | "outputId": "eb9f3d82-6395-46bf-c148-78abb5415c8e" 514 | }, 515 | "source": [ 516 | "models = [(\"torch\", torch_model), (\"onnx\", onnx_model)]\n", 517 | "results = {}\n", 518 | "for label, model in models:\n", 519 | " # Compute \n", 520 | " time_buffer = []\n", 521 | " for _ in trange(100, desc=f\"Tracking inference time for {label}\"):\n", 522 | " with track_infer_time(time_buffer):\n", 523 | " model.generate(\n", 524 | " **enc,\n", 525 | " num_beams=4,\n", 526 | " length_penalty=2.0,\n", 527 | " max_length=142,\n", 528 | " min_length=56,\n", 529 | " no_repeat_ngram_size=3,\n", 530 | " do_sample=False,\n", 531 | " early_stopping=True\n", 532 | " )\n", 533 | "\n", 534 | " # Store the result\n", 535 | " results[label] = OnnxInferenceResult(\n", 536 | " time_buffer, \n", 537 | " None\n", 538 | " )" 539 | ], 540 | "execution_count": null, 541 | "outputs": [ 542 | { 543 | "output_type": "stream", 544 | "text": [ 545 | "Tracking inference time for torch: 100%|██████████| 100/100 [07:56<00:00, 4.76s/it]\n", 546 | "Tracking inference time for onnx: 0%| | 0/100 [00:00" 570 | ] 571 | }, 572 | "metadata": { 573 | "tags": [] 574 | } 575 | } 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "metadata": { 581 | "id": "4UoA7UYdbIHh" 582 | }, 583 | "source": [ 584 | "en_text = 'translate English to French: This image section from an infrared recording by the Spitzer telescope shows a \"family portrait\" of countless generations of stars: the oldest stars are seen as blue dots. '\n", 585 | "enc = tokenizer(en_text, truncation=True, return_tensors=\"pt\")" 586 | ], 587 | "execution_count": null, 588 | "outputs": [] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "metadata": { 593 | "id": "UwDhA-jEbIHm", 594 | "outputId": "a90a4e2e-6d83-4c55-b7af-87af14cf60d6" 595 | }, 596 | "source": [ 597 | "results = {}\n", 598 | "for label, model in models:\n", 599 | " # Compute \n", 600 | " time_buffer = []\n", 601 | " for _ in trange(100, desc=f\"Tracking inference time for {label}\"):\n", 602 | " with track_infer_time(time_buffer):\n", 603 | " model.generate(\n", 604 | " **enc,\n", 605 | " num_beams=4,\n", 606 | " length_penalty=2.0,\n", 607 | " max_length=100,\n", 608 | " no_repeat_ngram_size=3,\n", 609 | " do_sample=False,\n", 610 | " early_stopping=True,\n", 611 | " )\n", 612 | "\n", 613 | " # Store the result\n", 614 | " results[label] = OnnxInferenceResult(\n", 615 | " time_buffer, \n", 616 | " None\n", 617 | " )" 618 | ], 619 | "execution_count": null, 620 | "outputs": [ 621 | { 622 | "output_type": "stream", 623 | "text": [ 624 | "Tracking inference time for torch: 100%|██████████| 100/100 [04:40<00:00, 2.81s/it]\n", 625 | "Tracking inference time for onnx: 0%| | 0/100 [00:00" 649 | ] 650 | }, 651 | "metadata": { 652 | "tags": [] 653 | } 654 | } 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "metadata": { 660 | "id": "gwmFWO25bIHt" 661 | }, 662 | "source": [ 663 | "" 664 | ], 665 | "execution_count": null, 666 | "outputs": [] 667 | } 668 | ] 669 | } --------------------------------------------------------------------------------