├── .gitignore
├── requirements.txt
├── .gitattributes
├── LICENSE
├── src
│   ├── utils.py
│   ├── run.py
│   └── model.py
├── download.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.mypy_cache/
models/
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fire>=0.1.3
regex==2017.4.5
requests==2.21.0
tqdm==4.31.1
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# convert to OS line endings on checkout, back to LF on commit
* text=auto

# ensure anything copied to the container has unix style line endings
*.sh text eol=lf
requirements.txt text eol=lf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Modified MIT License

Software Copyright (c) 2020 OpenAI

We don’t claim ownership of the content you create with iGPT, so it is yours to do with as you please.
We only ask that you use iGPT responsibly and clearly indicate your content was created using iGPT.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
The above copyright notice and this permission notice
need not be included with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import os
import json
import time
import pickle
import subprocess
import math
from tqdm import tqdm

import numpy as np
import tensorflow as tf

def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=float("inf")):
    n = len(datas[0])
    if truncate:
        n = (n//n_batch)*n_batch
    n = min(n, max_batches*n_batch)
    n_batches = 0
    for i in tqdm(range(0, n, n_batch), total=n//n_batch, disable=not verbose, ncols=80, leave=False):
        if n_batches >= max_batches:
            # raising StopIteration inside a generator is a RuntimeError under PEP 479 (Python 3.7+)
            return
        if len(datas) == 1:
            yield datas[0][i:i+n_batch]
        else:
            yield (d[i:i+n_batch] for d in datas)
        n_batches += 1

def squared_euclidean_distance(a, b):
    b = tf.transpose(b)
    a2 = tf.reduce_sum(tf.square(a), axis=1, keepdims=True)
    b2 = tf.reduce_sum(tf.square(b), axis=0, keepdims=True)
    ab = tf.matmul(a, b)
    d = a2 - 2*ab + b2
    return d

def color_quantize(x, np_clusters):
    clusters = tf.Variable(np_clusters, dtype=tf.float32, trainable=False)
    x = tf.reshape(x, [-1, 3])
    d = squared_euclidean_distance(x, clusters)
    return tf.argmin(d, 1)

def count_parameters():
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    return total_parameters
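
# A minimal usage sketch for color_quantize, assuming kmeans_centers.npy was
# fetched via download.py --clusters and that pixel values are scaled to
# [-1, 1] to match the cluster centers. Names other than color_quantize
# (np_clusters, x, indices, images) are illustrative, not from this repo.
#
#     np_clusters = np.load("/root/downloads/kmeans_centers.npy")  # [512, 3]
#     x = tf.placeholder(tf.float32, [None, 32, 32, 3])
#     indices = color_quantize(x, np_clusters)  # flattened [n*32*32] int64 codes
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())  # initializes the clusters variable
#         codes = sess.run(indices, {x: images})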
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
import argparse
import json
import os
import sys
import requests
from tqdm import tqdm

def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("--download_dir", type=str, default="/root/downloads/")

    parser.add_argument("--bert", action="store_true", help="download a bert model (default: ar)")
    parser.add_argument("--model", type=str, choices=["s", "m", "l"], help="parameter counts are s:76M, m:455M, l:1362M")
    parser.add_argument("--ckpt", type=str, choices=["131000", "262000", "524000", "1000000"])
    parser.add_argument("--clusters", action="store_true", help="download the color clusters file")
    parser.add_argument("--dataset", type=str, choices=["imagenet", "cifar10"])

    args = parser.parse_args()
    print("input args:\n", json.dumps(vars(args), indent=4, separators=(",", ":")))
    return args


def main(args):
    if not os.path.exists(args.download_dir):
        os.makedirs(args.download_dir)

    urls = []

    # download the checkpoint
    if args.model and args.ckpt:
        base_url = f"https://openaipublic.blob.core.windows.net/image-gpt/checkpoints/igpt-{args.model}{'-bert' if args.bert else ''}/{args.ckpt}"

        size_to_shards = {"s": 32, "m": 32, "l": 64}
        shards = size_to_shards[args.model]

        for filename in [f"model.ckpt-{args.ckpt}.data-{i:05d}-of-{shards:05d}" for i in range(shards)]:
            urls.append(f"{base_url}/{filename}")
        urls.append(f"{base_url}/model.ckpt-{args.ckpt}.index")
        urls.append(f"{base_url}/model.ckpt-{args.ckpt}.meta")

    # download the color clusters file
    if args.clusters:
        urls.append("https://openaipublic.blob.core.windows.net/image-gpt/color-clusters/kmeans_centers.npy")

    # download color clustered dataset
    if args.dataset:
        for split in ["trX", "trY", "vaX", "vaY", "teX", "teY"]:
            urls.append(f"https://openaipublic.blob.core.windows.net/image-gpt/datasets/{args.dataset}_{split}.npy")

    # run the download
    for url in urls:
        filename = url.split("/")[-1]
        r = requests.get(url, stream=True)
        with open(f"{args.download_dir}/{filename}", "wb") as f:
            file_size = int(r.headers["content-length"])
            chunk_size = 1000
            with tqdm(ncols=80, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    pbar.update(len(chunk))  # use the actual chunk length so the bar matches file_size

if __name__ == "__main__":
    args = parse_arguments()
    main(args)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**Status:** Archive (code is provided as-is, no updates expected)

# image-gpt

Code and models from the paper ["Generative Pretraining from Pixels"](https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf).

Supported Platforms:

- Ubuntu 16.04

## Install

You can get miniconda from https://docs.conda.io/en/latest/miniconda.html, or install the dependencies shown below manually.

```
conda create --name image-gpt python=3.7.3
conda activate image-gpt

conda install numpy=1.16.3
conda install tensorflow-gpu=1.13.1

conda install imageio=2.8.0
conda install requests=2.21.0
conda install tqdm=4.46.0
```

## Usage

This repository is meant to be a starting point for researchers and engineers to experiment with image GPT (iGPT). Our code forks GPT-2 to highlight that it can be easily applied across domains. The diff from `gpt-2/src/model.py` to `image-gpt/src/model.py` includes a new activation function, renaming of several variables, and the introduction of a start-of-sequence token, none of which change the model architecture.

### Downloading Pre-trained Models

To download a model checkpoint, run `download.py`. The `--model` argument should be one of "s", "m", or "l", and the `--ckpt` argument should be one of "131000", "262000", "524000", or "1000000".

```
python download.py --model s --ckpt 1000000
```

This command downloads the iGPT-S checkpoint at 1M training iterations. The default download directory is set to `/root/downloads/`, and can be changed using the `--download_dir` argument.

### Downloading Datasets

To download datasets, run `download.py` with the `--dataset` argument set to "imagenet" or "cifar10".

```
python download.py --model s --ckpt 1000000 --dataset imagenet
```

This command additionally downloads 32x32 ImageNet encoded with the 9-bit color palette described in the paper. The datasets we provide are center-cropped images intended for evaluation; randomly cropped images are required to faithfully replicate training.

### Downloading Color Clusters

To download the color cluster file defining our 9-bit color palette, run `download.py` with the `--clusters` flag set.

```
python download.py --model s --ckpt 1000000 --dataset imagenet --clusters
```

This command additionally downloads the color cluster file. `src/run.py:sample` shows how to decode from 9-bit color to RGB and `src/utils.py:color_quantize` shows how to go the other way around.
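
As a minimal numpy sketch (assuming the default download path above): decoding is a table lookup into `kmeans_centers.npy` followed by rescaling from [-1, 1] back to [0, 255], mirroring the dequantization line in `src/run.py:sample`, and encoding is a nearest-centroid search as in `src/utils.py:color_quantize`. The `decode`/`encode` helper names are illustrative.

```
import numpy as np

clusters = np.load("/root/downloads/kmeans_centers.npy")  # [512, 3], centers in [-1, 1]

def decode(indices, n_px=32):
    # palette indices [n_px * n_px] -> uint8 RGB image [n_px, n_px, 3]
    return np.rint(127.5 * (clusters[indices] + 1.0)).reshape(n_px, n_px, 3).astype(np.uint8)

def encode(image):
    # uint8 RGB image -> nearest-centroid palette indices [n_px * n_px]
    pixels = image.reshape(-1, 3) / 127.5 - 1.0
    distances = ((pixels[:, None, :] - clusters[None, :, :]) ** 2).sum(axis=-1)
    return distances.argmin(axis=1)
```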
### Sampling

Once the desired checkpoint and color cluster file are downloaded, we can run the script in sampling mode. The following commands sample from iGPT-S, iGPT-M, and iGPT-L respectively:

```
python src/run.py --sample --n_embd 512 --n_head 8 --n_layer 24
python src/run.py --sample --n_embd 1024 --n_head 8 --n_layer 36
python src/run.py --sample --n_embd 1536 --n_head 16 --n_layer 48
```

If your data is not in `/root/downloads/`, set `--ckpt_path` and `--color_cluster_path` manually. To run on fewer than 8 GPUs, use a command of the following form:

```
CUDA_VISIBLE_DEVICES=0,1 python src/run.py --sample --n_embd 512 --n_head 8 --n_layer 24 --n_gpu 2
```

### Evaluating

Once the desired checkpoint and evaluation dataset are downloaded, we can run the script in evaluation mode. The following commands evaluate iGPT-S, iGPT-M, and iGPT-L on ImageNet respectively:

```
python src/run.py --eval --n_embd 512 --n_head 8 --n_layer 24
python src/run.py --eval --n_embd 1024 --n_head 8 --n_layer 36
python src/run.py --eval --n_embd 1536 --n_head 16 --n_layer 48
```

If your data is not in `/root/downloads/`, set `--ckpt_path` and `--data_path` manually. You should see test generative losses of 2.0895, 2.0614, and 2.0466 respectively, matching Figure 3 in the paper.

### Citation

Please use the following bibtex entry:
```
@article{chen2020generative,
  title={Generative Pretraining from Pixels},
  author={Chen, Mark and Radford, Alec and Child, Rewon and Wu, Jeff and Jun, Heewoo and Dhariwal, Prafulla and Luan, David and Sutskever, Ilya},
  year={2020}
}
```

## License

[Modified MIT](./LICENSE)
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
import argparse
import json
import math
import os
import random
import sys
import time

import numpy as np
import tensorflow as tf

from imageio import imwrite
from scipy.special import softmax
from tensorflow.contrib.training import HParams
from tqdm import tqdm

from model import model
from utils import iter_data, count_parameters


def parse_arguments():
    parser = argparse.ArgumentParser()

    # data and I/O
    parser.add_argument("--data_path", type=str, default="/root/downloads/imagenet")
    parser.add_argument("--ckpt_path", type=str, default="/root/downloads/model.ckpt-1000000")
    parser.add_argument("--color_cluster_path", type=str, default="/root/downloads/kmeans_centers.npy")
    parser.add_argument("--save_dir", type=str, default="/root/save/")

    # model
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_layer", type=int, default=24)
    parser.add_argument("--n_px", type=int, default=32, help="image height or width in pixels")
    parser.add_argument("--n_vocab", type=int, default=512, help="possible values for each pixel")

    parser.add_argument("--bert", action="store_true", help="use the bert objective (default: autoregressive)")
    parser.add_argument("--bert_mask_prob", type=float, default=0.15)
    parser.add_argument("--clf", action="store_true", help="add a learnable classification head")

    # parallelism
    parser.add_argument("--n_sub_batch", type=int, default=8, help="per-gpu batch size")
    parser.add_argument("--n_gpu", type=int, default=8, help="number of gpus to distribute training across")

    # mode
    parser.add_argument("--eval", action="store_true", help="evaluates the model, requires a checkpoint and dataset")
    parser.add_argument("--sample", action="store_true", help="samples from the model, requires a checkpoint and clusters")

    # reproducibility
    parser.add_argument("--seed", type=int, default=42, help="seed for random, np, tf")

    args = parser.parse_args()
    print("input args:\n", json.dumps(vars(args), indent=4, separators=(",", ":")))
    return args


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)


def load_data(data_path):
    trX = np.load(f'{data_path}_trX.npy')
    trY = np.load(f'{data_path}_trY.npy')
    vaX = np.load(f'{data_path}_vaX.npy')
    vaY = np.load(f'{data_path}_vaY.npy')
    teX = np.load(f'{data_path}_teX.npy')
    teY = np.load(f'{data_path}_teY.npy')
    return (trX, trY), (vaX, vaY), (teX, teY)
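
# Note: data_path acts as a filename prefix, not a directory. download.py
# writes e.g. imagenet_trX.npy into --download_dir, so the default
# --data_path of /root/downloads/imagenet makes the np.load calls above
# resolve to /root/downloads/imagenet_trX.npy and so on.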
def set_hparams(args):
    return HParams(
        n_ctx=args.n_px*args.n_px,
        n_embd=args.n_embd,
        n_head=args.n_head,
        n_layer=args.n_layer,
        n_vocab=args.n_vocab,
        bert=args.bert,
        bert_mask_prob=args.bert_mask_prob,
        clf=args.clf,
    )


def create_model(x, y, n_gpu, hparams):
    gen_logits = []
    gen_loss = []
    clf_loss = []
    tot_loss = []
    accuracy = []

    trainable_params = None
    for i in range(n_gpu):
        with tf.device("/gpu:%d" % i):
            results = model(hparams, x[i], y[i], reuse=(i != 0))

            gen_logits.append(results["gen_logits"])
            gen_loss.append(results["gen_loss"])
            clf_loss.append(results["clf_loss"])

            if hparams.clf:
                tot_loss.append(results["gen_loss"] + results["clf_loss"])
            else:
                tot_loss.append(results["gen_loss"])

            accuracy.append(results["accuracy"])

            if i == 0:
                trainable_params = tf.trainable_variables()
                print("trainable parameters:", count_parameters())

    return trainable_params, gen_logits, gen_loss, clf_loss, tot_loss, accuracy


def reduce_mean(gen_loss, clf_loss, tot_loss, accuracy, n_gpu):
    with tf.device("/gpu:0"):
        for i in range(1, n_gpu):
            gen_loss[0] += gen_loss[i]
            clf_loss[0] += clf_loss[i]
            tot_loss[0] += tot_loss[i]
            accuracy[0] += accuracy[i]
        gen_loss[0] /= n_gpu
        clf_loss[0] /= n_gpu
        tot_loss[0] /= n_gpu
        accuracy[0] /= n_gpu


def evaluate(sess, evX, evY, X, Y, gen_loss, clf_loss, accuracy, n_batch, desc, permute=False):
    metrics = []
    for xmb, ymb in iter_data(evX, evY, n_batch=n_batch, truncate=True, verbose=True):
        metrics.append(sess.run([gen_loss[0], clf_loss[0], accuracy[0]], {X: xmb, Y: ymb}))
    eval_gen_loss, eval_clf_loss, eval_accuracy = [np.mean(m) for m in zip(*metrics)]
    print(f"{desc} gen: {eval_gen_loss:.4f} clf: {eval_clf_loss:.4f} acc: {eval_accuracy:.2f}")

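
# Each step of the sampling loop below reruns the full forward pass on the
# whole (partially filled) sequence, so a 32x32 image costs n_px * n_px = 1024
# forward passes. The past/present plumbing in src/model.py could cache keys
# and values between steps, but this script keeps the sampler simple and does
# not use it.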
# naive sampler without caching
def sample(sess, X, gen_logits, n_sub_batch, n_gpu, n_px, n_vocab, clusters, save_dir):
    samples = np.zeros([n_gpu * n_sub_batch, n_px * n_px], dtype=np.int32)

    for i in tqdm(range(n_px * n_px), ncols=80, leave=False):
        np_gen_logits = sess.run(gen_logits, {X: samples})
        for j in range(n_gpu):
            p = softmax(np_gen_logits[j][:, i, :], axis=-1)  # logits to probas
            for k in range(n_sub_batch):
                c = np.random.choice(n_vocab, p=p[k])  # choose based on probas
                samples[j * n_sub_batch + k, i] = c

    # dequantize
    samples = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples]

    # write to png
    for i in range(n_gpu * n_sub_batch):
        imwrite(f"{save_dir}/sample_{i}.png", samples[i])  # use the save_dir argument rather than the global args


def main(args):
    set_seed(args.seed)

    n_batch = args.n_sub_batch * args.n_gpu

    if args.data_path.endswith("cifar10"):
        n_class = 10
    elif args.data_path.endswith("imagenet"):
        n_class = 1000
    else:
        raise ValueError("Dataset not supported.")

    X = tf.placeholder(tf.int32, [n_batch, args.n_px * args.n_px])
    Y = tf.placeholder(tf.float32, [n_batch, n_class])

    x = tf.split(X, args.n_gpu, 0)
    y = tf.split(Y, args.n_gpu, 0)

    hparams = set_hparams(args)
    trainable_params, gen_logits, gen_loss, clf_loss, tot_loss, accuracy = create_model(x, y, args.n_gpu, hparams)
    reduce_mean(gen_loss, clf_loss, tot_loss, accuracy, args.n_gpu)

    saver = tf.train.Saver(var_list=[tp for tp in trainable_params if 'clf' not in tp.name])
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
        sess.run(tf.global_variables_initializer())

        saver.restore(sess, args.ckpt_path)

        if args.eval:
            (trX, trY), (vaX, vaY), (teX, teY) = load_data(args.data_path)
            evaluate(sess, trX[:len(vaX)], trY[:len(vaY)], X, Y, gen_loss, clf_loss, accuracy, n_batch, "train")
            evaluate(sess, vaX, vaY, X, Y, gen_loss, clf_loss, accuracy, n_batch, "valid")
            evaluate(sess, teX, teY, X, Y, gen_loss, clf_loss, accuracy, n_batch, "test")

        if args.sample:
            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
            clusters = np.load(args.color_cluster_path)
            sample(sess, X, gen_logits, args.n_sub_batch, args.n_gpu, args.n_px, args.n_vocab, clusters, args.save_dir)


if __name__ == "__main__":
    args = parse_arguments()
    main(args)
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.contrib.training import HParams

def default_hparams():
    return HParams(
        n_vocab=0,
        n_ctx=1024,
        n_embd=768,
        n_head=12,
        n_layer=12,
    )

def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

def softmax(x, axis=-1):
    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
    ex = tf.exp(x)
    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)

def gelu(x):
    return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))

def gelu2(x):
    return x * tf.sigmoid(1.702 * x)
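
# gelu2 is the sigmoid ("quick") approximation x * sigmoid(1.702 * x); gelu
# above is the tanh approximation used in GPT-2. This is the new activation
# function mentioned in the README: mlp below calls gelu2 rather than gelu.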
def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Scale to unit RMS along `axis`, then apply a diagonal gain.

    Unlike GPT-2's layer norm, this does not subtract the mean or add a bias.
    """
    with tf.variable_scope(scope):
        n_state = x.shape[axis].value
        g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
        s = tf.reduce_mean(tf.square(x), axis=axis, keepdims=True)
        x = x * tf.rsqrt(s + epsilon)
        x = x*g
        return x

def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *start, m = shape_list(x)
    return tf.reshape(x, start + [n, m//n])

def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a*b])

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable('w', [nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])), start+[nf])
        return c

def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:,None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)
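
# For example, attention_mask(2, 4, dtype=tf.float32) evaluates to
#     [[1., 1., 1., 0.],
#      [1., 1., 1., 1.]]
# i.e. the last nd rows of an ns x ns lower-triangular (causal) mask, which is
# what incremental decoding with a key/value cache needs.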

def attn(x, scope, n_state, *, past, hparams):
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        if not hparams.bert:
            w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        *start, nx = shape_list(x)

        wk = tf.get_variable("k_proj", [hparams.n_head, nx // hparams.n_head, n_state], initializer=tf.random_normal_initializer(stddev=1.0/np.sqrt(n_state)))
        wq = tf.get_variable("q_proj", [hparams.n_head, nx // hparams.n_head, n_state], initializer=tf.random_normal_initializer(stddev=1.0/np.sqrt(n_state)))
        wv = tf.get_variable("v_proj", [hparams.n_head, nx // hparams.n_head, n_state], initializer=tf.random_normal_initializer(stddev=1.0/np.sqrt(n_state)))
        k = tf.einsum("bsf,hef->bhse", x, wk)
        q = tf.einsum("bsf,hef->bhse", x, wq)
        v = tf.einsum("bsf,hef->bhse", x, wv)

        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        wc = tf.get_variable("c_proj", [hparams.n_head, nx // hparams.n_head, n_state], initializer=tf.random_normal_initializer(stddev=1.0/np.sqrt(n_state*hparams.n_layer)))
        a = tf.einsum("bhse,hef->bsf", a, wc)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        h = gelu2(conv1d(x, 'c_fc', n_state))
        h2 = conv1d(h, 'c_proj', nx)
        return h2


def block(x, scope, *, past, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        x = x + m
        return x, present

def past_shape(*, hparams, batch_size=None, sequence=None):
    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
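
# For example, with the iGPT-S settings from the README (n_layer=24, n_head=8,
# n_embd=512), past_shape(hparams=hparams, batch_size=8) gives
# [8, 24, 2, 8, None, 64]: a stacked [k, v] pair of per-head activations for
# each of the 24 layers, with the sequence length left dynamic.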

def expand_tile(value, size):
    """Add a new axis of given size."""
    value = tf.convert_to_tensor(value, name='value')
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)

def positions_for(tokens, past_length):
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)


def model(hparams, X, Y=None, past=None, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        if hparams.bert:
            M = tf.greater(tf.random.uniform([batch, sequence]), hparams.bert_mask_prob)
            M = tf.cast(M, tf.float32)

        wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                              initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                              initializer=tf.random_normal_initializer(stddev=0.02))
        wtet = tf.get_variable('wtet', [hparams.n_vocab, hparams.n_embd],
                               initializer=tf.random_normal_initializer(stddev=0.0))
        past_length = 0 if past is None else tf.shape(past)[-2]

        h = tf.gather(wte, X)

        if hparams.bert:
            h = h * tf.expand_dims(M, 2)
        else:
            sos = tf.get_variable('sos', [hparams.n_embd],
                                  initializer=tf.random_normal_initializer(stddev=0.02))
            sos_tok = tf.ones([batch, 1, hparams.n_embd], dtype=tf.float32) * sos
            h = tf.concat([sos_tok, h[:,:-1,:]], axis=1)

        h += tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Generative loss. Do tokens