├── README.md ├── batching_experiment.py ├── dataset_utils.py ├── environment.yml ├── figures ├── ablation_speed_up_over_batch_size_both_models.pdf ├── alldataset_comparison.pdf ├── batch_size_figure.pdf ├── dataset_level_ablation.pdf ├── prepacking_gif_final.gif └── speedup_gain16llama1b5000.pdf ├── model.py ├── prepack_generation_demo.ipynb ├── processor.py ├── profiling_dataset_level_prepacking.py ├── profiling_time_and_memory.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models 2 | 3 |
<div align="center">
4 | <img src="figures/prepacking_gif_final.gif" alt="Prepacking Demo">
5 | </div>
6 | 7 | 8 | This repository contains the source code of the following paper: 9 | 10 | > **"Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models"**
11 | > Siyan Zhao, Daniel Israel, Guy Van den Broeck, Aditya Grover
12 | >
13 | > **Abstract:** *During inference for transformer-based large language models (LLMs), prefilling is the computation of the key-value (KV) cache for input tokens in the prompt prior to autoregressive generation. For longer input prompt lengths, prefilling incurs a significant overhead on decoding time. In this work, we highlight the following pitfall of prefilling: for batches containing prompts of highly varying lengths, significant computation is wasted by the standard practice of padding sequences to the maximum length. As LLMs increasingly support longer context lengths, potentially up to 10 million tokens, variations in prompt lengths within a batch become more pronounced. To address this, we propose Prepacking, a simple yet effective method to optimize prefilling computation. To avoid redundant computation on pad tokens, prepacking combines prompts of varying lengths into a sequence and packs multiple sequences into a compact batch using a bin-packing algorithm. It then modifies the attention mask and positional encoding to compute multiple prefilled KV caches for multiple prompts within a single sequence. On standard curated datasets containing prompts with varying lengths, we obtain significant speed and memory efficiency improvements compared to the default padding-based prefilling computation within Huggingface, across a range of base model configurations and inference serving scenarios.*
14 |
15 |
16 | [[Paper]](https://arxiv.org/abs/2404.09529)
17 |
18 | ## Setup Environment
19 | ### Clone the repository
20 | ```bash
21 | git clone https://github.com/siyan-zhao/prepacking.git
22 | cd prepacking
23 | ```
24 |
25 | ### Conda Setup
26 | ```bash
27 | conda env create -f environment.yml
28 | conda activate prepack
29 | ```
30 |
31 | ## Profile Speed and Memory
32 |
33 |
34 | ### Profile Prefill or Time to First Token (TTFT) Latency and Compare Peak GPU Memory and Utilization
35 |
36 | ```bash
37 | CUDA_VISIBLE_DEVICES=0 python profiling_time_and_memory.py --metric=prefill --dataset=mmlu --batch_size=64 --model_name=llama1b --num_runs=5
38 | ```
39 |
40 | Example output when profiled on a single 48GB NVIDIA A6000 GPU:
41 |
42 | | Method | Avg Prefill Time / Batch (s) | Max GPU Utilization (%) | Max GPU Memory (MB) | Mean GPU Utilization (%) | Std Dev Time (s) | Std Dev Max GPU Util (%) | Std Dev Mean GPU Util (%) |
43 | |----------------|------------------------------|-------------------------|---------------------|--------------------------|------------------|--------------------------|---------------------------|
44 | | prepacking     | 0.441 | 100.000 | 4578.328 | 91.156 | 0.347 | 0.000 | 7.966 |
45 | | full-batching  | 2.299 | 100.000 | 34599.695 | 99.719 | 1.741 | 0.000 | 0.223 |
46 | | length-ordered | 0.658 | 100.000 | 22950.019 | 97.865 | 0.815 | 0.000 | 3.236 |
47 |
48 | ### Compare Per-Prompt Inference Prefill Time, Including Dataset Prepacking
49 |
50 | ```bash
51 | CUDA_VISIBLE_DEVICES=0 python profiling_dataset_level_prepacking.py --metric=prefill --model_name=llama1b --batch_size=32 --loadbit=8 --dataset=mmlu
52 | ```
53 |
54 | ## Play with Prepacking Generation
55 |
56 | A Colab example of using prepacking for generation; compare it against default generation yourself. A condensed version of the notebook's pipeline and a toy illustration of the packing step are sketched below.
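The snippet below condenses the demo notebook's pipeline into one script (the notebook additionally requests `output_hidden_states` and generation scores, omitted here for brevity): pack the prompts, prefill once over the packed batch, un-pack the per-prompt KV caches, then generate as usual.

```python
import torch
from transformers import AutoTokenizer
from processor import PrePackProcessor
from model import CustomCausalLlamaModel
from dataset_utils import unpack_kv

device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "princeton-nlp/Sheared-LLaMA-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = "[PAD]"
model = CustomCausalLlamaModel.from_pretrained(model_path).to(device).eval()

# Pack several variable-length prompts into fewer, denser sequences.
processor = PrePackProcessor(tokenizer)
sentences = ["The capital of Germany is", "The capital of Spain is", "My"]
packed_tokens, restart_positions, independent_mask, restart_dict, original_ids = processor.batch_process(sentences)

# A single prefill pass over the packed batch computes every prompt's KV cache.
with torch.no_grad():
    packed_outputs = model(
        input_ids=packed_tokens.to(device),
        attention_mask=independent_mask.to(device),
        position_ids=restart_positions.to(device),
        return_dict=True,
    )

# Un-pack the caches into per-prompt form, then decode as usual.
cache, final_tokens, attention_mask = unpack_kv(
    packed_outputs["past_key_values"], restart_dict, original_ids, device
)
outputs = model.generate(
    input_ids=final_tokens.to(device),
    attention_mask=attention_mask.to(device),
    past_key_values=cache,
    max_new_tokens=20,
    use_cache=True,
    do_sample=False,
)
for seq in outputs:
    print(tokenizer.decode(seq[1:]))  # seq[0] is the prompt's final token
```

Because `unpack_kv` restores each prompt's cache to its own batch row before decoding, `generate` proceeds exactly as it would after a padded prefill.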

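For intuition about what the packing step produces, here is a toy, self-contained sketch (not the repository's implementation; `PrePackProcessor` in `processor.py` is the real one): it bin-packs tokenized prompts into rows, restarts the position ids at every prompt boundary, and builds the block-diagonal "independent" attention mask that keeps packed prompts from attending to each other. The helper name `toy_prepack` and the exact tensor layout are illustrative assumptions; the bin-packing uses the `binpacking` package pinned in `environment.yml`.

```python
import binpacking
import torch

def toy_prepack(token_lists, pad_id=0):
    """Pack variable-length prompts into rows no longer than the longest prompt."""
    capacity = max(len(t) for t in token_lists)
    # First-fit style bin-packing of prompt indices, weighted by token count.
    bins = binpacking.to_constant_volume(
        {i: len(t) for i, t in enumerate(token_lists)}, capacity
    )
    packed, positions, masks = [], [], []
    for b in bins:  # each bin is a dict {prompt_index: length}
        ids, pos, seg = [], [], []
        for seg_id, i in enumerate(b):
            ids += token_lists[i]
            pos += list(range(len(token_lists[i])))  # positions restart per prompt
            seg += [seg_id] * len(token_lists[i])    # segment id marks prompt membership
        pad = capacity - len(ids)
        ids += [pad_id] * pad
        pos += [0] * pad
        seg += [-1] * pad  # -1 marks padding
        seg_t = torch.tensor(seg)
        # Block-diagonal causal mask: attend only within the same prompt, never to pads.
        same_prompt = (seg_t[:, None] == seg_t[None, :]) & (seg_t[:, None] >= 0)
        causal = torch.tril(torch.ones(capacity, capacity, dtype=torch.bool))
        masks.append((same_prompt & causal).int())
        packed.append(torch.tensor(ids))
        positions.append(torch.tensor(pos))
    return torch.stack(packed), torch.stack(positions), torch.stack(masks)

# Two short prompts share one row; the long prompt gets its own row.
tokens, position_ids, attention_mask = toy_prepack([[5, 6, 7, 8, 9, 10], [11, 12, 13], [14, 15]])
print(tokens.shape, attention_mask.shape)  # torch.Size([2, 6]) torch.Size([2, 6, 6])
```

Note that the custom models in `model.py` accept such a per-row 2-D mask stacked into a 3-D batch directly (see the `attention_mask.dim() == 3` branches there).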
57 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/siyan-zhao/prepacking/blob/main/prepack_generation_demo.ipynb)
58 |
59 | ## Reference
60 | If you find our work useful, please consider citing our [paper](https://arxiv.org/abs/2404.09529):
61 |
62 | ```bibtex
63 | @misc{zhao2024prepacking,
64 |       title={Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models},
65 |       author={Siyan Zhao and Daniel Israel and Guy Van den Broeck and Aditya Grover},
66 |       year={2024},
67 |       eprint={2404.09529},
68 |       archivePrefix={arXiv},
69 |       primaryClass={cs.LG}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/batching_experiment.py:
--------------------------------------------------------------------------------
1 | from transformers import LlamaForCausalLM
2 | import torch
3 | import time
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 |
8 |
9 | torch.set_num_threads(1)
10 |
11 |
12 | def run_batching_experiment(model, batch_size, length):
13 |     """Time 100 prefill (forward) passes over random token batches of a fixed length."""
14 |     model.eval()
15 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |     scenario_times = []
17 |     for _ in range(100):
18 |         data = torch.randint(0, 32000, (batch_size, length)).to(device)
19 |         with torch.no_grad():
20 |             start_time = time.time()
21 |             _ = model(input_ids=data)
22 |             if torch.cuda.is_available():
23 |                 torch.cuda.synchronize()  # flush queued CUDA kernels before stopping the clock
24 |             elapsed = time.time() - start_time
25 |         scenario_times.append(elapsed)
26 |     return np.mean(scenario_times), np.std(scenario_times)
27 |
28 |
29 | if __name__ == "__main__":
30 |     model_path = "princeton-nlp/Sheared-LLaMA-1.3B"
31 |     model = LlamaForCausalLM.from_pretrained(
32 |         model_path,
33 |         load_in_4bit=True,
34 |         bnb_4bit_compute_dtype=torch.float16,
35 |         low_cpu_mem_usage=True,
36 |         device_map="auto",
37 |     )
38 |
39 |     LENGTHS = [50, 100, 200, 400]
40 |     BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64]
41 |     data = {}
42 |     for length in tqdm(LENGTHS):
43 |         data[length] = [run_batching_experiment(model, bs, length) for bs in BATCH_SIZES]
44 |         print(f"Length: {length}", data[length])
45 |
46 |     # Plot mean TTFT per batch size with a +/- one-standard-deviation band per prompt length.
47 |     for length in LENGTHS:
48 |         means = [mean for mean, _ in data[length]]
49 |         stds = [std for _, std in data[length]]
50 |         plt.plot(BATCH_SIZES, means, label=f"D={length}")
51 |         plt.fill_between(
52 |             BATCH_SIZES, [m - s for m, s in zip(means, stds)], [m + s for m, s in zip(means, stds)], alpha=0.2
53 |         )
54 |     plt.legend()
55 |     plt.xlabel("Batch size")
56 |     plt.xscale("log", base=2)
57 |     plt.ylabel("TTFT (s)")
58 |     plt.savefig("batch_size_figure.png")
--------------------------------------------------------------------------------
/dataset_utils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from torch.utils.data import DataLoader, Dataset
3 | import random
4 | import torch
5 | from utils import left_pad_sequence
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from transformers.trainer_utils import set_seed
9 |
10 | SEED = 41
11 | set_seed(SEED)
12 |
13 |
14 | class HFDataset(Dataset):
15 |     def __init__(self, texts):
16 |         self.texts = texts
17 |
18 |     def __len__(self):
19 |         return len(self.texts)
20 |
21 |     def __getitem__(self, idx):
22 |         return self.texts[idx]
23 |
24 |
25 | def load_and_prepare_data(
26 |     dataset_name: str,
27 |     config_name: str,
28 |     split: str,
29 |     tokenizer,
30 |     sample_size: int = 1000,
31 |     max_length: int = None,
32 | ):
33 |     dataset = 
load_dataset(dataset_name, config_name, split=split) 34 | if dataset_name == "wikitext": 35 | texts = dataset["text"] 36 | elif "mmlu" in dataset_name: 37 | texts = dataset["question"] 38 | elif "rlhf" in dataset_name: 39 | texts = dataset["chosen"] 40 | elif "alpaca" in dataset_name: 41 | texts = dataset["text"] 42 | elif "samsum" in dataset_name: 43 | texts = dataset["dialogue"] 44 | print(len(texts), "texts loaded", sample_size) 45 | texts = [text for text in texts if len(text.split()) > 1] 46 | texts = random.sample(texts, sample_size) 47 | 48 | if max_length: 49 | tokenized_texts = [ 50 | tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt") for text in texts 51 | ] 52 | texts = [ 53 | tokenizer.decode(tokens.input_ids[0], skip_special_tokens=True) for tokens in tokenized_texts 54 | ] 55 | return texts 56 | 57 | 58 | def load_and_evaluate_dataset(dataset: str, tokenizer): 59 | dataset_config = { 60 | "mmlu": ("cais/mmlu", "all", None, 1000, "test"), 61 | "wikitext512": ("wikitext", "wikitext-2-raw-v1", 512, 1000, "test"), 62 | "wikitext256": ("wikitext", "wikitext-2-raw-v1", 256, 1000, "test"), 63 | "rlhf": ("Anthropic/hh-rlhf", "default", None, 1000, "train"), 64 | "alpaca": ("tatsu-lab/alpaca", "default", None, 1000, "train"), 65 | "samsum": ("samsum", "samsum", None, 1000, "train"), 66 | } 67 | 68 | if dataset not in dataset_config: 69 | raise ValueError(f"Unsupported dataset: {dataset}") 70 | 71 | dataset_name, config_name, max_length, sample_size, split = dataset_config[dataset] 72 | 73 | # Load dataset 74 | texts = load_and_prepare_data( 75 | dataset_name=dataset_name, 76 | config_name=config_name, 77 | split=split, 78 | tokenizer=tokenizer, 79 | sample_size=sample_size, 80 | max_length=max_length, 81 | ) 82 | 83 | # Evaluate sentences from the dataset 84 | print(f"Evaluating {len(texts)} sentences from the dataset {dataset_name}") 85 | lengths = [len(tokenizer(text).input_ids) for text in texts] 86 | print("Mean length:", np.mean(lengths)) 87 | print("Max length:", np.max(lengths)) 88 | print("Min length:", np.min(lengths)) 89 | 90 | # Plotting the length distribution 91 | plt.figure(figsize=(10, 6)) 92 | plt.hist(lengths, bins=30, color="skyblue", edgecolor="black") 93 | plt.title(f"Length Distribution of Texts in {dataset_name}") 94 | plt.xlabel("Length of Texts in Tokens") 95 | plt.ylabel("Frequency") 96 | plt.savefig(f"length_distribution_{dataset}.png") 97 | plt.close() 98 | 99 | return texts 100 | 101 | 102 | def sample_batches(texts, batch_size): 103 | random.shuffle(texts) 104 | for i in range(0, len(texts), batch_size): 105 | yield texts[i : i + batch_size] 106 | 107 | 108 | def sample_batches_deterministic(texts, batch_size): 109 | for i in range(0, len(texts), batch_size): 110 | yield texts[i : i + batch_size] 111 | 112 | 113 | def sample_batches_by_length(texts, batch_size): 114 | # Sort texts by their length 115 | texts_sorted = sorted(texts, key=lambda x: len(x)) 116 | # Yield batches of similar length 117 | for i in range(0, len(texts_sorted), batch_size): 118 | yield texts_sorted[i : i + batch_size] 119 | 120 | 121 | def sample_packed_dataset(dataset): 122 | # Assumption is dataset is already shuffled and batched 123 | for i in range(len(dataset)): 124 | yield dataset[i] 125 | 126 | 127 | def unpack_kv(packed_outputs, restart_dict, original_ids, device): 128 | 129 | batch_size = sum(map(len, restart_dict)) - len(restart_dict) 130 | 131 | dim1, dim2 = len(packed_outputs), len(packed_outputs[0]) 132 | 133 | save_cache = [[None for _ in 
range(dim2)] for _ in range(dim1)] 134 | batch_length = [len(ids) - 1 for ids in original_ids] 135 | compute = True 136 | 137 | attention_masks = torch.zeros((batch_size, max(batch_length) + 1), dtype=torch.int, device=device) 138 | final_tokens = torch.empty(batch_size, dtype=torch.int, device=device) 139 | 140 | for j in range(dim1): # layer 141 | for k in range(dim2): # k, v 142 | batch_cache = np.empty(batch_size, dtype=object) 143 | 144 | for b, batch in enumerate(restart_dict): 145 | batch_indices = list(batch.keys()) 146 | for i in range(len(batch) - 1): 147 | c = packed_outputs[j][k][b, :, batch_indices[i] : batch_indices[i + 1], :].permute( 148 | 1, 0, 2 149 | ) 150 | original_index = restart_dict[b][batch_indices[i + 1]] 151 | batch_cache[original_index] = c 152 | if compute: 153 | prompt = original_ids[batch[batch_indices[i + 1]]] 154 | final_tokens[original_index] = prompt[-1] 155 | attention_masks[original_index, -(batch_length[original_index]) - 1 :] = 1 156 | compute = False 157 | padded = left_pad_sequence(batch_cache, batch_first=True, padding_value=0).permute(0, 2, 1, 3) 158 | save_cache[j][k] = padded 159 | 160 | return save_cache, final_tokens.unsqueeze(dim=-1), attention_masks 161 | 162 | 163 | class PackedDataset(Dataset): 164 | def __init__(self, new_tokens, new_positions, new_mask, restart_indices, original_ids, batch_size): 165 | self.tensor_tuple = (new_tokens, new_positions, new_mask) 166 | self.restart_indices = restart_indices 167 | self.original_ids = original_ids 168 | self.initial_processing(batch_size) 169 | 170 | def __len__(self): 171 | return len(self.batch_indices) 172 | 173 | def __getitem__(self, idx): 174 | batch_idx = self.batch_indices[idx] 175 | tensors = tuple(tensor[batch_idx] for tensor in self.tensor_tuple) 176 | restart_idx = self.restart_indices[idx] 177 | original_id = self.original_ids[idx] 178 | return tensors + (restart_idx,) + (original_id,) 179 | 180 | def initial_processing(self, batch_size): 181 | 182 | num_samples = len(self.tensor_tuple[0]) 183 | indices = list(range(num_samples)) 184 | random.shuffle(indices) 185 | batch_indices = [indices[i : i + batch_size] for i in range(0, num_samples, batch_size)] 186 | 187 | # Purpose of below code is that with batching, restart indices and original ids must be reformatted 188 | # for consistency with existing unpack_kv function 189 | batched_restart_indices = [] 190 | for idx in batch_indices: 191 | batch_restart_indices = list(map(self.restart_indices.__getitem__, idx)) 192 | batched_restart_indices.append(batch_restart_indices) 193 | 194 | new_original_ids = [] 195 | for idx in range(len(batch_indices)): 196 | original_id = [] 197 | i = 0 198 | restart_idx = batched_restart_indices[idx] 199 | for d in restart_idx: 200 | for key in d: 201 | value = d[key] 202 | if value != -1: 203 | original_id.append(self.original_ids[value]) 204 | d[key] = i 205 | i += 1 206 | new_original_ids.append(original_id) 207 | 208 | self.original_ids = new_original_ids 209 | self.restart_indices = batched_restart_indices 210 | self.batch_indices = batch_indices 211 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: prepack 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2024.3.11=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.3=he6710b0_2 10 | - libgcc-ng=11.2.0=h1234567_1 11 | 
- libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1w=h7f8727e_0 15 | - pip=23.3.1=py39h06a4308_0 16 | - python=3.9.7=h12debd9_1 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.2.2=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py39h06a4308_0 22 | - xz=5.4.6=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==2.1.0 26 | - accelerate==0.29.2 27 | - aiohttp==3.9.4 28 | - aiosignal==1.3.1 29 | - async-timeout==4.0.3 30 | - attrs==23.2.0 31 | - binpacking==1.5.2 32 | - bitsandbytes==0.41.1 33 | - certifi==2024.2.2 34 | - charset-normalizer==3.3.2 35 | - cmake==3.29.2 36 | - contourpy==1.2.1 37 | - cycler==0.12.1 38 | - datasets==2.18.0 39 | - dill==0.3.8 40 | - filelock==3.13.4 41 | - fire==0.6.0 42 | - fonttools==4.51.0 43 | - frozenlist==1.4.1 44 | - fsspec==2024.2.0 45 | - future==1.0.0 46 | - gputil==1.4.0 47 | - huggingface-hub==0.22.2 48 | - idna==3.7 49 | - immutabledict==4.2.0 50 | - importlib-resources==6.4.0 51 | - jinja2==3.1.3 52 | - kiwisolver==1.4.5 53 | - lit==18.1.3 54 | - markupsafe==2.1.5 55 | - matplotlib==3.8.3 56 | - mpmath==1.3.0 57 | - multidict==6.0.5 58 | - multiprocess==0.70.16 59 | - networkx==3.2.1 60 | - numpy==1.25.2 61 | - nvidia-cublas-cu11==11.10.3.66 62 | - nvidia-cuda-cupti-cu11==11.7.101 63 | - nvidia-cuda-nvrtc-cu11==11.7.99 64 | - nvidia-cuda-runtime-cu11==11.7.99 65 | - nvidia-cudnn-cu11==8.5.0.96 66 | - nvidia-cufft-cu11==10.9.0.58 67 | - nvidia-curand-cu11==10.2.10.91 68 | - nvidia-cusolver-cu11==11.4.0.1 69 | - nvidia-cusparse-cu11==11.7.4.91 70 | - nvidia-nccl-cu11==2.14.3 71 | - nvidia-nvtx-cu11==11.7.91 72 | - ortools==9.9.3963 73 | - packaging==24.0 74 | - pandas==2.2.2 75 | - pillow==10.3.0 76 | - prettytable==3.9.0 77 | - protobuf==5.26.1 78 | - psutil==5.9.8 79 | - pyarrow==15.0.2 80 | - pyarrow-hotfix==0.6 81 | - pyparsing==3.1.2 82 | - python-dateutil==2.9.0.post0 83 | - pytz==2024.1 84 | - pyyaml==6.0.1 85 | - regex==2023.12.25 86 | - requests==2.31.0 87 | - safetensors==0.4.2 88 | - scipy==1.11.2 89 | - six==1.16.0 90 | - sympy==1.12 91 | - termcolor==2.4.0 92 | - tokenizers==0.15.2 93 | - torch==2.0.1 94 | - torchaudio==2.0.2 95 | - torchvision==0.15.2 96 | - tqdm==4.66.2 97 | - transformers==4.38.2 98 | - triton==2.0.0 99 | - typing-extensions==4.11.0 100 | - tzdata==2024.1 101 | - urllib3==2.2.1 102 | - wcwidth==0.2.13 103 | - xxhash==3.4.1 104 | - yarl==1.9.4 105 | - zipp==3.18.1 106 | prefix: /home/siyanz/miniconda3/envs/prepack1 107 | -------------------------------------------------------------------------------- /figures/ablation_speed_up_over_batch_size_both_models.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/ablation_speed_up_over_batch_size_both_models.pdf -------------------------------------------------------------------------------- /figures/alldataset_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/alldataset_comparison.pdf -------------------------------------------------------------------------------- /figures/batch_size_figure.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/batch_size_figure.pdf -------------------------------------------------------------------------------- /figures/dataset_level_ablation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/dataset_level_ablation.pdf -------------------------------------------------------------------------------- /figures/prepacking_gif_final.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/prepacking_gif_final.gif -------------------------------------------------------------------------------- /figures/speedup_gain16llama1b5000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/speedup_gain16llama1b5000.pdf -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import transformers 4 | from transformers.modeling_outputs import BaseModelOutputWithPast 5 | 6 | from typing import List, Optional, Tuple, Union 7 | from transformers.generation.utils import GenerationMixin 8 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa 9 | from transformers.cache_utils import Cache, DynamicCache 10 | from transformers.utils import logging 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | # Custom model to allow passing of 2D attention masks 15 | class CustomLlamaModel(transformers.LlamaModel): 16 | def __init__(self, config): 17 | super().__init__(config) 18 | 19 | def _update_causal_mask(self, attention_mask, input_tensor): 20 | if self.config._attn_implementation == "flash_attention_2": 21 | if attention_mask is not None and 0.0 in attention_mask: 22 | return attention_mask 23 | return None 24 | 25 | batch_size, seq_length = input_tensor.shape[:2] 26 | dtype = input_tensor.dtype 27 | device = input_tensor.device 28 | 29 | # support going beyond cached `max_position_embedding` 30 | if seq_length > self.causal_mask.shape[-1]: 31 | causal_mask = torch.full( 32 | (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), 33 | fill_value=1, 34 | ) 35 | self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) 36 | 37 | # We use the current dtype to avoid any overflows 38 | min_dtype = torch.finfo(dtype).min 39 | causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype 40 | 41 | causal_mask = causal_mask.to(dtype=dtype, device=device) 42 | if attention_mask is not None and attention_mask.dim() == 2: 43 | mask_length = attention_mask.shape[-1] 44 | padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) 45 | causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( 46 | padding_mask, min_dtype 47 | ) 48 | 49 | if attention_mask is not None and attention_mask.dim() == 3: 50 | attention_mask = attention_mask.to(dtype=dtype, device=device) 51 | mask_length = attention_mask.shape[-1] 52 | padding_mask = causal_mask[..., :mask_length, :mask_length].eq(0.0) 
* attention_mask[ 53 | :, None, :, : 54 | ].eq(0.0) 55 | causal_mask[..., :mask_length, :mask_length] = causal_mask[ 56 | ..., :mask_length, :mask_length 57 | ].masked_fill(padding_mask, min_dtype) 58 | 59 | if self.config._attn_implementation == "sdpa" and attention_mask is not None: 60 | # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). 61 | is_tracing = ( 62 | torch.jit.is_tracing() 63 | or isinstance(input_tensor, torch.fx.Proxy) 64 | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) 65 | ) 66 | if not is_tracing and torch.any(attention_mask != 1): 67 | # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when 68 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 69 | # Details: https://github.com/pytorch/pytorch/issues/110213 70 | causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to( 71 | dtype 72 | ) 73 | 74 | return causal_mask 75 | 76 | 77 | class CustomCausalLlamaModel(transformers.LlamaForCausalLM, GenerationMixin): 78 | def __init__(self, config): 79 | super().__init__(config) 80 | self.model = CustomLlamaModel(config) 81 | 82 | 83 | class CustomMistralModel(transformers.MistralModel): 84 | def __init__(self, config): 85 | super().__init__(config) 86 | 87 | # @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 88 | def forward( 89 | self, 90 | input_ids: torch.LongTensor = None, 91 | attention_mask: Optional[torch.Tensor] = None, 92 | position_ids: Optional[torch.LongTensor] = None, 93 | past_key_values: Optional[List[torch.FloatTensor]] = None, 94 | inputs_embeds: Optional[torch.FloatTensor] = None, 95 | use_cache: Optional[bool] = None, 96 | output_attentions: Optional[bool] = None, 97 | output_hidden_states: Optional[bool] = None, 98 | return_dict: Optional[bool] = None, 99 | ) -> Union[Tuple, BaseModelOutputWithPast]: 100 | output_attentions = ( 101 | output_attentions if output_attentions is not None else self.config.output_attentions 102 | ) 103 | output_hidden_states = ( 104 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 105 | ) 106 | use_cache = use_cache if use_cache is not None else self.config.use_cache 107 | 108 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 109 | 110 | # retrieve input_ids and inputs_embeds 111 | if input_ids is not None and inputs_embeds is not None: 112 | raise ValueError( 113 | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" 114 | ) 115 | elif input_ids is not None: 116 | batch_size, seq_length = input_ids.shape 117 | elif inputs_embeds is not None: 118 | batch_size, seq_length, _ = inputs_embeds.shape 119 | else: 120 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 121 | 122 | if self.gradient_checkpointing and self.training: 123 | if use_cache: 124 | logger.warning_once( 125 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
126 | ) 127 | use_cache = False 128 | 129 | past_key_values_length = 0 130 | 131 | if use_cache: 132 | use_legacy_cache = not isinstance(past_key_values, Cache) 133 | if use_legacy_cache: 134 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 135 | past_key_values_length = past_key_values.get_usable_length(seq_length) 136 | 137 | if position_ids is None: 138 | device = input_ids.device if input_ids is not None else inputs_embeds.device 139 | position_ids = torch.arange( 140 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 141 | ) 142 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 143 | else: 144 | position_ids = position_ids.view(-1, seq_length).long() 145 | 146 | if inputs_embeds is None: 147 | inputs_embeds = self.embed_tokens(input_ids) 148 | 149 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: 150 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size 151 | if is_padding_right: 152 | raise ValueError( 153 | "You are attempting to perform batched generation with padding_side='right'" 154 | " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " 155 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 156 | ) 157 | 158 | if self._attn_implementation == "flash_attention_2": 159 | # 2d mask is passed through the layers 160 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 161 | elif self._attn_implementation == "sdpa" and not output_attentions: 162 | # output_attentions=True can not be supported when using SDPA, and we fall back on 163 | # the manual implementation that requires a 4D causal mask in all cases. 
164 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 165 | attention_mask, 166 | (batch_size, seq_length), 167 | inputs_embeds, 168 | past_key_values_length, 169 | ) 170 | else: 171 | if attention_mask.shape[1] == 1: # for use cache 172 | attention_mask = attention_mask.view(batch_size, -1) 173 | elif attention_mask.dim() == 3: 174 | attention_mask = attention_mask.view(batch_size, 1, seq_length, seq_length) 175 | # 4d mask is passed through the layers 176 | attention_mask = _prepare_4d_causal_attention_mask( 177 | attention_mask, 178 | (batch_size, seq_length), 179 | inputs_embeds, 180 | past_key_values_length, 181 | sliding_window=self.config.sliding_window, 182 | ) 183 | hidden_states = inputs_embeds 184 | 185 | # decoder layers 186 | all_hidden_states = () if output_hidden_states else None 187 | all_self_attns = () if output_attentions else None 188 | next_decoder_cache = None 189 | 190 | for decoder_layer in self.layers: 191 | if output_hidden_states: 192 | all_hidden_states += (hidden_states,) 193 | 194 | if self.gradient_checkpointing and self.training: 195 | layer_outputs = self._gradient_checkpointing_func( 196 | decoder_layer.__call__, 197 | hidden_states, 198 | attention_mask, 199 | position_ids, 200 | past_key_values, 201 | output_attentions, 202 | use_cache, 203 | ) 204 | else: 205 | layer_outputs = decoder_layer( 206 | hidden_states, 207 | attention_mask=attention_mask, 208 | position_ids=position_ids, 209 | past_key_value=past_key_values, 210 | output_attentions=output_attentions, 211 | use_cache=use_cache, 212 | ) 213 | 214 | hidden_states = layer_outputs[0] 215 | 216 | if use_cache: 217 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 218 | 219 | if output_attentions: 220 | all_self_attns += (layer_outputs[1],) 221 | 222 | hidden_states = self.norm(hidden_states) 223 | 224 | # add hidden states from the last decoder layer 225 | if output_hidden_states: 226 | all_hidden_states += (hidden_states,) 227 | 228 | next_cache = None 229 | if use_cache: 230 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 231 | 232 | if not return_dict: 233 | return tuple( 234 | v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None 235 | ) 236 | return BaseModelOutputWithPast( 237 | last_hidden_state=hidden_states, 238 | past_key_values=next_cache, 239 | hidden_states=all_hidden_states, 240 | attentions=all_self_attns, 241 | ) 242 | 243 | class CustomCausalMistralModel(transformers.MistralForCausalLM, GenerationMixin): 244 | def __init__(self, config): 245 | super().__init__(config) 246 | self.model = CustomMistralModel(config) 247 | -------------------------------------------------------------------------------- /prepack_generation_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Imports**" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "!git clone https://github.com/siyan-zhao/prepacking.git\n", 17 | "%cd prepacking/\n", 18 | "!pip install ortools==9.9.3963\n", 19 | "!pip install binpacking==1.5.2\n", 20 | "!pip install datasets==2.18.0\n", 21 | "!pip install -i https://pypi.org/simple/ bitsandbytes" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "id": "UYWdWPXytfLG" 29 | }, 30 | "outputs": [], 31 | 
"source": [ 32 | "import torch\n", 33 | "from processor import PrePackProcessor\n", 34 | "from model import CustomCausalLlamaModel\n", 35 | "from transformers import AutoTokenizer\n", 36 | "from transformers.trainer_utils import set_seed\n", 37 | "from dataset_utils import unpack_kv\n", 38 | "from transformers import BitsAndBytesConfig" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "YW6yOlJxzCVw" 45 | }, 46 | "source": [ 47 | "**Load Model**" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "colab": { 55 | "base_uri": "https://localhost:8080/", 56 | "height": 833, 57 | "referenced_widgets": [ 58 | "27732c5e88314d12a403bfbeea05b71a", 59 | "2bdbfdcdc9a54b8fad0068fda0fecd55", 60 | "48ee78ad568843ddba1619a3e50404ce", 61 | "68025228c36c419e88ce26efcac76456", 62 | "1ff88c3b38de421da1bcff730cccbf52", 63 | "4149a1ba8c174a83abcadc4bcb1ef98a", 64 | "0b9b2e91fe864f25ae09c37bf9145cae", 65 | "782c8fe1763b4977a678d9fb4fd5dadc", 66 | "d3b47705f4e14680b13b21be52d2e604", 67 | "2b5de781f41e4d6dbdc6c7939b405af0", 68 | "1b9ee7656cae4986ba45656fef43d154", 69 | "e022c45f119c4eb281d5679b80d6387d", 70 | "c6ab5db12b0647f5989e233d914edd87", 71 | "6e238cbf7ab84e5b9d6e6d756ccd9487", 72 | "7f160c6489514b35945179d204c06bc7", 73 | "24ffd51165764526962c2de4c004ae51", 74 | "8322ac7b0d8644d78804e6d85ff05baa", 75 | "1fecf11618b244909505323d2968e5fc", 76 | "0f57f44193ef462490f4e1526637c874", 77 | "94f49a2b0fca4fa5a22ea8a07f628b99", 78 | "3f21f1f034744945a9b0657e313a6350", 79 | "e98ec60115014dc0a5ec44fdec3a0073", 80 | "aadb942209e84408a2bf11e55412b574", 81 | "fa9cafdcab2f439b9196f6f05166d898", 82 | "240fed609ad44aafad099550d3c265df", 83 | "4afabea1be0249068ad394e7265abf24", 84 | "91bf2b38bd224cbfabc77aa85d75f48b", 85 | "1a0d37386914443f983a89cf56d7a900", 86 | "42075dab0a5443b0b65e1ff8b2cb5e89", 87 | "b73c8d9ccf494d6193ff8262777a98da", 88 | "01c044c6e5364772b2032197bdca6a37", 89 | "29ee1203a8444b579caeec2f9fe201e6", 90 | "5cbd1a0fe71c434ca0bd655d6347947f", 91 | "a001503024fe44c8acc0d5e1a6d561b7", 92 | "0c0d9cda00e045c2a16eab19347b49d2", 93 | "aa4160a1837f427bb8ae673d31a0f24a", 94 | "7562626ebef743cfb5c234a8babd884d", 95 | "ed4321893f5f4630bf2dbf1f4a38bd39", 96 | "e4685f450cb543f482d41881e2555b34", 97 | "2cdfa2a8e85548ce9259285a04310822", 98 | "bb6f8aa46e3a459aacc6c99f9f76c72c", 99 | "1f057e7c483d41a8b63f8b59480962f1", 100 | "8119e15ecae64227857fb5a927c95877", 101 | "abfc8e73b9f94e4aa684de84a6194e20", 102 | "975d4035cf144139b455fcb06dde003e", 103 | "4c44190d75cb4f21826fee5b25052886", 104 | "48e7aedb9eb8467ba5418860c5d99a07", 105 | "1a65ae611f96427e85b5e5ec21bca22d", 106 | "49536e13224f481bbb4a02a1391ab347", 107 | "9b5632fd42b7452295e7873e7c7a230f", 108 | "b5ca1b962b874a50bd3657a1d74ea998", 109 | "c8219473c99d480081104fe4551036de", 110 | "58524a4c38cf4289801ee2569c07948d", 111 | "b3ffc98ce8fa41ddad546ceed1f3e0fb", 112 | "6d38a5c62d1243bc9c8d68fc0a20393b", 113 | "8c2d4bf32f5e46539cfe78ba705e379a", 114 | "d160b88d7f00493d81c424be93d7030f", 115 | "256bd87c03b2449a9cf25303a84b96d6", 116 | "e05657d7cdab4c4082ebc87b6d903969", 117 | "8246ecc55fbd476ba7fca6a183e4d741", 118 | "95584181b8fd41eebb322c951c72c1f6", 119 | "88795b72502645b4bdc95beadb4471d5", 120 | "a44eefdea1bd4b83a210f226f074e1d0", 121 | "4e3c9dbe0057456683e12c1357925ec5", 122 | "55690f32de47405da42599803e654fc6", 123 | "6ea49b6d6530403fb9acf7a15493b6e9", 124 | "12e2bb57b55c43b2a46ccdde99cb981b", 125 | "dd92cd36ad544a17a22374c0896947da", 126 | "4c7aa0c521cc4e0e8f9bb8789e6ad8f9", 127 | 
"454099c862aa46149f79467e54b7f7ac", 128 | "10a555507a6246ee9c2ff36a5405d63e", 129 | "b568300f99bd4bd48e8f37029282eed6", 130 | "13a9fb93d75e41d180e7673acad43dab", 131 | "cfee62ce88bf4a409ce34337556539cf", 132 | "a16259ceb15b464fa87355a14aaac331", 133 | "405acde9b0b74f18aaca033b36c28fbd", 134 | "9c5163fdeb61471bbafdfe08393a6d80" 135 | ] 136 | }, 137 | "id": "whQESxVbzB4d", 138 | "outputId": "8f3035cc-17cf-483c-9559-90e8055bad01" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "SEED = 42\n", 143 | "set_seed(SEED)\n", 144 | "model_path = \"princeton-nlp/Sheared-LLaMA-1.3B\"\n", 145 | "tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 146 | "tokenizer.pad_token = \"[PAD]\"\n", 147 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 148 | "custom_model = CustomCausalLlamaModel.from_pretrained(model_path)\n", 149 | "custom_model.to(device)\n", 150 | "custom_model.eval()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "Dy590zUqzHeV" 157 | }, 158 | "source": [ 159 | "**Prepacking Generation**" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": { 166 | "id": "ZiRs_hLetkCB" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "\n", 171 | "processor = PrePackProcessor(tokenizer)\n", 172 | "\n", 173 | "# Change to any prompts\n", 174 | "sentences = [\n", 175 | " \"Rescuers are searching for multiple people in the water after Baltimore bridge collapse, report says\",\n", 176 | " \"Major bridge in Maryland collapses after being hit by a ship\",\n", 177 | " \"The capital of Germany is\",\n", 178 | " \"The capital of Spain is\",\n", 179 | " \"The capital of Greece is\",\n", 180 | " \"Today I'm going to the\",\n", 181 | " \"Baltimore Police Department told NBC\",\n", 182 | " \"My\",\n", 183 | " \"It\",\n", 184 | "]\n", 185 | "\n", 186 | "packed_tokens, restart_positions, independent_mask, restart_dict, original_ids = processor.batch_process(sentences)\n", 187 | "\n", 188 | "\n", 189 | "with torch.no_grad():\n", 190 | " packed_outputs = custom_model(\n", 191 | " input_ids=packed_tokens.to(device),\n", 192 | " attention_mask=independent_mask.to(device),\n", 193 | " position_ids=restart_positions.to(device),\n", 194 | " return_dict=True,\n", 195 | " output_hidden_states=True,\n", 196 | " )\n", 197 | "\n", 198 | "cache, final_tokens, attention_mask = unpack_kv(\n", 199 | " packed_outputs[\"past_key_values\"], restart_dict, original_ids, device\n", 200 | ")\n", 201 | "\n", 202 | "prepack_generated_output = custom_model.generate(\n", 203 | " input_ids=final_tokens.to(device),\n", 204 | " attention_mask=attention_mask.to(device),\n", 205 | " max_new_tokens=20,\n", 206 | " use_cache=True,\n", 207 | " do_sample=False,\n", 208 | " past_key_values=cache,\n", 209 | " num_return_sequences=1,\n", 210 | " output_scores=True,\n", 211 | " return_dict_in_generate=True,\n", 212 | ")\n", 213 | "\n" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": { 219 | "id": "T2Q5miklzaDH" 220 | }, 221 | "source": [ 222 | "**Default Generation**" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 8, 228 | "metadata": { 229 | "id": "bdgfwCkbxCos" 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "\n", 234 | "with torch.no_grad():\n", 235 | " normal_tokens_id = tokenizer(sentences, return_tensors=\"pt\", padding=True, truncation=False).to(\n", 236 | " device\n", 237 | " )\n", 238 | " normal_outputs = custom_model(**normal_tokens_id, return_dict=True, 
output_hidden_states=True)\n", 239 | "\n", 240 | "default_generated_output = custom_model.generate(\n", 241 | " **normal_tokens_id,\n", 242 | " max_new_tokens=20,\n", 243 | " use_cache=True,\n", 244 | " do_sample=False,\n", 245 | " num_return_sequences=1,\n", 246 | " output_scores=True,\n", 247 | " return_dict_in_generate=True\n", 248 | ")\n", 249 | "\n", 250 | "attention_mask = normal_tokens_id[\"attention_mask\"]\n", 251 | "\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "**Compare Generations**" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 9, 264 | "metadata": { 265 | "colab": { 266 | "base_uri": "https://localhost:8080/" 267 | }, 268 | "id": "4_unN3rS1pNQ", 269 | "outputId": "ecaae29d-5b29-4160-921c-46a009313e93" 270 | }, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Asserting Same Tokens\n", 277 | "--------------- comparing ---------------\n", 278 | "Prepacked 0 : \n", 279 | "The Baltimore Sun reports that the collapse of a bridge in Baltimore on Monday morning has left at least\n", 280 | "Default 0 : \n", 281 | "The Baltimore Sun reports that the collapse of a bridge in Baltimore on Monday morning has left at least\n", 282 | "--------------- comparing ---------------\n", 283 | "Prepacked 1 : \n", 284 | "The bridge was built in 1968 and was the first of its kind in the\n", 285 | "Default 1 : \n", 286 | "The bridge was built in 1968 and was the first of its kind in the\n", 287 | "--------------- comparing ---------------\n", 288 | "Prepacked 2 : Berlin. Berlin is a city of contrasts. Berlin is a city of contrasts. It is\n", 289 | "Default 2 : Berlin. Berlin is a city of contrasts. Berlin is a city of contrasts. It is\n", 290 | "--------------- comparing ---------------\n", 291 | "Prepacked 3 : Madrid, and it is the largest city in the country. The city is located in the central part\n", 292 | "Default 3 : Madrid, and it is the largest city in the country. The city is located in the central part\n", 293 | "--------------- comparing ---------------\n", 294 | "Prepacked 4 : Athens, the largest city in Greece and the second largest in Europe. The city is located on\n", 295 | "Default 4 : Athens, the largest city in Greece and the second largest in Europe. The city is located on\n", 296 | "--------------- comparing ---------------\n", 297 | "Prepacked 5 : gym. I'm going to do a 30 minute workout. I'm\n", 298 | "Default 5 : gym. I'm going to do a 30 minute workout. I'm\n", 299 | "--------------- comparing ---------------\n", 300 | "Prepacked 6 : 10 that the suspect was arrested on a warrant for a probation violation.\n", 301 | "The\n", 302 | "Default 6 : 10 that the suspect was arrested on a warrant for a probation violation.\n", 303 | "The\n", 304 | "--------------- comparing ---------------\n", 305 | "Prepacked 7 : name is Katie and I am a 20 year old student from the UK. I am\n", 306 | "Default 7 : name is Katie and I am a 20 year old student from the UK. I am\n", 307 | "--------------- comparing ---------------\n", 308 | "Prepacked 8 : ’s been a while since I’ve posted anything on this blog. I’ve been busy\n", 309 | "Default 8 : ’s been a while since I’ve posted anything on this blog. 
I’ve been busy\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "print(\"Asserting Same Tokens\")\n", 315 | "\n", 316 | "# Check tokens\n", 317 | "# Note that it is possible to have different generations due to numerical instability\n", 318 | "for i, (prepack_token, default_token) in enumerate(\n", 319 | " zip(prepack_generated_output.sequences, default_generated_output.sequences)\n", 320 | "):\n", 321 | "\n", 322 | " prepack = tokenizer.decode(prepack_token[1:])\n", 323 | " default = tokenizer.decode(default_token[attention_mask.shape[-1] :])\n", 324 | " print(\"-\" * 15, \"comparing\", \"-\" * 15)\n", 325 | " print(\"Prepacked\", i, \":\", prepack)\n", 326 | " print(\"Default\", i, \":\", default)\n", 327 | "\n", 328 | " assert prepack == default" 329 | ] 330 | } 331 | ], 332 | "metadata": { 333 | "accelerator": "GPU", 334 | "colab": { 335 | "gpuType": "T4", 336 | "provenance": [] 337 | }, 338 | "kernelspec": { 339 | "display_name": "Python 3", 340 | "name": "python3" 341 | }, 342 | "language_info": { 343 | "name": "python" 344 | }, 345 | "widgets": { 346 | "application/vnd.jupyter.widget-state+json": { 347 | "01c044c6e5364772b2032197bdca6a37": { 348 | "model_module": "@jupyter-widgets/controls", 349 | "model_module_version": "1.5.0", 350 | "model_name": "ProgressStyleModel", 351 | "state": { 352 | "_model_module": "@jupyter-widgets/controls", 353 | "_model_module_version": "1.5.0", 354 | "_model_name": "ProgressStyleModel", 355 | "_view_count": null, 356 | "_view_module": "@jupyter-widgets/base", 357 | "_view_module_version": "1.2.0", 358 | "_view_name": "StyleView", 359 | "bar_color": null, 360 | "description_width": "" 361 | } 362 | }, 363 | "0b9b2e91fe864f25ae09c37bf9145cae": { 364 | "model_module": "@jupyter-widgets/controls", 365 | "model_module_version": "1.5.0", 366 | "model_name": "DescriptionStyleModel", 367 | "state": { 368 | "_model_module": "@jupyter-widgets/controls", 369 | "_model_module_version": "1.5.0", 370 | "_model_name": "DescriptionStyleModel", 371 | "_view_count": null, 372 | "_view_module": "@jupyter-widgets/base", 373 | "_view_module_version": "1.2.0", 374 | "_view_name": "StyleView", 375 | "description_width": "" 376 | } 377 | }, 378 | "0c0d9cda00e045c2a16eab19347b49d2": { 379 | "model_module": "@jupyter-widgets/controls", 380 | "model_module_version": "1.5.0", 381 | "model_name": "HTMLModel", 382 | "state": { 383 | "_dom_classes": [], 384 | "_model_module": "@jupyter-widgets/controls", 385 | "_model_module_version": "1.5.0", 386 | "_model_name": "HTMLModel", 387 | "_view_count": null, 388 | "_view_module": "@jupyter-widgets/controls", 389 | "_view_module_version": "1.5.0", 390 | "_view_name": "HTMLView", 391 | "description": "", 392 | "description_tooltip": null, 393 | "layout": "IPY_MODEL_e4685f450cb543f482d41881e2555b34", 394 | "placeholder": "​", 395 | "style": "IPY_MODEL_2cdfa2a8e85548ce9259285a04310822", 396 | "value": "special_tokens_map.json: 100%" 397 | } 398 | }, 399 | "0f57f44193ef462490f4e1526637c874": { 400 | "model_module": "@jupyter-widgets/base", 401 | "model_module_version": "1.2.0", 402 | "model_name": "LayoutModel", 403 | "state": { 404 | "_model_module": "@jupyter-widgets/base", 405 | "_model_module_version": "1.2.0", 406 | "_model_name": "LayoutModel", 407 | "_view_count": null, 408 | "_view_module": "@jupyter-widgets/base", 409 | "_view_module_version": "1.2.0", 410 | "_view_name": "LayoutView", 411 | "align_content": null, 412 | "align_items": null, 413 | "align_self": null, 414 | "border": null, 415 | "bottom": null, 416 | "display": 
null, 417 | "flex": null, 418 | "flex_flow": null, 419 | "grid_area": null, 420 | "grid_auto_columns": null, 421 | "grid_auto_flow": null, 422 | "grid_auto_rows": null, 423 | "grid_column": null, 424 | "grid_gap": null, 425 | "grid_row": null, 426 | "grid_template_areas": null, 427 | "grid_template_columns": null, 428 | "grid_template_rows": null, 429 | "height": null, 430 | "justify_content": null, 431 | "justify_items": null, 432 | "left": null, 433 | "margin": null, 434 | "max_height": null, 435 | "max_width": null, 436 | "min_height": null, 437 | "min_width": null, 438 | "object_fit": null, 439 | "object_position": null, 440 | "order": null, 441 | "overflow": null, 442 | "overflow_x": null, 443 | "overflow_y": null, 444 | "padding": null, 445 | "right": null, 446 | "top": null, 447 | "visibility": null, 448 | "width": null 449 | } 450 | }, 451 | "10a555507a6246ee9c2ff36a5405d63e": { 452 | "model_module": "@jupyter-widgets/base", 453 | "model_module_version": "1.2.0", 454 | "model_name": "LayoutModel", 455 | "state": { 456 | "_model_module": "@jupyter-widgets/base", 457 | "_model_module_version": "1.2.0", 458 | "_model_name": "LayoutModel", 459 | "_view_count": null, 460 | "_view_module": "@jupyter-widgets/base", 461 | "_view_module_version": "1.2.0", 462 | "_view_name": "LayoutView", 463 | "align_content": null, 464 | "align_items": null, 465 | "align_self": null, 466 | "border": null, 467 | "bottom": null, 468 | "display": null, 469 | "flex": null, 470 | "flex_flow": null, 471 | "grid_area": null, 472 | "grid_auto_columns": null, 473 | "grid_auto_flow": null, 474 | "grid_auto_rows": null, 475 | "grid_column": null, 476 | "grid_gap": null, 477 | "grid_row": null, 478 | "grid_template_areas": null, 479 | "grid_template_columns": null, 480 | "grid_template_rows": null, 481 | "height": null, 482 | "justify_content": null, 483 | "justify_items": null, 484 | "left": null, 485 | "margin": null, 486 | "max_height": null, 487 | "max_width": null, 488 | "min_height": null, 489 | "min_width": null, 490 | "object_fit": null, 491 | "object_position": null, 492 | "order": null, 493 | "overflow": null, 494 | "overflow_x": null, 495 | "overflow_y": null, 496 | "padding": null, 497 | "right": null, 498 | "top": null, 499 | "visibility": null, 500 | "width": null 501 | } 502 | }, 503 | "12e2bb57b55c43b2a46ccdde99cb981b": { 504 | "model_module": "@jupyter-widgets/controls", 505 | "model_module_version": "1.5.0", 506 | "model_name": "HBoxModel", 507 | "state": { 508 | "_dom_classes": [], 509 | "_model_module": "@jupyter-widgets/controls", 510 | "_model_module_version": "1.5.0", 511 | "_model_name": "HBoxModel", 512 | "_view_count": null, 513 | "_view_module": "@jupyter-widgets/controls", 514 | "_view_module_version": "1.5.0", 515 | "_view_name": "HBoxView", 516 | "box_style": "", 517 | "children": [ 518 | "IPY_MODEL_dd92cd36ad544a17a22374c0896947da", 519 | "IPY_MODEL_4c7aa0c521cc4e0e8f9bb8789e6ad8f9", 520 | "IPY_MODEL_454099c862aa46149f79467e54b7f7ac" 521 | ], 522 | "layout": "IPY_MODEL_10a555507a6246ee9c2ff36a5405d63e" 523 | } 524 | }, 525 | "13a9fb93d75e41d180e7673acad43dab": { 526 | "model_module": "@jupyter-widgets/controls", 527 | "model_module_version": "1.5.0", 528 | "model_name": "DescriptionStyleModel", 529 | "state": { 530 | "_model_module": "@jupyter-widgets/controls", 531 | "_model_module_version": "1.5.0", 532 | "_model_name": "DescriptionStyleModel", 533 | "_view_count": null, 534 | "_view_module": "@jupyter-widgets/base", 535 | "_view_module_version": "1.2.0", 536 | "_view_name": 
"StyleView", 537 | "description_width": "" 538 | } 539 | }, 540 | "1a0d37386914443f983a89cf56d7a900": { 541 | "model_module": "@jupyter-widgets/base", 542 | "model_module_version": "1.2.0", 543 | "model_name": "LayoutModel", 544 | "state": { 545 | "_model_module": "@jupyter-widgets/base", 546 | "_model_module_version": "1.2.0", 547 | "_model_name": "LayoutModel", 548 | "_view_count": null, 549 | "_view_module": "@jupyter-widgets/base", 550 | "_view_module_version": "1.2.0", 551 | "_view_name": "LayoutView", 552 | "align_content": null, 553 | "align_items": null, 554 | "align_self": null, 555 | "border": null, 556 | "bottom": null, 557 | "display": null, 558 | "flex": null, 559 | "flex_flow": null, 560 | "grid_area": null, 561 | "grid_auto_columns": null, 562 | "grid_auto_flow": null, 563 | "grid_auto_rows": null, 564 | "grid_column": null, 565 | "grid_gap": null, 566 | "grid_row": null, 567 | "grid_template_areas": null, 568 | "grid_template_columns": null, 569 | "grid_template_rows": null, 570 | "height": null, 571 | "justify_content": null, 572 | "justify_items": null, 573 | "left": null, 574 | "margin": null, 575 | "max_height": null, 576 | "max_width": null, 577 | "min_height": null, 578 | "min_width": null, 579 | "object_fit": null, 580 | "object_position": null, 581 | "order": null, 582 | "overflow": null, 583 | "overflow_x": null, 584 | "overflow_y": null, 585 | "padding": null, 586 | "right": null, 587 | "top": null, 588 | "visibility": null, 589 | "width": null 590 | } 591 | }, 592 | "1a65ae611f96427e85b5e5ec21bca22d": { 593 | "model_module": "@jupyter-widgets/controls", 594 | "model_module_version": "1.5.0", 595 | "model_name": "HTMLModel", 596 | "state": { 597 | "_dom_classes": [], 598 | "_model_module": "@jupyter-widgets/controls", 599 | "_model_module_version": "1.5.0", 600 | "_model_name": "HTMLModel", 601 | "_view_count": null, 602 | "_view_module": "@jupyter-widgets/controls", 603 | "_view_module_version": "1.5.0", 604 | "_view_name": "HTMLView", 605 | "description": "", 606 | "description_tooltip": null, 607 | "layout": "IPY_MODEL_b3ffc98ce8fa41ddad546ceed1f3e0fb", 608 | "placeholder": "​", 609 | "style": "IPY_MODEL_6d38a5c62d1243bc9c8d68fc0a20393b", 610 | "value": " 632/632 [00:00<00:00, 23.6kB/s]" 611 | } 612 | }, 613 | "1b9ee7656cae4986ba45656fef43d154": { 614 | "model_module": "@jupyter-widgets/controls", 615 | "model_module_version": "1.5.0", 616 | "model_name": "DescriptionStyleModel", 617 | "state": { 618 | "_model_module": "@jupyter-widgets/controls", 619 | "_model_module_version": "1.5.0", 620 | "_model_name": "DescriptionStyleModel", 621 | "_view_count": null, 622 | "_view_module": "@jupyter-widgets/base", 623 | "_view_module_version": "1.2.0", 624 | "_view_name": "StyleView", 625 | "description_width": "" 626 | } 627 | }, 628 | "1f057e7c483d41a8b63f8b59480962f1": { 629 | "model_module": "@jupyter-widgets/controls", 630 | "model_module_version": "1.5.0", 631 | "model_name": "ProgressStyleModel", 632 | "state": { 633 | "_model_module": "@jupyter-widgets/controls", 634 | "_model_module_version": "1.5.0", 635 | "_model_name": "ProgressStyleModel", 636 | "_view_count": null, 637 | "_view_module": "@jupyter-widgets/base", 638 | "_view_module_version": "1.2.0", 639 | "_view_name": "StyleView", 640 | "bar_color": null, 641 | "description_width": "" 642 | } 643 | }, 644 | "1fecf11618b244909505323d2968e5fc": { 645 | "model_module": "@jupyter-widgets/controls", 646 | "model_module_version": "1.5.0", 647 | "model_name": "DescriptionStyleModel", 648 | "state": { 649 | 
"_model_module": "@jupyter-widgets/controls", 650 | "_model_module_version": "1.5.0", 651 | "_model_name": "DescriptionStyleModel", 652 | "_view_count": null, 653 | "_view_module": "@jupyter-widgets/base", 654 | "_view_module_version": "1.2.0", 655 | "_view_name": "StyleView", 656 | "description_width": "" 657 | } 658 | }, 659 | "1ff88c3b38de421da1bcff730cccbf52": { 660 | "model_module": "@jupyter-widgets/base", 661 | "model_module_version": "1.2.0", 662 | "model_name": "LayoutModel", 663 | "state": { 664 | "_model_module": "@jupyter-widgets/base", 665 | "_model_module_version": "1.2.0", 666 | "_model_name": "LayoutModel", 667 | "_view_count": null, 668 | "_view_module": "@jupyter-widgets/base", 669 | "_view_module_version": "1.2.0", 670 | "_view_name": "LayoutView", 671 | "align_content": null, 672 | "align_items": null, 673 | "align_self": null, 674 | "border": null, 675 | "bottom": null, 676 | "display": null, 677 | "flex": null, 678 | "flex_flow": null, 679 | "grid_area": null, 680 | "grid_auto_columns": null, 681 | "grid_auto_flow": null, 682 | "grid_auto_rows": null, 683 | "grid_column": null, 684 | "grid_gap": null, 685 | "grid_row": null, 686 | "grid_template_areas": null, 687 | "grid_template_columns": null, 688 | "grid_template_rows": null, 689 | "height": null, 690 | "justify_content": null, 691 | "justify_items": null, 692 | "left": null, 693 | "margin": null, 694 | "max_height": null, 695 | "max_width": null, 696 | "min_height": null, 697 | "min_width": null, 698 | "object_fit": null, 699 | "object_position": null, 700 | "order": null, 701 | "overflow": null, 702 | "overflow_x": null, 703 | "overflow_y": null, 704 | "padding": null, 705 | "right": null, 706 | "top": null, 707 | "visibility": null, 708 | "width": null 709 | } 710 | }, 711 | "240fed609ad44aafad099550d3c265df": { 712 | "model_module": "@jupyter-widgets/controls", 713 | "model_module_version": "1.5.0", 714 | "model_name": "FloatProgressModel", 715 | "state": { 716 | "_dom_classes": [], 717 | "_model_module": "@jupyter-widgets/controls", 718 | "_model_module_version": "1.5.0", 719 | "_model_name": "FloatProgressModel", 720 | "_view_count": null, 721 | "_view_module": "@jupyter-widgets/controls", 722 | "_view_module_version": "1.5.0", 723 | "_view_name": "ProgressView", 724 | "bar_style": "success", 725 | "description": "", 726 | "description_tooltip": null, 727 | "layout": "IPY_MODEL_b73c8d9ccf494d6193ff8262777a98da", 728 | "max": 1842665, 729 | "min": 0, 730 | "orientation": "horizontal", 731 | "style": "IPY_MODEL_01c044c6e5364772b2032197bdca6a37", 732 | "value": 1842665 733 | } 734 | }, 735 | "24ffd51165764526962c2de4c004ae51": { 736 | "model_module": "@jupyter-widgets/base", 737 | "model_module_version": "1.2.0", 738 | "model_name": "LayoutModel", 739 | "state": { 740 | "_model_module": "@jupyter-widgets/base", 741 | "_model_module_version": "1.2.0", 742 | "_model_name": "LayoutModel", 743 | "_view_count": null, 744 | "_view_module": "@jupyter-widgets/base", 745 | "_view_module_version": "1.2.0", 746 | "_view_name": "LayoutView", 747 | "align_content": null, 748 | "align_items": null, 749 | "align_self": null, 750 | "border": null, 751 | "bottom": null, 752 | "display": null, 753 | "flex": null, 754 | "flex_flow": null, 755 | "grid_area": null, 756 | "grid_auto_columns": null, 757 | "grid_auto_flow": null, 758 | "grid_auto_rows": null, 759 | "grid_column": null, 760 | "grid_gap": null, 761 | "grid_row": null, 762 | "grid_template_areas": null, 763 | "grid_template_columns": null, 764 | "grid_template_rows": 
null, 765 | "height": null, 766 | "justify_content": null, 767 | "justify_items": null, 768 | "left": null, 769 | "margin": null, 770 | "max_height": null, 771 | "max_width": null, 772 | "min_height": null, 773 | "min_width": null, 774 | "object_fit": null, 775 | "object_position": null, 776 | "order": null, 777 | "overflow": null, 778 | "overflow_x": null, 779 | "overflow_y": null, 780 | "padding": null, 781 | "right": null, 782 | "top": null, 783 | "visibility": null, 784 | "width": null 785 | } 786 | }, 787 | "256bd87c03b2449a9cf25303a84b96d6": { 788 | "model_module": "@jupyter-widgets/controls", 789 | "model_module_version": "1.5.0", 790 | "model_name": "FloatProgressModel", 791 | "state": { 792 | "_dom_classes": [], 793 | "_model_module": "@jupyter-widgets/controls", 794 | "_model_module_version": "1.5.0", 795 | "_model_name": "FloatProgressModel", 796 | "_view_count": null, 797 | "_view_module": "@jupyter-widgets/controls", 798 | "_view_module_version": "1.5.0", 799 | "_view_name": "ProgressView", 800 | "bar_style": "success", 801 | "description": "", 802 | "description_tooltip": null, 803 | "layout": "IPY_MODEL_a44eefdea1bd4b83a210f226f074e1d0", 804 | "max": 5381778489, 805 | "min": 0, 806 | "orientation": "horizontal", 807 | "style": "IPY_MODEL_4e3c9dbe0057456683e12c1357925ec5", 808 | "value": 5381778489 809 | } 810 | }, 811 | "27732c5e88314d12a403bfbeea05b71a": { 812 | "model_module": "@jupyter-widgets/controls", 813 | "model_module_version": "1.5.0", 814 | "model_name": "HBoxModel", 815 | "state": { 816 | "_dom_classes": [], 817 | "_model_module": "@jupyter-widgets/controls", 818 | "_model_module_version": "1.5.0", 819 | "_model_name": "HBoxModel", 820 | "_view_count": null, 821 | "_view_module": "@jupyter-widgets/controls", 822 | "_view_module_version": "1.5.0", 823 | "_view_name": "HBoxView", 824 | "box_style": "", 825 | "children": [ 826 | "IPY_MODEL_2bdbfdcdc9a54b8fad0068fda0fecd55", 827 | "IPY_MODEL_48ee78ad568843ddba1619a3e50404ce", 828 | "IPY_MODEL_68025228c36c419e88ce26efcac76456" 829 | ], 830 | "layout": "IPY_MODEL_1ff88c3b38de421da1bcff730cccbf52" 831 | } 832 | }, 833 | "29ee1203a8444b579caeec2f9fe201e6": { 834 | "model_module": "@jupyter-widgets/base", 835 | "model_module_version": "1.2.0", 836 | "model_name": "LayoutModel", 837 | "state": { 838 | "_model_module": "@jupyter-widgets/base", 839 | "_model_module_version": "1.2.0", 840 | "_model_name": "LayoutModel", 841 | "_view_count": null, 842 | "_view_module": "@jupyter-widgets/base", 843 | "_view_module_version": "1.2.0", 844 | "_view_name": "LayoutView", 845 | "align_content": null, 846 | "align_items": null, 847 | "align_self": null, 848 | "border": null, 849 | "bottom": null, 850 | "display": null, 851 | "flex": null, 852 | "flex_flow": null, 853 | "grid_area": null, 854 | "grid_auto_columns": null, 855 | "grid_auto_flow": null, 856 | "grid_auto_rows": null, 857 | "grid_column": null, 858 | "grid_gap": null, 859 | "grid_row": null, 860 | "grid_template_areas": null, 861 | "grid_template_columns": null, 862 | "grid_template_rows": null, 863 | "height": null, 864 | "justify_content": null, 865 | "justify_items": null, 866 | "left": null, 867 | "margin": null, 868 | "max_height": null, 869 | "max_width": null, 870 | "min_height": null, 871 | "min_width": null, 872 | "object_fit": null, 873 | "object_position": null, 874 | "order": null, 875 | "overflow": null, 876 | "overflow_x": null, 877 | "overflow_y": null, 878 | "padding": null, 879 | "right": null, 880 | "top": null, 881 | "visibility": null, 882 | "width": 
null 883 | } 884 | }, 885 | "2b5de781f41e4d6dbdc6c7939b405af0": { 886 | "model_module": "@jupyter-widgets/base", 887 | "model_module_version": "1.2.0", 888 | "model_name": "LayoutModel", 889 | "state": { 890 | "_model_module": "@jupyter-widgets/base", 891 | "_model_module_version": "1.2.0", 892 | "_model_name": "LayoutModel", 893 | "_view_count": null, 894 | "_view_module": "@jupyter-widgets/base", 895 | "_view_module_version": "1.2.0", 896 | "_view_name": "LayoutView", 897 | "align_content": null, 898 | "align_items": null, 899 | "align_self": null, 900 | "border": null, 901 | "bottom": null, 902 | "display": null, 903 | "flex": null, 904 | "flex_flow": null, 905 | "grid_area": null, 906 | "grid_auto_columns": null, 907 | "grid_auto_flow": null, 908 | "grid_auto_rows": null, 909 | "grid_column": null, 910 | "grid_gap": null, 911 | "grid_row": null, 912 | "grid_template_areas": null, 913 | "grid_template_columns": null, 914 | "grid_template_rows": null, 915 | "height": null, 916 | "justify_content": null, 917 | "justify_items": null, 918 | "left": null, 919 | "margin": null, 920 | "max_height": null, 921 | "max_width": null, 922 | "min_height": null, 923 | "min_width": null, 924 | "object_fit": null, 925 | "object_position": null, 926 | "order": null, 927 | "overflow": null, 928 | "overflow_x": null, 929 | "overflow_y": null, 930 | "padding": null, 931 | "right": null, 932 | "top": null, 933 | "visibility": null, 934 | "width": null 935 | } 936 | }, 937 | "2bdbfdcdc9a54b8fad0068fda0fecd55": { 938 | "model_module": "@jupyter-widgets/controls", 939 | "model_module_version": "1.5.0", 940 | "model_name": "HTMLModel", 941 | "state": { 942 | "_dom_classes": [], 943 | "_model_module": "@jupyter-widgets/controls", 944 | "_model_module_version": "1.5.0", 945 | "_model_name": "HTMLModel", 946 | "_view_count": null, 947 | "_view_module": "@jupyter-widgets/controls", 948 | "_view_module_version": "1.5.0", 949 | "_view_name": "HTMLView", 950 | "description": "", 951 | "description_tooltip": null, 952 | "layout": "IPY_MODEL_4149a1ba8c174a83abcadc4bcb1ef98a", 953 | "placeholder": "​", 954 | "style": "IPY_MODEL_0b9b2e91fe864f25ae09c37bf9145cae", 955 | "value": "tokenizer_config.json: 100%" 956 | } 957 | }, 958 | "2cdfa2a8e85548ce9259285a04310822": { 959 | "model_module": "@jupyter-widgets/controls", 960 | "model_module_version": "1.5.0", 961 | "model_name": "DescriptionStyleModel", 962 | "state": { 963 | "_model_module": "@jupyter-widgets/controls", 964 | "_model_module_version": "1.5.0", 965 | "_model_name": "DescriptionStyleModel", 966 | "_view_count": null, 967 | "_view_module": "@jupyter-widgets/base", 968 | "_view_module_version": "1.2.0", 969 | "_view_name": "StyleView", 970 | "description_width": "" 971 | } 972 | }, 973 | "3f21f1f034744945a9b0657e313a6350": { 974 | "model_module": "@jupyter-widgets/base", 975 | "model_module_version": "1.2.0", 976 | "model_name": "LayoutModel", 977 | "state": { 978 | "_model_module": "@jupyter-widgets/base", 979 | "_model_module_version": "1.2.0", 980 | "_model_name": "LayoutModel", 981 | "_view_count": null, 982 | "_view_module": "@jupyter-widgets/base", 983 | "_view_module_version": "1.2.0", 984 | "_view_name": "LayoutView", 985 | "align_content": null, 986 | "align_items": null, 987 | "align_self": null, 988 | "border": null, 989 | "bottom": null, 990 | "display": null, 991 | "flex": null, 992 | "flex_flow": null, 993 | "grid_area": null, 994 | "grid_auto_columns": null, 995 | "grid_auto_flow": null, 996 | "grid_auto_rows": null, 997 | "grid_column": null, 998 | 
"grid_gap": null, 999 | "grid_row": null, 1000 | "grid_template_areas": null, 1001 | "grid_template_columns": null, 1002 | "grid_template_rows": null, 1003 | "height": null, 1004 | "justify_content": null, 1005 | "justify_items": null, 1006 | "left": null, 1007 | "margin": null, 1008 | "max_height": null, 1009 | "max_width": null, 1010 | "min_height": null, 1011 | "min_width": null, 1012 | "object_fit": null, 1013 | "object_position": null, 1014 | "order": null, 1015 | "overflow": null, 1016 | "overflow_x": null, 1017 | "overflow_y": null, 1018 | "padding": null, 1019 | "right": null, 1020 | "top": null, 1021 | "visibility": null, 1022 | "width": null 1023 | } 1024 | }, 1025 | "405acde9b0b74f18aaca033b36c28fbd": { 1026 | "model_module": "@jupyter-widgets/base", 1027 | "model_module_version": "1.2.0", 1028 | "model_name": "LayoutModel", 1029 | "state": { 1030 | "_model_module": "@jupyter-widgets/base", 1031 | "_model_module_version": "1.2.0", 1032 | "_model_name": "LayoutModel", 1033 | "_view_count": null, 1034 | "_view_module": "@jupyter-widgets/base", 1035 | "_view_module_version": "1.2.0", 1036 | "_view_name": "LayoutView", 1037 | "align_content": null, 1038 | "align_items": null, 1039 | "align_self": null, 1040 | "border": null, 1041 | "bottom": null, 1042 | "display": null, 1043 | "flex": null, 1044 | "flex_flow": null, 1045 | "grid_area": null, 1046 | "grid_auto_columns": null, 1047 | "grid_auto_flow": null, 1048 | "grid_auto_rows": null, 1049 | "grid_column": null, 1050 | "grid_gap": null, 1051 | "grid_row": null, 1052 | "grid_template_areas": null, 1053 | "grid_template_columns": null, 1054 | "grid_template_rows": null, 1055 | "height": null, 1056 | "justify_content": null, 1057 | "justify_items": null, 1058 | "left": null, 1059 | "margin": null, 1060 | "max_height": null, 1061 | "max_width": null, 1062 | "min_height": null, 1063 | "min_width": null, 1064 | "object_fit": null, 1065 | "object_position": null, 1066 | "order": null, 1067 | "overflow": null, 1068 | "overflow_x": null, 1069 | "overflow_y": null, 1070 | "padding": null, 1071 | "right": null, 1072 | "top": null, 1073 | "visibility": null, 1074 | "width": null 1075 | } 1076 | }, 1077 | "4149a1ba8c174a83abcadc4bcb1ef98a": { 1078 | "model_module": "@jupyter-widgets/base", 1079 | "model_module_version": "1.2.0", 1080 | "model_name": "LayoutModel", 1081 | "state": { 1082 | "_model_module": "@jupyter-widgets/base", 1083 | "_model_module_version": "1.2.0", 1084 | "_model_name": "LayoutModel", 1085 | "_view_count": null, 1086 | "_view_module": "@jupyter-widgets/base", 1087 | "_view_module_version": "1.2.0", 1088 | "_view_name": "LayoutView", 1089 | "align_content": null, 1090 | "align_items": null, 1091 | "align_self": null, 1092 | "border": null, 1093 | "bottom": null, 1094 | "display": null, 1095 | "flex": null, 1096 | "flex_flow": null, 1097 | "grid_area": null, 1098 | "grid_auto_columns": null, 1099 | "grid_auto_flow": null, 1100 | "grid_auto_rows": null, 1101 | "grid_column": null, 1102 | "grid_gap": null, 1103 | "grid_row": null, 1104 | "grid_template_areas": null, 1105 | "grid_template_columns": null, 1106 | "grid_template_rows": null, 1107 | "height": null, 1108 | "justify_content": null, 1109 | "justify_items": null, 1110 | "left": null, 1111 | "margin": null, 1112 | "max_height": null, 1113 | "max_width": null, 1114 | "min_height": null, 1115 | "min_width": null, 1116 | "object_fit": null, 1117 | "object_position": null, 1118 | "order": null, 1119 | "overflow": null, 1120 | "overflow_x": null, 1121 | "overflow_y": null, 
1122 | "padding": null, 1123 | "right": null, 1124 | "top": null, 1125 | "visibility": null, 1126 | "width": null 1127 | } 1128 | }, 1129 | "42075dab0a5443b0b65e1ff8b2cb5e89": { 1130 | "model_module": "@jupyter-widgets/controls", 1131 | "model_module_version": "1.5.0", 1132 | "model_name": "DescriptionStyleModel", 1133 | "state": { 1134 | "_model_module": "@jupyter-widgets/controls", 1135 | "_model_module_version": "1.5.0", 1136 | "_model_name": "DescriptionStyleModel", 1137 | "_view_count": null, 1138 | "_view_module": "@jupyter-widgets/base", 1139 | "_view_module_version": "1.2.0", 1140 | "_view_name": "StyleView", 1141 | "description_width": "" 1142 | } 1143 | }, 1144 | "454099c862aa46149f79467e54b7f7ac": { 1145 | "model_module": "@jupyter-widgets/controls", 1146 | "model_module_version": "1.5.0", 1147 | "model_name": "HTMLModel", 1148 | "state": { 1149 | "_dom_classes": [], 1150 | "_model_module": "@jupyter-widgets/controls", 1151 | "_model_module_version": "1.5.0", 1152 | "_model_name": "HTMLModel", 1153 | "_view_count": null, 1154 | "_view_module": "@jupyter-widgets/controls", 1155 | "_view_module_version": "1.5.0", 1156 | "_view_name": "HTMLView", 1157 | "description": "", 1158 | "description_tooltip": null, 1159 | "layout": "IPY_MODEL_405acde9b0b74f18aaca033b36c28fbd", 1160 | "placeholder": "​", 1161 | "style": "IPY_MODEL_9c5163fdeb61471bbafdfe08393a6d80", 1162 | "value": " 132/132 [00:00<00:00, 8.87kB/s]" 1163 | } 1164 | }, 1165 | "48e7aedb9eb8467ba5418860c5d99a07": { 1166 | "model_module": "@jupyter-widgets/controls", 1167 | "model_module_version": "1.5.0", 1168 | "model_name": "FloatProgressModel", 1169 | "state": { 1170 | "_dom_classes": [], 1171 | "_model_module": "@jupyter-widgets/controls", 1172 | "_model_module_version": "1.5.0", 1173 | "_model_name": "FloatProgressModel", 1174 | "_view_count": null, 1175 | "_view_module": "@jupyter-widgets/controls", 1176 | "_view_module_version": "1.5.0", 1177 | "_view_name": "ProgressView", 1178 | "bar_style": "success", 1179 | "description": "", 1180 | "description_tooltip": null, 1181 | "layout": "IPY_MODEL_c8219473c99d480081104fe4551036de", 1182 | "max": 632, 1183 | "min": 0, 1184 | "orientation": "horizontal", 1185 | "style": "IPY_MODEL_58524a4c38cf4289801ee2569c07948d", 1186 | "value": 632 1187 | } 1188 | }, 1189 | "48ee78ad568843ddba1619a3e50404ce": { 1190 | "model_module": "@jupyter-widgets/controls", 1191 | "model_module_version": "1.5.0", 1192 | "model_name": "FloatProgressModel", 1193 | "state": { 1194 | "_dom_classes": [], 1195 | "_model_module": "@jupyter-widgets/controls", 1196 | "_model_module_version": "1.5.0", 1197 | "_model_name": "FloatProgressModel", 1198 | "_view_count": null, 1199 | "_view_module": "@jupyter-widgets/controls", 1200 | "_view_module_version": "1.5.0", 1201 | "_view_name": "ProgressView", 1202 | "bar_style": "success", 1203 | "description": "", 1204 | "description_tooltip": null, 1205 | "layout": "IPY_MODEL_782c8fe1763b4977a678d9fb4fd5dadc", 1206 | "max": 727, 1207 | "min": 0, 1208 | "orientation": "horizontal", 1209 | "style": "IPY_MODEL_d3b47705f4e14680b13b21be52d2e604", 1210 | "value": 727 1211 | } 1212 | }, 1213 | "49536e13224f481bbb4a02a1391ab347": { 1214 | "model_module": "@jupyter-widgets/base", 1215 | "model_module_version": "1.2.0", 1216 | "model_name": "LayoutModel", 1217 | "state": { 1218 | "_model_module": "@jupyter-widgets/base", 1219 | "_model_module_version": "1.2.0", 1220 | "_model_name": "LayoutModel", 1221 | "_view_count": null, 1222 | "_view_module": "@jupyter-widgets/base", 1223 | 
"_view_module_version": "1.2.0", 1224 | "_view_name": "LayoutView", 1225 | "align_content": null, 1226 | "align_items": null, 1227 | "align_self": null, 1228 | "border": null, 1229 | "bottom": null, 1230 | "display": null, 1231 | "flex": null, 1232 | "flex_flow": null, 1233 | "grid_area": null, 1234 | "grid_auto_columns": null, 1235 | "grid_auto_flow": null, 1236 | "grid_auto_rows": null, 1237 | "grid_column": null, 1238 | "grid_gap": null, 1239 | "grid_row": null, 1240 | "grid_template_areas": null, 1241 | "grid_template_columns": null, 1242 | "grid_template_rows": null, 1243 | "height": null, 1244 | "justify_content": null, 1245 | "justify_items": null, 1246 | "left": null, 1247 | "margin": null, 1248 | "max_height": null, 1249 | "max_width": null, 1250 | "min_height": null, 1251 | "min_width": null, 1252 | "object_fit": null, 1253 | "object_position": null, 1254 | "order": null, 1255 | "overflow": null, 1256 | "overflow_x": null, 1257 | "overflow_y": null, 1258 | "padding": null, 1259 | "right": null, 1260 | "top": null, 1261 | "visibility": null, 1262 | "width": null 1263 | } 1264 | }, 1265 | "4afabea1be0249068ad394e7265abf24": { 1266 | "model_module": "@jupyter-widgets/controls", 1267 | "model_module_version": "1.5.0", 1268 | "model_name": "HTMLModel", 1269 | "state": { 1270 | "_dom_classes": [], 1271 | "_model_module": "@jupyter-widgets/controls", 1272 | "_model_module_version": "1.5.0", 1273 | "_model_name": "HTMLModel", 1274 | "_view_count": null, 1275 | "_view_module": "@jupyter-widgets/controls", 1276 | "_view_module_version": "1.5.0", 1277 | "_view_name": "HTMLView", 1278 | "description": "", 1279 | "description_tooltip": null, 1280 | "layout": "IPY_MODEL_29ee1203a8444b579caeec2f9fe201e6", 1281 | "placeholder": "​", 1282 | "style": "IPY_MODEL_5cbd1a0fe71c434ca0bd655d6347947f", 1283 | "value": " 1.84M/1.84M [00:00<00:00, 7.76MB/s]" 1284 | } 1285 | }, 1286 | "4c44190d75cb4f21826fee5b25052886": { 1287 | "model_module": "@jupyter-widgets/controls", 1288 | "model_module_version": "1.5.0", 1289 | "model_name": "HTMLModel", 1290 | "state": { 1291 | "_dom_classes": [], 1292 | "_model_module": "@jupyter-widgets/controls", 1293 | "_model_module_version": "1.5.0", 1294 | "_model_name": "HTMLModel", 1295 | "_view_count": null, 1296 | "_view_module": "@jupyter-widgets/controls", 1297 | "_view_module_version": "1.5.0", 1298 | "_view_name": "HTMLView", 1299 | "description": "", 1300 | "description_tooltip": null, 1301 | "layout": "IPY_MODEL_9b5632fd42b7452295e7873e7c7a230f", 1302 | "placeholder": "​", 1303 | "style": "IPY_MODEL_b5ca1b962b874a50bd3657a1d74ea998", 1304 | "value": "config.json: 100%" 1305 | } 1306 | }, 1307 | "4c7aa0c521cc4e0e8f9bb8789e6ad8f9": { 1308 | "model_module": "@jupyter-widgets/controls", 1309 | "model_module_version": "1.5.0", 1310 | "model_name": "FloatProgressModel", 1311 | "state": { 1312 | "_dom_classes": [], 1313 | "_model_module": "@jupyter-widgets/controls", 1314 | "_model_module_version": "1.5.0", 1315 | "_model_name": "FloatProgressModel", 1316 | "_view_count": null, 1317 | "_view_module": "@jupyter-widgets/controls", 1318 | "_view_module_version": "1.5.0", 1319 | "_view_name": "ProgressView", 1320 | "bar_style": "success", 1321 | "description": "", 1322 | "description_tooltip": null, 1323 | "layout": "IPY_MODEL_cfee62ce88bf4a409ce34337556539cf", 1324 | "max": 132, 1325 | "min": 0, 1326 | "orientation": "horizontal", 1327 | "style": "IPY_MODEL_a16259ceb15b464fa87355a14aaac331", 1328 | "value": 132 1329 | } 1330 | }, 1331 | "4e3c9dbe0057456683e12c1357925ec5": { 
1332 | "model_module": "@jupyter-widgets/controls", 1333 | "model_module_version": "1.5.0", 1334 | "model_name": "ProgressStyleModel", 1335 | "state": { 1336 | "_model_module": "@jupyter-widgets/controls", 1337 | "_model_module_version": "1.5.0", 1338 | "_model_name": "ProgressStyleModel", 1339 | "_view_count": null, 1340 | "_view_module": "@jupyter-widgets/base", 1341 | "_view_module_version": "1.2.0", 1342 | "_view_name": "StyleView", 1343 | "bar_color": null, 1344 | "description_width": "" 1345 | } 1346 | }, 1347 | "55690f32de47405da42599803e654fc6": { 1348 | "model_module": "@jupyter-widgets/base", 1349 | "model_module_version": "1.2.0", 1350 | "model_name": "LayoutModel", 1351 | "state": { 1352 | "_model_module": "@jupyter-widgets/base", 1353 | "_model_module_version": "1.2.0", 1354 | "_model_name": "LayoutModel", 1355 | "_view_count": null, 1356 | "_view_module": "@jupyter-widgets/base", 1357 | "_view_module_version": "1.2.0", 1358 | "_view_name": "LayoutView", 1359 | "align_content": null, 1360 | "align_items": null, 1361 | "align_self": null, 1362 | "border": null, 1363 | "bottom": null, 1364 | "display": null, 1365 | "flex": null, 1366 | "flex_flow": null, 1367 | "grid_area": null, 1368 | "grid_auto_columns": null, 1369 | "grid_auto_flow": null, 1370 | "grid_auto_rows": null, 1371 | "grid_column": null, 1372 | "grid_gap": null, 1373 | "grid_row": null, 1374 | "grid_template_areas": null, 1375 | "grid_template_columns": null, 1376 | "grid_template_rows": null, 1377 | "height": null, 1378 | "justify_content": null, 1379 | "justify_items": null, 1380 | "left": null, 1381 | "margin": null, 1382 | "max_height": null, 1383 | "max_width": null, 1384 | "min_height": null, 1385 | "min_width": null, 1386 | "object_fit": null, 1387 | "object_position": null, 1388 | "order": null, 1389 | "overflow": null, 1390 | "overflow_x": null, 1391 | "overflow_y": null, 1392 | "padding": null, 1393 | "right": null, 1394 | "top": null, 1395 | "visibility": null, 1396 | "width": null 1397 | } 1398 | }, 1399 | "58524a4c38cf4289801ee2569c07948d": { 1400 | "model_module": "@jupyter-widgets/controls", 1401 | "model_module_version": "1.5.0", 1402 | "model_name": "ProgressStyleModel", 1403 | "state": { 1404 | "_model_module": "@jupyter-widgets/controls", 1405 | "_model_module_version": "1.5.0", 1406 | "_model_name": "ProgressStyleModel", 1407 | "_view_count": null, 1408 | "_view_module": "@jupyter-widgets/base", 1409 | "_view_module_version": "1.2.0", 1410 | "_view_name": "StyleView", 1411 | "bar_color": null, 1412 | "description_width": "" 1413 | } 1414 | }, 1415 | "5cbd1a0fe71c434ca0bd655d6347947f": { 1416 | "model_module": "@jupyter-widgets/controls", 1417 | "model_module_version": "1.5.0", 1418 | "model_name": "DescriptionStyleModel", 1419 | "state": { 1420 | "_model_module": "@jupyter-widgets/controls", 1421 | "_model_module_version": "1.5.0", 1422 | "_model_name": "DescriptionStyleModel", 1423 | "_view_count": null, 1424 | "_view_module": "@jupyter-widgets/base", 1425 | "_view_module_version": "1.2.0", 1426 | "_view_name": "StyleView", 1427 | "description_width": "" 1428 | } 1429 | }, 1430 | "68025228c36c419e88ce26efcac76456": { 1431 | "model_module": "@jupyter-widgets/controls", 1432 | "model_module_version": "1.5.0", 1433 | "model_name": "HTMLModel", 1434 | "state": { 1435 | "_dom_classes": [], 1436 | "_model_module": "@jupyter-widgets/controls", 1437 | "_model_module_version": "1.5.0", 1438 | "_model_name": "HTMLModel", 1439 | "_view_count": null, 1440 | "_view_module": "@jupyter-widgets/controls", 1441 
| "_view_module_version": "1.5.0", 1442 | "_view_name": "HTMLView", 1443 | "description": "", 1444 | "description_tooltip": null, 1445 | "layout": "IPY_MODEL_2b5de781f41e4d6dbdc6c7939b405af0", 1446 | "placeholder": "​", 1447 | "style": "IPY_MODEL_1b9ee7656cae4986ba45656fef43d154", 1448 | "value": " 727/727 [00:00<00:00, 48.7kB/s]" 1449 | } 1450 | }, 1451 | "6d38a5c62d1243bc9c8d68fc0a20393b": { 1452 | "model_module": "@jupyter-widgets/controls", 1453 | "model_module_version": "1.5.0", 1454 | "model_name": "DescriptionStyleModel", 1455 | "state": { 1456 | "_model_module": "@jupyter-widgets/controls", 1457 | "_model_module_version": "1.5.0", 1458 | "_model_name": "DescriptionStyleModel", 1459 | "_view_count": null, 1460 | "_view_module": "@jupyter-widgets/base", 1461 | "_view_module_version": "1.2.0", 1462 | "_view_name": "StyleView", 1463 | "description_width": "" 1464 | } 1465 | }, 1466 | "6e238cbf7ab84e5b9d6e6d756ccd9487": { 1467 | "model_module": "@jupyter-widgets/controls", 1468 | "model_module_version": "1.5.0", 1469 | "model_name": "FloatProgressModel", 1470 | "state": { 1471 | "_dom_classes": [], 1472 | "_model_module": "@jupyter-widgets/controls", 1473 | "_model_module_version": "1.5.0", 1474 | "_model_name": "FloatProgressModel", 1475 | "_view_count": null, 1476 | "_view_module": "@jupyter-widgets/controls", 1477 | "_view_module_version": "1.5.0", 1478 | "_view_name": "ProgressView", 1479 | "bar_style": "success", 1480 | "description": "", 1481 | "description_tooltip": null, 1482 | "layout": "IPY_MODEL_0f57f44193ef462490f4e1526637c874", 1483 | "max": 499723, 1484 | "min": 0, 1485 | "orientation": "horizontal", 1486 | "style": "IPY_MODEL_94f49a2b0fca4fa5a22ea8a07f628b99", 1487 | "value": 499723 1488 | } 1489 | }, 1490 | "6ea49b6d6530403fb9acf7a15493b6e9": { 1491 | "model_module": "@jupyter-widgets/controls", 1492 | "model_module_version": "1.5.0", 1493 | "model_name": "DescriptionStyleModel", 1494 | "state": { 1495 | "_model_module": "@jupyter-widgets/controls", 1496 | "_model_module_version": "1.5.0", 1497 | "_model_name": "DescriptionStyleModel", 1498 | "_view_count": null, 1499 | "_view_module": "@jupyter-widgets/base", 1500 | "_view_module_version": "1.2.0", 1501 | "_view_name": "StyleView", 1502 | "description_width": "" 1503 | } 1504 | }, 1505 | "7562626ebef743cfb5c234a8babd884d": { 1506 | "model_module": "@jupyter-widgets/controls", 1507 | "model_module_version": "1.5.0", 1508 | "model_name": "HTMLModel", 1509 | "state": { 1510 | "_dom_classes": [], 1511 | "_model_module": "@jupyter-widgets/controls", 1512 | "_model_module_version": "1.5.0", 1513 | "_model_name": "HTMLModel", 1514 | "_view_count": null, 1515 | "_view_module": "@jupyter-widgets/controls", 1516 | "_view_module_version": "1.5.0", 1517 | "_view_name": "HTMLView", 1518 | "description": "", 1519 | "description_tooltip": null, 1520 | "layout": "IPY_MODEL_8119e15ecae64227857fb5a927c95877", 1521 | "placeholder": "​", 1522 | "style": "IPY_MODEL_abfc8e73b9f94e4aa684de84a6194e20", 1523 | "value": " 411/411 [00:00<00:00, 12.4kB/s]" 1524 | } 1525 | }, 1526 | "782c8fe1763b4977a678d9fb4fd5dadc": { 1527 | "model_module": "@jupyter-widgets/base", 1528 | "model_module_version": "1.2.0", 1529 | "model_name": "LayoutModel", 1530 | "state": { 1531 | "_model_module": "@jupyter-widgets/base", 1532 | "_model_module_version": "1.2.0", 1533 | "_model_name": "LayoutModel", 1534 | "_view_count": null, 1535 | "_view_module": "@jupyter-widgets/base", 1536 | "_view_module_version": "1.2.0", 1537 | "_view_name": "LayoutView", 1538 | 
"align_content": null, 1539 | "align_items": null, 1540 | "align_self": null, 1541 | "border": null, 1542 | "bottom": null, 1543 | "display": null, 1544 | "flex": null, 1545 | "flex_flow": null, 1546 | "grid_area": null, 1547 | "grid_auto_columns": null, 1548 | "grid_auto_flow": null, 1549 | "grid_auto_rows": null, 1550 | "grid_column": null, 1551 | "grid_gap": null, 1552 | "grid_row": null, 1553 | "grid_template_areas": null, 1554 | "grid_template_columns": null, 1555 | "grid_template_rows": null, 1556 | "height": null, 1557 | "justify_content": null, 1558 | "justify_items": null, 1559 | "left": null, 1560 | "margin": null, 1561 | "max_height": null, 1562 | "max_width": null, 1563 | "min_height": null, 1564 | "min_width": null, 1565 | "object_fit": null, 1566 | "object_position": null, 1567 | "order": null, 1568 | "overflow": null, 1569 | "overflow_x": null, 1570 | "overflow_y": null, 1571 | "padding": null, 1572 | "right": null, 1573 | "top": null, 1574 | "visibility": null, 1575 | "width": null 1576 | } 1577 | }, 1578 | "7f160c6489514b35945179d204c06bc7": { 1579 | "model_module": "@jupyter-widgets/controls", 1580 | "model_module_version": "1.5.0", 1581 | "model_name": "HTMLModel", 1582 | "state": { 1583 | "_dom_classes": [], 1584 | "_model_module": "@jupyter-widgets/controls", 1585 | "_model_module_version": "1.5.0", 1586 | "_model_name": "HTMLModel", 1587 | "_view_count": null, 1588 | "_view_module": "@jupyter-widgets/controls", 1589 | "_view_module_version": "1.5.0", 1590 | "_view_name": "HTMLView", 1591 | "description": "", 1592 | "description_tooltip": null, 1593 | "layout": "IPY_MODEL_3f21f1f034744945a9b0657e313a6350", 1594 | "placeholder": "​", 1595 | "style": "IPY_MODEL_e98ec60115014dc0a5ec44fdec3a0073", 1596 | "value": " 500k/500k [00:00<00:00, 9.94MB/s]" 1597 | } 1598 | }, 1599 | "8119e15ecae64227857fb5a927c95877": { 1600 | "model_module": "@jupyter-widgets/base", 1601 | "model_module_version": "1.2.0", 1602 | "model_name": "LayoutModel", 1603 | "state": { 1604 | "_model_module": "@jupyter-widgets/base", 1605 | "_model_module_version": "1.2.0", 1606 | "_model_name": "LayoutModel", 1607 | "_view_count": null, 1608 | "_view_module": "@jupyter-widgets/base", 1609 | "_view_module_version": "1.2.0", 1610 | "_view_name": "LayoutView", 1611 | "align_content": null, 1612 | "align_items": null, 1613 | "align_self": null, 1614 | "border": null, 1615 | "bottom": null, 1616 | "display": null, 1617 | "flex": null, 1618 | "flex_flow": null, 1619 | "grid_area": null, 1620 | "grid_auto_columns": null, 1621 | "grid_auto_flow": null, 1622 | "grid_auto_rows": null, 1623 | "grid_column": null, 1624 | "grid_gap": null, 1625 | "grid_row": null, 1626 | "grid_template_areas": null, 1627 | "grid_template_columns": null, 1628 | "grid_template_rows": null, 1629 | "height": null, 1630 | "justify_content": null, 1631 | "justify_items": null, 1632 | "left": null, 1633 | "margin": null, 1634 | "max_height": null, 1635 | "max_width": null, 1636 | "min_height": null, 1637 | "min_width": null, 1638 | "object_fit": null, 1639 | "object_position": null, 1640 | "order": null, 1641 | "overflow": null, 1642 | "overflow_x": null, 1643 | "overflow_y": null, 1644 | "padding": null, 1645 | "right": null, 1646 | "top": null, 1647 | "visibility": null, 1648 | "width": null 1649 | } 1650 | }, 1651 | "8246ecc55fbd476ba7fca6a183e4d741": { 1652 | "model_module": "@jupyter-widgets/base", 1653 | "model_module_version": "1.2.0", 1654 | "model_name": "LayoutModel", 1655 | "state": { 1656 | "_model_module": 
"@jupyter-widgets/base", 1657 | "_model_module_version": "1.2.0", 1658 | "_model_name": "LayoutModel", 1659 | "_view_count": null, 1660 | "_view_module": "@jupyter-widgets/base", 1661 | "_view_module_version": "1.2.0", 1662 | "_view_name": "LayoutView", 1663 | "align_content": null, 1664 | "align_items": null, 1665 | "align_self": null, 1666 | "border": null, 1667 | "bottom": null, 1668 | "display": null, 1669 | "flex": null, 1670 | "flex_flow": null, 1671 | "grid_area": null, 1672 | "grid_auto_columns": null, 1673 | "grid_auto_flow": null, 1674 | "grid_auto_rows": null, 1675 | "grid_column": null, 1676 | "grid_gap": null, 1677 | "grid_row": null, 1678 | "grid_template_areas": null, 1679 | "grid_template_columns": null, 1680 | "grid_template_rows": null, 1681 | "height": null, 1682 | "justify_content": null, 1683 | "justify_items": null, 1684 | "left": null, 1685 | "margin": null, 1686 | "max_height": null, 1687 | "max_width": null, 1688 | "min_height": null, 1689 | "min_width": null, 1690 | "object_fit": null, 1691 | "object_position": null, 1692 | "order": null, 1693 | "overflow": null, 1694 | "overflow_x": null, 1695 | "overflow_y": null, 1696 | "padding": null, 1697 | "right": null, 1698 | "top": null, 1699 | "visibility": null, 1700 | "width": null 1701 | } 1702 | }, 1703 | "8322ac7b0d8644d78804e6d85ff05baa": { 1704 | "model_module": "@jupyter-widgets/base", 1705 | "model_module_version": "1.2.0", 1706 | "model_name": "LayoutModel", 1707 | "state": { 1708 | "_model_module": "@jupyter-widgets/base", 1709 | "_model_module_version": "1.2.0", 1710 | "_model_name": "LayoutModel", 1711 | "_view_count": null, 1712 | "_view_module": "@jupyter-widgets/base", 1713 | "_view_module_version": "1.2.0", 1714 | "_view_name": "LayoutView", 1715 | "align_content": null, 1716 | "align_items": null, 1717 | "align_self": null, 1718 | "border": null, 1719 | "bottom": null, 1720 | "display": null, 1721 | "flex": null, 1722 | "flex_flow": null, 1723 | "grid_area": null, 1724 | "grid_auto_columns": null, 1725 | "grid_auto_flow": null, 1726 | "grid_auto_rows": null, 1727 | "grid_column": null, 1728 | "grid_gap": null, 1729 | "grid_row": null, 1730 | "grid_template_areas": null, 1731 | "grid_template_columns": null, 1732 | "grid_template_rows": null, 1733 | "height": null, 1734 | "justify_content": null, 1735 | "justify_items": null, 1736 | "left": null, 1737 | "margin": null, 1738 | "max_height": null, 1739 | "max_width": null, 1740 | "min_height": null, 1741 | "min_width": null, 1742 | "object_fit": null, 1743 | "object_position": null, 1744 | "order": null, 1745 | "overflow": null, 1746 | "overflow_x": null, 1747 | "overflow_y": null, 1748 | "padding": null, 1749 | "right": null, 1750 | "top": null, 1751 | "visibility": null, 1752 | "width": null 1753 | } 1754 | }, 1755 | "88795b72502645b4bdc95beadb4471d5": { 1756 | "model_module": "@jupyter-widgets/controls", 1757 | "model_module_version": "1.5.0", 1758 | "model_name": "DescriptionStyleModel", 1759 | "state": { 1760 | "_model_module": "@jupyter-widgets/controls", 1761 | "_model_module_version": "1.5.0", 1762 | "_model_name": "DescriptionStyleModel", 1763 | "_view_count": null, 1764 | "_view_module": "@jupyter-widgets/base", 1765 | "_view_module_version": "1.2.0", 1766 | "_view_name": "StyleView", 1767 | "description_width": "" 1768 | } 1769 | }, 1770 | "8c2d4bf32f5e46539cfe78ba705e379a": { 1771 | "model_module": "@jupyter-widgets/controls", 1772 | "model_module_version": "1.5.0", 1773 | "model_name": "HBoxModel", 1774 | "state": { 1775 | "_dom_classes": [], 
1776 | "_model_module": "@jupyter-widgets/controls", 1777 | "_model_module_version": "1.5.0", 1778 | "_model_name": "HBoxModel", 1779 | "_view_count": null, 1780 | "_view_module": "@jupyter-widgets/controls", 1781 | "_view_module_version": "1.5.0", 1782 | "_view_name": "HBoxView", 1783 | "box_style": "", 1784 | "children": [ 1785 | "IPY_MODEL_d160b88d7f00493d81c424be93d7030f", 1786 | "IPY_MODEL_256bd87c03b2449a9cf25303a84b96d6", 1787 | "IPY_MODEL_e05657d7cdab4c4082ebc87b6d903969" 1788 | ], 1789 | "layout": "IPY_MODEL_8246ecc55fbd476ba7fca6a183e4d741" 1790 | } 1791 | }, 1792 | "91bf2b38bd224cbfabc77aa85d75f48b": { 1793 | "model_module": "@jupyter-widgets/base", 1794 | "model_module_version": "1.2.0", 1795 | "model_name": "LayoutModel", 1796 | "state": { 1797 | "_model_module": "@jupyter-widgets/base", 1798 | "_model_module_version": "1.2.0", 1799 | "_model_name": "LayoutModel", 1800 | "_view_count": null, 1801 | "_view_module": "@jupyter-widgets/base", 1802 | "_view_module_version": "1.2.0", 1803 | "_view_name": "LayoutView", 1804 | "align_content": null, 1805 | "align_items": null, 1806 | "align_self": null, 1807 | "border": null, 1808 | "bottom": null, 1809 | "display": null, 1810 | "flex": null, 1811 | "flex_flow": null, 1812 | "grid_area": null, 1813 | "grid_auto_columns": null, 1814 | "grid_auto_flow": null, 1815 | "grid_auto_rows": null, 1816 | "grid_column": null, 1817 | "grid_gap": null, 1818 | "grid_row": null, 1819 | "grid_template_areas": null, 1820 | "grid_template_columns": null, 1821 | "grid_template_rows": null, 1822 | "height": null, 1823 | "justify_content": null, 1824 | "justify_items": null, 1825 | "left": null, 1826 | "margin": null, 1827 | "max_height": null, 1828 | "max_width": null, 1829 | "min_height": null, 1830 | "min_width": null, 1831 | "object_fit": null, 1832 | "object_position": null, 1833 | "order": null, 1834 | "overflow": null, 1835 | "overflow_x": null, 1836 | "overflow_y": null, 1837 | "padding": null, 1838 | "right": null, 1839 | "top": null, 1840 | "visibility": null, 1841 | "width": null 1842 | } 1843 | }, 1844 | "94f49a2b0fca4fa5a22ea8a07f628b99": { 1845 | "model_module": "@jupyter-widgets/controls", 1846 | "model_module_version": "1.5.0", 1847 | "model_name": "ProgressStyleModel", 1848 | "state": { 1849 | "_model_module": "@jupyter-widgets/controls", 1850 | "_model_module_version": "1.5.0", 1851 | "_model_name": "ProgressStyleModel", 1852 | "_view_count": null, 1853 | "_view_module": "@jupyter-widgets/base", 1854 | "_view_module_version": "1.2.0", 1855 | "_view_name": "StyleView", 1856 | "bar_color": null, 1857 | "description_width": "" 1858 | } 1859 | }, 1860 | "95584181b8fd41eebb322c951c72c1f6": { 1861 | "model_module": "@jupyter-widgets/base", 1862 | "model_module_version": "1.2.0", 1863 | "model_name": "LayoutModel", 1864 | "state": { 1865 | "_model_module": "@jupyter-widgets/base", 1866 | "_model_module_version": "1.2.0", 1867 | "_model_name": "LayoutModel", 1868 | "_view_count": null, 1869 | "_view_module": "@jupyter-widgets/base", 1870 | "_view_module_version": "1.2.0", 1871 | "_view_name": "LayoutView", 1872 | "align_content": null, 1873 | "align_items": null, 1874 | "align_self": null, 1875 | "border": null, 1876 | "bottom": null, 1877 | "display": null, 1878 | "flex": null, 1879 | "flex_flow": null, 1880 | "grid_area": null, 1881 | "grid_auto_columns": null, 1882 | "grid_auto_flow": null, 1883 | "grid_auto_rows": null, 1884 | "grid_column": null, 1885 | "grid_gap": null, 1886 | "grid_row": null, 1887 | "grid_template_areas": null, 1888 | 
"grid_template_columns": null, 1889 | "grid_template_rows": null, 1890 | "height": null, 1891 | "justify_content": null, 1892 | "justify_items": null, 1893 | "left": null, 1894 | "margin": null, 1895 | "max_height": null, 1896 | "max_width": null, 1897 | "min_height": null, 1898 | "min_width": null, 1899 | "object_fit": null, 1900 | "object_position": null, 1901 | "order": null, 1902 | "overflow": null, 1903 | "overflow_x": null, 1904 | "overflow_y": null, 1905 | "padding": null, 1906 | "right": null, 1907 | "top": null, 1908 | "visibility": null, 1909 | "width": null 1910 | } 1911 | }, 1912 | "975d4035cf144139b455fcb06dde003e": { 1913 | "model_module": "@jupyter-widgets/controls", 1914 | "model_module_version": "1.5.0", 1915 | "model_name": "HBoxModel", 1916 | "state": { 1917 | "_dom_classes": [], 1918 | "_model_module": "@jupyter-widgets/controls", 1919 | "_model_module_version": "1.5.0", 1920 | "_model_name": "HBoxModel", 1921 | "_view_count": null, 1922 | "_view_module": "@jupyter-widgets/controls", 1923 | "_view_module_version": "1.5.0", 1924 | "_view_name": "HBoxView", 1925 | "box_style": "", 1926 | "children": [ 1927 | "IPY_MODEL_4c44190d75cb4f21826fee5b25052886", 1928 | "IPY_MODEL_48e7aedb9eb8467ba5418860c5d99a07", 1929 | "IPY_MODEL_1a65ae611f96427e85b5e5ec21bca22d" 1930 | ], 1931 | "layout": "IPY_MODEL_49536e13224f481bbb4a02a1391ab347" 1932 | } 1933 | }, 1934 | "9b5632fd42b7452295e7873e7c7a230f": { 1935 | "model_module": "@jupyter-widgets/base", 1936 | "model_module_version": "1.2.0", 1937 | "model_name": "LayoutModel", 1938 | "state": { 1939 | "_model_module": "@jupyter-widgets/base", 1940 | "_model_module_version": "1.2.0", 1941 | "_model_name": "LayoutModel", 1942 | "_view_count": null, 1943 | "_view_module": "@jupyter-widgets/base", 1944 | "_view_module_version": "1.2.0", 1945 | "_view_name": "LayoutView", 1946 | "align_content": null, 1947 | "align_items": null, 1948 | "align_self": null, 1949 | "border": null, 1950 | "bottom": null, 1951 | "display": null, 1952 | "flex": null, 1953 | "flex_flow": null, 1954 | "grid_area": null, 1955 | "grid_auto_columns": null, 1956 | "grid_auto_flow": null, 1957 | "grid_auto_rows": null, 1958 | "grid_column": null, 1959 | "grid_gap": null, 1960 | "grid_row": null, 1961 | "grid_template_areas": null, 1962 | "grid_template_columns": null, 1963 | "grid_template_rows": null, 1964 | "height": null, 1965 | "justify_content": null, 1966 | "justify_items": null, 1967 | "left": null, 1968 | "margin": null, 1969 | "max_height": null, 1970 | "max_width": null, 1971 | "min_height": null, 1972 | "min_width": null, 1973 | "object_fit": null, 1974 | "object_position": null, 1975 | "order": null, 1976 | "overflow": null, 1977 | "overflow_x": null, 1978 | "overflow_y": null, 1979 | "padding": null, 1980 | "right": null, 1981 | "top": null, 1982 | "visibility": null, 1983 | "width": null 1984 | } 1985 | }, 1986 | "9c5163fdeb61471bbafdfe08393a6d80": { 1987 | "model_module": "@jupyter-widgets/controls", 1988 | "model_module_version": "1.5.0", 1989 | "model_name": "DescriptionStyleModel", 1990 | "state": { 1991 | "_model_module": "@jupyter-widgets/controls", 1992 | "_model_module_version": "1.5.0", 1993 | "_model_name": "DescriptionStyleModel", 1994 | "_view_count": null, 1995 | "_view_module": "@jupyter-widgets/base", 1996 | "_view_module_version": "1.2.0", 1997 | "_view_name": "StyleView", 1998 | "description_width": "" 1999 | } 2000 | }, 2001 | "a001503024fe44c8acc0d5e1a6d561b7": { 2002 | "model_module": "@jupyter-widgets/controls", 2003 | 
"model_module_version": "1.5.0", 2004 | "model_name": "HBoxModel", 2005 | "state": { 2006 | "_dom_classes": [], 2007 | "_model_module": "@jupyter-widgets/controls", 2008 | "_model_module_version": "1.5.0", 2009 | "_model_name": "HBoxModel", 2010 | "_view_count": null, 2011 | "_view_module": "@jupyter-widgets/controls", 2012 | "_view_module_version": "1.5.0", 2013 | "_view_name": "HBoxView", 2014 | "box_style": "", 2015 | "children": [ 2016 | "IPY_MODEL_0c0d9cda00e045c2a16eab19347b49d2", 2017 | "IPY_MODEL_aa4160a1837f427bb8ae673d31a0f24a", 2018 | "IPY_MODEL_7562626ebef743cfb5c234a8babd884d" 2019 | ], 2020 | "layout": "IPY_MODEL_ed4321893f5f4630bf2dbf1f4a38bd39" 2021 | } 2022 | }, 2023 | "a16259ceb15b464fa87355a14aaac331": { 2024 | "model_module": "@jupyter-widgets/controls", 2025 | "model_module_version": "1.5.0", 2026 | "model_name": "ProgressStyleModel", 2027 | "state": { 2028 | "_model_module": "@jupyter-widgets/controls", 2029 | "_model_module_version": "1.5.0", 2030 | "_model_name": "ProgressStyleModel", 2031 | "_view_count": null, 2032 | "_view_module": "@jupyter-widgets/base", 2033 | "_view_module_version": "1.2.0", 2034 | "_view_name": "StyleView", 2035 | "bar_color": null, 2036 | "description_width": "" 2037 | } 2038 | }, 2039 | "a44eefdea1bd4b83a210f226f074e1d0": { 2040 | "model_module": "@jupyter-widgets/base", 2041 | "model_module_version": "1.2.0", 2042 | "model_name": "LayoutModel", 2043 | "state": { 2044 | "_model_module": "@jupyter-widgets/base", 2045 | "_model_module_version": "1.2.0", 2046 | "_model_name": "LayoutModel", 2047 | "_view_count": null, 2048 | "_view_module": "@jupyter-widgets/base", 2049 | "_view_module_version": "1.2.0", 2050 | "_view_name": "LayoutView", 2051 | "align_content": null, 2052 | "align_items": null, 2053 | "align_self": null, 2054 | "border": null, 2055 | "bottom": null, 2056 | "display": null, 2057 | "flex": null, 2058 | "flex_flow": null, 2059 | "grid_area": null, 2060 | "grid_auto_columns": null, 2061 | "grid_auto_flow": null, 2062 | "grid_auto_rows": null, 2063 | "grid_column": null, 2064 | "grid_gap": null, 2065 | "grid_row": null, 2066 | "grid_template_areas": null, 2067 | "grid_template_columns": null, 2068 | "grid_template_rows": null, 2069 | "height": null, 2070 | "justify_content": null, 2071 | "justify_items": null, 2072 | "left": null, 2073 | "margin": null, 2074 | "max_height": null, 2075 | "max_width": null, 2076 | "min_height": null, 2077 | "min_width": null, 2078 | "object_fit": null, 2079 | "object_position": null, 2080 | "order": null, 2081 | "overflow": null, 2082 | "overflow_x": null, 2083 | "overflow_y": null, 2084 | "padding": null, 2085 | "right": null, 2086 | "top": null, 2087 | "visibility": null, 2088 | "width": null 2089 | } 2090 | }, 2091 | "aa4160a1837f427bb8ae673d31a0f24a": { 2092 | "model_module": "@jupyter-widgets/controls", 2093 | "model_module_version": "1.5.0", 2094 | "model_name": "FloatProgressModel", 2095 | "state": { 2096 | "_dom_classes": [], 2097 | "_model_module": "@jupyter-widgets/controls", 2098 | "_model_module_version": "1.5.0", 2099 | "_model_name": "FloatProgressModel", 2100 | "_view_count": null, 2101 | "_view_module": "@jupyter-widgets/controls", 2102 | "_view_module_version": "1.5.0", 2103 | "_view_name": "ProgressView", 2104 | "bar_style": "success", 2105 | "description": "", 2106 | "description_tooltip": null, 2107 | "layout": "IPY_MODEL_bb6f8aa46e3a459aacc6c99f9f76c72c", 2108 | "max": 411, 2109 | "min": 0, 2110 | "orientation": "horizontal", 2111 | "style": 
"IPY_MODEL_1f057e7c483d41a8b63f8b59480962f1", 2112 | "value": 411 2113 | } 2114 | }, 2115 | "aadb942209e84408a2bf11e55412b574": { 2116 | "model_module": "@jupyter-widgets/controls", 2117 | "model_module_version": "1.5.0", 2118 | "model_name": "HBoxModel", 2119 | "state": { 2120 | "_dom_classes": [], 2121 | "_model_module": "@jupyter-widgets/controls", 2122 | "_model_module_version": "1.5.0", 2123 | "_model_name": "HBoxModel", 2124 | "_view_count": null, 2125 | "_view_module": "@jupyter-widgets/controls", 2126 | "_view_module_version": "1.5.0", 2127 | "_view_name": "HBoxView", 2128 | "box_style": "", 2129 | "children": [ 2130 | "IPY_MODEL_fa9cafdcab2f439b9196f6f05166d898", 2131 | "IPY_MODEL_240fed609ad44aafad099550d3c265df", 2132 | "IPY_MODEL_4afabea1be0249068ad394e7265abf24" 2133 | ], 2134 | "layout": "IPY_MODEL_91bf2b38bd224cbfabc77aa85d75f48b" 2135 | } 2136 | }, 2137 | "abfc8e73b9f94e4aa684de84a6194e20": { 2138 | "model_module": "@jupyter-widgets/controls", 2139 | "model_module_version": "1.5.0", 2140 | "model_name": "DescriptionStyleModel", 2141 | "state": { 2142 | "_model_module": "@jupyter-widgets/controls", 2143 | "_model_module_version": "1.5.0", 2144 | "_model_name": "DescriptionStyleModel", 2145 | "_view_count": null, 2146 | "_view_module": "@jupyter-widgets/base", 2147 | "_view_module_version": "1.2.0", 2148 | "_view_name": "StyleView", 2149 | "description_width": "" 2150 | } 2151 | }, 2152 | "b3ffc98ce8fa41ddad546ceed1f3e0fb": { 2153 | "model_module": "@jupyter-widgets/base", 2154 | "model_module_version": "1.2.0", 2155 | "model_name": "LayoutModel", 2156 | "state": { 2157 | "_model_module": "@jupyter-widgets/base", 2158 | "_model_module_version": "1.2.0", 2159 | "_model_name": "LayoutModel", 2160 | "_view_count": null, 2161 | "_view_module": "@jupyter-widgets/base", 2162 | "_view_module_version": "1.2.0", 2163 | "_view_name": "LayoutView", 2164 | "align_content": null, 2165 | "align_items": null, 2166 | "align_self": null, 2167 | "border": null, 2168 | "bottom": null, 2169 | "display": null, 2170 | "flex": null, 2171 | "flex_flow": null, 2172 | "grid_area": null, 2173 | "grid_auto_columns": null, 2174 | "grid_auto_flow": null, 2175 | "grid_auto_rows": null, 2176 | "grid_column": null, 2177 | "grid_gap": null, 2178 | "grid_row": null, 2179 | "grid_template_areas": null, 2180 | "grid_template_columns": null, 2181 | "grid_template_rows": null, 2182 | "height": null, 2183 | "justify_content": null, 2184 | "justify_items": null, 2185 | "left": null, 2186 | "margin": null, 2187 | "max_height": null, 2188 | "max_width": null, 2189 | "min_height": null, 2190 | "min_width": null, 2191 | "object_fit": null, 2192 | "object_position": null, 2193 | "order": null, 2194 | "overflow": null, 2195 | "overflow_x": null, 2196 | "overflow_y": null, 2197 | "padding": null, 2198 | "right": null, 2199 | "top": null, 2200 | "visibility": null, 2201 | "width": null 2202 | } 2203 | }, 2204 | "b568300f99bd4bd48e8f37029282eed6": { 2205 | "model_module": "@jupyter-widgets/base", 2206 | "model_module_version": "1.2.0", 2207 | "model_name": "LayoutModel", 2208 | "state": { 2209 | "_model_module": "@jupyter-widgets/base", 2210 | "_model_module_version": "1.2.0", 2211 | "_model_name": "LayoutModel", 2212 | "_view_count": null, 2213 | "_view_module": "@jupyter-widgets/base", 2214 | "_view_module_version": "1.2.0", 2215 | "_view_name": "LayoutView", 2216 | "align_content": null, 2217 | "align_items": null, 2218 | "align_self": null, 2219 | "border": null, 2220 | "bottom": null, 2221 | "display": null, 2222 | 
"flex": null, 2223 | "flex_flow": null, 2224 | "grid_area": null, 2225 | "grid_auto_columns": null, 2226 | "grid_auto_flow": null, 2227 | "grid_auto_rows": null, 2228 | "grid_column": null, 2229 | "grid_gap": null, 2230 | "grid_row": null, 2231 | "grid_template_areas": null, 2232 | "grid_template_columns": null, 2233 | "grid_template_rows": null, 2234 | "height": null, 2235 | "justify_content": null, 2236 | "justify_items": null, 2237 | "left": null, 2238 | "margin": null, 2239 | "max_height": null, 2240 | "max_width": null, 2241 | "min_height": null, 2242 | "min_width": null, 2243 | "object_fit": null, 2244 | "object_position": null, 2245 | "order": null, 2246 | "overflow": null, 2247 | "overflow_x": null, 2248 | "overflow_y": null, 2249 | "padding": null, 2250 | "right": null, 2251 | "top": null, 2252 | "visibility": null, 2253 | "width": null 2254 | } 2255 | }, 2256 | "b5ca1b962b874a50bd3657a1d74ea998": { 2257 | "model_module": "@jupyter-widgets/controls", 2258 | "model_module_version": "1.5.0", 2259 | "model_name": "DescriptionStyleModel", 2260 | "state": { 2261 | "_model_module": "@jupyter-widgets/controls", 2262 | "_model_module_version": "1.5.0", 2263 | "_model_name": "DescriptionStyleModel", 2264 | "_view_count": null, 2265 | "_view_module": "@jupyter-widgets/base", 2266 | "_view_module_version": "1.2.0", 2267 | "_view_name": "StyleView", 2268 | "description_width": "" 2269 | } 2270 | }, 2271 | "b73c8d9ccf494d6193ff8262777a98da": { 2272 | "model_module": "@jupyter-widgets/base", 2273 | "model_module_version": "1.2.0", 2274 | "model_name": "LayoutModel", 2275 | "state": { 2276 | "_model_module": "@jupyter-widgets/base", 2277 | "_model_module_version": "1.2.0", 2278 | "_model_name": "LayoutModel", 2279 | "_view_count": null, 2280 | "_view_module": "@jupyter-widgets/base", 2281 | "_view_module_version": "1.2.0", 2282 | "_view_name": "LayoutView", 2283 | "align_content": null, 2284 | "align_items": null, 2285 | "align_self": null, 2286 | "border": null, 2287 | "bottom": null, 2288 | "display": null, 2289 | "flex": null, 2290 | "flex_flow": null, 2291 | "grid_area": null, 2292 | "grid_auto_columns": null, 2293 | "grid_auto_flow": null, 2294 | "grid_auto_rows": null, 2295 | "grid_column": null, 2296 | "grid_gap": null, 2297 | "grid_row": null, 2298 | "grid_template_areas": null, 2299 | "grid_template_columns": null, 2300 | "grid_template_rows": null, 2301 | "height": null, 2302 | "justify_content": null, 2303 | "justify_items": null, 2304 | "left": null, 2305 | "margin": null, 2306 | "max_height": null, 2307 | "max_width": null, 2308 | "min_height": null, 2309 | "min_width": null, 2310 | "object_fit": null, 2311 | "object_position": null, 2312 | "order": null, 2313 | "overflow": null, 2314 | "overflow_x": null, 2315 | "overflow_y": null, 2316 | "padding": null, 2317 | "right": null, 2318 | "top": null, 2319 | "visibility": null, 2320 | "width": null 2321 | } 2322 | }, 2323 | "bb6f8aa46e3a459aacc6c99f9f76c72c": { 2324 | "model_module": "@jupyter-widgets/base", 2325 | "model_module_version": "1.2.0", 2326 | "model_name": "LayoutModel", 2327 | "state": { 2328 | "_model_module": "@jupyter-widgets/base", 2329 | "_model_module_version": "1.2.0", 2330 | "_model_name": "LayoutModel", 2331 | "_view_count": null, 2332 | "_view_module": "@jupyter-widgets/base", 2333 | "_view_module_version": "1.2.0", 2334 | "_view_name": "LayoutView", 2335 | "align_content": null, 2336 | "align_items": null, 2337 | "align_self": null, 2338 | "border": null, 2339 | "bottom": null, 2340 | "display": null, 2341 | 
"flex": null, 2342 | "flex_flow": null, 2343 | "grid_area": null, 2344 | "grid_auto_columns": null, 2345 | "grid_auto_flow": null, 2346 | "grid_auto_rows": null, 2347 | "grid_column": null, 2348 | "grid_gap": null, 2349 | "grid_row": null, 2350 | "grid_template_areas": null, 2351 | "grid_template_columns": null, 2352 | "grid_template_rows": null, 2353 | "height": null, 2354 | "justify_content": null, 2355 | "justify_items": null, 2356 | "left": null, 2357 | "margin": null, 2358 | "max_height": null, 2359 | "max_width": null, 2360 | "min_height": null, 2361 | "min_width": null, 2362 | "object_fit": null, 2363 | "object_position": null, 2364 | "order": null, 2365 | "overflow": null, 2366 | "overflow_x": null, 2367 | "overflow_y": null, 2368 | "padding": null, 2369 | "right": null, 2370 | "top": null, 2371 | "visibility": null, 2372 | "width": null 2373 | } 2374 | }, 2375 | "c6ab5db12b0647f5989e233d914edd87": { 2376 | "model_module": "@jupyter-widgets/controls", 2377 | "model_module_version": "1.5.0", 2378 | "model_name": "HTMLModel", 2379 | "state": { 2380 | "_dom_classes": [], 2381 | "_model_module": "@jupyter-widgets/controls", 2382 | "_model_module_version": "1.5.0", 2383 | "_model_name": "HTMLModel", 2384 | "_view_count": null, 2385 | "_view_module": "@jupyter-widgets/controls", 2386 | "_view_module_version": "1.5.0", 2387 | "_view_name": "HTMLView", 2388 | "description": "", 2389 | "description_tooltip": null, 2390 | "layout": "IPY_MODEL_8322ac7b0d8644d78804e6d85ff05baa", 2391 | "placeholder": "​", 2392 | "style": "IPY_MODEL_1fecf11618b244909505323d2968e5fc", 2393 | "value": "tokenizer.model: 100%" 2394 | } 2395 | }, 2396 | "c8219473c99d480081104fe4551036de": { 2397 | "model_module": "@jupyter-widgets/base", 2398 | "model_module_version": "1.2.0", 2399 | "model_name": "LayoutModel", 2400 | "state": { 2401 | "_model_module": "@jupyter-widgets/base", 2402 | "_model_module_version": "1.2.0", 2403 | "_model_name": "LayoutModel", 2404 | "_view_count": null, 2405 | "_view_module": "@jupyter-widgets/base", 2406 | "_view_module_version": "1.2.0", 2407 | "_view_name": "LayoutView", 2408 | "align_content": null, 2409 | "align_items": null, 2410 | "align_self": null, 2411 | "border": null, 2412 | "bottom": null, 2413 | "display": null, 2414 | "flex": null, 2415 | "flex_flow": null, 2416 | "grid_area": null, 2417 | "grid_auto_columns": null, 2418 | "grid_auto_flow": null, 2419 | "grid_auto_rows": null, 2420 | "grid_column": null, 2421 | "grid_gap": null, 2422 | "grid_row": null, 2423 | "grid_template_areas": null, 2424 | "grid_template_columns": null, 2425 | "grid_template_rows": null, 2426 | "height": null, 2427 | "justify_content": null, 2428 | "justify_items": null, 2429 | "left": null, 2430 | "margin": null, 2431 | "max_height": null, 2432 | "max_width": null, 2433 | "min_height": null, 2434 | "min_width": null, 2435 | "object_fit": null, 2436 | "object_position": null, 2437 | "order": null, 2438 | "overflow": null, 2439 | "overflow_x": null, 2440 | "overflow_y": null, 2441 | "padding": null, 2442 | "right": null, 2443 | "top": null, 2444 | "visibility": null, 2445 | "width": null 2446 | } 2447 | }, 2448 | "cfee62ce88bf4a409ce34337556539cf": { 2449 | "model_module": "@jupyter-widgets/base", 2450 | "model_module_version": "1.2.0", 2451 | "model_name": "LayoutModel", 2452 | "state": { 2453 | "_model_module": "@jupyter-widgets/base", 2454 | "_model_module_version": "1.2.0", 2455 | "_model_name": "LayoutModel", 2456 | "_view_count": null, 2457 | "_view_module": "@jupyter-widgets/base", 2458 | 
"_view_module_version": "1.2.0", 2459 | "_view_name": "LayoutView", 2460 | "align_content": null, 2461 | "align_items": null, 2462 | "align_self": null, 2463 | "border": null, 2464 | "bottom": null, 2465 | "display": null, 2466 | "flex": null, 2467 | "flex_flow": null, 2468 | "grid_area": null, 2469 | "grid_auto_columns": null, 2470 | "grid_auto_flow": null, 2471 | "grid_auto_rows": null, 2472 | "grid_column": null, 2473 | "grid_gap": null, 2474 | "grid_row": null, 2475 | "grid_template_areas": null, 2476 | "grid_template_columns": null, 2477 | "grid_template_rows": null, 2478 | "height": null, 2479 | "justify_content": null, 2480 | "justify_items": null, 2481 | "left": null, 2482 | "margin": null, 2483 | "max_height": null, 2484 | "max_width": null, 2485 | "min_height": null, 2486 | "min_width": null, 2487 | "object_fit": null, 2488 | "object_position": null, 2489 | "order": null, 2490 | "overflow": null, 2491 | "overflow_x": null, 2492 | "overflow_y": null, 2493 | "padding": null, 2494 | "right": null, 2495 | "top": null, 2496 | "visibility": null, 2497 | "width": null 2498 | } 2499 | }, 2500 | "d160b88d7f00493d81c424be93d7030f": { 2501 | "model_module": "@jupyter-widgets/controls", 2502 | "model_module_version": "1.5.0", 2503 | "model_name": "HTMLModel", 2504 | "state": { 2505 | "_dom_classes": [], 2506 | "_model_module": "@jupyter-widgets/controls", 2507 | "_model_module_version": "1.5.0", 2508 | "_model_name": "HTMLModel", 2509 | "_view_count": null, 2510 | "_view_module": "@jupyter-widgets/controls", 2511 | "_view_module_version": "1.5.0", 2512 | "_view_name": "HTMLView", 2513 | "description": "", 2514 | "description_tooltip": null, 2515 | "layout": "IPY_MODEL_95584181b8fd41eebb322c951c72c1f6", 2516 | "placeholder": "​", 2517 | "style": "IPY_MODEL_88795b72502645b4bdc95beadb4471d5", 2518 | "value": "pytorch_model.bin: 100%" 2519 | } 2520 | }, 2521 | "d3b47705f4e14680b13b21be52d2e604": { 2522 | "model_module": "@jupyter-widgets/controls", 2523 | "model_module_version": "1.5.0", 2524 | "model_name": "ProgressStyleModel", 2525 | "state": { 2526 | "_model_module": "@jupyter-widgets/controls", 2527 | "_model_module_version": "1.5.0", 2528 | "_model_name": "ProgressStyleModel", 2529 | "_view_count": null, 2530 | "_view_module": "@jupyter-widgets/base", 2531 | "_view_module_version": "1.2.0", 2532 | "_view_name": "StyleView", 2533 | "bar_color": null, 2534 | "description_width": "" 2535 | } 2536 | }, 2537 | "dd92cd36ad544a17a22374c0896947da": { 2538 | "model_module": "@jupyter-widgets/controls", 2539 | "model_module_version": "1.5.0", 2540 | "model_name": "HTMLModel", 2541 | "state": { 2542 | "_dom_classes": [], 2543 | "_model_module": "@jupyter-widgets/controls", 2544 | "_model_module_version": "1.5.0", 2545 | "_model_name": "HTMLModel", 2546 | "_view_count": null, 2547 | "_view_module": "@jupyter-widgets/controls", 2548 | "_view_module_version": "1.5.0", 2549 | "_view_name": "HTMLView", 2550 | "description": "", 2551 | "description_tooltip": null, 2552 | "layout": "IPY_MODEL_b568300f99bd4bd48e8f37029282eed6", 2553 | "placeholder": "​", 2554 | "style": "IPY_MODEL_13a9fb93d75e41d180e7673acad43dab", 2555 | "value": "generation_config.json: 100%" 2556 | } 2557 | }, 2558 | "e022c45f119c4eb281d5679b80d6387d": { 2559 | "model_module": "@jupyter-widgets/controls", 2560 | "model_module_version": "1.5.0", 2561 | "model_name": "HBoxModel", 2562 | "state": { 2563 | "_dom_classes": [], 2564 | "_model_module": "@jupyter-widgets/controls", 2565 | "_model_module_version": "1.5.0", 2566 | "_model_name": 
"HBoxModel", 2567 | "_view_count": null, 2568 | "_view_module": "@jupyter-widgets/controls", 2569 | "_view_module_version": "1.5.0", 2570 | "_view_name": "HBoxView", 2571 | "box_style": "", 2572 | "children": [ 2573 | "IPY_MODEL_c6ab5db12b0647f5989e233d914edd87", 2574 | "IPY_MODEL_6e238cbf7ab84e5b9d6e6d756ccd9487", 2575 | "IPY_MODEL_7f160c6489514b35945179d204c06bc7" 2576 | ], 2577 | "layout": "IPY_MODEL_24ffd51165764526962c2de4c004ae51" 2578 | } 2579 | }, 2580 | "e05657d7cdab4c4082ebc87b6d903969": { 2581 | "model_module": "@jupyter-widgets/controls", 2582 | "model_module_version": "1.5.0", 2583 | "model_name": "HTMLModel", 2584 | "state": { 2585 | "_dom_classes": [], 2586 | "_model_module": "@jupyter-widgets/controls", 2587 | "_model_module_version": "1.5.0", 2588 | "_model_name": "HTMLModel", 2589 | "_view_count": null, 2590 | "_view_module": "@jupyter-widgets/controls", 2591 | "_view_module_version": "1.5.0", 2592 | "_view_name": "HTMLView", 2593 | "description": "", 2594 | "description_tooltip": null, 2595 | "layout": "IPY_MODEL_55690f32de47405da42599803e654fc6", 2596 | "placeholder": "​", 2597 | "style": "IPY_MODEL_6ea49b6d6530403fb9acf7a15493b6e9", 2598 | "value": " 5.38G/5.38G [00:45<00:00, 216MB/s]" 2599 | } 2600 | }, 2601 | "e4685f450cb543f482d41881e2555b34": { 2602 | "model_module": "@jupyter-widgets/base", 2603 | "model_module_version": "1.2.0", 2604 | "model_name": "LayoutModel", 2605 | "state": { 2606 | "_model_module": "@jupyter-widgets/base", 2607 | "_model_module_version": "1.2.0", 2608 | "_model_name": "LayoutModel", 2609 | "_view_count": null, 2610 | "_view_module": "@jupyter-widgets/base", 2611 | "_view_module_version": "1.2.0", 2612 | "_view_name": "LayoutView", 2613 | "align_content": null, 2614 | "align_items": null, 2615 | "align_self": null, 2616 | "border": null, 2617 | "bottom": null, 2618 | "display": null, 2619 | "flex": null, 2620 | "flex_flow": null, 2621 | "grid_area": null, 2622 | "grid_auto_columns": null, 2623 | "grid_auto_flow": null, 2624 | "grid_auto_rows": null, 2625 | "grid_column": null, 2626 | "grid_gap": null, 2627 | "grid_row": null, 2628 | "grid_template_areas": null, 2629 | "grid_template_columns": null, 2630 | "grid_template_rows": null, 2631 | "height": null, 2632 | "justify_content": null, 2633 | "justify_items": null, 2634 | "left": null, 2635 | "margin": null, 2636 | "max_height": null, 2637 | "max_width": null, 2638 | "min_height": null, 2639 | "min_width": null, 2640 | "object_fit": null, 2641 | "object_position": null, 2642 | "order": null, 2643 | "overflow": null, 2644 | "overflow_x": null, 2645 | "overflow_y": null, 2646 | "padding": null, 2647 | "right": null, 2648 | "top": null, 2649 | "visibility": null, 2650 | "width": null 2651 | } 2652 | }, 2653 | "e98ec60115014dc0a5ec44fdec3a0073": { 2654 | "model_module": "@jupyter-widgets/controls", 2655 | "model_module_version": "1.5.0", 2656 | "model_name": "DescriptionStyleModel", 2657 | "state": { 2658 | "_model_module": "@jupyter-widgets/controls", 2659 | "_model_module_version": "1.5.0", 2660 | "_model_name": "DescriptionStyleModel", 2661 | "_view_count": null, 2662 | "_view_module": "@jupyter-widgets/base", 2663 | "_view_module_version": "1.2.0", 2664 | "_view_name": "StyleView", 2665 | "description_width": "" 2666 | } 2667 | }, 2668 | "ed4321893f5f4630bf2dbf1f4a38bd39": { 2669 | "model_module": "@jupyter-widgets/base", 2670 | "model_module_version": "1.2.0", 2671 | "model_name": "LayoutModel", 2672 | "state": { 2673 | "_model_module": "@jupyter-widgets/base", 2674 | 
"_model_module_version": "1.2.0", 2675 | "_model_name": "LayoutModel", 2676 | "_view_count": null, 2677 | "_view_module": "@jupyter-widgets/base", 2678 | "_view_module_version": "1.2.0", 2679 | "_view_name": "LayoutView", 2680 | "align_content": null, 2681 | "align_items": null, 2682 | "align_self": null, 2683 | "border": null, 2684 | "bottom": null, 2685 | "display": null, 2686 | "flex": null, 2687 | "flex_flow": null, 2688 | "grid_area": null, 2689 | "grid_auto_columns": null, 2690 | "grid_auto_flow": null, 2691 | "grid_auto_rows": null, 2692 | "grid_column": null, 2693 | "grid_gap": null, 2694 | "grid_row": null, 2695 | "grid_template_areas": null, 2696 | "grid_template_columns": null, 2697 | "grid_template_rows": null, 2698 | "height": null, 2699 | "justify_content": null, 2700 | "justify_items": null, 2701 | "left": null, 2702 | "margin": null, 2703 | "max_height": null, 2704 | "max_width": null, 2705 | "min_height": null, 2706 | "min_width": null, 2707 | "object_fit": null, 2708 | "object_position": null, 2709 | "order": null, 2710 | "overflow": null, 2711 | "overflow_x": null, 2712 | "overflow_y": null, 2713 | "padding": null, 2714 | "right": null, 2715 | "top": null, 2716 | "visibility": null, 2717 | "width": null 2718 | } 2719 | }, 2720 | "fa9cafdcab2f439b9196f6f05166d898": { 2721 | "model_module": "@jupyter-widgets/controls", 2722 | "model_module_version": "1.5.0", 2723 | "model_name": "HTMLModel", 2724 | "state": { 2725 | "_dom_classes": [], 2726 | "_model_module": "@jupyter-widgets/controls", 2727 | "_model_module_version": "1.5.0", 2728 | "_model_name": "HTMLModel", 2729 | "_view_count": null, 2730 | "_view_module": "@jupyter-widgets/controls", 2731 | "_view_module_version": "1.5.0", 2732 | "_view_name": "HTMLView", 2733 | "description": "", 2734 | "description_tooltip": null, 2735 | "layout": "IPY_MODEL_1a0d37386914443f983a89cf56d7a900", 2736 | "placeholder": "​", 2737 | "style": "IPY_MODEL_42075dab0a5443b0b65e1ff8b2cb5e89", 2738 | "value": "tokenizer.json: 100%" 2739 | } 2740 | } 2741 | } 2742 | } 2743 | }, 2744 | "nbformat": 4, 2745 | "nbformat_minor": 0 2746 | } 2747 | -------------------------------------------------------------------------------- /processor.py: -------------------------------------------------------------------------------- 1 | from torch.nn.utils.rnn import pad_sequence 2 | import torch.nn.functional as F 3 | import torch 4 | from utils import greedy_packing 5 | 6 | 7 | class PrePackProcessor: 8 | def __init__(self, tokenizer, packing_fn=None): 9 | self.tokenizer = tokenizer 10 | self.pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 11 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | if packing_fn: 13 | self.packing_fn = packing_fn 14 | else: 15 | self.packing_fn = greedy_packing 16 | 17 | def process(self, length_dict, packing_dict, token_dict): 18 | ''' 19 | Takes batch of tokens and packs them according to the bin-packing algorithm. 
20 | 
21 |         Args:
22 |             length_dict (dict): maps original batch index to its prompt length
23 |             packing_dict (dict): one packed bin, mapping original batch index to prompt length
24 |             token_dict (dict): maps original batch index to its tokenized prompt
25 | 
26 |         Returns:
27 |             new_tokens (Tensor): packed sequence of tokens
28 |             new_positions (Tensor): restart positions for the new_tokens
29 |             new_mask (Tensor): independent mask for the new_tokens
30 |             restart_dict (dict): maps each restart index to its original batch index
31 |         '''
32 |         new_positions = []
33 |         new_tokens = []
34 |         restart_dict = {0: -1}  # -1 is a placeholder
35 |         restart_index = 0
36 | 
37 |         for key in packing_dict:
38 |             new_tokens += token_dict[key][:-1]  # omit final token for generation
39 |             restart_index += length_dict[key] - 1
40 |             new_positions += list(range(length_dict[key] - 1))
41 |             restart_dict[restart_index] = key
42 | 
43 |         restart_indices = list(restart_dict.keys())
44 |         size = len(new_tokens)
45 |         new_mask = torch.zeros(size, size, device=self.device)
46 | 
47 |         for i in range(len(restart_indices) - 1):
48 |             start = restart_indices[i]
49 |             end = restart_indices[i + 1]
50 |             new_mask[start:end, start:end] = torch.tril(torch.ones((end - start, end - start)))
51 | 
52 |         new_tokens = torch.tensor(new_tokens, device=self.device)
53 |         new_positions = torch.tensor(new_positions, device=self.device)
54 |         new_mask = new_mask.clone().detach()
55 | 
56 |         return new_tokens, new_positions, new_mask, restart_dict
57 | 
58 |     def batch_process(self, sentences):
59 | 
60 |         original_ids = self.tokenizer(sentences).input_ids
61 |         token_dict = dict(enumerate(original_ids))
62 |         # lengths are keyed by original batch index so packed prompts can be mapped back
63 |         length_dict = {index: len(toks) for index, toks in enumerate(original_ids)}
64 | 
65 |         max_bin_size = max(length_dict.values())
66 |         packing_lst = self.packing_fn(length_dict, max_bin_size)
67 | 
68 |         batch_new_tokens = []
69 |         batch_new_positions = []
70 |         batch_new_mask = []
71 |         batch_restart_indices = []
72 |         for packing_dict in packing_lst:
73 |             new_tokens, new_positions, new_mask, restart_indices = self.process(
74 |                 length_dict, packing_dict, token_dict
75 |             )
76 |             batch_new_tokens.append(new_tokens)
77 |             batch_new_positions.append(new_positions)
78 |             batch_new_mask.append(new_mask)
79 |             batch_restart_indices.append(restart_indices)
80 | 
81 |         batch_new_tokens = pad_sequence(batch_new_tokens, batch_first=True, padding_value=self.pad_token)
82 |         batch_new_positions = pad_sequence(batch_new_positions, batch_first=True, padding_value=1)
83 | 
84 |         max_size = max(tensor.size(-1) for tensor in batch_new_mask)  # masks are square
85 |         padded_masks = [
86 |             F.pad(tensor, (0, max_size - tensor.size(0), 0, max_size - tensor.size(1)))
87 |             for tensor in batch_new_mask
88 |         ]
89 |         batch_new_mask = torch.stack(padded_masks)
90 | 
91 |         return batch_new_tokens, batch_new_positions, batch_new_mask, batch_restart_indices, original_ids
92 | 
--------------------------------------------------------------------------------
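At the core of `process` is the block-diagonal causal mask that keeps packed prompts independent of one another. A minimal standalone sketch of that construction, with made-up lengths rather than real tokenizer output:

```python
import torch

# Toy illustration of the mask built by PrePackProcessor.process:
# each packed prompt attends only within its own causal block.
lengths = [3, 2]
size = sum(lengths)
mask = torch.zeros(size, size)
start = 0
for n in lengths:
    mask[start:start + n, start:start + n] = torch.tril(torch.ones(n, n))
    start += n
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 0., 0., 1., 0.],
#         [0., 0., 0., 1., 1.]])
```

Position ids restart the same way (here `[0, 1, 2, 0, 1]`), so each packed prompt is encoded as if it started its own sequence.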
/profiling_dataset_level_prepacking.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | import torch
4 | import fire
5 | import random
6 | import time
7 | from tqdm import tqdm
8 | import numpy as np
9 | from prettytable import PrettyTable
10 | from dataset_utils import (
11 |     PackedDataset,
12 |     sample_batches,
13 |     sample_batches_by_length,
14 |     sample_packed_dataset,
15 |     unpack_kv,
16 |     load_and_evaluate_dataset,
17 | )
18 | from processor import PrePackProcessor
19 | from utils import integer_program_packing, load_model_and_tokenizer
20 | 
21 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
22 | 
23 | # All timed functions share one signature so measure_inference_time can dispatch to them uniformly.
24 | def prefill_packed_sentence_output(sentences, model, tokenizer, device, processor):
25 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
26 |     with torch.no_grad():
27 |         packed_outputs = model(
28 |             input_ids=new_tokens,
29 |             attention_mask=new_mask,
30 |             position_ids=new_positions,
31 |             return_dict=True,
32 |         )
33 |     return packed_outputs
34 | 
35 | 
36 | def TTFT_packed_sentence_output(sentences, model, tokenizer, device, processor):
37 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
38 |     with torch.no_grad():
39 |         packed_outputs = model(
40 |             input_ids=new_tokens,
41 |             attention_mask=new_mask,
42 |             position_ids=new_positions,
43 |             return_dict=True,
44 |         )
45 |     cache, final_tokens, attention_mask = unpack_kv(
46 |         packed_outputs["past_key_values"], restart_dict, original_ids, device
47 |     )
48 |     _ = model.generate(
49 |         input_ids=final_tokens,
50 |         attention_mask=attention_mask,
51 |         max_new_tokens=1,
52 |         use_cache=True,
53 |         do_sample=False,
54 |         past_key_values=cache,
55 |     )
56 |     return
57 | 
58 | 
59 | def TTFT_packed_dataset_output(batch, model, tokenizer=None, model_device=None, optimized_processor=None):
60 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = batch
61 |     with torch.no_grad():
62 |         packed_outputs = model(
63 |             input_ids=new_tokens,
64 |             attention_mask=new_mask,
65 |             position_ids=new_positions,
66 |             return_dict=True,
67 |         )
68 |     cache, final_tokens, attention_mask = unpack_kv(
69 |         packed_outputs["past_key_values"], restart_dict, original_ids, model_device
70 |     )
71 |     _ = model.generate(
72 |         input_ids=final_tokens,
73 |         attention_mask=attention_mask,
74 |         max_new_tokens=1,
75 |         use_cache=True,
76 |         do_sample=False,
77 |         past_key_values=cache,
78 |     )
79 |     return
80 | 
81 | 
82 | def prefill_packed_dataset_output(batch, model, tokenizer=None, device=None, processor=None):
83 | 
84 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = batch
85 |     with torch.no_grad():
86 |         packed_outputs = model(
87 |             input_ids=new_tokens,
88 |             attention_mask=new_mask,
89 |             position_ids=new_positions,
90 |         )
91 |     return packed_outputs
92 | 
93 | 
94 | def prefill_batch_sentence_output(sentences, model, tokenizer, device, processor=None):
95 |     batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
96 | 
97 |     with torch.no_grad():
98 |         batch_sentences_outputs = model(
99 |             batch_sentences["input_ids"].to(device),
100 |             attention_mask=batch_sentences["attention_mask"].to(device),
101 |         )
102 |     return batch_sentences_outputs
103 | 
104 | 
105 | def TTFT_batch_sentence_output(sentences, model, tokenizer, device, processor=None):
106 |     batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
107 | 
108 |     with torch.no_grad():
109 |         _ = model.generate(
110 |             **batch_sentences,
111 |             max_new_tokens=1,
112 |             use_cache=True,
113 |             do_sample=False,
114 |         )
115 |     return
116 | 
117 | 
118 | def measure_inference_time(
119 |     method,
120 |     texts,
121 |     batch_size,
122 |     num_runs,
123 |     total_batches,
124 |     model,
125 |     tokenizer,
126 |     model_device,
127 |     metric="TTFT",
128 |     binpack_algo="greedy",
129 | ):
130 | 
131 |     if metric == "TTFT":
132 |         method_functions = {
133 |             "prepack": TTFT_packed_sentence_output,
134 |             "full-batching": TTFT_batch_sentence_output,
135 |             "length-ordered": TTFT_batch_sentence_output,
"length-ordered": TTFT_batch_sentence_output, 136 | "prepack_dataset": TTFT_packed_dataset_output, 137 | } 138 | elif metric == "prefill": 139 | method_functions = { 140 | "prepack": prefill_packed_sentence_output, 141 | "full-batching": prefill_batch_sentence_output, 142 | "length-ordered": prefill_batch_sentence_output, 143 | "prepack_dataset": prefill_packed_dataset_output, 144 | } 145 | desc = method 146 | method_function = method_functions.get(method) 147 | packing_fn = None if binpack_algo == "greedy" else integer_program_packing 148 | optimized_processor = PrePackProcessor(tokenizer, packing_fn=packing_fn) 149 | total_request_times = [] 150 | for _ in range(num_runs): 151 | if method == "length-ordered": 152 | batches_generator = sample_batches_by_length(texts, batch_size) 153 | elif method == "prepack_dataset": 154 | new_tokens, new_positions, new_mask, restart_indices, original_ids = ( 155 | optimized_processor.batch_process(texts) 156 | ) 157 | dataset = PackedDataset( 158 | new_tokens, new_positions, new_mask, restart_indices, original_ids, batch_size=batch_size 159 | ) 160 | 161 | batches_generator = sample_packed_dataset(dataset) 162 | del new_tokens, new_positions, new_mask, restart_indices, original_ids 163 | else: 164 | batches_generator = sample_batches(texts, batch_size) 165 | start_time = time.time() 166 | for batch in tqdm(batches_generator, total=total_batches, desc=desc): 167 | _ = method_function(batch, model, tokenizer, model_device, optimized_processor) 168 | elapsed = time.time() - start_time 169 | total_request_times.append(elapsed) 170 | 171 | per_request_time = np.mean(total_request_times) / (len(texts) * num_runs) 172 | per_request_time_std = np.std(total_request_times) / (len(texts) * num_runs) 173 | return per_request_time, per_request_time_std 174 | 175 | 176 | def main( 177 | methods: List[str] = ["prepack_dataset", "prepack", "full-batching", "length-ordered"], 178 | metric: str = "prefill", 179 | dataset: str = "mmlu", 180 | model_name: str = "llama1b", 181 | loadbit: int = 8, 182 | num_runs: int = 5, 183 | batch_size: int = 32, 184 | binpack_algo: str = "greedy", 185 | ): 186 | 187 | torch.set_num_threads(5) 188 | seed = 42 189 | random.seed(seed) 190 | np.random.seed(seed) 191 | torch.manual_seed(seed) 192 | torch.cuda.manual_seed_all(seed) 193 | os.environ["PYTHONHASHSEED"] = str(seed) 194 | 195 | if binpack_algo != "greedy": 196 | binpack_algo = "ip" 197 | 198 | # Load the model and tokenizer 199 | model, tokenizer = load_model_and_tokenizer(base_model=model_name, loadbit=loadbit) 200 | 201 | # Load and prepare the dataset 202 | texts = load_and_evaluate_dataset(dataset, tokenizer) 203 | 204 | total_batches = len(texts) // batch_size 205 | if len(texts) % batch_size != 0: 206 | total_batches += 1 207 | table = PrettyTable() 208 | 209 | table.field_names = [ 210 | "Method", 211 | f"Avg Prefill Time per request (s). 
/profiling_time_and_memory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import threading
4 | import time
5 | import GPUtil
6 | import fire
7 | import random
8 | from tqdm import tqdm
9 | import numpy as np
10 | from typing import List
11 | from prettytable import PrettyTable
12 | from dataset_utils import (
13 |     load_and_evaluate_dataset,
14 |     sample_batches,
15 |     sample_batches_by_length,
16 |     unpack_kv,
17 | )
18 | from processor import PrePackProcessor
19 | from utils import integer_program_packing, load_model_and_tokenizer
20 | 
21 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
22 | 
23 | 
24 | def monitor_gpu_utilization(stop_event, utilization_stats, device_id=0, interval=0.1):
25 |     max_utilization, total_utilization, count = 0, 0, 0
26 |     cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
27 |     visible_device_ids = list(map(int, cuda_visible_devices.split(","))) if cuda_visible_devices else list(range(len(GPUtil.getGPUs())))  # fall back to all GPUs if unset
28 |     while not stop_event.is_set():
29 |         gpus = GPUtil.getGPUs()
30 |         current_utilization = gpus[visible_device_ids[device_id]].load
31 |         max_utilization = max(max_utilization, current_utilization)
32 |         total_utilization += current_utilization
33 |         count += 1
34 |         time.sleep(interval)
35 |     utilization_stats["max_util"] = max_utilization * 100  # Convert to percentage
36 |     utilization_stats["mean_util"] = (total_utilization / count) * 100 if count > 0 else 0
37 | 
38 | 
39 | def prefill_with_prepacking(sentences, model, tokenizer, device, processor):
40 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
41 |     with torch.no_grad():
42 |         output = model(
43 |             input_ids=new_tokens,
44 |             attention_mask=new_mask,
45 |             position_ids=new_positions,
46 |             return_dict=True,
47 |         )
48 |     return output
49 | 
50 | 
51 | def TTFT_with_prepacking(sentences, model, tokenizer, device, processor):
52 |     new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
53 |     with torch.no_grad():
54 |         packed_outputs = model(
55 |             input_ids=new_tokens,
56 |             attention_mask=new_mask,
57 |             position_ids=new_positions,
58 |             return_dict=True,
59 |         )
60 |     cache, final_tokens, attention_mask = unpack_kv(
61 |         packed_outputs["past_key_values"], restart_dict, original_ids, device
62 |     )
63 |     _ = model.generate(
64 |         input_ids=final_tokens,
65 |         attention_mask=attention_mask,
66 |         max_new_tokens=1,
67 |         use_cache=True,
68 |         do_sample=False,
69 |         past_key_values=cache,
70 |     )
71 |     return
72 | 
73 | 
74 | def prefill_with_baseline(sentences, model, tokenizer, device, processor=None):
75 |     batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
76 | 
77 |     with torch.no_grad():
78 |         batch_sentences_outputs = model(
79 |             batch_sentences["input_ids"].to(device),
80 |             attention_mask=batch_sentences["attention_mask"].to(device),
81 |             return_dict=True,
82 |         )
83 |     return batch_sentences_outputs
84 | 
85 | 
86 | def TTFT_with_baseline(sentences, model, tokenizer, device, processor=None):
87 |     batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
88 | 
89 |     with torch.no_grad():
90 |         _ = model.generate(
91 |             **batch_sentences,
92 |             max_new_tokens=1,
93 |             use_cache=True,
94 |             do_sample=False,
95 |         )
96 |     return
97 | 
98 | 
99 | def get_average_gpu_utilization():
100 |     # get current device id, assuming only 1 GPU is used
101 |     cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
102 |     visible_device_ids = list(map(int, cuda_visible_devices.split(","))) if cuda_visible_devices else [0]  # default to GPU 0 if unset
103 |     gpus = GPUtil.getGPUs()
104 |     return gpus[visible_device_ids[0]].load
105 | 
106 | 
107 | def measure_inference_resources(
108 |     method,
109 |     texts,
110 |     batch_size,
111 |     num_runs,
112 |     total_batches,
113 |     model,
114 |     tokenizer,
115 |     model_device,
116 |     metric="TTFT",
117 |     binpack_algo="greedy",
118 | ):
119 |     scenario_times = []
120 | 
121 |     if metric == "TTFT":
122 |         method_functions = {
123 |             "prepacking": TTFT_with_prepacking,
124 |             "full-batching": TTFT_with_baseline,
125 |             "length-ordered": TTFT_with_baseline,
126 |         }
127 |     elif metric == "prefill":
128 |         method_functions = {
129 |             "prepacking": prefill_with_prepacking,
130 |             "full-batching": prefill_with_baseline,
131 |             "length-ordered": prefill_with_baseline,
132 |         }
133 |     method_function = method_functions.get(method)
134 |     packing_fn = None if binpack_algo == "greedy" else integer_program_packing
135 |     optimized_processor = PrePackProcessor(tokenizer, packing_fn=packing_fn)
136 |     assert method_function is not None, f"Unknown method: {method}"
137 | 
138 |     for _ in range(num_runs):
139 |         batches_generator = (
140 |             sample_batches(texts, batch_size)
141 |             if method != "length-ordered"
142 |             else sample_batches_by_length(texts, batch_size)
143 |         )
144 | 
145 |         max_gpu_utilization = []
146 |         max_gpu_memory = []
147 |         batch_gpu_memories = []
148 |         batch_gpu_utilizations = []
149 |         mean_gpu_utilizations = []
150 |         for batch in tqdm(batches_generator, total=total_batches, desc=method):
151 |             utilization_stats = {}
152 |             stop_event = threading.Event()
153 |             monitor_thread = threading.Thread(
154 |                 target=monitor_gpu_utilization, args=(stop_event, utilization_stats), daemon=True
155 |             )
156 |             monitor_thread.start()
157 | 
158 |             torch.cuda.reset_peak_memory_stats(model_device)  # Reset memory stats at the start
159 |             torch.cuda.empty_cache()
160 | 
161 |             start_time = time.time()
162 |             _ = method_function(batch, model, tokenizer, model_device, optimized_processor)
163 |             elapsed = time.time() - start_time
164 | 
165 |             scenario_times.append(elapsed)
166 |             stop_event.set()
167 |             monitor_thread.join()
168 |             peak_memory = torch.cuda.max_memory_allocated(model_device) / (1024**2)
169 |             batch_gpu_memories.append(peak_memory)
170 | 
171 |             max_util = utilization_stats.get("max_util", 0)
172 |             mean_util = utilization_stats.get("mean_util", 0)  # Get mean utilization
173 |             batch_gpu_utilizations.append(max_util)
174 |             mean_gpu_utilizations.append(mean_util)
175 | 
176 |         max_gpu_memory.append(max(batch_gpu_memories))
177 |         max_gpu_utilization.append(max(batch_gpu_utilizations))
178 |     avg_scenario_time = np.mean(scenario_times)
179 |     avg_gpu_utilization = np.mean(max_gpu_utilization)
180 |     avg_gpu_memory = np.mean(max_gpu_memory)
181 |     avg_mean_gpu_utilization = np.mean(mean_gpu_utilizations)
182 |     std_dev_time = np.std(scenario_times)
183 |     std_gpu_utilization = np.std(max_gpu_utilization)
184 |     # std of max GPU memory is not reported: peak allocation is stable across identical runs
185 |     std_mean_gpu_utilization = np.std(mean_gpu_utilizations)
186 |     return (
187 |         avg_scenario_time,
188 |         avg_gpu_utilization,
189 |         avg_gpu_memory,
190 |         avg_mean_gpu_utilization,
191 |         std_dev_time,
192 |         std_gpu_utilization,
193 |         std_mean_gpu_utilization,
194 |     )
195 | 
196 | 
197 | def main(
198 |     methods: List[str] = [
199 |         "prepacking",
200 |         "full-batching",
201 |         "length-ordered",
202 |     ],
203 |     metric: str = "prefill",
204 |     dataset: str = "mmlu",
205 |     model_name: str = "llama1b",
206 |     loadbit: int = 4,
207 |     num_runs: int = 5,
208 |     batch_size: int = 64,
209 |     binpack_algo: str = "greedy",
210 | ):
211 | 
212 |     torch.set_num_threads(5)
213 | 
214 |     seed = 42
215 |     random.seed(seed)
216 |     np.random.seed(seed)
217 |     torch.manual_seed(seed)
218 |     torch.cuda.manual_seed_all(seed)
219 |     os.environ["PYTHONHASHSEED"] = str(seed)
220 | 
221 |     if binpack_algo != "greedy":
222 |         binpack_algo = "ip"
223 | 
224 |     # Load the model and tokenizer
225 |     model, tokenizer = load_model_and_tokenizer(base_model=model_name, loadbit=loadbit)
226 | 
227 |     # Load and prepare the dataset
228 |     texts = load_and_evaluate_dataset(dataset, tokenizer)
229 | 
230 |     total_batches = len(texts) // batch_size
231 |     if len(texts) % batch_size != 0:
232 |         total_batches += 1
233 |     table = PrettyTable()
234 | 
235 |     table.field_names = [
236 |         "Method",
237 |         f"Avg {metric} Time /batch (s)",
238 |         "Max GPU Utilization (%)",
239 |         "Max GPU Memory (MB)",
240 |         "Mean GPU Utilization (%)",
241 |         "Std Dev Time (s)",
242 |         "Std Dev Max GPU Util (%)",
243 |         "Std Dev Mean GPU Util (%)",
244 |     ]
245 | 
246 |     for method in methods:
247 |         try:
248 |             (
249 |                 avg_scenario_time,
250 |                 avg_gpu_utilization,
251 |                 avg_gpu_memory,
252 |                 avg_mean_gpu_utilization,
253 |                 std_dev_time,
254 |                 std_gpu_util,
255 |                 std_mean_gpu_util,
256 |             ) = measure_inference_resources(
257 |                 method,
258 |                 texts,
259 |                 batch_size,
260 |                 num_runs,
261 |                 total_batches,
262 |                 model,
263 |                 tokenizer,
264 |                 model.device,
265 |                 metric=metric,
266 |                 binpack_algo=binpack_algo,
267 |             )
268 |             table.add_row(
269 |                 [
270 |                     method,
271 |                     f"{avg_scenario_time:.3f}",
272 |                     f"{avg_gpu_utilization:.3f}",
273 |                     f"{avg_gpu_memory:.3f}",
274 |                     f"{avg_mean_gpu_utilization:.3f}",
275 |                     f"{std_dev_time:.3f}",
276 |                     f"{std_gpu_util:.3f}",
277 |                     f"{std_mean_gpu_util:.3f}",
278 |                 ]
279 |             )
280 |             print(table)
281 |         except Exception as e:  # most commonly a CUDA out-of-memory error
282 |             print(f"An error occurred while processing method {method}: {e}")
283 |             torch.cuda.empty_cache()
284 | 
285 |         finally:
286 |             torch.cuda.empty_cache()
287 | 
288 | 
289 | if __name__ == "__main__":
290 |     fire.Fire(main)
291 | 
--------------------------------------------------------------------------------
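The utilization columns in the profiling table come from the daemon-thread pattern in `monitor_gpu_utilization` above: poll `GPUtil` at a fixed interval while the measured code runs, then stop the thread and read back the stats. A minimal self-contained sketch of the same pattern, assuming `GPUtil` is installed and GPU 0 is the device of interest:

```python
import threading
import time

import GPUtil


def monitor(stop_event, stats, interval=0.1):
    # Poll GPU 0 until asked to stop, tracking peak and running-mean load.
    peak, total, count = 0.0, 0.0, 0
    while not stop_event.is_set():
        load = GPUtil.getGPUs()[0].load
        peak = max(peak, load)
        total += load
        count += 1
        time.sleep(interval)
    stats["max_util"] = peak * 100
    stats["mean_util"] = (total / count) * 100 if count else 0.0


stats = {}
stop_event = threading.Event()
thread = threading.Thread(target=monitor, args=(stop_event, stats), daemon=True)
thread.start()
time.sleep(1.0)  # stand-in for the measured workload
stop_event.set()
thread.join()
print(stats)
```

The 0.1 s polling interval matches the script above; a shorter interval sharpens the peak estimate at the cost of extra CPU work.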
/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | import torch
3 | from torch import Tensor
4 | from ortools.linear_solver import pywraplp
5 | import binpacking
6 | from transformers import AutoTokenizer
7 | from model import CustomCausalLlamaModel, CustomCausalMistralModel
8 | 
9 | 
10 | # As implemented here:
11 | # https://github.com/pytorch/pytorch/issues/10536#issuecomment-1320935162
12 | def left_pad_sequence(
13 |     sequences: Union[Tensor, List[Tensor]],
14 |     batch_first: bool = True,
15 |     padding_value: float = 0.0,
16 | ) -> Tensor:
17 | 
18 |     sequences = tuple(map(lambda s: s.flip(0), sequences))  # reverse each sequence
19 |     padded_sequence = torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
20 |     _seq_dim = padded_sequence.dim()
21 |     padded_sequence = padded_sequence.flip(-_seq_dim + batch_first)  # flip back along the sequence dimension
22 |     return padded_sequence
23 | 
24 | 
25 | def greedy_packing(length_dict, max_bin_size):
26 |     return binpacking.to_constant_volume(length_dict, max_bin_size)
27 | 
28 | 
29 | # https://developers.google.com/optimization/pack/bin_packing
30 | def integer_program_packing(length_dict, max_bin_size):
31 |     data = {}
32 |     data["items"] = list(length_dict.keys())
33 |     data["weights"] = list(length_dict.values())
34 |     data["bins"] = data["items"]
35 |     data["bin_capacity"] = max_bin_size
36 | 
37 |     solver = pywraplp.Solver.CreateSolver("SCIP")
38 | 
39 |     if not solver:
40 |         raise RuntimeError("SCIP solver is unavailable")
41 |     x = {}
42 |     for i in data["items"]:
43 |         for j in data["bins"]:
44 |             x[(i, j)] = solver.IntVar(0, 1, "x_%i_%i" % (i, j))
45 |     y = {}
46 |     for j in data["bins"]:
47 |         y[j] = solver.IntVar(0, 1, "y[%i]" % j)
48 | 
49 |     for i in data["items"]:
50 |         solver.Add(sum(x[i, j] for j in data["bins"]) == 1)  # each item goes in exactly one bin
51 | 
52 |     for j in data["bins"]:
53 |         solver.Add(sum(x[(i, j)] * data["weights"][i] for i in data["items"]) <= y[j] * data["bin_capacity"])  # respect the capacity of each open bin
54 | 
55 |     solver.Minimize(solver.Sum([y[j] for j in data["bins"]]))
56 | 
57 |     status = solver.Solve()
58 | 
59 |     if status == pywraplp.Solver.OPTIMAL:
60 |         result = []
61 |         for j in data["bins"]:
62 |             if y[j].solution_value() == 1:
63 |                 bin_dict = {}
64 |                 for i in data["items"]:
65 |                     if x[i, j].solution_value() > 0:
66 |                         bin_dict[i] = data["weights"][i]
67 |                 result.append(bin_dict)
68 |     else:
69 |         raise ValueError("The problem does not have an optimal solution.")
70 | 
71 |     return result
72 | 
73 | 
74 | def load_model_and_tokenizer(
75 |     base_model: str = "llama1b",
76 |     loadbit: int = 8,
77 | ):
78 |     # Load tokenizer and model
79 |     if base_model == "llama1b":
80 |         path = "princeton-nlp/Sheared-LLaMA-1.3B"
81 |     elif base_model == "llama2":
82 |         path = "/path/to/llama2"
83 |     else:
84 |         raise ValueError(f"Unknown base model: {base_model}")
85 | 
86 |     tokenizer = AutoTokenizer.from_pretrained(path)
87 |     tokenizer.pad_token = "[PAD]"
88 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89 |     load_in_8bit = loadbit == 8
90 |     load_in_4bit = loadbit == 4
91 |     if "llama" in base_model:
92 |         model = CustomCausalLlamaModel.from_pretrained(
93 |             path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
94 |         )
95 |     elif "mistral" in base_model:
96 |         model = CustomCausalMistralModel.from_pretrained(
97 |             path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
98 |         )
99 |     model.eval()
100 |     if loadbit != 8 and loadbit != 4:
101 |         model.to(device)  # quantized models are already placed on device by bitsandbytes
102 | 
103 |     return model, tokenizer
104 | 
--------------------------------------------------------------------------------
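Putting the pieces together, a minimal end-to-end sketch of prepacked prefilling with the helpers above. It mirrors `prefill_with_prepacking`; the prompts are placeholders, and it assumes a CUDA device plus the default `llama1b` checkpoint used by `load_model_and_tokenizer`:

```python
import torch

from processor import PrePackProcessor
from utils import load_model_and_tokenizer

model, tokenizer = load_model_and_tokenizer(base_model="llama1b", loadbit=8)
processor = PrePackProcessor(tokenizer)

prompts = ["A short prompt.", "A noticeably longer prompt about something else.", "Hi."]
tokens, positions, mask, restart_indices, original_ids = processor.batch_process(prompts)

with torch.no_grad():
    outputs = model(
        input_ids=tokens,
        attention_mask=mask,     # block-diagonal causal mask from the processor
        position_ids=positions,  # positions restart at 0 for each packed prompt
        return_dict=True,
    )
# outputs.past_key_values now holds the prefilled KV cache for all prompts,
# ready to be unpacked (see dataset_utils.unpack_kv) before generation.
```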