├── README.md
├── batching_experiment.py
├── dataset_utils.py
├── environment.yml
├── figures
├── ablation_speed_up_over_batch_size_both_models.pdf
├── alldataset_comparison.pdf
├── batch_size_figure.pdf
├── dataset_level_ablation.pdf
├── prepacking_gif_final.gif
└── speedup_gain16llama1b5000.pdf
├── model.py
├── prepack_generation_demo.ipynb
├── processor.py
├── profiling_dataset_level_prepacking.py
├── profiling_time_and_memory.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models
2 |
3 | ![Prepacking animation](figures/prepacking_gif_final.gif)
4 |
8 | This repository contains the source code of the following paper:
9 |
10 | > **"Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models"**
11 | > Siyan Zhao, Daniel Israel, Guy Van den Broeck, Aditya Grover
12 | >
13 | > **Abstract:** *During inference for transformer-based large language models (LLMs), prefilling is the computation of the key-value (KV) cache for input tokens in the prompt prior to autoregressive generation. For longer input prompt lengths, prefilling will incur a significant overhead on decoding time. In this work, we highlight the following pitfall of prefilling: for batches containing high-varying prompt lengths, significant computation is wasted by the standard practice of padding sequences to the maximum length. As LLMs increasingly support longer context lengths, potentially up to 10 million tokens, variations in prompt lengths within a batch become more pronounced. To address this, we propose Prepacking, a simple yet effective method to optimize prefilling computation. To avoid redundant computation on pad tokens, prepacking combines prompts of varying lengths into a sequence and packs multiple sequences into a compact batch using a bin-packing algorithm. It then modifies the attention mask and positional encoding to compute multiple prefilled KV-caches for multiple prompts within a single sequence. On standard curated datasets containing prompts with varying lengths, we obtain significant speed and memory efficiency improvements compared to the default padding-based prefilling computation within Huggingface across a range of base model configurations and inference serving scenarios.*
14 |
15 |
16 | [[Paper]](https://arxiv.org/abs/2404.09529)
17 |
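## How Prepacking Works (Sketch)

Prepacking packs prompts of different lengths into a few full sequences using a bin-packing heuristic, restarts the position ids at the beginning of each packed prompt, and applies a block-diagonal "independence" attention mask so packed prompts cannot attend to one another. The sketch below illustrates only the packing and position-id logic with a hypothetical `toy_prepack` helper; see `processor.py` for the repository's actual `PrePackProcessor`.

```python
def toy_prepack(prompts, max_len):
    """Greedy first-fit-decreasing bin packing of tokenized prompts."""
    bins = []  # each bin is a list of prompts (each a list of token ids)
    for p in sorted(prompts, key=len, reverse=True):
        for b in bins:
            if sum(map(len, b)) + len(p) <= max_len:
                b.append(p)
                break
        else:
            bins.append([p])
    packed, positions = [], []
    for b in bins:
        packed.append([tok for p in b for tok in p])
        # Position ids restart at 0 for every packed prompt.
        positions.append([i for p in b for i in range(len(p))])
    return packed, positions

tokens, positions = toy_prepack([[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]], max_len=5)
# tokens    -> [[6, 7, 8, 9, 10], [1, 2, 3, 4, 5]]
# positions -> [[0, 1, 2, 3, 4], [0, 1, 2, 0, 1]]
# A block-diagonal attention mask (not shown) keeps the packed prompts independent.
```
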
18 | ## Setup Environment
19 | ### Clone the repository
20 | ```bash
21 | git clone https://github.com/siyan-zhao/prepacking.git
22 | cd prepacking
23 | ```
24 |
25 | ### Conda Setup
26 | ```bash
27 | conda env create -f environment.yml
28 | conda activate prepack
29 | ```
30 |
31 | ## Profile Speed and Memory
32 |
33 |
34 | ### Profile Prefill / Time to First Token (TTFT) Latency and Compare Peak GPU Memory and Utilization
35 |
36 | ```bash
37 | CUDA_VISIBLE_DEVICES=0 python profiling_time_and_memory.py --metric=prefill --dataset=mmlu --batch_size=64 --model_name=llama1b --num_runs=5
38 | ```
39 |
40 | Example output when profiled on a single 48GB NVIDIA A6000 GPU:
41 |
42 | | Method | Avg prefill Time /batch (s) | Max GPU Utilization (%) | Max GPU Memory (MB) | Mean GPU Utilization (%) | Std Dev Time (s) | Std Dev Max GPU Util (%) | Std Dev Mean GPU Util (%) |
43 | |----------------|-----------------------------|-------------------------|---------------------|--------------------------|------------------|--------------------------|---------------------------|
44 | | prepacking | 0.441 | 100.000 | 4578.328 | 91.156 | 0.347 | 0.000 | 7.966 |
45 | | full-batching | 2.299 | 100.000 | 34599.695 | 99.719 | 1.741 | 0.000 | 0.223 |
46 | | length-ordered | 0.658 | 100.000 | 22950.019 | 97.865 | 0.815 | 0.000 | 3.236 |
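
For this run, prepacking cuts average prefill time by about 5.2× (0.441 s vs. 2.299 s) and peak GPU memory by about 7.6× (4,578.3 MB vs. 34,599.7 MB) relative to full batching.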
47 |
48 | ### Compare Per-Prompt Prefill Time, Including Dataset-Level Prepacking Overhead
49 |
50 | ```bash
51 | CUDA_VISIBLE_DEVICES=0 python profiling_dataset_level_prepacking.py --metric=prefill --model_name=llama1b --batch_size=32 --loadbit=8 --dataset=mmlu
52 | ```
53 |
54 | ## Play with Prepacking Generation
55 |
56 | A Colab example of using prepacking for generation; compare it against default generation yourself.
57 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/siyan-zhao/prepacking/blob/main/prepack_generation_demo.ipynb)
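
For a quick look outside Colab, the snippet below condenses the notebook's prepacking generation flow; the APIs are those defined in `processor.py`, `model.py`, and `dataset_utils.py`, and the example prompts are arbitrary.

```python
import torch
from transformers import AutoTokenizer
from processor import PrePackProcessor
from model import CustomCausalLlamaModel
from dataset_utils import unpack_kv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "princeton-nlp/Sheared-LLaMA-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = "[PAD]"
model = CustomCausalLlamaModel.from_pretrained(model_path).to(device).eval()

# Pack the prompts and prefill every KV cache with one packed forward pass.
processor = PrePackProcessor(tokenizer)
sentences = ["The capital of Spain is", "Today I'm going to the", "My"]
tokens, positions, mask, restart_dict, original_ids = processor.batch_process(sentences)
with torch.no_grad():
    out = model(
        input_ids=tokens.to(device),
        attention_mask=mask.to(device),
        position_ids=positions.to(device),
        return_dict=True,
    )

# Unpack per-prompt KV caches, then generate from each prompt's final token.
cache, final_tokens, attention_mask = unpack_kv(
    out["past_key_values"], restart_dict, original_ids, device
)
generated = model.generate(
    input_ids=final_tokens,
    attention_mask=attention_mask,
    past_key_values=cache,
    use_cache=True,
    max_new_tokens=20,
    do_sample=False,
)
for seq in generated:
    print(tokenizer.decode(seq[1:]))
```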
58 |
59 | ## Reference
60 | If you find our work useful, please consider citing our [paper](https://arxiv.org/abs/2404.09529):
61 |
62 | ```
63 | @misc{zhao2024prepacking,
64 | title={Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models},
65 | author={Siyan Zhao and Daniel Israel and Guy Van den Broeck and Aditya Grover},
66 | year={2024},
67 | eprint={2404.09529},
68 | archivePrefix={arXiv},
69 | primaryClass={cs.LG}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/batching_experiment.py:
--------------------------------------------------------------------------------
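"""Measure prefill / time-to-first-token (TTFT) latency of a single forward
pass across batch sizes and prompt lengths, and plot the mean and one standard
deviation per prompt length (saved to batch_size_figure.png)."""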
1 | from transformers import LlamaForCausalLM
2 | import torch
3 | import time
4 | import numpy as np
5 | import pickle
6 | import matplotlib.pyplot as plt
7 | from tqdm import tqdm
8 |
9 |
10 | torch.set_num_threads(1)
11 |
12 |
13 | def run_batching_experiment(model, batch_size, length):
14 | model.eval()
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | scenario_times = []
17 | for _ in range(100):
18 | data = torch.randint(0, 32000, (batch_size, length)).to(device)
19 |         with torch.no_grad():
20 |             start_time = time.time()
21 |             _ = model(input_ids=data)
22 |             # CUDA kernels launch asynchronously; wait for the forward pass
23 |             # to finish before stopping the clock.
24 |             if device.type == "cuda":
25 |                 torch.cuda.synchronize()
26 |             elapsed = time.time() - start_time
27 |         scenario_times.append(elapsed)
24 | avg_scenario_time = np.mean(scenario_times)
25 | std_dev = np.std(scenario_times)
26 | return avg_scenario_time, std_dev
27 |
28 |
29 | if __name__ == "__main__":
30 | model_path = "princeton-nlp/Sheared-LLaMA-1.3B"
31 | model = LlamaForCausalLM.from_pretrained(
32 | model_path,
33 | load_in_4bit=True,
34 | bnb_4bit_compute_dtype=torch.float16,
35 | low_cpu_mem_usage=True,
36 | device_map="auto",
37 | )
38 |
39 | LENGTHS = [50, 100, 200, 400]
40 | BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64]
41 | data = {}
42 | for length in tqdm(LENGTHS):
43 | data[length] = [run_batching_experiment(model, bs, length) for bs in BATCH_SIZES]
44 | print(f"Length: {length}", data[length])
45 |
46 |     for length in LENGTHS:
47 |         means, stds = zip(*data[length])
48 |         plt.plot(BATCH_SIZES, means, label=f"D={length}")
49 |         plt.fill_between(
50 |             BATCH_SIZES, [m - s for m, s in zip(means, stds)], [m + s for m, s in zip(means, stds)], alpha=0.2
51 |         )
52 | plt.legend()
53 | plt.xlabel("Batch size")
54 | plt.xscale("log", base=2)
55 | plt.ylabel("TTFT (s)")
56 | plt.savefig("batch_size_figure.png")
57 |
--------------------------------------------------------------------------------
/dataset_utils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from torch.utils.data import DataLoader, Dataset
3 | import random
4 | import torch
5 | from utils import left_pad_sequence
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from transformers.trainer_utils import set_seed
9 |
10 | SEED = 41
11 | set_seed(SEED)
12 |
13 |
14 | class HFDataset(Dataset):
15 | def __init__(self, texts):
16 | self.texts = texts
17 |
18 | def __len__(self):
19 | return len(self.texts)
20 |
21 | def __getitem__(self, idx):
22 | return self.texts[idx]
23 |
24 |
25 | def load_and_prepare_data(
26 | dataset_name: str,
27 | config_name: str,
28 | split: str,
29 | tokenizer,
30 | sample_size: int = 1000,
31 | max_length: int = None,
32 | ):
33 | dataset = load_dataset(dataset_name, config_name, split=split)
34 | if dataset_name == "wikitext":
35 | texts = dataset["text"]
36 | elif "mmlu" in dataset_name:
37 | texts = dataset["question"]
38 | elif "rlhf" in dataset_name:
39 | texts = dataset["chosen"]
40 | elif "alpaca" in dataset_name:
41 | texts = dataset["text"]
42 |     elif "samsum" in dataset_name:
43 |         texts = dataset["dialogue"]
44 |     else:
45 |         raise ValueError(f"Unsupported dataset: {dataset_name}")
46 |     print(f"Loaded {len(texts)} texts; sampling {sample_size}")
45 | texts = [text for text in texts if len(text.split()) > 1]
46 | texts = random.sample(texts, sample_size)
47 |
48 | if max_length:
49 | tokenized_texts = [
50 | tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt") for text in texts
51 | ]
52 | texts = [
53 | tokenizer.decode(tokens.input_ids[0], skip_special_tokens=True) for tokens in tokenized_texts
54 | ]
55 | return texts
56 |
57 |
58 | def load_and_evaluate_dataset(dataset: str, tokenizer):
59 | dataset_config = {
60 | "mmlu": ("cais/mmlu", "all", None, 1000, "test"),
61 | "wikitext512": ("wikitext", "wikitext-2-raw-v1", 512, 1000, "test"),
62 | "wikitext256": ("wikitext", "wikitext-2-raw-v1", 256, 1000, "test"),
63 | "rlhf": ("Anthropic/hh-rlhf", "default", None, 1000, "train"),
64 | "alpaca": ("tatsu-lab/alpaca", "default", None, 1000, "train"),
65 | "samsum": ("samsum", "samsum", None, 1000, "train"),
66 | }
67 |
68 | if dataset not in dataset_config:
69 | raise ValueError(f"Unsupported dataset: {dataset}")
70 |
71 | dataset_name, config_name, max_length, sample_size, split = dataset_config[dataset]
72 |
73 | # Load dataset
74 | texts = load_and_prepare_data(
75 | dataset_name=dataset_name,
76 | config_name=config_name,
77 | split=split,
78 | tokenizer=tokenizer,
79 | sample_size=sample_size,
80 | max_length=max_length,
81 | )
82 |
83 | # Evaluate sentences from the dataset
84 | print(f"Evaluating {len(texts)} sentences from the dataset {dataset_name}")
85 | lengths = [len(tokenizer(text).input_ids) for text in texts]
86 | print("Mean length:", np.mean(lengths))
87 | print("Max length:", np.max(lengths))
88 | print("Min length:", np.min(lengths))
89 |
90 | # Plotting the length distribution
91 | plt.figure(figsize=(10, 6))
92 | plt.hist(lengths, bins=30, color="skyblue", edgecolor="black")
93 | plt.title(f"Length Distribution of Texts in {dataset_name}")
94 | plt.xlabel("Length of Texts in Tokens")
95 | plt.ylabel("Frequency")
96 | plt.savefig(f"length_distribution_{dataset}.png")
97 | plt.close()
98 |
99 | return texts
100 |
101 |
102 | def sample_batches(texts, batch_size):
103 | random.shuffle(texts)
104 | for i in range(0, len(texts), batch_size):
105 | yield texts[i : i + batch_size]
106 |
107 |
108 | def sample_batches_deterministic(texts, batch_size):
109 | for i in range(0, len(texts), batch_size):
110 | yield texts[i : i + batch_size]
111 |
112 |
113 | def sample_batches_by_length(texts, batch_size):
114 | # Sort texts by their length
115 | texts_sorted = sorted(texts, key=lambda x: len(x))
116 | # Yield batches of similar length
117 | for i in range(0, len(texts_sorted), batch_size):
118 | yield texts_sorted[i : i + batch_size]
119 |
120 |
121 | def sample_packed_dataset(dataset):
122 | # Assumption is dataset is already shuffled and batched
123 | for i in range(len(dataset)):
124 | yield dataset[i]
125 |
126 |
127 | def unpack_kv(packed_outputs, restart_dict, original_ids, device):
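    """Split the KV cache of a packed forward pass back into per-prompt caches.

    packed_outputs: past_key_values from the packed forward pass, indexed as
        [layer][key or value][row, head, position, head_dim].
    restart_dict: one dict per packed row, mapping prompt-boundary offsets
        within that row to the prompt's index in the original batch.
    Returns left-padded per-layer KV caches, the final token of each prompt
    (used to seed generation), and the matching left-padded attention masks.
    """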
128 |
129 | batch_size = sum(map(len, restart_dict)) - len(restart_dict)
130 |
131 | dim1, dim2 = len(packed_outputs), len(packed_outputs[0])
132 |
133 | save_cache = [[None for _ in range(dim2)] for _ in range(dim1)]
134 | batch_length = [len(ids) - 1 for ids in original_ids]
135 | compute = True
136 |
137 | attention_masks = torch.zeros((batch_size, max(batch_length) + 1), dtype=torch.int, device=device)
138 | final_tokens = torch.empty(batch_size, dtype=torch.int, device=device)
139 |
140 | for j in range(dim1): # layer
141 | for k in range(dim2): # k, v
142 | batch_cache = np.empty(batch_size, dtype=object)
143 |
144 | for b, batch in enumerate(restart_dict):
145 | batch_indices = list(batch.keys())
146 | for i in range(len(batch) - 1):
147 | c = packed_outputs[j][k][b, :, batch_indices[i] : batch_indices[i + 1], :].permute(
148 | 1, 0, 2
149 | )
150 | original_index = restart_dict[b][batch_indices[i + 1]]
151 | batch_cache[original_index] = c
152 | if compute:
153 | prompt = original_ids[batch[batch_indices[i + 1]]]
154 | final_tokens[original_index] = prompt[-1]
155 | attention_masks[original_index, -(batch_length[original_index]) - 1 :] = 1
156 | compute = False
157 | padded = left_pad_sequence(batch_cache, batch_first=True, padding_value=0).permute(0, 2, 1, 3)
158 | save_cache[j][k] = padded
159 |
160 | return save_cache, final_tokens.unsqueeze(dim=-1), attention_masks
161 |
162 |
163 | class PackedDataset(Dataset):
164 | def __init__(self, new_tokens, new_positions, new_mask, restart_indices, original_ids, batch_size):
165 | self.tensor_tuple = (new_tokens, new_positions, new_mask)
166 | self.restart_indices = restart_indices
167 | self.original_ids = original_ids
168 | self.initial_processing(batch_size)
169 |
170 | def __len__(self):
171 | return len(self.batch_indices)
172 |
173 | def __getitem__(self, idx):
174 | batch_idx = self.batch_indices[idx]
175 | tensors = tuple(tensor[batch_idx] for tensor in self.tensor_tuple)
176 | restart_idx = self.restart_indices[idx]
177 | original_id = self.original_ids[idx]
178 | return tensors + (restart_idx,) + (original_id,)
179 |
180 | def initial_processing(self, batch_size):
181 |
182 | num_samples = len(self.tensor_tuple[0])
183 | indices = list(range(num_samples))
184 | random.shuffle(indices)
185 | batch_indices = [indices[i : i + batch_size] for i in range(0, num_samples, batch_size)]
186 |
187 | # Purpose of below code is that with batching, restart indices and original ids must be reformatted
188 | # for consistency with existing unpack_kv function
189 | batched_restart_indices = []
190 | for idx in batch_indices:
191 | batch_restart_indices = list(map(self.restart_indices.__getitem__, idx))
192 | batched_restart_indices.append(batch_restart_indices)
193 |
194 | new_original_ids = []
195 | for idx in range(len(batch_indices)):
196 | original_id = []
197 | i = 0
198 | restart_idx = batched_restart_indices[idx]
199 | for d in restart_idx:
200 | for key in d:
201 | value = d[key]
202 | if value != -1:
203 | original_id.append(self.original_ids[value])
204 | d[key] = i
205 | i += 1
206 | new_original_ids.append(original_id)
207 |
208 | self.original_ids = new_original_ids
209 | self.restart_indices = batched_restart_indices
210 | self.batch_indices = batch_indices
211 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: prepack
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2024.3.11=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.3=he6710b0_2
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=1.1.1w=h7f8727e_0
15 | - pip=23.3.1=py39h06a4308_0
16 | - python=3.9.7=h12debd9_1
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=68.2.2=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.41.2=py39h06a4308_0
22 | - xz=5.4.6=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - absl-py==2.1.0
26 | - accelerate==0.29.2
27 | - aiohttp==3.9.4
28 | - aiosignal==1.3.1
29 | - async-timeout==4.0.3
30 | - attrs==23.2.0
31 | - binpacking==1.5.2
32 | - bitsandbytes==0.41.1
33 | - certifi==2024.2.2
34 | - charset-normalizer==3.3.2
35 | - cmake==3.29.2
36 | - contourpy==1.2.1
37 | - cycler==0.12.1
38 | - datasets==2.18.0
39 | - dill==0.3.8
40 | - filelock==3.13.4
41 | - fire==0.6.0
42 | - fonttools==4.51.0
43 | - frozenlist==1.4.1
44 | - fsspec==2024.2.0
45 | - future==1.0.0
46 | - gputil==1.4.0
47 | - huggingface-hub==0.22.2
48 | - idna==3.7
49 | - immutabledict==4.2.0
50 | - importlib-resources==6.4.0
51 | - jinja2==3.1.3
52 | - kiwisolver==1.4.5
53 | - lit==18.1.3
54 | - markupsafe==2.1.5
55 | - matplotlib==3.8.3
56 | - mpmath==1.3.0
57 | - multidict==6.0.5
58 | - multiprocess==0.70.16
59 | - networkx==3.2.1
60 | - numpy==1.25.2
61 | - nvidia-cublas-cu11==11.10.3.66
62 | - nvidia-cuda-cupti-cu11==11.7.101
63 | - nvidia-cuda-nvrtc-cu11==11.7.99
64 | - nvidia-cuda-runtime-cu11==11.7.99
65 | - nvidia-cudnn-cu11==8.5.0.96
66 | - nvidia-cufft-cu11==10.9.0.58
67 | - nvidia-curand-cu11==10.2.10.91
68 | - nvidia-cusolver-cu11==11.4.0.1
69 | - nvidia-cusparse-cu11==11.7.4.91
70 | - nvidia-nccl-cu11==2.14.3
71 | - nvidia-nvtx-cu11==11.7.91
72 | - ortools==9.9.3963
73 | - packaging==24.0
74 | - pandas==2.2.2
75 | - pillow==10.3.0
76 | - prettytable==3.9.0
77 | - protobuf==5.26.1
78 | - psutil==5.9.8
79 | - pyarrow==15.0.2
80 | - pyarrow-hotfix==0.6
81 | - pyparsing==3.1.2
82 | - python-dateutil==2.9.0.post0
83 | - pytz==2024.1
84 | - pyyaml==6.0.1
85 | - regex==2023.12.25
86 | - requests==2.31.0
87 | - safetensors==0.4.2
88 | - scipy==1.11.2
89 | - six==1.16.0
90 | - sympy==1.12
91 | - termcolor==2.4.0
92 | - tokenizers==0.15.2
93 | - torch==2.0.1
94 | - torchaudio==2.0.2
95 | - torchvision==0.15.2
96 | - tqdm==4.66.2
97 | - transformers==4.38.2
98 | - triton==2.0.0
99 | - typing-extensions==4.11.0
100 | - tzdata==2024.1
101 | - urllib3==2.2.1
102 | - wcwidth==0.2.13
103 | - xxhash==3.4.1
104 | - yarl==1.9.4
105 | - zipp==3.18.1
107 |
--------------------------------------------------------------------------------
/figures/ablation_speed_up_over_batch_size_both_models.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/ablation_speed_up_over_batch_size_both_models.pdf
--------------------------------------------------------------------------------
/figures/alldataset_comparison.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/alldataset_comparison.pdf
--------------------------------------------------------------------------------
/figures/batch_size_figure.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/batch_size_figure.pdf
--------------------------------------------------------------------------------
/figures/dataset_level_ablation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/dataset_level_ablation.pdf
--------------------------------------------------------------------------------
/figures/prepacking_gif_final.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/prepacking_gif_final.gif
--------------------------------------------------------------------------------
/figures/speedup_gain16llama1b5000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyan-zhao/prepacking/04e723a590415c198473b91c1400f2ee65240b0e/figures/speedup_gain16llama1b5000.pdf
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import transformers
4 | from transformers.modeling_outputs import BaseModelOutputWithPast
5 |
6 | from typing import List, Optional, Tuple, Union
7 | from transformers.generation.utils import GenerationMixin
8 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
9 | from transformers.cache_utils import Cache, DynamicCache
10 | from transformers.utils import logging
11 |
12 | logger = logging.get_logger(__name__)
13 |
14 | # Custom LlamaModel that additionally accepts per-example 2D attention masks (a 3D (batch, seq, seq) tensor), as produced by prepacking
15 | class CustomLlamaModel(transformers.LlamaModel):
16 | def __init__(self, config):
17 | super().__init__(config)
18 |
19 | def _update_causal_mask(self, attention_mask, input_tensor):
20 | if self.config._attn_implementation == "flash_attention_2":
21 | if attention_mask is not None and 0.0 in attention_mask:
22 | return attention_mask
23 | return None
24 |
25 | batch_size, seq_length = input_tensor.shape[:2]
26 | dtype = input_tensor.dtype
27 | device = input_tensor.device
28 |
29 | # support going beyond cached `max_position_embedding`
30 | if seq_length > self.causal_mask.shape[-1]:
31 | causal_mask = torch.full(
32 | (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
33 | fill_value=1,
34 | )
35 | self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
36 |
37 | # We use the current dtype to avoid any overflows
38 | min_dtype = torch.finfo(dtype).min
39 | causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
40 |
41 | causal_mask = causal_mask.to(dtype=dtype, device=device)
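        # 2D case: a (batch, seq) padding mask is broadcast and merged into the 4D causal mask.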
42 | if attention_mask is not None and attention_mask.dim() == 2:
43 | mask_length = attention_mask.shape[-1]
44 | padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
45 | causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
46 | padding_mask, min_dtype
47 | )
48 |
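        # 3D case (prepacking): a per-example (batch, seq, seq) mask lets each packed
        # prompt attend only to its own tokens (block-diagonal within a row).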
49 | if attention_mask is not None and attention_mask.dim() == 3:
50 | attention_mask = attention_mask.to(dtype=dtype, device=device)
51 | mask_length = attention_mask.shape[-1]
52 | padding_mask = causal_mask[..., :mask_length, :mask_length].eq(0.0) * attention_mask[
53 | :, None, :, :
54 | ].eq(0.0)
55 | causal_mask[..., :mask_length, :mask_length] = causal_mask[
56 | ..., :mask_length, :mask_length
57 | ].masked_fill(padding_mask, min_dtype)
58 |
59 | if self.config._attn_implementation == "sdpa" and attention_mask is not None:
60 | # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
61 | is_tracing = (
62 | torch.jit.is_tracing()
63 | or isinstance(input_tensor, torch.fx.Proxy)
64 | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
65 | )
66 | if not is_tracing and torch.any(attention_mask != 1):
67 | # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
68 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
69 | # Details: https://github.com/pytorch/pytorch/issues/110213
70 | causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(
71 | dtype
72 | )
73 |
74 | return causal_mask
75 |
76 |
77 | class CustomCausalLlamaModel(transformers.LlamaForCausalLM, GenerationMixin):
78 | def __init__(self, config):
79 | super().__init__(config)
80 | self.model = CustomLlamaModel(config)
81 |
82 |
83 | class CustomMistralModel(transformers.MistralModel):
84 | def __init__(self, config):
85 | super().__init__(config)
86 |
87 | # @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
88 | def forward(
89 | self,
90 | input_ids: torch.LongTensor = None,
91 | attention_mask: Optional[torch.Tensor] = None,
92 | position_ids: Optional[torch.LongTensor] = None,
93 | past_key_values: Optional[List[torch.FloatTensor]] = None,
94 | inputs_embeds: Optional[torch.FloatTensor] = None,
95 | use_cache: Optional[bool] = None,
96 | output_attentions: Optional[bool] = None,
97 | output_hidden_states: Optional[bool] = None,
98 | return_dict: Optional[bool] = None,
99 | ) -> Union[Tuple, BaseModelOutputWithPast]:
100 | output_attentions = (
101 | output_attentions if output_attentions is not None else self.config.output_attentions
102 | )
103 | output_hidden_states = (
104 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
105 | )
106 | use_cache = use_cache if use_cache is not None else self.config.use_cache
107 |
108 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
109 |
110 | # retrieve input_ids and inputs_embeds
111 | if input_ids is not None and inputs_embeds is not None:
112 | raise ValueError(
113 | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
114 | )
115 | elif input_ids is not None:
116 | batch_size, seq_length = input_ids.shape
117 | elif inputs_embeds is not None:
118 | batch_size, seq_length, _ = inputs_embeds.shape
119 | else:
120 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
121 |
122 | if self.gradient_checkpointing and self.training:
123 | if use_cache:
124 | logger.warning_once(
125 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
126 | )
127 | use_cache = False
128 |
129 | past_key_values_length = 0
130 |
131 | if use_cache:
132 | use_legacy_cache = not isinstance(past_key_values, Cache)
133 | if use_legacy_cache:
134 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
135 | past_key_values_length = past_key_values.get_usable_length(seq_length)
136 |
137 | if position_ids is None:
138 | device = input_ids.device if input_ids is not None else inputs_embeds.device
139 | position_ids = torch.arange(
140 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
141 | )
142 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
143 | else:
144 | position_ids = position_ids.view(-1, seq_length).long()
145 |
146 | if inputs_embeds is None:
147 | inputs_embeds = self.embed_tokens(input_ids)
148 |
149 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
150 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size
151 | if is_padding_right:
152 | raise ValueError(
153 | "You are attempting to perform batched generation with padding_side='right'"
154 | " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
155 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
156 | )
157 |
158 | if self._attn_implementation == "flash_attention_2":
159 | # 2d mask is passed through the layers
160 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
161 | elif self._attn_implementation == "sdpa" and not output_attentions:
162 | # output_attentions=True can not be supported when using SDPA, and we fall back on
163 | # the manual implementation that requires a 4D causal mask in all cases.
164 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
165 | attention_mask,
166 | (batch_size, seq_length),
167 | inputs_embeds,
168 | past_key_values_length,
169 | )
170 | else:
171 |             if attention_mask is not None and attention_mask.shape[1] == 1:  # for use cache
172 |                 attention_mask = attention_mask.view(batch_size, -1)
173 |             elif attention_mask is not None and attention_mask.dim() == 3:
174 | attention_mask = attention_mask.view(batch_size, 1, seq_length, seq_length)
175 | # 4d mask is passed through the layers
176 | attention_mask = _prepare_4d_causal_attention_mask(
177 | attention_mask,
178 | (batch_size, seq_length),
179 | inputs_embeds,
180 | past_key_values_length,
181 | sliding_window=self.config.sliding_window,
182 | )
183 | hidden_states = inputs_embeds
184 |
185 | # decoder layers
186 | all_hidden_states = () if output_hidden_states else None
187 | all_self_attns = () if output_attentions else None
188 | next_decoder_cache = None
189 |
190 | for decoder_layer in self.layers:
191 | if output_hidden_states:
192 | all_hidden_states += (hidden_states,)
193 |
194 | if self.gradient_checkpointing and self.training:
195 | layer_outputs = self._gradient_checkpointing_func(
196 | decoder_layer.__call__,
197 | hidden_states,
198 | attention_mask,
199 | position_ids,
200 | past_key_values,
201 | output_attentions,
202 | use_cache,
203 | )
204 | else:
205 | layer_outputs = decoder_layer(
206 | hidden_states,
207 | attention_mask=attention_mask,
208 | position_ids=position_ids,
209 | past_key_value=past_key_values,
210 | output_attentions=output_attentions,
211 | use_cache=use_cache,
212 | )
213 |
214 | hidden_states = layer_outputs[0]
215 |
216 | if use_cache:
217 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
218 |
219 | if output_attentions:
220 | all_self_attns += (layer_outputs[1],)
221 |
222 | hidden_states = self.norm(hidden_states)
223 |
224 | # add hidden states from the last decoder layer
225 | if output_hidden_states:
226 | all_hidden_states += (hidden_states,)
227 |
228 | next_cache = None
229 | if use_cache:
230 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
231 |
232 | if not return_dict:
233 | return tuple(
234 | v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None
235 | )
236 | return BaseModelOutputWithPast(
237 | last_hidden_state=hidden_states,
238 | past_key_values=next_cache,
239 | hidden_states=all_hidden_states,
240 | attentions=all_self_attns,
241 | )
242 |
243 | class CustomCausalMistralModel(transformers.MistralForCausalLM, GenerationMixin):
244 | def __init__(self, config):
245 | super().__init__(config)
246 | self.model = CustomMistralModel(config)
247 |
--------------------------------------------------------------------------------
/prepack_generation_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "**Imports**"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "!git clone https://github.com/siyan-zhao/prepacking.git\n",
17 | "%cd prepacking/\n",
18 | "!pip install ortools==9.9.3963\n",
19 | "!pip install binpacking==1.5.2\n",
20 | "!pip install datasets==2.18.0\n",
21 | "!pip install -i https://pypi.org/simple/ bitsandbytes"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {
28 | "id": "UYWdWPXytfLG"
29 | },
30 | "outputs": [],
31 | "source": [
32 | "import torch\n",
33 | "from processor import PrePackProcessor\n",
34 | "from model import CustomCausalLlamaModel\n",
35 | "from transformers import AutoTokenizer\n",
36 | "from transformers.trainer_utils import set_seed\n",
37 | "from dataset_utils import unpack_kv\n",
38 | "from transformers import BitsAndBytesConfig"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "YW6yOlJxzCVw"
45 | },
46 | "source": [
47 | "**Load Model**"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "colab": {
55 | "base_uri": "https://localhost:8080/",
56 |      "height": 833
57 |     },
137 | "id": "whQESxVbzB4d",
138 | "outputId": "8f3035cc-17cf-483c-9559-90e8055bad01"
139 | },
140 | "outputs": [],
141 | "source": [
142 | "SEED = 42\n",
143 | "set_seed(SEED)\n",
144 | "model_path = \"princeton-nlp/Sheared-LLaMA-1.3B\"\n",
145 | "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
146 | "tokenizer.pad_token = \"[PAD]\"\n",
147 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
148 | "custom_model = CustomCausalLlamaModel.from_pretrained(model_path)\n",
149 | "custom_model.to(device)\n",
150 | "custom_model.eval()"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "id": "Dy590zUqzHeV"
157 | },
158 | "source": [
159 | "**Prepacking Generation**"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 7,
165 | "metadata": {
166 | "id": "ZiRs_hLetkCB"
167 | },
168 | "outputs": [],
169 | "source": [
170 | "\n",
171 | "processor = PrePackProcessor(tokenizer)\n",
172 | "\n",
173 | "# Change to any prompts\n",
174 | "sentences = [\n",
175 | " \"Rescuers are searching for multiple people in the water after Baltimore bridge collapse, report says\",\n",
176 | " \"Major bridge in Maryland collapses after being hit by a ship\",\n",
177 | " \"The capital of Germany is\",\n",
178 | " \"The capital of Spain is\",\n",
179 | " \"The capital of Greece is\",\n",
180 | " \"Today I'm going to the\",\n",
181 | " \"Baltimore Police Department told NBC\",\n",
182 | " \"My\",\n",
183 | " \"It\",\n",
184 | "]\n",
185 | "\n",
186 | "packed_tokens, restart_positions, independent_mask, restart_dict, original_ids = processor.batch_process(sentences)\n",
187 | "\n",
188 | "\n",
189 | "with torch.no_grad():\n",
190 | " packed_outputs = custom_model(\n",
191 | " input_ids=packed_tokens.to(device),\n",
192 | " attention_mask=independent_mask.to(device),\n",
193 | " position_ids=restart_positions.to(device),\n",
194 | " return_dict=True,\n",
195 | " output_hidden_states=True,\n",
196 | " )\n",
197 | "\n",
198 | "cache, final_tokens, attention_mask = unpack_kv(\n",
199 | " packed_outputs[\"past_key_values\"], restart_dict, original_ids, device\n",
200 | ")\n",
201 | "\n",
202 | "prepack_generated_output = custom_model.generate(\n",
203 | " input_ids=final_tokens.to(device),\n",
204 | " attention_mask=attention_mask.to(device),\n",
205 | " max_new_tokens=20,\n",
206 | " use_cache=True,\n",
207 | " do_sample=False,\n",
208 | " past_key_values=cache,\n",
209 | " num_return_sequences=1,\n",
210 | " output_scores=True,\n",
211 | " return_dict_in_generate=True,\n",
212 | ")\n",
213 | "\n"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {
219 | "id": "T2Q5miklzaDH"
220 | },
221 | "source": [
222 | "**Default Generation**"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 8,
228 | "metadata": {
229 | "id": "bdgfwCkbxCos"
230 | },
231 | "outputs": [],
232 | "source": [
233 | "\n",
234 | "with torch.no_grad():\n",
235 | " normal_tokens_id = tokenizer(sentences, return_tensors=\"pt\", padding=True, truncation=False).to(\n",
236 | " device\n",
237 | " )\n",
238 | " normal_outputs = custom_model(**normal_tokens_id, return_dict=True, output_hidden_states=True)\n",
239 | "\n",
240 | "default_generated_output = custom_model.generate(\n",
241 | " **normal_tokens_id,\n",
242 | " max_new_tokens=20,\n",
243 | " use_cache=True,\n",
244 | " do_sample=False,\n",
245 | " num_return_sequences=1,\n",
246 | " output_scores=True,\n",
247 | " return_dict_in_generate=True\n",
248 | ")\n",
249 | "\n",
250 | "attention_mask = normal_tokens_id[\"attention_mask\"]\n",
251 | "\n"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "**Compare Generations**"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 9,
264 | "metadata": {
265 | "colab": {
266 | "base_uri": "https://localhost:8080/"
267 | },
268 | "id": "4_unN3rS1pNQ",
269 | "outputId": "ecaae29d-5b29-4160-921c-46a009313e93"
270 | },
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "Asserting Same Tokens\n",
277 | "--------------- comparing ---------------\n",
278 | "Prepacked 0 : \n",
279 | "The Baltimore Sun reports that the collapse of a bridge in Baltimore on Monday morning has left at least\n",
280 | "Default 0 : \n",
281 | "The Baltimore Sun reports that the collapse of a bridge in Baltimore on Monday morning has left at least\n",
282 | "--------------- comparing ---------------\n",
283 | "Prepacked 1 : \n",
284 | "The bridge was built in 1968 and was the first of its kind in the\n",
285 | "Default 1 : \n",
286 | "The bridge was built in 1968 and was the first of its kind in the\n",
287 | "--------------- comparing ---------------\n",
288 | "Prepacked 2 : Berlin. Berlin is a city of contrasts. Berlin is a city of contrasts. It is\n",
289 | "Default 2 : Berlin. Berlin is a city of contrasts. Berlin is a city of contrasts. It is\n",
290 | "--------------- comparing ---------------\n",
291 | "Prepacked 3 : Madrid, and it is the largest city in the country. The city is located in the central part\n",
292 | "Default 3 : Madrid, and it is the largest city in the country. The city is located in the central part\n",
293 | "--------------- comparing ---------------\n",
294 | "Prepacked 4 : Athens, the largest city in Greece and the second largest in Europe. The city is located on\n",
295 | "Default 4 : Athens, the largest city in Greece and the second largest in Europe. The city is located on\n",
296 | "--------------- comparing ---------------\n",
297 | "Prepacked 5 : gym. I'm going to do a 30 minute workout. I'm\n",
298 | "Default 5 : gym. I'm going to do a 30 minute workout. I'm\n",
299 | "--------------- comparing ---------------\n",
300 | "Prepacked 6 : 10 that the suspect was arrested on a warrant for a probation violation.\n",
301 | "The\n",
302 | "Default 6 : 10 that the suspect was arrested on a warrant for a probation violation.\n",
303 | "The\n",
304 | "--------------- comparing ---------------\n",
305 | "Prepacked 7 : name is Katie and I am a 20 year old student from the UK. I am\n",
306 | "Default 7 : name is Katie and I am a 20 year old student from the UK. I am\n",
307 | "--------------- comparing ---------------\n",
308 | "Prepacked 8 : ’s been a while since I’ve posted anything on this blog. I’ve been busy\n",
309 | "Default 8 : ’s been a while since I’ve posted anything on this blog. I’ve been busy\n"
310 | ]
311 | }
312 | ],
313 | "source": [
314 | "print(\"Asserting Same Tokens\")\n",
315 | "\n",
316 | "# Check tokens\n",
317 | "# Note that it is possible to have different generations due to numerical instability\n",
318 | "for i, (prepack_token, default_token) in enumerate(\n",
319 | " zip(prepack_generated_output.sequences, default_generated_output.sequences)\n",
320 | "):\n",
321 | "\n",
322 | " prepack = tokenizer.decode(prepack_token[1:])\n",
323 | " default = tokenizer.decode(default_token[attention_mask.shape[-1] :])\n",
324 | " print(\"-\" * 15, \"comparing\", \"-\" * 15)\n",
325 | " print(\"Prepacked\", i, \":\", prepack)\n",
326 | " print(\"Default\", i, \":\", default)\n",
327 | "\n",
328 | " assert prepack == default"
329 | ]
330 | }
331 | ],
332 | "metadata": {
333 | "accelerator": "GPU",
334 | "colab": {
335 | "gpuType": "T4",
336 | "provenance": []
337 | },
338 | "kernelspec": {
339 | "display_name": "Python 3",
340 | "name": "python3"
341 | },
342 | "language_info": {
343 | "name": "python"
344 |   }
345 |  },
346 |  "nbformat": 4,
347 |  "nbformat_minor": 0
348 | }
953 | "placeholder": "",
954 | "style": "IPY_MODEL_0b9b2e91fe864f25ae09c37bf9145cae",
955 | "value": "tokenizer_config.json: 100%"
956 | }
957 | },
958 | "2cdfa2a8e85548ce9259285a04310822": {
959 | "model_module": "@jupyter-widgets/controls",
960 | "model_module_version": "1.5.0",
961 | "model_name": "DescriptionStyleModel",
962 | "state": {
963 | "_model_module": "@jupyter-widgets/controls",
964 | "_model_module_version": "1.5.0",
965 | "_model_name": "DescriptionStyleModel",
966 | "_view_count": null,
967 | "_view_module": "@jupyter-widgets/base",
968 | "_view_module_version": "1.2.0",
969 | "_view_name": "StyleView",
970 | "description_width": ""
971 | }
972 | },
973 | "3f21f1f034744945a9b0657e313a6350": {
974 | "model_module": "@jupyter-widgets/base",
975 | "model_module_version": "1.2.0",
976 | "model_name": "LayoutModel",
977 | "state": {
978 | "_model_module": "@jupyter-widgets/base",
979 | "_model_module_version": "1.2.0",
980 | "_model_name": "LayoutModel",
981 | "_view_count": null,
982 | "_view_module": "@jupyter-widgets/base",
983 | "_view_module_version": "1.2.0",
984 | "_view_name": "LayoutView",
985 | "align_content": null,
986 | "align_items": null,
987 | "align_self": null,
988 | "border": null,
989 | "bottom": null,
990 | "display": null,
991 | "flex": null,
992 | "flex_flow": null,
993 | "grid_area": null,
994 | "grid_auto_columns": null,
995 | "grid_auto_flow": null,
996 | "grid_auto_rows": null,
997 | "grid_column": null,
998 | "grid_gap": null,
999 | "grid_row": null,
1000 | "grid_template_areas": null,
1001 | "grid_template_columns": null,
1002 | "grid_template_rows": null,
1003 | "height": null,
1004 | "justify_content": null,
1005 | "justify_items": null,
1006 | "left": null,
1007 | "margin": null,
1008 | "max_height": null,
1009 | "max_width": null,
1010 | "min_height": null,
1011 | "min_width": null,
1012 | "object_fit": null,
1013 | "object_position": null,
1014 | "order": null,
1015 | "overflow": null,
1016 | "overflow_x": null,
1017 | "overflow_y": null,
1018 | "padding": null,
1019 | "right": null,
1020 | "top": null,
1021 | "visibility": null,
1022 | "width": null
1023 | }
1024 | },
1025 | "405acde9b0b74f18aaca033b36c28fbd": {
1026 | "model_module": "@jupyter-widgets/base",
1027 | "model_module_version": "1.2.0",
1028 | "model_name": "LayoutModel",
1029 | "state": {
1030 | "_model_module": "@jupyter-widgets/base",
1031 | "_model_module_version": "1.2.0",
1032 | "_model_name": "LayoutModel",
1033 | "_view_count": null,
1034 | "_view_module": "@jupyter-widgets/base",
1035 | "_view_module_version": "1.2.0",
1036 | "_view_name": "LayoutView",
1037 | "align_content": null,
1038 | "align_items": null,
1039 | "align_self": null,
1040 | "border": null,
1041 | "bottom": null,
1042 | "display": null,
1043 | "flex": null,
1044 | "flex_flow": null,
1045 | "grid_area": null,
1046 | "grid_auto_columns": null,
1047 | "grid_auto_flow": null,
1048 | "grid_auto_rows": null,
1049 | "grid_column": null,
1050 | "grid_gap": null,
1051 | "grid_row": null,
1052 | "grid_template_areas": null,
1053 | "grid_template_columns": null,
1054 | "grid_template_rows": null,
1055 | "height": null,
1056 | "justify_content": null,
1057 | "justify_items": null,
1058 | "left": null,
1059 | "margin": null,
1060 | "max_height": null,
1061 | "max_width": null,
1062 | "min_height": null,
1063 | "min_width": null,
1064 | "object_fit": null,
1065 | "object_position": null,
1066 | "order": null,
1067 | "overflow": null,
1068 | "overflow_x": null,
1069 | "overflow_y": null,
1070 | "padding": null,
1071 | "right": null,
1072 | "top": null,
1073 | "visibility": null,
1074 | "width": null
1075 | }
1076 | },
1077 | "4149a1ba8c174a83abcadc4bcb1ef98a": {
1078 | "model_module": "@jupyter-widgets/base",
1079 | "model_module_version": "1.2.0",
1080 | "model_name": "LayoutModel",
1081 | "state": {
1082 | "_model_module": "@jupyter-widgets/base",
1083 | "_model_module_version": "1.2.0",
1084 | "_model_name": "LayoutModel",
1085 | "_view_count": null,
1086 | "_view_module": "@jupyter-widgets/base",
1087 | "_view_module_version": "1.2.0",
1088 | "_view_name": "LayoutView",
1089 | "align_content": null,
1090 | "align_items": null,
1091 | "align_self": null,
1092 | "border": null,
1093 | "bottom": null,
1094 | "display": null,
1095 | "flex": null,
1096 | "flex_flow": null,
1097 | "grid_area": null,
1098 | "grid_auto_columns": null,
1099 | "grid_auto_flow": null,
1100 | "grid_auto_rows": null,
1101 | "grid_column": null,
1102 | "grid_gap": null,
1103 | "grid_row": null,
1104 | "grid_template_areas": null,
1105 | "grid_template_columns": null,
1106 | "grid_template_rows": null,
1107 | "height": null,
1108 | "justify_content": null,
1109 | "justify_items": null,
1110 | "left": null,
1111 | "margin": null,
1112 | "max_height": null,
1113 | "max_width": null,
1114 | "min_height": null,
1115 | "min_width": null,
1116 | "object_fit": null,
1117 | "object_position": null,
1118 | "order": null,
1119 | "overflow": null,
1120 | "overflow_x": null,
1121 | "overflow_y": null,
1122 | "padding": null,
1123 | "right": null,
1124 | "top": null,
1125 | "visibility": null,
1126 | "width": null
1127 | }
1128 | },
1129 | "42075dab0a5443b0b65e1ff8b2cb5e89": {
1130 | "model_module": "@jupyter-widgets/controls",
1131 | "model_module_version": "1.5.0",
1132 | "model_name": "DescriptionStyleModel",
1133 | "state": {
1134 | "_model_module": "@jupyter-widgets/controls",
1135 | "_model_module_version": "1.5.0",
1136 | "_model_name": "DescriptionStyleModel",
1137 | "_view_count": null,
1138 | "_view_module": "@jupyter-widgets/base",
1139 | "_view_module_version": "1.2.0",
1140 | "_view_name": "StyleView",
1141 | "description_width": ""
1142 | }
1143 | },
1144 | "454099c862aa46149f79467e54b7f7ac": {
1145 | "model_module": "@jupyter-widgets/controls",
1146 | "model_module_version": "1.5.0",
1147 | "model_name": "HTMLModel",
1148 | "state": {
1149 | "_dom_classes": [],
1150 | "_model_module": "@jupyter-widgets/controls",
1151 | "_model_module_version": "1.5.0",
1152 | "_model_name": "HTMLModel",
1153 | "_view_count": null,
1154 | "_view_module": "@jupyter-widgets/controls",
1155 | "_view_module_version": "1.5.0",
1156 | "_view_name": "HTMLView",
1157 | "description": "",
1158 | "description_tooltip": null,
1159 | "layout": "IPY_MODEL_405acde9b0b74f18aaca033b36c28fbd",
1160 | "placeholder": "",
1161 | "style": "IPY_MODEL_9c5163fdeb61471bbafdfe08393a6d80",
1162 | "value": " 132/132 [00:00<00:00, 8.87kB/s]"
1163 | }
1164 | },
1165 | "48e7aedb9eb8467ba5418860c5d99a07": {
1166 | "model_module": "@jupyter-widgets/controls",
1167 | "model_module_version": "1.5.0",
1168 | "model_name": "FloatProgressModel",
1169 | "state": {
1170 | "_dom_classes": [],
1171 | "_model_module": "@jupyter-widgets/controls",
1172 | "_model_module_version": "1.5.0",
1173 | "_model_name": "FloatProgressModel",
1174 | "_view_count": null,
1175 | "_view_module": "@jupyter-widgets/controls",
1176 | "_view_module_version": "1.5.0",
1177 | "_view_name": "ProgressView",
1178 | "bar_style": "success",
1179 | "description": "",
1180 | "description_tooltip": null,
1181 | "layout": "IPY_MODEL_c8219473c99d480081104fe4551036de",
1182 | "max": 632,
1183 | "min": 0,
1184 | "orientation": "horizontal",
1185 | "style": "IPY_MODEL_58524a4c38cf4289801ee2569c07948d",
1186 | "value": 632
1187 | }
1188 | },
1189 | "48ee78ad568843ddba1619a3e50404ce": {
1190 | "model_module": "@jupyter-widgets/controls",
1191 | "model_module_version": "1.5.0",
1192 | "model_name": "FloatProgressModel",
1193 | "state": {
1194 | "_dom_classes": [],
1195 | "_model_module": "@jupyter-widgets/controls",
1196 | "_model_module_version": "1.5.0",
1197 | "_model_name": "FloatProgressModel",
1198 | "_view_count": null,
1199 | "_view_module": "@jupyter-widgets/controls",
1200 | "_view_module_version": "1.5.0",
1201 | "_view_name": "ProgressView",
1202 | "bar_style": "success",
1203 | "description": "",
1204 | "description_tooltip": null,
1205 | "layout": "IPY_MODEL_782c8fe1763b4977a678d9fb4fd5dadc",
1206 | "max": 727,
1207 | "min": 0,
1208 | "orientation": "horizontal",
1209 | "style": "IPY_MODEL_d3b47705f4e14680b13b21be52d2e604",
1210 | "value": 727
1211 | }
1212 | },
1213 | "49536e13224f481bbb4a02a1391ab347": {
1214 | "model_module": "@jupyter-widgets/base",
1215 | "model_module_version": "1.2.0",
1216 | "model_name": "LayoutModel",
1217 | "state": {
1218 | "_model_module": "@jupyter-widgets/base",
1219 | "_model_module_version": "1.2.0",
1220 | "_model_name": "LayoutModel",
1221 | "_view_count": null,
1222 | "_view_module": "@jupyter-widgets/base",
1223 | "_view_module_version": "1.2.0",
1224 | "_view_name": "LayoutView",
1225 | "align_content": null,
1226 | "align_items": null,
1227 | "align_self": null,
1228 | "border": null,
1229 | "bottom": null,
1230 | "display": null,
1231 | "flex": null,
1232 | "flex_flow": null,
1233 | "grid_area": null,
1234 | "grid_auto_columns": null,
1235 | "grid_auto_flow": null,
1236 | "grid_auto_rows": null,
1237 | "grid_column": null,
1238 | "grid_gap": null,
1239 | "grid_row": null,
1240 | "grid_template_areas": null,
1241 | "grid_template_columns": null,
1242 | "grid_template_rows": null,
1243 | "height": null,
1244 | "justify_content": null,
1245 | "justify_items": null,
1246 | "left": null,
1247 | "margin": null,
1248 | "max_height": null,
1249 | "max_width": null,
1250 | "min_height": null,
1251 | "min_width": null,
1252 | "object_fit": null,
1253 | "object_position": null,
1254 | "order": null,
1255 | "overflow": null,
1256 | "overflow_x": null,
1257 | "overflow_y": null,
1258 | "padding": null,
1259 | "right": null,
1260 | "top": null,
1261 | "visibility": null,
1262 | "width": null
1263 | }
1264 | },
1265 | "4afabea1be0249068ad394e7265abf24": {
1266 | "model_module": "@jupyter-widgets/controls",
1267 | "model_module_version": "1.5.0",
1268 | "model_name": "HTMLModel",
1269 | "state": {
1270 | "_dom_classes": [],
1271 | "_model_module": "@jupyter-widgets/controls",
1272 | "_model_module_version": "1.5.0",
1273 | "_model_name": "HTMLModel",
1274 | "_view_count": null,
1275 | "_view_module": "@jupyter-widgets/controls",
1276 | "_view_module_version": "1.5.0",
1277 | "_view_name": "HTMLView",
1278 | "description": "",
1279 | "description_tooltip": null,
1280 | "layout": "IPY_MODEL_29ee1203a8444b579caeec2f9fe201e6",
1281 | "placeholder": "",
1282 | "style": "IPY_MODEL_5cbd1a0fe71c434ca0bd655d6347947f",
1283 | "value": " 1.84M/1.84M [00:00<00:00, 7.76MB/s]"
1284 | }
1285 | },
1286 | "4c44190d75cb4f21826fee5b25052886": {
1287 | "model_module": "@jupyter-widgets/controls",
1288 | "model_module_version": "1.5.0",
1289 | "model_name": "HTMLModel",
1290 | "state": {
1291 | "_dom_classes": [],
1292 | "_model_module": "@jupyter-widgets/controls",
1293 | "_model_module_version": "1.5.0",
1294 | "_model_name": "HTMLModel",
1295 | "_view_count": null,
1296 | "_view_module": "@jupyter-widgets/controls",
1297 | "_view_module_version": "1.5.0",
1298 | "_view_name": "HTMLView",
1299 | "description": "",
1300 | "description_tooltip": null,
1301 | "layout": "IPY_MODEL_9b5632fd42b7452295e7873e7c7a230f",
1302 | "placeholder": "",
1303 | "style": "IPY_MODEL_b5ca1b962b874a50bd3657a1d74ea998",
1304 | "value": "config.json: 100%"
1305 | }
1306 | },
1307 | "4c7aa0c521cc4e0e8f9bb8789e6ad8f9": {
1308 | "model_module": "@jupyter-widgets/controls",
1309 | "model_module_version": "1.5.0",
1310 | "model_name": "FloatProgressModel",
1311 | "state": {
1312 | "_dom_classes": [],
1313 | "_model_module": "@jupyter-widgets/controls",
1314 | "_model_module_version": "1.5.0",
1315 | "_model_name": "FloatProgressModel",
1316 | "_view_count": null,
1317 | "_view_module": "@jupyter-widgets/controls",
1318 | "_view_module_version": "1.5.0",
1319 | "_view_name": "ProgressView",
1320 | "bar_style": "success",
1321 | "description": "",
1322 | "description_tooltip": null,
1323 | "layout": "IPY_MODEL_cfee62ce88bf4a409ce34337556539cf",
1324 | "max": 132,
1325 | "min": 0,
1326 | "orientation": "horizontal",
1327 | "style": "IPY_MODEL_a16259ceb15b464fa87355a14aaac331",
1328 | "value": 132
1329 | }
1330 | },
1331 | "4e3c9dbe0057456683e12c1357925ec5": {
1332 | "model_module": "@jupyter-widgets/controls",
1333 | "model_module_version": "1.5.0",
1334 | "model_name": "ProgressStyleModel",
1335 | "state": {
1336 | "_model_module": "@jupyter-widgets/controls",
1337 | "_model_module_version": "1.5.0",
1338 | "_model_name": "ProgressStyleModel",
1339 | "_view_count": null,
1340 | "_view_module": "@jupyter-widgets/base",
1341 | "_view_module_version": "1.2.0",
1342 | "_view_name": "StyleView",
1343 | "bar_color": null,
1344 | "description_width": ""
1345 | }
1346 | },
1347 | "55690f32de47405da42599803e654fc6": {
1348 | "model_module": "@jupyter-widgets/base",
1349 | "model_module_version": "1.2.0",
1350 | "model_name": "LayoutModel",
1351 | "state": {
1352 | "_model_module": "@jupyter-widgets/base",
1353 | "_model_module_version": "1.2.0",
1354 | "_model_name": "LayoutModel",
1355 | "_view_count": null,
1356 | "_view_module": "@jupyter-widgets/base",
1357 | "_view_module_version": "1.2.0",
1358 | "_view_name": "LayoutView",
1359 | "align_content": null,
1360 | "align_items": null,
1361 | "align_self": null,
1362 | "border": null,
1363 | "bottom": null,
1364 | "display": null,
1365 | "flex": null,
1366 | "flex_flow": null,
1367 | "grid_area": null,
1368 | "grid_auto_columns": null,
1369 | "grid_auto_flow": null,
1370 | "grid_auto_rows": null,
1371 | "grid_column": null,
1372 | "grid_gap": null,
1373 | "grid_row": null,
1374 | "grid_template_areas": null,
1375 | "grid_template_columns": null,
1376 | "grid_template_rows": null,
1377 | "height": null,
1378 | "justify_content": null,
1379 | "justify_items": null,
1380 | "left": null,
1381 | "margin": null,
1382 | "max_height": null,
1383 | "max_width": null,
1384 | "min_height": null,
1385 | "min_width": null,
1386 | "object_fit": null,
1387 | "object_position": null,
1388 | "order": null,
1389 | "overflow": null,
1390 | "overflow_x": null,
1391 | "overflow_y": null,
1392 | "padding": null,
1393 | "right": null,
1394 | "top": null,
1395 | "visibility": null,
1396 | "width": null
1397 | }
1398 | },
1399 | "58524a4c38cf4289801ee2569c07948d": {
1400 | "model_module": "@jupyter-widgets/controls",
1401 | "model_module_version": "1.5.0",
1402 | "model_name": "ProgressStyleModel",
1403 | "state": {
1404 | "_model_module": "@jupyter-widgets/controls",
1405 | "_model_module_version": "1.5.0",
1406 | "_model_name": "ProgressStyleModel",
1407 | "_view_count": null,
1408 | "_view_module": "@jupyter-widgets/base",
1409 | "_view_module_version": "1.2.0",
1410 | "_view_name": "StyleView",
1411 | "bar_color": null,
1412 | "description_width": ""
1413 | }
1414 | },
1415 | "5cbd1a0fe71c434ca0bd655d6347947f": {
1416 | "model_module": "@jupyter-widgets/controls",
1417 | "model_module_version": "1.5.0",
1418 | "model_name": "DescriptionStyleModel",
1419 | "state": {
1420 | "_model_module": "@jupyter-widgets/controls",
1421 | "_model_module_version": "1.5.0",
1422 | "_model_name": "DescriptionStyleModel",
1423 | "_view_count": null,
1424 | "_view_module": "@jupyter-widgets/base",
1425 | "_view_module_version": "1.2.0",
1426 | "_view_name": "StyleView",
1427 | "description_width": ""
1428 | }
1429 | },
1430 | "68025228c36c419e88ce26efcac76456": {
1431 | "model_module": "@jupyter-widgets/controls",
1432 | "model_module_version": "1.5.0",
1433 | "model_name": "HTMLModel",
1434 | "state": {
1435 | "_dom_classes": [],
1436 | "_model_module": "@jupyter-widgets/controls",
1437 | "_model_module_version": "1.5.0",
1438 | "_model_name": "HTMLModel",
1439 | "_view_count": null,
1440 | "_view_module": "@jupyter-widgets/controls",
1441 | "_view_module_version": "1.5.0",
1442 | "_view_name": "HTMLView",
1443 | "description": "",
1444 | "description_tooltip": null,
1445 | "layout": "IPY_MODEL_2b5de781f41e4d6dbdc6c7939b405af0",
1446 | "placeholder": "",
1447 | "style": "IPY_MODEL_1b9ee7656cae4986ba45656fef43d154",
1448 | "value": " 727/727 [00:00<00:00, 48.7kB/s]"
1449 | }
1450 | },
1451 | "6d38a5c62d1243bc9c8d68fc0a20393b": {
1452 | "model_module": "@jupyter-widgets/controls",
1453 | "model_module_version": "1.5.0",
1454 | "model_name": "DescriptionStyleModel",
1455 | "state": {
1456 | "_model_module": "@jupyter-widgets/controls",
1457 | "_model_module_version": "1.5.0",
1458 | "_model_name": "DescriptionStyleModel",
1459 | "_view_count": null,
1460 | "_view_module": "@jupyter-widgets/base",
1461 | "_view_module_version": "1.2.0",
1462 | "_view_name": "StyleView",
1463 | "description_width": ""
1464 | }
1465 | },
1466 | "6e238cbf7ab84e5b9d6e6d756ccd9487": {
1467 | "model_module": "@jupyter-widgets/controls",
1468 | "model_module_version": "1.5.0",
1469 | "model_name": "FloatProgressModel",
1470 | "state": {
1471 | "_dom_classes": [],
1472 | "_model_module": "@jupyter-widgets/controls",
1473 | "_model_module_version": "1.5.0",
1474 | "_model_name": "FloatProgressModel",
1475 | "_view_count": null,
1476 | "_view_module": "@jupyter-widgets/controls",
1477 | "_view_module_version": "1.5.0",
1478 | "_view_name": "ProgressView",
1479 | "bar_style": "success",
1480 | "description": "",
1481 | "description_tooltip": null,
1482 | "layout": "IPY_MODEL_0f57f44193ef462490f4e1526637c874",
1483 | "max": 499723,
1484 | "min": 0,
1485 | "orientation": "horizontal",
1486 | "style": "IPY_MODEL_94f49a2b0fca4fa5a22ea8a07f628b99",
1487 | "value": 499723
1488 | }
1489 | },
1490 | "6ea49b6d6530403fb9acf7a15493b6e9": {
1491 | "model_module": "@jupyter-widgets/controls",
1492 | "model_module_version": "1.5.0",
1493 | "model_name": "DescriptionStyleModel",
1494 | "state": {
1495 | "_model_module": "@jupyter-widgets/controls",
1496 | "_model_module_version": "1.5.0",
1497 | "_model_name": "DescriptionStyleModel",
1498 | "_view_count": null,
1499 | "_view_module": "@jupyter-widgets/base",
1500 | "_view_module_version": "1.2.0",
1501 | "_view_name": "StyleView",
1502 | "description_width": ""
1503 | }
1504 | },
1505 | "7562626ebef743cfb5c234a8babd884d": {
1506 | "model_module": "@jupyter-widgets/controls",
1507 | "model_module_version": "1.5.0",
1508 | "model_name": "HTMLModel",
1509 | "state": {
1510 | "_dom_classes": [],
1511 | "_model_module": "@jupyter-widgets/controls",
1512 | "_model_module_version": "1.5.0",
1513 | "_model_name": "HTMLModel",
1514 | "_view_count": null,
1515 | "_view_module": "@jupyter-widgets/controls",
1516 | "_view_module_version": "1.5.0",
1517 | "_view_name": "HTMLView",
1518 | "description": "",
1519 | "description_tooltip": null,
1520 | "layout": "IPY_MODEL_8119e15ecae64227857fb5a927c95877",
1521 | "placeholder": "",
1522 | "style": "IPY_MODEL_abfc8e73b9f94e4aa684de84a6194e20",
1523 | "value": " 411/411 [00:00<00:00, 12.4kB/s]"
1524 | }
1525 | },
1526 | "782c8fe1763b4977a678d9fb4fd5dadc": {
1527 | "model_module": "@jupyter-widgets/base",
1528 | "model_module_version": "1.2.0",
1529 | "model_name": "LayoutModel",
1530 | "state": {
1531 | "_model_module": "@jupyter-widgets/base",
1532 | "_model_module_version": "1.2.0",
1533 | "_model_name": "LayoutModel",
1534 | "_view_count": null,
1535 | "_view_module": "@jupyter-widgets/base",
1536 | "_view_module_version": "1.2.0",
1537 | "_view_name": "LayoutView",
1538 | "align_content": null,
1539 | "align_items": null,
1540 | "align_self": null,
1541 | "border": null,
1542 | "bottom": null,
1543 | "display": null,
1544 | "flex": null,
1545 | "flex_flow": null,
1546 | "grid_area": null,
1547 | "grid_auto_columns": null,
1548 | "grid_auto_flow": null,
1549 | "grid_auto_rows": null,
1550 | "grid_column": null,
1551 | "grid_gap": null,
1552 | "grid_row": null,
1553 | "grid_template_areas": null,
1554 | "grid_template_columns": null,
1555 | "grid_template_rows": null,
1556 | "height": null,
1557 | "justify_content": null,
1558 | "justify_items": null,
1559 | "left": null,
1560 | "margin": null,
1561 | "max_height": null,
1562 | "max_width": null,
1563 | "min_height": null,
1564 | "min_width": null,
1565 | "object_fit": null,
1566 | "object_position": null,
1567 | "order": null,
1568 | "overflow": null,
1569 | "overflow_x": null,
1570 | "overflow_y": null,
1571 | "padding": null,
1572 | "right": null,
1573 | "top": null,
1574 | "visibility": null,
1575 | "width": null
1576 | }
1577 | },
1578 | "7f160c6489514b35945179d204c06bc7": {
1579 | "model_module": "@jupyter-widgets/controls",
1580 | "model_module_version": "1.5.0",
1581 | "model_name": "HTMLModel",
1582 | "state": {
1583 | "_dom_classes": [],
1584 | "_model_module": "@jupyter-widgets/controls",
1585 | "_model_module_version": "1.5.0",
1586 | "_model_name": "HTMLModel",
1587 | "_view_count": null,
1588 | "_view_module": "@jupyter-widgets/controls",
1589 | "_view_module_version": "1.5.0",
1590 | "_view_name": "HTMLView",
1591 | "description": "",
1592 | "description_tooltip": null,
1593 | "layout": "IPY_MODEL_3f21f1f034744945a9b0657e313a6350",
1594 | "placeholder": "",
1595 | "style": "IPY_MODEL_e98ec60115014dc0a5ec44fdec3a0073",
1596 | "value": " 500k/500k [00:00<00:00, 9.94MB/s]"
1597 | }
1598 | },
1599 | "8119e15ecae64227857fb5a927c95877": {
1600 | "model_module": "@jupyter-widgets/base",
1601 | "model_module_version": "1.2.0",
1602 | "model_name": "LayoutModel",
1603 | "state": {
1604 | "_model_module": "@jupyter-widgets/base",
1605 | "_model_module_version": "1.2.0",
1606 | "_model_name": "LayoutModel",
1607 | "_view_count": null,
1608 | "_view_module": "@jupyter-widgets/base",
1609 | "_view_module_version": "1.2.0",
1610 | "_view_name": "LayoutView",
1611 | "align_content": null,
1612 | "align_items": null,
1613 | "align_self": null,
1614 | "border": null,
1615 | "bottom": null,
1616 | "display": null,
1617 | "flex": null,
1618 | "flex_flow": null,
1619 | "grid_area": null,
1620 | "grid_auto_columns": null,
1621 | "grid_auto_flow": null,
1622 | "grid_auto_rows": null,
1623 | "grid_column": null,
1624 | "grid_gap": null,
1625 | "grid_row": null,
1626 | "grid_template_areas": null,
1627 | "grid_template_columns": null,
1628 | "grid_template_rows": null,
1629 | "height": null,
1630 | "justify_content": null,
1631 | "justify_items": null,
1632 | "left": null,
1633 | "margin": null,
1634 | "max_height": null,
1635 | "max_width": null,
1636 | "min_height": null,
1637 | "min_width": null,
1638 | "object_fit": null,
1639 | "object_position": null,
1640 | "order": null,
1641 | "overflow": null,
1642 | "overflow_x": null,
1643 | "overflow_y": null,
1644 | "padding": null,
1645 | "right": null,
1646 | "top": null,
1647 | "visibility": null,
1648 | "width": null
1649 | }
1650 | },
1651 | "8246ecc55fbd476ba7fca6a183e4d741": {
1652 | "model_module": "@jupyter-widgets/base",
1653 | "model_module_version": "1.2.0",
1654 | "model_name": "LayoutModel",
1655 | "state": {
1656 | "_model_module": "@jupyter-widgets/base",
1657 | "_model_module_version": "1.2.0",
1658 | "_model_name": "LayoutModel",
1659 | "_view_count": null,
1660 | "_view_module": "@jupyter-widgets/base",
1661 | "_view_module_version": "1.2.0",
1662 | "_view_name": "LayoutView",
1663 | "align_content": null,
1664 | "align_items": null,
1665 | "align_self": null,
1666 | "border": null,
1667 | "bottom": null,
1668 | "display": null,
1669 | "flex": null,
1670 | "flex_flow": null,
1671 | "grid_area": null,
1672 | "grid_auto_columns": null,
1673 | "grid_auto_flow": null,
1674 | "grid_auto_rows": null,
1675 | "grid_column": null,
1676 | "grid_gap": null,
1677 | "grid_row": null,
1678 | "grid_template_areas": null,
1679 | "grid_template_columns": null,
1680 | "grid_template_rows": null,
1681 | "height": null,
1682 | "justify_content": null,
1683 | "justify_items": null,
1684 | "left": null,
1685 | "margin": null,
1686 | "max_height": null,
1687 | "max_width": null,
1688 | "min_height": null,
1689 | "min_width": null,
1690 | "object_fit": null,
1691 | "object_position": null,
1692 | "order": null,
1693 | "overflow": null,
1694 | "overflow_x": null,
1695 | "overflow_y": null,
1696 | "padding": null,
1697 | "right": null,
1698 | "top": null,
1699 | "visibility": null,
1700 | "width": null
1701 | }
1702 | },
1703 | "8322ac7b0d8644d78804e6d85ff05baa": {
1704 | "model_module": "@jupyter-widgets/base",
1705 | "model_module_version": "1.2.0",
1706 | "model_name": "LayoutModel",
1707 | "state": {
1708 | "_model_module": "@jupyter-widgets/base",
1709 | "_model_module_version": "1.2.0",
1710 | "_model_name": "LayoutModel",
1711 | "_view_count": null,
1712 | "_view_module": "@jupyter-widgets/base",
1713 | "_view_module_version": "1.2.0",
1714 | "_view_name": "LayoutView",
1715 | "align_content": null,
1716 | "align_items": null,
1717 | "align_self": null,
1718 | "border": null,
1719 | "bottom": null,
1720 | "display": null,
1721 | "flex": null,
1722 | "flex_flow": null,
1723 | "grid_area": null,
1724 | "grid_auto_columns": null,
1725 | "grid_auto_flow": null,
1726 | "grid_auto_rows": null,
1727 | "grid_column": null,
1728 | "grid_gap": null,
1729 | "grid_row": null,
1730 | "grid_template_areas": null,
1731 | "grid_template_columns": null,
1732 | "grid_template_rows": null,
1733 | "height": null,
1734 | "justify_content": null,
1735 | "justify_items": null,
1736 | "left": null,
1737 | "margin": null,
1738 | "max_height": null,
1739 | "max_width": null,
1740 | "min_height": null,
1741 | "min_width": null,
1742 | "object_fit": null,
1743 | "object_position": null,
1744 | "order": null,
1745 | "overflow": null,
1746 | "overflow_x": null,
1747 | "overflow_y": null,
1748 | "padding": null,
1749 | "right": null,
1750 | "top": null,
1751 | "visibility": null,
1752 | "width": null
1753 | }
1754 | },
1755 | "88795b72502645b4bdc95beadb4471d5": {
1756 | "model_module": "@jupyter-widgets/controls",
1757 | "model_module_version": "1.5.0",
1758 | "model_name": "DescriptionStyleModel",
1759 | "state": {
1760 | "_model_module": "@jupyter-widgets/controls",
1761 | "_model_module_version": "1.5.0",
1762 | "_model_name": "DescriptionStyleModel",
1763 | "_view_count": null,
1764 | "_view_module": "@jupyter-widgets/base",
1765 | "_view_module_version": "1.2.0",
1766 | "_view_name": "StyleView",
1767 | "description_width": ""
1768 | }
1769 | },
1770 | "8c2d4bf32f5e46539cfe78ba705e379a": {
1771 | "model_module": "@jupyter-widgets/controls",
1772 | "model_module_version": "1.5.0",
1773 | "model_name": "HBoxModel",
1774 | "state": {
1775 | "_dom_classes": [],
1776 | "_model_module": "@jupyter-widgets/controls",
1777 | "_model_module_version": "1.5.0",
1778 | "_model_name": "HBoxModel",
1779 | "_view_count": null,
1780 | "_view_module": "@jupyter-widgets/controls",
1781 | "_view_module_version": "1.5.0",
1782 | "_view_name": "HBoxView",
1783 | "box_style": "",
1784 | "children": [
1785 | "IPY_MODEL_d160b88d7f00493d81c424be93d7030f",
1786 | "IPY_MODEL_256bd87c03b2449a9cf25303a84b96d6",
1787 | "IPY_MODEL_e05657d7cdab4c4082ebc87b6d903969"
1788 | ],
1789 | "layout": "IPY_MODEL_8246ecc55fbd476ba7fca6a183e4d741"
1790 | }
1791 | },
1792 | "91bf2b38bd224cbfabc77aa85d75f48b": {
1793 | "model_module": "@jupyter-widgets/base",
1794 | "model_module_version": "1.2.0",
1795 | "model_name": "LayoutModel",
1796 | "state": {
1797 | "_model_module": "@jupyter-widgets/base",
1798 | "_model_module_version": "1.2.0",
1799 | "_model_name": "LayoutModel",
1800 | "_view_count": null,
1801 | "_view_module": "@jupyter-widgets/base",
1802 | "_view_module_version": "1.2.0",
1803 | "_view_name": "LayoutView",
1804 | "align_content": null,
1805 | "align_items": null,
1806 | "align_self": null,
1807 | "border": null,
1808 | "bottom": null,
1809 | "display": null,
1810 | "flex": null,
1811 | "flex_flow": null,
1812 | "grid_area": null,
1813 | "grid_auto_columns": null,
1814 | "grid_auto_flow": null,
1815 | "grid_auto_rows": null,
1816 | "grid_column": null,
1817 | "grid_gap": null,
1818 | "grid_row": null,
1819 | "grid_template_areas": null,
1820 | "grid_template_columns": null,
1821 | "grid_template_rows": null,
1822 | "height": null,
1823 | "justify_content": null,
1824 | "justify_items": null,
1825 | "left": null,
1826 | "margin": null,
1827 | "max_height": null,
1828 | "max_width": null,
1829 | "min_height": null,
1830 | "min_width": null,
1831 | "object_fit": null,
1832 | "object_position": null,
1833 | "order": null,
1834 | "overflow": null,
1835 | "overflow_x": null,
1836 | "overflow_y": null,
1837 | "padding": null,
1838 | "right": null,
1839 | "top": null,
1840 | "visibility": null,
1841 | "width": null
1842 | }
1843 | },
1844 | "94f49a2b0fca4fa5a22ea8a07f628b99": {
1845 | "model_module": "@jupyter-widgets/controls",
1846 | "model_module_version": "1.5.0",
1847 | "model_name": "ProgressStyleModel",
1848 | "state": {
1849 | "_model_module": "@jupyter-widgets/controls",
1850 | "_model_module_version": "1.5.0",
1851 | "_model_name": "ProgressStyleModel",
1852 | "_view_count": null,
1853 | "_view_module": "@jupyter-widgets/base",
1854 | "_view_module_version": "1.2.0",
1855 | "_view_name": "StyleView",
1856 | "bar_color": null,
1857 | "description_width": ""
1858 | }
1859 | },
1860 | "95584181b8fd41eebb322c951c72c1f6": {
1861 | "model_module": "@jupyter-widgets/base",
1862 | "model_module_version": "1.2.0",
1863 | "model_name": "LayoutModel",
1864 | "state": {
1865 | "_model_module": "@jupyter-widgets/base",
1866 | "_model_module_version": "1.2.0",
1867 | "_model_name": "LayoutModel",
1868 | "_view_count": null,
1869 | "_view_module": "@jupyter-widgets/base",
1870 | "_view_module_version": "1.2.0",
1871 | "_view_name": "LayoutView",
1872 | "align_content": null,
1873 | "align_items": null,
1874 | "align_self": null,
1875 | "border": null,
1876 | "bottom": null,
1877 | "display": null,
1878 | "flex": null,
1879 | "flex_flow": null,
1880 | "grid_area": null,
1881 | "grid_auto_columns": null,
1882 | "grid_auto_flow": null,
1883 | "grid_auto_rows": null,
1884 | "grid_column": null,
1885 | "grid_gap": null,
1886 | "grid_row": null,
1887 | "grid_template_areas": null,
1888 | "grid_template_columns": null,
1889 | "grid_template_rows": null,
1890 | "height": null,
1891 | "justify_content": null,
1892 | "justify_items": null,
1893 | "left": null,
1894 | "margin": null,
1895 | "max_height": null,
1896 | "max_width": null,
1897 | "min_height": null,
1898 | "min_width": null,
1899 | "object_fit": null,
1900 | "object_position": null,
1901 | "order": null,
1902 | "overflow": null,
1903 | "overflow_x": null,
1904 | "overflow_y": null,
1905 | "padding": null,
1906 | "right": null,
1907 | "top": null,
1908 | "visibility": null,
1909 | "width": null
1910 | }
1911 | },
1912 | "975d4035cf144139b455fcb06dde003e": {
1913 | "model_module": "@jupyter-widgets/controls",
1914 | "model_module_version": "1.5.0",
1915 | "model_name": "HBoxModel",
1916 | "state": {
1917 | "_dom_classes": [],
1918 | "_model_module": "@jupyter-widgets/controls",
1919 | "_model_module_version": "1.5.0",
1920 | "_model_name": "HBoxModel",
1921 | "_view_count": null,
1922 | "_view_module": "@jupyter-widgets/controls",
1923 | "_view_module_version": "1.5.0",
1924 | "_view_name": "HBoxView",
1925 | "box_style": "",
1926 | "children": [
1927 | "IPY_MODEL_4c44190d75cb4f21826fee5b25052886",
1928 | "IPY_MODEL_48e7aedb9eb8467ba5418860c5d99a07",
1929 | "IPY_MODEL_1a65ae611f96427e85b5e5ec21bca22d"
1930 | ],
1931 | "layout": "IPY_MODEL_49536e13224f481bbb4a02a1391ab347"
1932 | }
1933 | },
1934 | "9b5632fd42b7452295e7873e7c7a230f": {
1935 | "model_module": "@jupyter-widgets/base",
1936 | "model_module_version": "1.2.0",
1937 | "model_name": "LayoutModel",
1938 | "state": {
1939 | "_model_module": "@jupyter-widgets/base",
1940 | "_model_module_version": "1.2.0",
1941 | "_model_name": "LayoutModel",
1942 | "_view_count": null,
1943 | "_view_module": "@jupyter-widgets/base",
1944 | "_view_module_version": "1.2.0",
1945 | "_view_name": "LayoutView",
1946 | "align_content": null,
1947 | "align_items": null,
1948 | "align_self": null,
1949 | "border": null,
1950 | "bottom": null,
1951 | "display": null,
1952 | "flex": null,
1953 | "flex_flow": null,
1954 | "grid_area": null,
1955 | "grid_auto_columns": null,
1956 | "grid_auto_flow": null,
1957 | "grid_auto_rows": null,
1958 | "grid_column": null,
1959 | "grid_gap": null,
1960 | "grid_row": null,
1961 | "grid_template_areas": null,
1962 | "grid_template_columns": null,
1963 | "grid_template_rows": null,
1964 | "height": null,
1965 | "justify_content": null,
1966 | "justify_items": null,
1967 | "left": null,
1968 | "margin": null,
1969 | "max_height": null,
1970 | "max_width": null,
1971 | "min_height": null,
1972 | "min_width": null,
1973 | "object_fit": null,
1974 | "object_position": null,
1975 | "order": null,
1976 | "overflow": null,
1977 | "overflow_x": null,
1978 | "overflow_y": null,
1979 | "padding": null,
1980 | "right": null,
1981 | "top": null,
1982 | "visibility": null,
1983 | "width": null
1984 | }
1985 | },
1986 | "9c5163fdeb61471bbafdfe08393a6d80": {
1987 | "model_module": "@jupyter-widgets/controls",
1988 | "model_module_version": "1.5.0",
1989 | "model_name": "DescriptionStyleModel",
1990 | "state": {
1991 | "_model_module": "@jupyter-widgets/controls",
1992 | "_model_module_version": "1.5.0",
1993 | "_model_name": "DescriptionStyleModel",
1994 | "_view_count": null,
1995 | "_view_module": "@jupyter-widgets/base",
1996 | "_view_module_version": "1.2.0",
1997 | "_view_name": "StyleView",
1998 | "description_width": ""
1999 | }
2000 | },
2001 | "a001503024fe44c8acc0d5e1a6d561b7": {
2002 | "model_module": "@jupyter-widgets/controls",
2003 | "model_module_version": "1.5.0",
2004 | "model_name": "HBoxModel",
2005 | "state": {
2006 | "_dom_classes": [],
2007 | "_model_module": "@jupyter-widgets/controls",
2008 | "_model_module_version": "1.5.0",
2009 | "_model_name": "HBoxModel",
2010 | "_view_count": null,
2011 | "_view_module": "@jupyter-widgets/controls",
2012 | "_view_module_version": "1.5.0",
2013 | "_view_name": "HBoxView",
2014 | "box_style": "",
2015 | "children": [
2016 | "IPY_MODEL_0c0d9cda00e045c2a16eab19347b49d2",
2017 | "IPY_MODEL_aa4160a1837f427bb8ae673d31a0f24a",
2018 | "IPY_MODEL_7562626ebef743cfb5c234a8babd884d"
2019 | ],
2020 | "layout": "IPY_MODEL_ed4321893f5f4630bf2dbf1f4a38bd39"
2021 | }
2022 | },
2023 | "a16259ceb15b464fa87355a14aaac331": {
2024 | "model_module": "@jupyter-widgets/controls",
2025 | "model_module_version": "1.5.0",
2026 | "model_name": "ProgressStyleModel",
2027 | "state": {
2028 | "_model_module": "@jupyter-widgets/controls",
2029 | "_model_module_version": "1.5.0",
2030 | "_model_name": "ProgressStyleModel",
2031 | "_view_count": null,
2032 | "_view_module": "@jupyter-widgets/base",
2033 | "_view_module_version": "1.2.0",
2034 | "_view_name": "StyleView",
2035 | "bar_color": null,
2036 | "description_width": ""
2037 | }
2038 | },
2039 | "a44eefdea1bd4b83a210f226f074e1d0": {
2040 | "model_module": "@jupyter-widgets/base",
2041 | "model_module_version": "1.2.0",
2042 | "model_name": "LayoutModel",
2043 | "state": {
2044 | "_model_module": "@jupyter-widgets/base",
2045 | "_model_module_version": "1.2.0",
2046 | "_model_name": "LayoutModel",
2047 | "_view_count": null,
2048 | "_view_module": "@jupyter-widgets/base",
2049 | "_view_module_version": "1.2.0",
2050 | "_view_name": "LayoutView",
2051 | "align_content": null,
2052 | "align_items": null,
2053 | "align_self": null,
2054 | "border": null,
2055 | "bottom": null,
2056 | "display": null,
2057 | "flex": null,
2058 | "flex_flow": null,
2059 | "grid_area": null,
2060 | "grid_auto_columns": null,
2061 | "grid_auto_flow": null,
2062 | "grid_auto_rows": null,
2063 | "grid_column": null,
2064 | "grid_gap": null,
2065 | "grid_row": null,
2066 | "grid_template_areas": null,
2067 | "grid_template_columns": null,
2068 | "grid_template_rows": null,
2069 | "height": null,
2070 | "justify_content": null,
2071 | "justify_items": null,
2072 | "left": null,
2073 | "margin": null,
2074 | "max_height": null,
2075 | "max_width": null,
2076 | "min_height": null,
2077 | "min_width": null,
2078 | "object_fit": null,
2079 | "object_position": null,
2080 | "order": null,
2081 | "overflow": null,
2082 | "overflow_x": null,
2083 | "overflow_y": null,
2084 | "padding": null,
2085 | "right": null,
2086 | "top": null,
2087 | "visibility": null,
2088 | "width": null
2089 | }
2090 | },
2091 | "aa4160a1837f427bb8ae673d31a0f24a": {
2092 | "model_module": "@jupyter-widgets/controls",
2093 | "model_module_version": "1.5.0",
2094 | "model_name": "FloatProgressModel",
2095 | "state": {
2096 | "_dom_classes": [],
2097 | "_model_module": "@jupyter-widgets/controls",
2098 | "_model_module_version": "1.5.0",
2099 | "_model_name": "FloatProgressModel",
2100 | "_view_count": null,
2101 | "_view_module": "@jupyter-widgets/controls",
2102 | "_view_module_version": "1.5.0",
2103 | "_view_name": "ProgressView",
2104 | "bar_style": "success",
2105 | "description": "",
2106 | "description_tooltip": null,
2107 | "layout": "IPY_MODEL_bb6f8aa46e3a459aacc6c99f9f76c72c",
2108 | "max": 411,
2109 | "min": 0,
2110 | "orientation": "horizontal",
2111 | "style": "IPY_MODEL_1f057e7c483d41a8b63f8b59480962f1",
2112 | "value": 411
2113 | }
2114 | },
2115 | "aadb942209e84408a2bf11e55412b574": {
2116 | "model_module": "@jupyter-widgets/controls",
2117 | "model_module_version": "1.5.0",
2118 | "model_name": "HBoxModel",
2119 | "state": {
2120 | "_dom_classes": [],
2121 | "_model_module": "@jupyter-widgets/controls",
2122 | "_model_module_version": "1.5.0",
2123 | "_model_name": "HBoxModel",
2124 | "_view_count": null,
2125 | "_view_module": "@jupyter-widgets/controls",
2126 | "_view_module_version": "1.5.0",
2127 | "_view_name": "HBoxView",
2128 | "box_style": "",
2129 | "children": [
2130 | "IPY_MODEL_fa9cafdcab2f439b9196f6f05166d898",
2131 | "IPY_MODEL_240fed609ad44aafad099550d3c265df",
2132 | "IPY_MODEL_4afabea1be0249068ad394e7265abf24"
2133 | ],
2134 | "layout": "IPY_MODEL_91bf2b38bd224cbfabc77aa85d75f48b"
2135 | }
2136 | },
2137 | "abfc8e73b9f94e4aa684de84a6194e20": {
2138 | "model_module": "@jupyter-widgets/controls",
2139 | "model_module_version": "1.5.0",
2140 | "model_name": "DescriptionStyleModel",
2141 | "state": {
2142 | "_model_module": "@jupyter-widgets/controls",
2143 | "_model_module_version": "1.5.0",
2144 | "_model_name": "DescriptionStyleModel",
2145 | "_view_count": null,
2146 | "_view_module": "@jupyter-widgets/base",
2147 | "_view_module_version": "1.2.0",
2148 | "_view_name": "StyleView",
2149 | "description_width": ""
2150 | }
2151 | },
2152 | "b3ffc98ce8fa41ddad546ceed1f3e0fb": {
2153 | "model_module": "@jupyter-widgets/base",
2154 | "model_module_version": "1.2.0",
2155 | "model_name": "LayoutModel",
2156 | "state": {
2157 | "_model_module": "@jupyter-widgets/base",
2158 | "_model_module_version": "1.2.0",
2159 | "_model_name": "LayoutModel",
2160 | "_view_count": null,
2161 | "_view_module": "@jupyter-widgets/base",
2162 | "_view_module_version": "1.2.0",
2163 | "_view_name": "LayoutView",
2164 | "align_content": null,
2165 | "align_items": null,
2166 | "align_self": null,
2167 | "border": null,
2168 | "bottom": null,
2169 | "display": null,
2170 | "flex": null,
2171 | "flex_flow": null,
2172 | "grid_area": null,
2173 | "grid_auto_columns": null,
2174 | "grid_auto_flow": null,
2175 | "grid_auto_rows": null,
2176 | "grid_column": null,
2177 | "grid_gap": null,
2178 | "grid_row": null,
2179 | "grid_template_areas": null,
2180 | "grid_template_columns": null,
2181 | "grid_template_rows": null,
2182 | "height": null,
2183 | "justify_content": null,
2184 | "justify_items": null,
2185 | "left": null,
2186 | "margin": null,
2187 | "max_height": null,
2188 | "max_width": null,
2189 | "min_height": null,
2190 | "min_width": null,
2191 | "object_fit": null,
2192 | "object_position": null,
2193 | "order": null,
2194 | "overflow": null,
2195 | "overflow_x": null,
2196 | "overflow_y": null,
2197 | "padding": null,
2198 | "right": null,
2199 | "top": null,
2200 | "visibility": null,
2201 | "width": null
2202 | }
2203 | },
2204 | "b568300f99bd4bd48e8f37029282eed6": {
2205 | "model_module": "@jupyter-widgets/base",
2206 | "model_module_version": "1.2.0",
2207 | "model_name": "LayoutModel",
2208 | "state": {
2209 | "_model_module": "@jupyter-widgets/base",
2210 | "_model_module_version": "1.2.0",
2211 | "_model_name": "LayoutModel",
2212 | "_view_count": null,
2213 | "_view_module": "@jupyter-widgets/base",
2214 | "_view_module_version": "1.2.0",
2215 | "_view_name": "LayoutView",
2216 | "align_content": null,
2217 | "align_items": null,
2218 | "align_self": null,
2219 | "border": null,
2220 | "bottom": null,
2221 | "display": null,
2222 | "flex": null,
2223 | "flex_flow": null,
2224 | "grid_area": null,
2225 | "grid_auto_columns": null,
2226 | "grid_auto_flow": null,
2227 | "grid_auto_rows": null,
2228 | "grid_column": null,
2229 | "grid_gap": null,
2230 | "grid_row": null,
2231 | "grid_template_areas": null,
2232 | "grid_template_columns": null,
2233 | "grid_template_rows": null,
2234 | "height": null,
2235 | "justify_content": null,
2236 | "justify_items": null,
2237 | "left": null,
2238 | "margin": null,
2239 | "max_height": null,
2240 | "max_width": null,
2241 | "min_height": null,
2242 | "min_width": null,
2243 | "object_fit": null,
2244 | "object_position": null,
2245 | "order": null,
2246 | "overflow": null,
2247 | "overflow_x": null,
2248 | "overflow_y": null,
2249 | "padding": null,
2250 | "right": null,
2251 | "top": null,
2252 | "visibility": null,
2253 | "width": null
2254 | }
2255 | },
2256 | "b5ca1b962b874a50bd3657a1d74ea998": {
2257 | "model_module": "@jupyter-widgets/controls",
2258 | "model_module_version": "1.5.0",
2259 | "model_name": "DescriptionStyleModel",
2260 | "state": {
2261 | "_model_module": "@jupyter-widgets/controls",
2262 | "_model_module_version": "1.5.0",
2263 | "_model_name": "DescriptionStyleModel",
2264 | "_view_count": null,
2265 | "_view_module": "@jupyter-widgets/base",
2266 | "_view_module_version": "1.2.0",
2267 | "_view_name": "StyleView",
2268 | "description_width": ""
2269 | }
2270 | },
2271 | "b73c8d9ccf494d6193ff8262777a98da": {
2272 | "model_module": "@jupyter-widgets/base",
2273 | "model_module_version": "1.2.0",
2274 | "model_name": "LayoutModel",
2275 | "state": {
2276 | "_model_module": "@jupyter-widgets/base",
2277 | "_model_module_version": "1.2.0",
2278 | "_model_name": "LayoutModel",
2279 | "_view_count": null,
2280 | "_view_module": "@jupyter-widgets/base",
2281 | "_view_module_version": "1.2.0",
2282 | "_view_name": "LayoutView",
2283 | "align_content": null,
2284 | "align_items": null,
2285 | "align_self": null,
2286 | "border": null,
2287 | "bottom": null,
2288 | "display": null,
2289 | "flex": null,
2290 | "flex_flow": null,
2291 | "grid_area": null,
2292 | "grid_auto_columns": null,
2293 | "grid_auto_flow": null,
2294 | "grid_auto_rows": null,
2295 | "grid_column": null,
2296 | "grid_gap": null,
2297 | "grid_row": null,
2298 | "grid_template_areas": null,
2299 | "grid_template_columns": null,
2300 | "grid_template_rows": null,
2301 | "height": null,
2302 | "justify_content": null,
2303 | "justify_items": null,
2304 | "left": null,
2305 | "margin": null,
2306 | "max_height": null,
2307 | "max_width": null,
2308 | "min_height": null,
2309 | "min_width": null,
2310 | "object_fit": null,
2311 | "object_position": null,
2312 | "order": null,
2313 | "overflow": null,
2314 | "overflow_x": null,
2315 | "overflow_y": null,
2316 | "padding": null,
2317 | "right": null,
2318 | "top": null,
2319 | "visibility": null,
2320 | "width": null
2321 | }
2322 | },
2323 | "bb6f8aa46e3a459aacc6c99f9f76c72c": {
2324 | "model_module": "@jupyter-widgets/base",
2325 | "model_module_version": "1.2.0",
2326 | "model_name": "LayoutModel",
2327 | "state": {
2328 | "_model_module": "@jupyter-widgets/base",
2329 | "_model_module_version": "1.2.0",
2330 | "_model_name": "LayoutModel",
2331 | "_view_count": null,
2332 | "_view_module": "@jupyter-widgets/base",
2333 | "_view_module_version": "1.2.0",
2334 | "_view_name": "LayoutView",
2335 | "align_content": null,
2336 | "align_items": null,
2337 | "align_self": null,
2338 | "border": null,
2339 | "bottom": null,
2340 | "display": null,
2341 | "flex": null,
2342 | "flex_flow": null,
2343 | "grid_area": null,
2344 | "grid_auto_columns": null,
2345 | "grid_auto_flow": null,
2346 | "grid_auto_rows": null,
2347 | "grid_column": null,
2348 | "grid_gap": null,
2349 | "grid_row": null,
2350 | "grid_template_areas": null,
2351 | "grid_template_columns": null,
2352 | "grid_template_rows": null,
2353 | "height": null,
2354 | "justify_content": null,
2355 | "justify_items": null,
2356 | "left": null,
2357 | "margin": null,
2358 | "max_height": null,
2359 | "max_width": null,
2360 | "min_height": null,
2361 | "min_width": null,
2362 | "object_fit": null,
2363 | "object_position": null,
2364 | "order": null,
2365 | "overflow": null,
2366 | "overflow_x": null,
2367 | "overflow_y": null,
2368 | "padding": null,
2369 | "right": null,
2370 | "top": null,
2371 | "visibility": null,
2372 | "width": null
2373 | }
2374 | },
2375 | "c6ab5db12b0647f5989e233d914edd87": {
2376 | "model_module": "@jupyter-widgets/controls",
2377 | "model_module_version": "1.5.0",
2378 | "model_name": "HTMLModel",
2379 | "state": {
2380 | "_dom_classes": [],
2381 | "_model_module": "@jupyter-widgets/controls",
2382 | "_model_module_version": "1.5.0",
2383 | "_model_name": "HTMLModel",
2384 | "_view_count": null,
2385 | "_view_module": "@jupyter-widgets/controls",
2386 | "_view_module_version": "1.5.0",
2387 | "_view_name": "HTMLView",
2388 | "description": "",
2389 | "description_tooltip": null,
2390 | "layout": "IPY_MODEL_8322ac7b0d8644d78804e6d85ff05baa",
2391 | "placeholder": "",
2392 | "style": "IPY_MODEL_1fecf11618b244909505323d2968e5fc",
2393 | "value": "tokenizer.model: 100%"
2394 | }
2395 | },
2396 | "c8219473c99d480081104fe4551036de": {
2397 | "model_module": "@jupyter-widgets/base",
2398 | "model_module_version": "1.2.0",
2399 | "model_name": "LayoutModel",
2400 | "state": {
2401 | "_model_module": "@jupyter-widgets/base",
2402 | "_model_module_version": "1.2.0",
2403 | "_model_name": "LayoutModel",
2404 | "_view_count": null,
2405 | "_view_module": "@jupyter-widgets/base",
2406 | "_view_module_version": "1.2.0",
2407 | "_view_name": "LayoutView",
2408 | "align_content": null,
2409 | "align_items": null,
2410 | "align_self": null,
2411 | "border": null,
2412 | "bottom": null,
2413 | "display": null,
2414 | "flex": null,
2415 | "flex_flow": null,
2416 | "grid_area": null,
2417 | "grid_auto_columns": null,
2418 | "grid_auto_flow": null,
2419 | "grid_auto_rows": null,
2420 | "grid_column": null,
2421 | "grid_gap": null,
2422 | "grid_row": null,
2423 | "grid_template_areas": null,
2424 | "grid_template_columns": null,
2425 | "grid_template_rows": null,
2426 | "height": null,
2427 | "justify_content": null,
2428 | "justify_items": null,
2429 | "left": null,
2430 | "margin": null,
2431 | "max_height": null,
2432 | "max_width": null,
2433 | "min_height": null,
2434 | "min_width": null,
2435 | "object_fit": null,
2436 | "object_position": null,
2437 | "order": null,
2438 | "overflow": null,
2439 | "overflow_x": null,
2440 | "overflow_y": null,
2441 | "padding": null,
2442 | "right": null,
2443 | "top": null,
2444 | "visibility": null,
2445 | "width": null
2446 | }
2447 | },
2448 | "cfee62ce88bf4a409ce34337556539cf": {
2449 | "model_module": "@jupyter-widgets/base",
2450 | "model_module_version": "1.2.0",
2451 | "model_name": "LayoutModel",
2452 | "state": {
2453 | "_model_module": "@jupyter-widgets/base",
2454 | "_model_module_version": "1.2.0",
2455 | "_model_name": "LayoutModel",
2456 | "_view_count": null,
2457 | "_view_module": "@jupyter-widgets/base",
2458 | "_view_module_version": "1.2.0",
2459 | "_view_name": "LayoutView",
2460 | "align_content": null,
2461 | "align_items": null,
2462 | "align_self": null,
2463 | "border": null,
2464 | "bottom": null,
2465 | "display": null,
2466 | "flex": null,
2467 | "flex_flow": null,
2468 | "grid_area": null,
2469 | "grid_auto_columns": null,
2470 | "grid_auto_flow": null,
2471 | "grid_auto_rows": null,
2472 | "grid_column": null,
2473 | "grid_gap": null,
2474 | "grid_row": null,
2475 | "grid_template_areas": null,
2476 | "grid_template_columns": null,
2477 | "grid_template_rows": null,
2478 | "height": null,
2479 | "justify_content": null,
2480 | "justify_items": null,
2481 | "left": null,
2482 | "margin": null,
2483 | "max_height": null,
2484 | "max_width": null,
2485 | "min_height": null,
2486 | "min_width": null,
2487 | "object_fit": null,
2488 | "object_position": null,
2489 | "order": null,
2490 | "overflow": null,
2491 | "overflow_x": null,
2492 | "overflow_y": null,
2493 | "padding": null,
2494 | "right": null,
2495 | "top": null,
2496 | "visibility": null,
2497 | "width": null
2498 | }
2499 | },
2500 | "d160b88d7f00493d81c424be93d7030f": {
2501 | "model_module": "@jupyter-widgets/controls",
2502 | "model_module_version": "1.5.0",
2503 | "model_name": "HTMLModel",
2504 | "state": {
2505 | "_dom_classes": [],
2506 | "_model_module": "@jupyter-widgets/controls",
2507 | "_model_module_version": "1.5.0",
2508 | "_model_name": "HTMLModel",
2509 | "_view_count": null,
2510 | "_view_module": "@jupyter-widgets/controls",
2511 | "_view_module_version": "1.5.0",
2512 | "_view_name": "HTMLView",
2513 | "description": "",
2514 | "description_tooltip": null,
2515 | "layout": "IPY_MODEL_95584181b8fd41eebb322c951c72c1f6",
2516 | "placeholder": "",
2517 | "style": "IPY_MODEL_88795b72502645b4bdc95beadb4471d5",
2518 | "value": "pytorch_model.bin: 100%"
2519 | }
2520 | },
2521 | "d3b47705f4e14680b13b21be52d2e604": {
2522 | "model_module": "@jupyter-widgets/controls",
2523 | "model_module_version": "1.5.0",
2524 | "model_name": "ProgressStyleModel",
2525 | "state": {
2526 | "_model_module": "@jupyter-widgets/controls",
2527 | "_model_module_version": "1.5.0",
... (Jupyter widget state metadata trimmed from prepack_generation_demo.ipynb) ...
2742 | }
2743 | },
2744 | "nbformat": 4,
2745 | "nbformat_minor": 0
2746 | }
2747 |
--------------------------------------------------------------------------------
/processor.py:
--------------------------------------------------------------------------------
1 | from torch.nn.utils.rnn import pad_sequence
2 | import torch.nn.functional as F
3 | import torch
4 | from utils import greedy_packing
5 |
6 |
7 | class PrePackProcessor:
8 | def __init__(self, tokenizer, packing_fn=None):
9 | self.tokenizer = tokenizer
10 | self.pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
11 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | if packing_fn:
13 | self.packing_fn = packing_fn
14 | else:
15 | self.packing_fn = greedy_packing
16 |
17 | def process(self, length_dict, packing_dict, token_dict):
18 |         '''
19 |         Packs one bin of tokenized prompts into a single sequence according to
20 |         the bin-packing assignment.
21 |         Args:
22 |             length_dict (dict): maps original batch index to its prompt length
23 |             packing_dict (dict): maps original batch index to prompt length, for
24 |                 the prompts assigned to this bin
25 |             token_dict (dict): maps original batch index to its tokenized prompt
26 |         Returns:
27 |             new_tokens (Tensor): packed sequence of tokens
28 |             new_positions (Tensor): position ids that restart from 0 at each prompt boundary
29 |             new_mask (Tensor): block-diagonal causal mask keeping packed prompts independent
30 |             restart_dict (dict): maps each prompt's restart index in the packed sequence to its original batch index
31 |         '''
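        # Illustrative sketch (hypothetical token ids): for a bin holding two prompts
        # A = [a0, a1, a2] and B = [b0, b1], this method returns
        #   new_tokens    = [a0, a1, b0]             (final tokens a2, b1 held out for generation)
        #   new_positions = [0, 1, 0]                (positions restart per prompt)
        #   restart_dict  = {0: -1, 2: idx_A, 3: idx_B}
        #   new_mask      = [[1,0,0],[1,1,0],[0,0,1]]  (block-diagonal causal)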
32 | new_positions = []
33 | new_tokens = []
34 | restart_dict = {0: -1} # -1 is a placeholder
35 | restart_index = 0
36 |
37 | for key in packing_dict:
38 | new_tokens += token_dict[key][:-1] # omit final token for generation
39 | restart_index += length_dict[key] - 1
40 | new_positions += list(range(length_dict[key] - 1))
41 | restart_dict[restart_index] = key
42 |
43 | restart_indices = list(restart_dict.keys())
44 | size = len(new_tokens)
45 | new_mask = torch.zeros(size, size, device=self.device)
46 |
47 | for i in range(len(restart_indices) - 1):
48 | start = restart_indices[i]
49 | end = restart_indices[i + 1]
50 | new_mask[start:end, start:end] = torch.tril(torch.ones((end - start, end - start)))
51 |
52 | new_tokens = torch.tensor(new_tokens, device=self.device)
53 | new_positions = torch.tensor(new_positions, device=self.device)
54 | new_mask = new_mask.clone().detach()
55 |
56 | return new_tokens, new_positions, new_mask, restart_dict
57 |
58 | def batch_process(self, sentences):
59 |
60 | original_ids = self.tokenizer(sentences).input_ids
61 | token_dict = dict(enumerate(original_ids))
62 |         # Map each prompt's original batch index to its token length.
63 |         length_dict = {index: len(toks) for index, toks in enumerate(original_ids)}
64 |
65 | max_bin_size = max(length_dict.values())
66 | packing_lst = self.packing_fn(length_dict, max_bin_size)
67 |
68 | batch_new_tokens = []
69 | batch_new_positions = []
70 | batch_new_mask = []
71 | batch_restart_indices = []
72 | for packing_dict in packing_lst:
73 | new_tokens, new_positions, new_mask, restart_indices = self.process(
74 | length_dict, packing_dict, token_dict
75 | )
76 | batch_new_tokens.append(new_tokens)
77 | batch_new_positions.append(new_positions)
78 | batch_new_mask.append(new_mask)
79 | batch_restart_indices.append(restart_indices)
80 |
81 | batch_new_tokens = pad_sequence(batch_new_tokens, batch_first=True, padding_value=self.pad_token)
82 | batch_new_positions = pad_sequence(batch_new_positions, batch_first=True, padding_value=1)
83 |
84 |         max_size = max(tensor.size(0) for tensor in batch_new_mask)
85 | padded_masks = [
86 | F.pad(tensor, (0, max_size - tensor.size(0), 0, max_size - tensor.size(1)))
87 | for tensor in batch_new_mask
88 | ]
89 | batch_new_mask = torch.stack(padded_masks)
90 |
91 | return batch_new_tokens, batch_new_positions, batch_new_mask, batch_restart_indices, original_ids
92 |
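# Minimal usage sketch (assumes a Hugging Face tokenizer with a pad token set and
# a model that accepts per-example 2D attention masks, as in model.py):
#
#   processor = PrePackProcessor(tokenizer)
#   tokens, positions, mask, restarts, ids = processor.batch_process(prompts)
#   outputs = model(input_ids=tokens, attention_mask=mask, position_ids=positions)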
--------------------------------------------------------------------------------
/profiling_dataset_level_prepacking.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | import torch
4 | import fire
5 | import random
6 | import time
7 | from tqdm import tqdm
8 | import numpy as np
9 | from prettytable import PrettyTable
10 | from dataset_utils import (
11 | PackedDataset,
12 | sample_batches,
13 | sample_batches_by_length,
14 | sample_packed_dataset,
15 | unpack_kv,
16 | load_and_evaluate_dataset,
17 | )
18 | from processor import PrePackProcessor
19 | from utils import integer_program_packing, load_model_and_tokenizer
20 |
21 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
22 |
23 |
24 | def prefill_packed_sentence_output(sentences, model, tokenizer, device, processor):
25 | new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
26 | with torch.no_grad():
27 | packed_outputs = model(
28 | input_ids=new_tokens,
29 | attention_mask=new_mask,
30 | position_ids=new_positions,
31 | return_dict=True,
32 | )
33 | return packed_outputs
34 |
35 |
36 | def TTFT_packed_sentence_output(sentences, model, tokenizer, device, processor):
37 | new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
38 | with torch.no_grad():
39 | packed_outputs = model(
40 | input_ids=new_tokens,
41 | attention_mask=new_mask,
42 | position_ids=new_positions,
43 | return_dict=True,
44 | )
45 | cache, final_tokens, attention_mask = unpack_kv(
46 | packed_outputs["past_key_values"], restart_dict, original_ids, device
47 | )
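    # One greedy decode step from the prefilled KV-cache: TTFT = prefill + first token.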
48 | _ = model.generate(
49 | input_ids=final_tokens,
50 | attention_mask=attention_mask,
51 | max_new_tokens=1,
52 | use_cache=True,
53 | do_sample=False,
54 | past_key_values=cache,
55 | )
56 | return
57 |
58 |
59 | def TTFT_packed_dataset_output(batch, model, tokenizer=None, model_device=None, optimized_processor=None):
60 | new_tokens, new_positions, new_mask, restart_dict, original_ids = batch
61 | with torch.no_grad():
62 | packed_outputs = model(
63 | input_ids=new_tokens,
64 | attention_mask=new_mask,
65 | position_ids=new_positions,
66 | return_dict=True,
67 | )
68 | cache, final_tokens, attention_mask = unpack_kv(
69 | packed_outputs["past_key_values"], restart_dict, original_ids, model_device
70 | )
71 | _ = model.generate(
72 | input_ids=final_tokens,
73 | attention_mask=attention_mask,
74 | max_new_tokens=1,
75 | use_cache=True,
76 | do_sample=False,
77 | past_key_values=cache,
78 | )
79 | return
80 |
81 |
82 | def prefill_packed_dataset_output(batch, model, tokenizer=None, device=None, processor=None):
83 |
84 | new_tokens, new_positions, new_mask, restart_dict, original_ids = batch
85 | with torch.no_grad():
86 | packed_outputs = model(
87 | input_ids=new_tokens,
88 | attention_mask=new_mask,
89 | position_ids=new_positions,
90 | )
91 | return packed_outputs
92 |
93 |
94 | def prefill_batch_sentence_output(sentences, model, tokenizer, device, processor=None):
95 | batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
96 |
97 | with torch.no_grad():
98 | batch_sentences_outputs = model(
99 | batch_sentences["input_ids"].to(device),
100 | attention_mask=batch_sentences["attention_mask"].to(device),
101 | )
102 | return batch_sentences_outputs
103 |
104 |
105 | def TTFT_batch_sentence_output(sentences, model, tokenizer, device, processor=None):
106 | batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
107 |
108 | with torch.no_grad():
109 | _ = model.generate(
110 | **batch_sentences,
111 | max_new_tokens=1,
112 | use_cache=True,
113 | do_sample=False,
114 | )
115 | return
116 |
117 |
118 | def measure_inference_time(
119 | method,
120 | texts,
121 | batch_size,
122 | num_runs,
123 | total_batches,
124 | model,
125 | tokenizer,
126 | model_device,
127 | metric="TTFT",
128 | binpack_algo="greedy",
129 | ):
130 |
131 | if metric == "TTFT":
132 | method_functions = {
133 | "prepack": TTFT_packed_sentence_output,
134 | "full-batching": TTFT_batch_sentence_output,
135 | "length-ordered": TTFT_batch_sentence_output,
136 | "prepack_dataset": TTFT_packed_dataset_output,
137 | }
138 | elif metric == "prefill":
139 | method_functions = {
140 | "prepack": prefill_packed_sentence_output,
141 | "full-batching": prefill_batch_sentence_output,
142 | "length-ordered": prefill_batch_sentence_output,
143 | "prepack_dataset": prefill_packed_dataset_output,
144 | }
145 | desc = method
146 | method_function = method_functions.get(method)
147 | packing_fn = None if binpack_algo == "greedy" else integer_program_packing
148 | optimized_processor = PrePackProcessor(tokenizer, packing_fn=packing_fn)
149 | total_request_times = []
150 | for _ in range(num_runs):
151 | if method == "length-ordered":
152 | batches_generator = sample_batches_by_length(texts, batch_size)
153 | elif method == "prepack_dataset":
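            # Dataset-level prepacking: pack the entire dataset once up front and
            # stream already-packed batches, amortizing the packing overhead.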
154 | new_tokens, new_positions, new_mask, restart_indices, original_ids = (
155 | optimized_processor.batch_process(texts)
156 | )
157 | dataset = PackedDataset(
158 | new_tokens, new_positions, new_mask, restart_indices, original_ids, batch_size=batch_size
159 | )
160 |
161 | batches_generator = sample_packed_dataset(dataset)
162 | del new_tokens, new_positions, new_mask, restart_indices, original_ids
163 | else:
164 | batches_generator = sample_batches(texts, batch_size)
165 | start_time = time.time()
166 | for batch in tqdm(batches_generator, total=total_batches, desc=desc):
167 | _ = method_function(batch, model, tokenizer, model_device, optimized_processor)
168 | elapsed = time.time() - start_time
169 | total_request_times.append(elapsed)
170 |
171 |     per_request_time = np.mean(total_request_times) / len(texts)  # mean already averages over runs
172 |     per_request_time_std = np.std(total_request_times) / len(texts)
173 | return per_request_time, per_request_time_std
174 |
175 |
176 | def main(
177 | methods: List[str] = ["prepack_dataset", "prepack", "full-batching", "length-ordered"],
178 | metric: str = "prefill",
179 | dataset: str = "mmlu",
180 | model_name: str = "llama1b",
181 | loadbit: int = 8,
182 | num_runs: int = 5,
183 | batch_size: int = 32,
184 | binpack_algo: str = "greedy",
185 | ):
186 |
187 | torch.set_num_threads(5)
188 | seed = 42
189 | random.seed(seed)
190 | np.random.seed(seed)
191 | torch.manual_seed(seed)
192 | torch.cuda.manual_seed_all(seed)
193 | os.environ["PYTHONHASHSEED"] = str(seed)
194 |
195 | if binpack_algo != "greedy":
196 | binpack_algo = "ip"
197 |
198 | # Load the model and tokenizer
199 | model, tokenizer = load_model_and_tokenizer(base_model=model_name, loadbit=loadbit)
200 |
201 | # Load and prepare the dataset
202 | texts = load_and_evaluate_dataset(dataset, tokenizer)
203 |
204 | total_batches = len(texts) // batch_size
205 | if len(texts) % batch_size != 0:
206 | total_batches += 1
207 | table = PrettyTable()
208 |
209 | table.field_names = [
210 | "Method",
211 |         f"Avg Prefill Time per request (s). bs={batch_size}, "
212 |         f"Bits: {loadbit}, {dataset}, {model_name}, "
213 |         f"metric: {metric}, "
214 |         f"binpack_algo: {binpack_algo}",
215 | f"std dev over {num_runs} runs",
216 | ]
217 | results = {}
218 | for method in methods:
219 | avg_time, std = measure_inference_time(
220 | method,
221 | texts,
222 | batch_size,
223 | num_runs,
224 | total_batches,
225 | model,
226 | tokenizer,
227 | model.device,
228 | metric=metric,
229 | binpack_algo=binpack_algo,
230 | )
231 | table.add_row([method, f"{avg_time:.5f}", f"{std:.5f}"])
232 | results[method] = {
233 | "Avg Prefill Time per request (s)": avg_time,
234 | "Std Dev": std,
235 | }
236 | print(table)
237 |
238 |
239 | if __name__ == "__main__":
240 | fire.Fire(main)
241 |
--------------------------------------------------------------------------------
/profiling_time_and_memory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import threading
4 | import time
5 | import GPUtil
6 | import fire
7 | import random
8 | from tqdm import tqdm
9 | import numpy as np
10 | from typing import List
11 | from prettytable import PrettyTable
12 | from dataset_utils import (
13 | load_and_evaluate_dataset,
14 | sample_batches,
15 | sample_batches_by_length,
16 | unpack_kv,
17 | )
18 | from processor import PrePackProcessor
19 | from utils import integer_program_packing, load_model_and_tokenizer
20 |
21 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
22 |
23 |
24 | def monitor_gpu_utilization(stop_event, utilization_stats, device_id=0, interval=0.1):
25 | max_utilization, total_utilization, count = 0, 0, 0
26 |     cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")  # assume GPU 0 if unset
27 | visible_device_ids = list(map(int, cuda_visible_devices.split(",")))
28 | while not stop_event.is_set():
29 | gpus = GPUtil.getGPUs()
30 | current_utilization = gpus[visible_device_ids[device_id]].load
31 | max_utilization = max(max_utilization, current_utilization)
32 | total_utilization += current_utilization
33 | count += 1
34 | time.sleep(interval)
35 | utilization_stats["max_util"] = max_utilization * 100 # Convert to percentage
36 | utilization_stats["mean_util"] = (total_utilization / count) * 100 if count > 0 else 0
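# Usage sketch: start monitor_gpu_utilization in a daemon thread around the timed
# region, call stop_event.set() when done, then read utilization_stats["max_util"]
# and utilization_stats["mean_util"] (see measure_inference_resources below).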
37 |
38 |
39 | def prefill_with_prepacking(sentences, model, tokenizer, device, processor):
40 | new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
41 | with torch.no_grad():
42 | output = model(
43 | input_ids=new_tokens,
44 | attention_mask=new_mask,
45 | position_ids=new_positions,
46 | return_dict=True,
47 | )
48 | return output
49 |
50 |
51 | def TTFT_with_prepacking(sentences, model, tokenizer, device, processor):
52 | new_tokens, new_positions, new_mask, restart_dict, original_ids = processor.batch_process(sentences)
53 | with torch.no_grad():
54 | packed_outputs = model(
55 | input_ids=new_tokens,
56 | attention_mask=new_mask,
57 | position_ids=new_positions,
58 | return_dict=True,
59 | )
60 | cache, final_tokens, attention_mask = unpack_kv(
61 | packed_outputs["past_key_values"], restart_dict, original_ids, device
62 | )
63 | _ = model.generate(
64 | input_ids=final_tokens,
65 | attention_mask=attention_mask,
66 | max_new_tokens=1,
67 | use_cache=True,
68 | do_sample=False,
69 | past_key_values=cache,
70 | )
71 | return
72 |
73 |
74 | def prefill_with_baseline(sentences, model, tokenizer, device, processor=None):
75 | batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
76 |
77 | with torch.no_grad():
78 | batch_sentences_outputs = model(
79 | batch_sentences["input_ids"].to(device),
80 | attention_mask=batch_sentences["attention_mask"].to(device),
81 | return_dict=True,
82 | )
83 | return batch_sentences_outputs
84 |
85 |
86 | def TTFT_with_baseline(sentences, model, tokenizer, device, processor=None):
87 | batch_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
88 |
89 | with torch.no_grad():
90 | _ = model.generate(
91 | **batch_sentences,
92 | max_new_tokens=1,
93 | use_cache=True,
94 | do_sample=False,
95 | )
96 | return
97 |
98 |
99 | def get_average_gpu_utilization():
100 | # get current device id, assuming only 1 GPU is used
101 |     cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")  # assume GPU 0 if unset
102 | visible_device_ids = list(map(int, cuda_visible_devices.split(",")))
103 | gpus = GPUtil.getGPUs()
104 | return gpus[visible_device_ids[0]].load
105 |
106 |
107 | def measure_inference_resources(
108 | method,
109 | texts,
110 | batch_size,
111 | num_runs,
112 | total_batches,
113 | model,
114 | tokenizer,
115 | model_device,
116 | metric="TTFT",
117 | binpack_algo="greedy",
118 | ):
119 | scenario_times = []
120 |
121 | if metric == "TTFT":
122 | method_functions = {
123 | "prepacking": TTFT_with_prepacking,
124 | "full-batching": TTFT_with_baseline,
125 | "length-ordered": TTFT_with_baseline,
126 | }
127 | elif metric == "prefill":
128 | method_functions = {
129 | "prepacking": prefill_with_prepacking,
130 | "full-batching": prefill_with_baseline,
131 | "length-ordered": prefill_with_baseline,
132 | }
133 |     method_function = method_functions.get(method)
134 |     packing_fn = None if binpack_algo == "greedy" else integer_program_packing
135 |     optimized_processor = PrePackProcessor(tokenizer, packing_fn=packing_fn)
136 |
137 |
138 | for _ in range(num_runs):
139 | batches_generator = (
140 | sample_batches(texts, batch_size)
141 | if method != "length-ordered"
142 | else sample_batches_by_length(texts, batch_size)
143 | )
144 |
145 | max_gpu_utilization = []
146 | max_gpu_memory = []
147 | batch_gpu_memories = []
148 | batch_gpu_utilizations = []
149 | mean_gpu_utilizations = []
150 | for batch in tqdm(batches_generator, total=total_batches, desc=method):
151 | utilization_stats = {}
152 | stop_event = threading.Event()
153 | monitor_thread = threading.Thread(
154 | target=monitor_gpu_utilization, args=(stop_event, utilization_stats), daemon=True
155 | )
156 | monitor_thread.start()
157 |
158 | torch.cuda.reset_peak_memory_stats(model_device) # Reset memory stats at the start
159 | torch.cuda.empty_cache()
160 |
161 | start_time = time.time()
162 | _ = method_function(batch, model, tokenizer, model_device, optimized_processor)
163 | elapsed = time.time() - start_time
164 |
165 | scenario_times.append(elapsed)
166 | stop_event.set()
167 | monitor_thread.join()
168 | peak_memory = torch.cuda.max_memory_allocated(model_device) / (1024**2)
169 | batch_gpu_memories.append(peak_memory)
170 |
171 | max_util = utilization_stats.get("max_util", 0)
172 | mean_util = utilization_stats.get("mean_util", 0) # Get mean utilization
173 | batch_gpu_utilizations.append(max_util)
174 | mean_gpu_utilizations.append(mean_util)
175 |
176 | max_gpu_memory.append(max(batch_gpu_memories))
177 | max_gpu_utilization.append(max(batch_gpu_utilizations))
178 | avg_scenario_time = np.mean(scenario_times)
179 | avg_gpu_utilization = np.mean(max_gpu_utilization)
180 | avg_gpu_memory = np.mean(max_gpu_memory)
181 | avg_mean_gpu_utilization = np.mean(mean_gpu_utilizations)
182 | std_dev_time = np.std(scenario_times)
183 | std_gpu_utilization = np.std(max_gpu_utilization)
184 |     # peak-memory std is ~0 across runs, so it is not included in the returned stats
185 | std_mean_gpu_utilization = np.std(mean_gpu_utilizations)
186 | return (
187 | avg_scenario_time,
188 | avg_gpu_utilization,
189 | avg_gpu_memory,
190 | avg_mean_gpu_utilization,
191 | std_dev_time,
192 | std_gpu_utilization,
193 | std_mean_gpu_utilization,
194 | )
195 |
196 |
197 | def main(
198 | methods: List[str] = [
199 | "prepacking",
200 | "full-batching",
201 | "length-ordered",
202 | ],
203 | metric: str = "prefill",
204 | dataset: str = "mmlu",
205 | model_name: str = "llama1b",
206 | loadbit: int = 4,
207 | num_runs: int = 5,
208 | batch_size: int = 64,
209 | binpack_algo: str = "greedy",
210 | ):
211 |
212 | torch.set_num_threads(5)
213 |
214 | seed = 42
215 | random.seed(seed)
216 | np.random.seed(seed)
217 | torch.manual_seed(seed)
218 | torch.cuda.manual_seed_all(seed)
219 | os.environ["PYTHONHASHSEED"] = str(seed)
220 |
221 | if binpack_algo != "greedy":
222 | binpack_algo = "ip"
223 |
224 | # Load the model and tokenizer
225 | model, tokenizer = load_model_and_tokenizer(base_model=model_name, loadbit=loadbit)
226 |
227 | # Load and prepare the dataset
228 | texts = load_and_evaluate_dataset(dataset, tokenizer)
229 |
230 | total_batches = len(texts) // batch_size
231 | if len(texts) % batch_size != 0:
232 | total_batches += 1
233 | table = PrettyTable()
234 |
235 | table.field_names = [
236 | "Method",
237 | f"Avg {metric} Time /batch (s)",
238 | "Max GPU Utilization (%)",
239 | "Max GPU Memory (MB)",
240 | "Mean GPU Utilization (%)",
241 | "Std Dev Time (s)",
242 | "Std Dev Max GPU Util (%)",
243 | "Std Dev Mean GPU Util (%)",
244 | ]
245 |
246 | for method in methods:
247 | try:
248 | (
249 | avg_scenario_time,
250 | avg_gpu_utilization,
251 | avg_gpu_memory,
252 | avg_mean_gpu_utilization,
253 | std_dev_time,
254 | std_gpu_util,
255 | std_mean_gpu_util,
256 | ) = measure_inference_resources(
257 | method,
258 | texts,
259 | batch_size,
260 | num_runs,
261 | total_batches,
262 | model,
263 | tokenizer,
264 | model.device,
265 | metric=metric,
266 | binpack_algo=binpack_algo,
267 | )
268 | table.add_row(
269 | [
270 | method,
271 | f"{avg_scenario_time:.3f}",
272 | f"{avg_gpu_utilization:.3f}",
273 | f"{avg_gpu_memory:.3f}",
274 | f"{avg_mean_gpu_utilization:.3f}",
275 | f"{std_dev_time:.3f}",
276 | f"{std_gpu_util:.3f}",
277 | f"{std_mean_gpu_util:.3f}",
278 | ]
279 | )
280 | print(table)
281 |         except Exception as e:  # most commonly a CUDA out-of-memory error
282 | print(f"An error occurred while processing method {method}: {e}")
283 | torch.cuda.empty_cache()
284 |
285 | finally:
286 | torch.cuda.empty_cache()
287 |
288 |
289 | if __name__ == "__main__":
290 | fire.Fire(main)
291 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | import torch
3 | from torch import Tensor
4 | from ortools.linear_solver import pywraplp
5 | import binpacking
6 | from transformers import AutoTokenizer
7 | from model import CustomCausalLlamaModel, CustomCausalMistralModel
8 |
9 |
10 | # As implemented here:
11 | # https://github.com/pytorch/pytorch/issues/10536#issuecomment-1320935162
12 | def left_pad_sequence(
13 | sequences: Union[Tensor, List[Tensor]],
14 | batch_first: bool = True,
15 | padding_value: float = 0.0,
16 | ) -> Tensor:
17 |
18 | sequences = tuple(map(lambda s: s.flip(0), sequences))
19 | padded_sequence = torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
20 | _seq_dim = padded_sequence.dim()
21 | padded_sequence = padded_sequence.flip(-_seq_dim + batch_first)
22 | return padded_sequence
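# e.g. left_pad_sequence([torch.tensor([1, 2]), torch.tensor([3])], padding_value=0)
#      -> tensor([[1, 2],
#                 [0, 3]])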
23 |
24 |
25 | def greedy_packing(length_dict, max_bin_size):
26 | return binpacking.to_constant_volume(length_dict, max_bin_size)
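# e.g. greedy_packing({0: 5, 1: 3, 2: 2}, max_bin_size=5) -> [{0: 5}, {1: 3, 2: 2}]
# (illustrative; the exact grouping depends on the binpacking library's heuristic)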
27 |
28 |
29 | # https://developers.google.com/optimization/pack/bin_packing
30 | def integer_program_packing(length_dict, max_bin_size):
31 | data = {}
32 | data["items"] = list(length_dict.keys())
33 | data["weights"] = list(length_dict.values())
34 | data["bins"] = data["items"]
35 | data["bin_capacity"] = max_bin_size
36 |
37 | solver = pywraplp.Solver.CreateSolver("SCIP")
38 |
39 | if not solver:
40 |         raise RuntimeError("SCIP solver is unavailable.")
41 | x = {}
42 | for i in data["items"]:
43 | for j in data["bins"]:
44 | x[(i, j)] = solver.IntVar(0, 1, "x_%i_%i" % (i, j))
45 | y = {}
46 | for j in data["bins"]:
47 | y[j] = solver.IntVar(0, 1, "y[%i]" % j)
48 |
49 | for i in data["items"]:
50 | solver.Add(sum(x[i, j] for j in data["bins"]) == 1)
51 |
52 | for j in data["bins"]:
53 | solver.Add(sum(x[(i, j)] * data["weights"][i] for i in data["items"]) <= y[j] * data["bin_capacity"])
54 |
55 | solver.Minimize(solver.Sum([y[j] for j in data["bins"]]))
56 |
57 | status = solver.Solve()
58 |
59 | if status == pywraplp.Solver.OPTIMAL:
60 | result = []
61 | for j in data["bins"]:
62 | if y[j].solution_value() == 1:
63 | bin_dict = {}
64 | for i in data["items"]:
65 | if x[i, j].solution_value() > 0:
66 | bin_dict[i] = data["weights"][i]
67 | result.append(bin_dict)
68 | else:
69 |         raise ValueError("The problem does not have an optimal solution.")
70 |
71 | return result
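# Illustrative sketch: for lengths {0: 4, 1: 4, 2: 2} and capacity 6, an optimal
# solution uses two bins, e.g. [{0: 4, 2: 2}, {1: 4}]. Exact packing can use fewer
# bins than greedy_packing, at the cost of solving an integer program.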
72 |
73 |
74 | def load_model_and_tokenizer(
75 | base_model: str = "llama1b",
76 | loadbit: int = 8,
77 | ):
78 | # Load tokenizer and model
79 |     if base_model == "llama1b":
80 |         path = "princeton-nlp/Sheared-LLaMA-1.3B"
81 |     elif base_model == "llama2":
82 |         path = "/path/to/llama2"  # placeholder: point this at a local Llama-2 checkpoint
83 |     else: raise ValueError(f"No checkpoint path configured for base_model={base_model!r}")
84 | tokenizer = AutoTokenizer.from_pretrained(path)
85 | tokenizer.pad_token = "[PAD]"
86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87 | load_in_8bit = loadbit == 8
88 | load_in_4bit = loadbit == 4
89 | if "llama" in base_model:
90 | model = CustomCausalLlamaModel.from_pretrained(
91 | path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
92 | )
93 | elif "mistral" in base_model:
94 | model = CustomCausalMistralModel.from_pretrained(
95 | path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
96 | )
97 | model.eval()
98 | if loadbit != 8 and loadbit != 4:
99 | model.to(device)
100 |
101 | return model, tokenizer
102 |
--------------------------------------------------------------------------------