├── .gitignore ├── LICENSE ├── arguments.py ├── assets ├── coco_new.png ├── cogviewcase.png └── logo.png ├── data └── placeholder ├── data_utils ├── __init__.py ├── configure_data.py ├── datasets.py ├── samplers.py ├── sp_tokenizer.py ├── templates.py ├── unified_tokenizer.py └── vqvae_tokenizer.py ├── env ├── dockerfile ├── ip_list.txt ├── ip_list.txt.example ├── setup_connection.py └── start_docker.sh ├── eval_utils ├── dataset.py ├── fid_score.py ├── inception.py └── inception_score.py ├── finetune └── __init__.py ├── fp16 ├── __init__.py ├── fp16.py ├── fp16util.py └── loss_scaler.py ├── generate_samples.py ├── generation ├── __init__.py ├── magnify.py └── sampling.py ├── learning_rates.py ├── model ├── __init__.py ├── distributed.py └── gpt2_modeling.py ├── mpu ├── __init__.py ├── cross_entropy.py ├── data.py ├── grads.py ├── initialize.py ├── layers.py ├── mappings.py ├── random.py ├── sparse_transformer.py └── utils.py ├── preprocess ├── __init__.py ├── preprocess_text_image_data.py ├── preprocess_text_jsonformat_data.py ├── pretokenized_data.py ├── raw_datasets.py └── utils.py ├── preprocess_entry.py ├── pretrain_gpt2.py ├── pretrained ├── chinese_sentencepiece │ ├── cog-pretrain.model │ └── cog-pretrain.vocab ├── cogview │ └── placeholder └── vqvae │ └── placeholder ├── readme.md ├── requirements.txt ├── scripts ├── ds_config.json ├── ds_config_zero.json ├── image2text.sh ├── low_level_super_resolution.sh ├── post_selection.sh ├── pretrain_multiple_nodes.sh ├── pretrain_single_node.sh ├── super_resolution.sh └── text2image.sh ├── test_lmdb.py ├── utils.py └── vqvae ├── LICENSE ├── README.md ├── __init__.py ├── api.py ├── distributed ├── __init__.py ├── distributed.py └── launch.py └── vqvae_zc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *_jax* 3 | events.out* 4 | __pycache__/ 5 | *.pt 6 | data 7 | core.* 8 | _cache* 9 | .vscode/ 10 | samples/ 11 | hostfile 12 | pretrained/checkpoints 13 | *.png 14 | *.jpg 15 | *.jpeg 16 | input*.txt 17 | samples* -------------------------------------------------------------------------------- /assets/coco_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/assets/coco_new.png -------------------------------------------------------------------------------- /assets/cogviewcase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/assets/cogviewcase.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/assets/logo.png -------------------------------------------------------------------------------- /data/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/data/placeholder -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : __init__.py 4 | @Time : 2021/01/11 16:35:24 5 | @Author : Ming Ding 6 | @Contact : 
dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | 11 | from .unified_tokenizer import get_tokenizer 12 | 13 | from .templates import * 14 | from .configure_data import make_loaders, detect_new_datasets -------------------------------------------------------------------------------- /data_utils/configure_data.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : configure_data.py 4 | @Time : 2021/01/11 23:28:38 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | import copy 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from bisect import bisect_right 21 | 22 | from .unified_tokenizer import get_tokenizer 23 | from .datasets import get_dataset_by_type 24 | from torch.utils import data 25 | from .samplers import DistributedBatchSampler 26 | 27 | import mpu 28 | 29 | 30 | def make_data_loader(dataset, batch_size, num_iters, args): 31 | world_size = torch.distributed.get_world_size( 32 | group=mpu.get_data_parallel_group()) 33 | rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) 34 | distributed = world_size > 1 35 | 36 | sampler = torch.utils.data.SequentialSampler(dataset) 37 | drop_last = distributed 38 | # the GPUs in the same model parallel group receive the same data 39 | if distributed: 40 | batch_sampler = DistributedBatchSampler(sampler, 41 | batch_size, 42 | drop_last, 43 | rank, 44 | world_size, 45 | gradient_accumulation_steps=args.gradient_accumulation_steps) 46 | else: 47 | batch_sampler = torch.utils.data.BatchSampler(sampler, 48 | batch_size, 49 | drop_last) 50 | data_loader = torch.utils.data.DataLoader(dataset, 51 | batch_sampler=batch_sampler, 52 | num_workers=args.num_workers, 53 | pin_memory=True) 54 | return data_loader 55 | 56 | 57 | def make_dataset(dataset_type, path, split, args, **kwargs): 58 | """function to create datasets+tokenizers for common options""" 59 | print('make dataset ...', path) 60 | if split is None: 61 | split = [1.] 62 | 63 | assert isinstance(path, list) 64 | # TODO other dsclass, e.g. odps 65 | # ds = [get_dataset_by_type(dataset_type, p, args) for p in path] 66 | # dataset object can be copied N times 67 | ds = [] 68 | for p in path: 69 | d = get_dataset_by_type(dataset_type, p, args) 70 | if p.find('t2i') >= 0: 71 | ds.extend([d] * 4) 72 | print(f'Enlarge {p} 4 times...') 73 | elif p.find('i2t') >= 0: 74 | ds.extend([d] * 2) 75 | print(f'Enlarge {p} 2 times...') 76 | else: 77 | ds.append(d) 78 | 79 | ds = RandomMappingDataset(ConcatDataset(ds)) 80 | 81 | if should_split(split): 82 | ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping 83 | # FIXME this will merge valid set and train set. 84 | return ds 85 | 86 | def make_loaders(args): 87 | """makes training/val/test""" 88 | 89 | world_size = torch.distributed.get_world_size( 90 | group=mpu.get_data_parallel_group()) 91 | batch_size = args.batch_size * world_size 92 | eval_batch_size = batch_size 93 | if args.eval_batch_size is not None: 94 | eval_batch_size = args.eval_batch_size * world_size 95 | 96 | split = get_split(args) 97 | 98 | data_set_args = { 99 | 'path': args.train_data, 100 | 'dataset_type': args.dataset_type, 101 | 'split': split, 102 | } 103 | 104 | eval_set_args = copy.copy(data_set_args) 105 | eval_set_args['split'] = [1.] 
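# Illustrative note (not part of the original control flow): the eval split is
# pinned to [1.] so that datasets given through args.valid_data / args.test_data
# are used in full, while the training files are partitioned by get_split()
# below. For example, with args.split == '0.95,0.05' and no valid/test data
# supplied, get_split() returns [0.95, 0.05, 0.0], should_split() is True, and
# make_dataset() returns a (train, valid, test) tuple via split_ds().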
106 | 107 | # make datasets splits and tokenizer 108 | train = None 109 | valid = None 110 | test = None 111 | 112 | if args.train_data is not None: 113 | train = make_dataset(**data_set_args, args=args) 114 | if should_split(split): 115 | train, valid, test = train 116 | 117 | # make training and val dataset if necessary 118 | if valid is None and args.valid_data is not None: 119 | eval_set_args['path'] = args.valid_data 120 | valid = make_dataset(**eval_set_args, args=args) 121 | if test is None and args.test_data is not None: 122 | eval_set_args['path'] = args.test_data 123 | test = make_dataset(**eval_set_args, args=args) 124 | 125 | # wrap datasets with data loader 126 | if train is not None and args.batch_size > 0: 127 | train = make_data_loader(train, batch_size, args.train_iters, args) 128 | args.do_train = True 129 | else: 130 | args.do_train = False 131 | eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size 132 | if valid is not None: 133 | valid = make_data_loader(valid, eval_batch_size, args.train_iters, args) 134 | args.do_valid = True 135 | else: 136 | args.do_valid = False 137 | if test is not None: 138 | test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args) 139 | args.do_test = True 140 | else: 141 | args.do_test = False 142 | 143 | return train, valid, test 144 | 145 | 146 | 147 | def get_split(args): 148 | """ 149 | Get dataset splits from comma separated string list 150 | """ 151 | splits = [] 152 | if args.split.find(',') != -1: 153 | splits = [float(s) for s in args.split.split(',')] 154 | elif args.split.find('/') != -1: 155 | splits = [float(s) for s in args.split.split('/')] 156 | else: 157 | splits = [float(args.split)] 158 | split_total = sum(splits) 159 | if split_total < 1.: 160 | splits.append(1-split_total) 161 | while len(splits) < 3: 162 | splits.append(0.) 163 | splits = splits[:3] 164 | if args.valid_data is not None: 165 | splits[1] = 0. 166 | if args.test_data is not None: 167 | splits[2] = 0. 168 | final_sum = sum(splits) 169 | return [s/final_sum for s in splits] 170 | 171 | def should_split(split): 172 | """ 173 | given split proportions checks if should split 174 | Examples: 175 | >>> should_split([10,0,0]) 176 | False 177 | >>> should_split([1,.1,.2]) 178 | True 179 | """ 180 | return max(split) / sum(split) != 1. 181 | 182 | def split_ds(ds, split=[.8,.2,.0]): 183 | """ 184 | Split a dataset into subsets given proportions of how 185 | much to allocate per split. If a split is 0% returns None for that split. 186 | Purpose: Useful for creating train/val/test splits 187 | Arguments: 188 | ds (Dataset or array-like): Data to be split. 189 | split (1D array-like): proportions to split `ds`. `sum(splits) != 0` 190 | shuffle (boolean): Randomly split dataset. Default: True 191 | """ 192 | split_sum = sum(split) 193 | if split_sum == 0: 194 | raise Exception('Split cannot sum to 0.') 195 | split = np.array(split) 196 | split /= split_sum 197 | ds_len = len(ds) 198 | 199 | start_idx = 0 200 | residual_idx = 0 201 | rtn_ds = [None]*len(split) 202 | for i, f in enumerate(split): 203 | if f != 0: 204 | proportion = ds_len*split[i] 205 | residual_idx += proportion % 1 206 | split_ = int(int(proportion) + residual_idx) 207 | split_range = (start_idx, start_idx+max(split_, 1)) 208 | rtn_ds[i] = SplitDataset(ds, split_range) 209 | start_idx += split_ 210 | residual_idx %= 1 211 | return rtn_ds 212 | 213 | class ConcatDataset(data.Dataset): 214 | """ 215 | Dataset to concatenate multiple datasets. 
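Example (a sketch with assumed inputs): for two map-style datasets ``a`` and ``b``
with ``len(a) == 3`` and ``len(b) == 5``, ``ConcatDataset([a, b])`` has length 8,
``cumulative_sizes == [3, 8]``, and index 5 is routed to ``b[2]`` by the
``bisect_right`` lookup in ``__getitem__``.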
216 | Purpose: useful to assemble different existing datasets, possibly 217 | large-scale datasets as the concatenation operation is done in an 218 | on-the-fly manner. 219 | Arguments: 220 | datasets (sequence): List of datasets to be concatenated. 221 | """ 222 | 223 | @staticmethod 224 | def cumsum(sequence): 225 | r, s = [], 0 226 | for e in sequence: 227 | l = len(e) 228 | r.append(l + s) 229 | s += l 230 | return r 231 | 232 | def __init__(self, datasets, **kwargs): 233 | super(ConcatDataset, self).__init__() 234 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 235 | self.datasets = list(datasets) 236 | self.cumulative_sizes = self.cumsum(self.datasets) 237 | 238 | def __len__(self): 239 | return self.cumulative_sizes[-1] 240 | 241 | def __getitem__(self, idx): 242 | dataset_idx = bisect_right(self.cumulative_sizes, idx) 243 | if dataset_idx == 0: 244 | sample_idx = idx 245 | else: 246 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 247 | return self.datasets[dataset_idx][sample_idx] 248 | 249 | 250 | class SplitDataset(data.Dataset): 251 | """ 252 | Dataset wrapper to access a subset of another dataset. 253 | Purpose: useful to index into existing datasets, possibly 254 | large-scale datasets as the subindexing operation is done in an 255 | on-the-fly manner. 256 | Arguments: 257 | ds (Dataset or array-like): List of datasets to be subindexed 258 | split_range (Tuple): (Left, Right) 259 | """ 260 | def __init__(self, ds, split_range, **kwargs): 261 | self.split_range = split_range 262 | self.wrapped_data = ds 263 | 264 | def __len__(self): 265 | return self.split_range[1] - self.split_range[0] 266 | 267 | def __getitem__(self, index): 268 | index += self.split_range[0] 269 | assert index < self.split_range[1] 270 | return self.wrapped_data[index] 271 | 272 | def __iter__(self): 273 | for idx in range(*self.split_range): 274 | yield self.wrapped_data[idx] 275 | 276 | class RandomMappingDataset(data.Dataset): 277 | ''' 278 | Dataset wrapper to randomly mapping indices to original order. 279 | Will also enlarge the length 280 | ''' 281 | def __init__(self, ds, **kwargs): 282 | self.wrapped_data = ds 283 | 284 | def __len__(self): 285 | return len(self.wrapped_data) * 200 286 | 287 | def __getitem__(self, index): 288 | rng = random.Random(index) 289 | rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)]) 290 | index = rng.randint(len(self.wrapped_data)) 291 | return self.wrapped_data[index] 292 | 293 | def detect_new_datasets(args): 294 | if args.new_dataset_path is None: 295 | return None 296 | if not os.path.exists(args.new_dataset_path): 297 | print('Warning: new_dataset_path not exists... 
skip detection.') 298 | return None 299 | current_datasets = [str(os.path.abspath(path)) for path in args.train_data] 300 | 301 | found = [] 302 | for _p in os.listdir(args.new_dataset_path): 303 | p = os.path.join(args.new_dataset_path, _p) 304 | if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets: 305 | found.append(p) 306 | if len(found) == 0: 307 | return None 308 | else: 309 | args.train_data = args.train_data + found 310 | return make_loaders(args) 311 | -------------------------------------------------------------------------------- /data_utils/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : datasets.py 4 | @Time : 2021/01/11 21:01:51 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | import logging 16 | 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | from torchvision import datasets, transforms 22 | import pickle 23 | from collections import namedtuple 24 | 25 | from torch.utils.data import Dataset 26 | import lmdb 27 | 28 | from .unified_tokenizer import get_tokenizer 29 | from .templates import TextCodeTemplate 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class LMDBDataset(Dataset): 35 | def __init__(self, path, process_fn): 36 | self.env = lmdb.open( 37 | path, 38 | max_readers=32, 39 | readonly=True, 40 | lock=False, 41 | readahead=False, 42 | meminit=False, 43 | ) 44 | self.process_fn = process_fn 45 | if not self.env: 46 | raise IOError('Cannot open lmdb dataset', path) 47 | 48 | with self.env.begin(write=False) as txn: 49 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 50 | 51 | def __len__(self): 52 | return self.length 53 | 54 | def __getitem__(self, idx): 55 | 56 | with self.env.begin(write=False) as txn: 57 | key = str(idx).encode('utf-8') 58 | 59 | row = pickle.loads(txn.get(key)) 60 | 61 | return self.process_fn(row) 62 | 63 | class BinaryDataset(Dataset): 64 | def __init__(self, path, process_fn, length_per_sample=64+1024, dtype='int32', preload=False, **kwargs): 65 | assert length_per_sample is not None 66 | self.length_per_sample = length_per_sample 67 | self.dtype = np.dtype(dtype) 68 | self.process_fn = process_fn 69 | if preload: 70 | self.bin = np.fromfile(path, dtype=self.dtype).reshape(-1, length_per_sample) 71 | else: 72 | with open(path, 'r') as fid: 73 | nbytes = fid.seek(0, 2) 74 | flen = fid.tell() // self.dtype.itemsize 75 | self.bin = np.memmap(path, dtype=self.dtype, shape=(flen // length_per_sample, length_per_sample)) 76 | 77 | def __len__(self): 78 | return self.bin.shape[0] 79 | 80 | def __getitem__(self, index): 81 | return self.process_fn(self.bin[index]) 82 | 83 | def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): 84 | 85 | tokenizer = get_tokenizer() 86 | if args.finetune and args.max_position_embeddings_finetune > args.max_position_embeddings: 87 | ml = args.max_position_embeddings_finetune 88 | else: 89 | ml = args.max_position_embeddings 90 | 91 | def pad_to_len(ret): 92 | 93 | if len(ret) < ml: # pad 94 | return np.concatenate((ret, 95 | np.array([tokenizer['[PAD]']] * (ml - len(ret)))), 96 | axis=0), len(ret) 97 | else: 98 | if len(ret) > ml: 99 | logger.warning('Out of max len, truncated.') 100 | return ret[:ml], ml 101 | 102 | if 
dataset_type == 'TokenizedDataset': 103 | # already tokenized when saved 104 | def process_fn(row): 105 | ret, attention_mask_sep = pad_to_len(row.flatten()) 106 | return {'text': ret, 107 | 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) 108 | } 109 | 110 | elif dataset_type == 'TextCodeDataset': 111 | def process_fn(row): 112 | text, code = row[0], row[1].flatten() 113 | ret = TextCodeTemplate(text, code) 114 | ret, attention_mask_sep = pad_to_len(ret) 115 | return {'text': ret, 116 | 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) 117 | } 118 | 119 | elif dataset_type == 'CompactBinaryDataset': 120 | DS_CLASS = BinaryDataset 121 | def process_fn(row): 122 | text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024 123 | text = text[text>-1] 124 | ret = TextCodeTemplate(text, code) 125 | ret, attention_mask_sep = pad_to_len(ret) 126 | return {'text': ret, 127 | 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) 128 | } 129 | 130 | return DS_CLASS(path, process_fn) 131 | 132 | -------------------------------------------------------------------------------- /data_utils/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """batch samplers that work with either random or sequential data samplers""" 16 | import math 17 | import os 18 | import sys 19 | 20 | import torch 21 | from torch.utils import data 22 | import numpy as np 23 | 24 | class RandomSampler(data.sampler.Sampler): 25 | r""" 26 | Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, 27 | but this class lets the user set an epoch like DistributedSampler 28 | Samples elements randomly. If without replacement, then sample from a shuffled dataset. 29 | If with replacement, then user can specify ``num_samples`` to draw. 
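A minimal usage sketch (``train_ds`` and ``num_epochs`` are assumed names;
setting the epoch makes the shuffled order reproducible):

    sampler = RandomSampler(train_ds)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        for idx in sampler:
            sample = train_ds[idx]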
30 | Arguments: 31 | data_source (Dataset): dataset to sample from 32 | num_samples (int): number of samples to draw, default=len(dataset) 33 | replacement (bool): samples are drawn with replacement if ``True``, default=False 34 | """ 35 | 36 | def __init__(self, data_source, replacement=False, num_samples=None): 37 | self.data_source = data_source 38 | self.replacement = replacement 39 | self._num_samples = num_samples 40 | self.epoch = -1 41 | 42 | if self._num_samples is not None and replacement is False: 43 | raise ValueError("With replacement=False, num_samples should not be specified, " 44 | "since a random permute will be performed.") 45 | 46 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 47 | raise ValueError("num_samples should be a positive integer " 48 | "value, but got num_samples={}".format(self.num_samples)) 49 | if not isinstance(self.replacement, bool): 50 | raise ValueError("replacement should be a boolean value, but got " 51 | "replacement={}".format(self.replacement)) 52 | 53 | @property 54 | def num_samples(self): 55 | # dataset size might change at runtime 56 | if self._num_samples is None: 57 | return len(self.data_source) 58 | return self._num_samples 59 | 60 | def __iter__(self): 61 | n = len(self.data_source) 62 | g = torch.Generator() 63 | if self.epoch >= 0: 64 | g.manual_seed(self.epoch) 65 | if self.replacement: 66 | return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist()) 67 | return iter(torch.randperm(n, generator=g).tolist()) 68 | 69 | def __len__(self): 70 | return self.num_samples 71 | 72 | def set_epoch(self, epoch): 73 | self.epoch = epoch 74 | 75 | 76 | class DistributedSequentialSampler(data.sampler.Sampler): 77 | def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2): 78 | super().__init__(num_samples) 79 | if rank == -1: 80 | rank = 0 81 | world_size = 1 82 | self.num_samples = num_samples 83 | self.rank = rank 84 | self.world_size = world_size 85 | self.start_iter = 0 86 | self.train_iters = train_iters 87 | self.batch_size = batch_size 88 | self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)] 89 | 90 | def __iter__(self): 91 | for idx in range(self.start_iter, self.train_iters * 10): 92 | batch = [(idx + bias) % self.num_samples for bias in self.batch_bias] 93 | tbatch = self._batch(batch) 94 | yield tbatch 95 | 96 | def __len__(self): 97 | return self.train_iters 98 | 99 | def _batch(self, batch): 100 | """extracts samples only pertaining to this worker's batch""" 101 | start = self.rank*self.batch_size//self.world_size 102 | end = (self.rank+1)*self.batch_size//self.world_size 103 | return batch[start:end] 104 | 105 | 106 | class DistributedBatchSampler(data.sampler.BatchSampler): 107 | """ 108 | similar to normal implementation of distributed sampler, except implementation is at the 109 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary 110 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. 
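A minimal sketch of how this class is wired up in make_data_loader
(data_utils/configure_data.py); rank and world_size refer to the
data-parallel group:

    batch_sampler = DistributedBatchSampler(
        sampler, batch_size, drop_last, rank, world_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps)
    loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                         num_workers=args.num_workers,
                                         pin_memory=True)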
111 | """ 112 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None): 113 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) 114 | if rank == -1: 115 | assert False, 'should not be here' 116 | self.rank = rank 117 | self.world_size = world_size 118 | self.sampler.wrap_around = 0 119 | self.wrap_around = 0 120 | self.wrap_last = wrap_last 121 | self.start_iter = 0 122 | self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps 123 | 124 | def __iter__(self): 125 | batch = [] 126 | i = 0 127 | for idx in self.data_iterator(self.sampler, wrap_around=False): 128 | batch.append(idx) 129 | if len(batch) == self.batch_size: 130 | tbatch = self._batch(batch) 131 | if i >= self.start_iter * self.effective_batch_size: 132 | yield tbatch 133 | self.start_iter = 0 134 | i += len(batch) 135 | batch = [] 136 | batch_len = len(batch) 137 | if batch_len > 0 and not self.drop_last: 138 | if self.wrap_last: 139 | self.sampler.wrap_around -= (self.batch_size) 140 | self.wrap_around += (len(batch)) 141 | self.wrap_around %= self.batch_size 142 | if isinstance(self.sampler, TransposedSampler): 143 | for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)): 144 | if i == 0: 145 | continue 146 | batch.append(idx) 147 | new_batch_len = len(batch) 148 | if len(batch) == self.batch_size: 149 | break 150 | yield self._batch(batch) 151 | if self.wrap_last: 152 | self.sampler.wrap_around += self.batch_size 153 | 154 | def data_iterator(self, _iter, wrap_around=False): 155 | """iterates through data and handles wrap around""" 156 | for i, idx in enumerate(_iter): 157 | if i < self.wrap_around%self.batch_size: 158 | continue 159 | if wrap_around: 160 | self.wrap_around += 1 161 | self.wrap_around %= self.batch_size 162 | yield idx 163 | 164 | def _batch(self, batch): 165 | """extracts samples only pertaining to this worker's batch""" 166 | start = self.rank*self.batch_size//self.world_size 167 | end = (self.rank+1)*self.batch_size//self.world_size 168 | return batch[start:end] 169 | -------------------------------------------------------------------------------- /data_utils/sp_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/openai/gpt-2/, changed for chinese 3 | """ 4 | import json 5 | import os 6 | import sentencepiece as spm 7 | 8 | """ 9 | SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 10 | systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 11 | subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the 12 | extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 13 | system that does not depend on language-specific pre/postprocessing. 
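A minimal sketch of how the Encoder_SP wrapper defined below is used (assuming
the pretrained model shipped with this repo exists at
pretrained/chinese_sentencepiece/cog-pretrain.model):

    sp = Encoder_SP("pretrained/chinese_sentencepiece/cog-pretrain.model")
    ids = sp.encode("some text")          # list of token ids
    text = sp.decode(ids)                 # back to a string
    pieces = sp.tokenize("some text")     # list of subword pieces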
14 | https://github.com/google/sentencepiece 15 | 16 | pip install sentencepiece 17 | 18 | or git clone https://github.com/google/sentencepiece.git 19 | python setup.py install 20 | 21 | """ 22 | PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model" 23 | 24 | 25 | def get_pairs(word): 26 | pairs = set() 27 | prev_char = word[0] 28 | for char in word[1:]: 29 | pairs.add((prev_char, char)) 30 | prev_char = char 31 | return pairs 32 | 33 | 34 | class Encoder: 35 | def __init__(self, encoder, bpe_merges): 36 | self.encoder = encoder 37 | self.decoder = {v: k for k, v in self.encoder.items()} 38 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 39 | self.cache = {} 40 | self.max_len = 0 41 | 42 | def bpe(self, token): 43 | if token in self.cache: 44 | return self.cache[token] 45 | word = tuple(token) 46 | pairs = get_pairs(word) 47 | if not pairs: 48 | return token 49 | 50 | while True: 51 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 52 | if bigram not in self.bpe_ranks: 53 | break 54 | first, second = bigram 55 | new_word = [] 56 | i = 0 57 | while i < len(word): 58 | try: 59 | j = word.index(first, i) 60 | new_word.extend(word[i:j]) 61 | i = j 62 | except: 63 | new_word.extend(word[i:]) 64 | break 65 | 66 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 67 | new_word.append(first + second) 68 | i += 2 69 | else: 70 | new_word.append(word[i]) 71 | i += 1 72 | new_word = tuple(new_word) 73 | word = new_word 74 | if len(word) == 1: 75 | break 76 | else: 77 | pairs = get_pairs(word) 78 | word = ' '.join(word) 79 | self.cache[token] = word 80 | return word 81 | 82 | def encode(self, text): 83 | return [self.encoder.get(token, 1) for token in self.tokenize(text)] 84 | 85 | def decode(self, tokens): 86 | text = ''.join([self.decoder[token] for token in tokens]) 87 | return text 88 | 89 | def tokenize(self, text): 90 | bpe_tokens = [] 91 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) 92 | return bpe_tokens 93 | 94 | def convert_tokens_to_ids(self, tokens): 95 | return [self.encoder.get(token, 1) for token in tokens] 96 | 97 | 98 | class Encoder_SP: 99 | def __init__(self, model_path): 100 | self.sp = spm.SentencePieceProcessor() 101 | self.sp.Load(model_path) 102 | self.num_tokens = self.sp.vocab_size() 103 | 104 | def encode(self, text): 105 | """ 106 | text="...." 107 | """ 108 | return self.sp.EncodeAsIds(text) 109 | 110 | def decode(self, tokens): 111 | """ 112 | tokens=[x1,x2,...] 
113 | """ 114 | text = [int(token) for token in tokens] 115 | return self.sp.DecodeIds(text) 116 | 117 | def tokenize(self, text): 118 | return self.sp.EncodeAsPieces(text) 119 | 120 | def convert_tokens_to_ids(self, tokens): 121 | return [self.sp.PieceToId(token) for token in tokens] 122 | 123 | def convert_token_to_id(self, token): 124 | return self.sp.PieceToId(token) 125 | 126 | def convert_id_to_token(self, idx): 127 | return self.sp.IdToPiece(idx) 128 | 129 | 130 | def get_encoder(encoder_file, bpe_file): 131 | # 以下是为了同一个函数入兼容sentencepiece 132 | filepath, filename = os.path.split(encoder_file) 133 | shotname, extension = os.path.splitext(filename) 134 | 135 | if (".model" == extension) and (bpe_file == ""): 136 | return Encoder_SP(encoder_file) 137 | else: 138 | with open(encoder_file, 'r', encoding="utf-8") as f: 139 | encoder = json.load(f) 140 | with open(bpe_file, 'r', encoding="utf-8") as f: 141 | bpe_data = f.read() 142 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 143 | return Encoder( 144 | encoder=encoder, 145 | bpe_merges=bpe_merges, 146 | ) 147 | 148 | 149 | def from_pretrained(): 150 | return get_encoder(PRETRAINED_MODEL_FILE, "") -------------------------------------------------------------------------------- /data_utils/templates.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : templates.py 4 | @Time : 2021/01/11 22:28:57 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | from .unified_tokenizer import get_tokenizer 21 | from .vqvae_tokenizer import sqrt_int 22 | 23 | def concat_codes(*codes): 24 | is_numpy = is_tensor = False 25 | for code in codes: 26 | if isinstance(code, np.ndarray): 27 | is_numpy = True 28 | if isinstance(code, torch.Tensor): 29 | is_tensor = True 30 | device = code.device 31 | if is_tensor: 32 | return torch.cat( 33 | [ 34 | torch.tensor(code, device=device) 35 | for code in codes 36 | ] 37 | ) 38 | elif is_numpy: 39 | return np.concatenate( 40 | [ 41 | np.array(code) 42 | for code in codes 43 | ], 44 | axis=0 45 | ) 46 | else: 47 | ret = [] 48 | for code in codes: 49 | ret = ret + code 50 | return ret 51 | 52 | def TextCodeTemplate(text, code): 53 | tokenizer = get_tokenizer() 54 | if isinstance(text, str): 55 | text_ids = [tokenizer['[ROI1]']] + tokenizer(text) 56 | else: 57 | text_ids = np.concatenate( 58 | ( 59 | np.array([tokenizer['[ROI1]']]), 60 | text, 61 | ), 62 | axis=0 63 | ) 64 | code = tokenizer.wrap_code(code) 65 | return concat_codes(text_ids, code) 66 | 67 | def Code2CodeTemplate(text, code0, code1): 68 | tokenizer = get_tokenizer() 69 | text_ids = tokenizer.parse_query(text) if isinstance(text, str) else text 70 | code0 = tokenizer.wrap_code(code0) 71 | code1 = tokenizer.wrap_code(code1, idx=2) 72 | return concat_codes(text_ids, code0, code1) 73 | 74 | def PureTextTemplate(text): 75 | tokenizer = get_tokenizer() 76 | return tokenizer(text) + [tokenizer['[SEP]']] 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /data_utils/unified_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : unified_tokenizer.py 4 | 
@Time : 2021/01/11 16:36:33 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | from .sp_tokenizer import from_pretrained 21 | from .vqvae_tokenizer import VQVAETokenizer, sqrt_int 22 | 23 | class UnifiedTokenizer(object): 24 | def __init__(self, img_tokenizer_path, device, img_tokenizer_num_tokens=None): 25 | self.device = device 26 | if img_tokenizer_path is None and img_tokenizer_num_tokens is not None: 27 | # pretraining but only know the vocab size of VQVAE, which is developing fast 28 | self.img_tokenizer = FakeTokenizer(img_tokenizer_num_tokens) 29 | else: 30 | self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device) 31 | self.txt_tokenizer = from_pretrained() 32 | self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens 33 | self.raw_command_tokens = [ 34 | ('[PAD]', 0), 35 | ('[BOI1]', 1), # Begin 36 | ('[BOI2]', 2), 37 | ('[BOI3]', 3), 38 | ('[EOI1]', 4), # End 39 | ('[EOI2]', 5), 40 | ('[EOI3]', 6), 41 | ('[ROI1]', 7), # Reference 42 | ('[ROI2]', 8), 43 | ('[ROI3]', 9), 44 | ('[SEP]', 10), 45 | ('[MASK]', 11), 46 | ('[CLS]', 12), 47 | ('[ENC]', 13), 48 | ('[TINY]', 14), # 8 * 8 49 | ('[SMALL]', 15), # 16 * 16 50 | ('[BASE]', 16), # 32 * 32 51 | ('[BIG]', 17), # 64 * 64 52 | ('[POS0]', 18), # 58210 53 | ('[POS1]', 19), 54 | ('[POS2]', 20), 55 | ('[POS3]', 21), 56 | ('[POS4]', 22), 57 | ('[POS5]', 23), 58 | ('[POS6]', 24), 59 | ('[POS7]', 25), 60 | ('[POS8]', 26) 61 | # Please leave the ``size tokens'' at the back of command tokens 62 | ] 63 | self.command_tokens = { 64 | k: v + self.num_tokens 65 | for k, v in self.raw_command_tokens 66 | } 67 | self.num_tokens += len(self.raw_command_tokens) 68 | 69 | def __getitem__(self, command_token): 70 | return self.command_tokens[command_token] 71 | 72 | def __len__(self): 73 | """total number of tokens""" 74 | return self.num_tokens 75 | 76 | def __call__(self, inputs, process_fn=None): 77 | """run preprocessing and encode inputs as Ids 78 | CANNOT contain command tokens""" 79 | if isinstance(inputs, torch.Tensor): # image 80 | if len(inputs.shape) == 3: 81 | inputs = inputs.unsqueeze(0) 82 | return self.img_tokenizer.EncodeAsIds(inputs) 83 | return self.EncodeAsIds(inputs, process_fn=process_fn) 84 | 85 | def EncodeAsIds(self, text, process_fn=None): 86 | processed_text = text 87 | if process_fn is not None: 88 | processed_text = process_fn(processed_text) 89 | ids = self.txt_tokenizer.encode(processed_text) 90 | return [x + self.img_tokenizer.num_tokens for x in ids] 91 | 92 | def DecodeIds(self, ids): 93 | ret, img_buffer, txt_buffer, ret_imgs = [], [], [], [] 94 | try: 95 | for x in ids: 96 | if self.num_tokens - len(self.raw_command_tokens) <= x: 97 | # command tokens 98 | token = self.raw_command_tokens[x - (self.num_tokens - len(self.raw_command_tokens))][0] 99 | if token.startswith('[EOI') and len(img_buffer) > 0: 100 | # dump image 101 | ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer)) 102 | img_buffer = [] 103 | if len(txt_buffer) > 0: 104 | # dump text 105 | ret.append(self.txt_tokenizer.decode(txt_buffer)) 106 | txt_buffer = [] 107 | ret.append(token) 108 | elif x < self.img_tokenizer.num_tokens: 109 | img_buffer.append(x) 110 | else: 111 | txt_buffer.append(x - self.img_tokenizer.num_tokens) 112 | 113 | if len(img_buffer) > 
0: 114 | # dump image 115 | ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer)) 116 | img_buffer = [] 117 | if len(txt_buffer) > 0: 118 | # dump text 119 | ret.append(self.txt_tokenizer.decode(txt_buffer)) 120 | txt_buffer = [] 121 | except ValueError: 122 | print('Value error in tokenization, skipping...') 123 | return ret, ret_imgs 124 | 125 | def wrap_code(self, code, idx=1): 126 | s = sqrt_int(len(code)) 127 | prefix = {8:'[TINY]', 16:'[SMALL]', 32:'[BASE]', 64:'[BIG]'}[s] 128 | boi = {1:'[BOI1]', 2: '[BOI2]', 3:'[BOI3]'}[idx] 129 | eoi = {1:'[EOI1]', 2: '[EOI2]', 3:'[EOI3]'}[idx] 130 | 131 | if isinstance(code, list): 132 | return [self.command_tokens[prefix], self.command_tokens[boi]] + \ 133 | code + [self.command_tokens[eoi]] 134 | elif isinstance(code, np.ndarray): 135 | return np.concatenate( 136 | ( 137 | np.array([self.command_tokens[prefix], self.command_tokens[boi]]), 138 | code, 139 | np.array([self.command_tokens[eoi]]) 140 | ), 141 | axis=0 142 | ) 143 | elif isinstance(code, torch.Tensor): 144 | return torch.cat( 145 | ( 146 | torch.tensor([self.command_tokens[prefix], self.command_tokens[boi]]), 147 | code, 148 | np.array([self.command_tokens[eoi]]) 149 | ) 150 | ) 151 | else: 152 | raise ValueError('') 153 | 154 | def parse_query(self, query, img_size=256): 155 | text_buffer = [] 156 | ret = [] 157 | for part in query.split(' '): 158 | if part in self.command_tokens: 159 | if len(text_buffer) > 0: 160 | # dump text ids 161 | ret.extend(self.EncodeAsIds(' '.join(text_buffer))) 162 | text_buffer = [] 163 | if part == '[MASK]': 164 | ret.append(-1) 165 | else: 166 | ret.append(self.command_tokens[part]) 167 | elif part.startswith('[MASK]*'): # special lang *N 168 | c = int(part[7:]) 169 | assert c > 0 170 | if len(text_buffer) > 0: 171 | # dump text ids 172 | ret.extend(self.EncodeAsIds(' '.join(text_buffer))) 173 | text_buffer = [] 174 | ret.extend([-1] * c) 175 | elif part.startswith('[Image'): # [Image*N]path 176 | c = part[6:] 177 | assert len(c) > 0 178 | num_codes, img_path = c.split(']') 179 | if num_codes == '': 180 | num_codes = 1024 181 | else: 182 | num_codes = int(num_codes) 183 | 184 | raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size) 185 | img_codes = self.img_tokenizer.EncodeAsIds(raw_img) # [1, 32*32] 186 | img_codes[0, num_codes:] = -1 187 | img_codes = img_codes[0].tolist() 188 | ret.extend(img_codes) 189 | else: 190 | text_buffer.append(part) 191 | 192 | if len(text_buffer) > 0: 193 | # dump text ids 194 | ret.extend(self.EncodeAsIds(' '.join(text_buffer))) 195 | text_buffer = [] 196 | return ret 197 | 198 | def get_tokenizer(args=None): 199 | if not hasattr(get_tokenizer, 'tokenizer'): 200 | # the first time to load the tokenizer, specify img_tokenizer_path 201 | get_tokenizer.tokenizer = UnifiedTokenizer( 202 | args.img_tokenizer_path, 203 | device=torch.cuda.current_device(), 204 | img_tokenizer_num_tokens=args.img_tokenizer_num_tokens 205 | ) 206 | return get_tokenizer.tokenizer 207 | 208 | class FakeTokenizer(object): 209 | def __init__(self, num_tokens): 210 | self.num_tokens = num_tokens 211 | def __len__(self): 212 | return self.num_tokens -------------------------------------------------------------------------------- /data_utils/vqvae_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : vqvae_tokenizer.py 4 | @Time : 2021/01/11 17:57:43 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the 
import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | 21 | from vqvae import new_model, img2code, code2img 22 | from torchvision import transforms 23 | from PIL import Image 24 | 25 | def is_exp2(x): 26 | t = math.log2(x) 27 | return abs(t - int(t)) < 1e-4 28 | def sqrt_int(x): 29 | r = int(math.sqrt(x) + 1e-4) 30 | assert r * r == x 31 | return r 32 | 33 | class VQVAETokenizer(object): 34 | def __init__(self, 35 | model_path, 36 | device='cuda' 37 | ): 38 | ckpt = torch.load(model_path, map_location=torch.device(device)) 39 | 40 | model = new_model() 41 | 42 | if list(ckpt.keys())[0].startswith('module.'): 43 | ckpt = {k[7:]: v for k, v in ckpt.items()} 44 | 45 | model.load_state_dict(ckpt) 46 | model = model.to(device) 47 | model.eval() 48 | 49 | self.model = model 50 | self.device = device 51 | self.image_tokens = model.quantize_t.n_embed 52 | self.num_tokens = model.quantize_t.n_embed 53 | 54 | def __len__(self): 55 | return self.num_tokens 56 | 57 | def EncodeAsIds(self, img): 58 | assert len(img.shape) == 4 # [b, c, h, w] 59 | return img2code(self.model, img) 60 | 61 | def DecodeIds(self, code, shape=None): 62 | if shape is None: 63 | if isinstance(code, list): 64 | code = torch.tensor(code, device=self.device) 65 | s = sqrt_int(len(code.view(-1))) 66 | assert s * s == len(code.view(-1)) 67 | shape = (1, s, s) 68 | code = code.view(shape) 69 | out = code2img(self.model, code) 70 | return out 71 | 72 | def read_img(self, path, img_size=256): 73 | tr = transforms.Compose([ 74 | transforms.Resize(img_size), 75 | transforms.CenterCrop(img_size), 76 | transforms.ToTensor(), 77 | ]) 78 | img = tr(Image.open(path)) 79 | if img.shape[0] == 4: 80 | img = img[:-1] 81 | tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]) 82 | img = tr_normalize(img) 83 | img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w] 84 | return img -------------------------------------------------------------------------------- /env/dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.1-devel-ubuntu18.04 2 | 3 | ############################################################################## 4 | # Temporary Installation Directory 5 | ############################################################################## 6 | ENV STAGE_DIR=/tmp 7 | RUN mkdir -p ${STAGE_DIR} 8 | 9 | ############################################################################## 10 | # Installation/Basic Utilities 11 | ############################################################################## 12 | RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list 13 | RUN sed -i s@/security.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list 14 | RUN rm /etc/apt/sources.list.d/nvidia-ml.list && rm /etc/apt/sources.list.d/cuda.list && apt-get clean 15 | RUN apt-get update && \ 16 | apt-get install -y --no-install-recommends \ 17 | software-properties-common build-essential autotools-dev \ 18 | nfs-common pdsh \ 19 | cmake g++ gcc \ 20 | curl wget vim tmux emacs less unzip \ 21 | htop iftop iotop ca-certificates openssh-client openssh-server \ 22 | rsync iputils-ping net-tools sudo \ 23 | llvm-9-dev libsndfile-dev \ 24 | libcupti-dev \ 25 | libjpeg-dev \ 26 | libpng-dev \ 27 | screen jq psmisc dnsutils lsof musl-dev systemd 28 | 29 | 
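##############################################################################
# Note: a typical way to build and launch this image (a sketch; the tag is the
# one env/start_docker.sh expects, and apex-master is assumed to be present in
# the build context for the COPY step further below):
#   docker build -f env/dockerfile -t cogview/cuda111_torch181_deepspeed040:base .
#   bash env/start_docker.sh
##############################################################################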
############################################################################## 30 | # Installation Latest Git 31 | ############################################################################## 32 | RUN add-apt-repository ppa:git-core/ppa -y && \ 33 | apt-get update && \ 34 | apt-get install -y git && \ 35 | git --version 36 | 37 | ############################################################################## 38 | # Client Liveness & Uncomment Port 22 for SSH Daemon 39 | ############################################################################## 40 | # Keep SSH client alive froGm server side 41 | RUN echo "ClientAliveInterval 30" >> /etc/ssh/sshd_config 42 | RUN cp /etc/ssh/sshd_config ${STAGE_DIR}/sshd_config && \ 43 | sed "0,/^#Port 22/s//Port 22/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config 44 | 45 | ############################################################################## 46 | # Mellanox OFED 47 | ############################################################################## 48 | ENV MLNX_OFED_VERSION=5.1-0.6.6.0 49 | #ENV MLNX_OFED_VERSION=4.6-1.0.1.1 50 | RUN apt-get install -y libnuma-dev 51 | RUN cd ${STAGE_DIR} && \ 52 | wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \ 53 | cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \ 54 | ./mlnxofedinstall --user-space-only --without-fw-update --umad-dev-rw --all -q && \ 55 | cd ${STAGE_DIR} && \ 56 | rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64* 57 | 58 | ############################################################################## 59 | # nv_peer_mem 60 | ############################################################################## 61 | ENV NV_PEER_MEM_VERSION=1.1 62 | ENV NV_PEER_MEM_TAG=1.1-0 63 | RUN mkdir -p ${STAGE_DIR} && \ 64 | git clone https://github.com/Mellanox/nv_peer_memory.git --branch ${NV_PEER_MEM_TAG} ${STAGE_DIR}/nv_peer_memory && \ 65 | cd ${STAGE_DIR}/nv_peer_memory && \ 66 | ./build_module.sh && \ 67 | cd ${STAGE_DIR} && \ 68 | tar xzf ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_VERSION}.orig.tar.gz && \ 69 | cd ${STAGE_DIR}/nvidia-peer-memory-${NV_PEER_MEM_VERSION} && \ 70 | apt-get update && \ 71 | apt-get install -y dkms && \ 72 | dpkg-buildpackage -us -uc && \ 73 | dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb 74 | 75 | ############################################################################## 76 | # OPENMPI 77 | ############################################################################## 78 | ENV OPENMPI_BASEVERSION=4.0 79 | ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.5 80 | #ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.1 81 | RUN cd ${STAGE_DIR} && \ 82 | wget -q -O - https://download.open-mpi.org/release/open-mpi/v${OPENMPI_BASEVERSION}/openmpi-${OPENMPI_VERSION}.tar.gz | tar xzf - && \ 83 | cd openmpi-${OPENMPI_VERSION} && \ 84 | ./configure --prefix=/usr/local/openmpi-${OPENMPI_VERSION} && \ 85 | make -j"$(nproc)" install && \ 86 | ln -s /usr/local/openmpi-${OPENMPI_VERSION} /usr/local/mpi && \ 87 | # Sanity check: 88 | test -f /usr/local/mpi/bin/mpic++ && \ 89 | cd ${STAGE_DIR} && \ 90 | rm -r ${STAGE_DIR}/openmpi-${OPENMPI_VERSION} 91 | ENV PATH=/usr/local/mpi/bin:${PATH} \ 92 | LD_LIBRARY_PATH=/usr/local/lib:/usr/local/mpi/lib:/usr/local/mpi/lib64:${LD_LIBRARY_PATH} 93 | # Create a wrapper for OpenMPI to allow running as root by default 94 | RUN mv /usr/local/mpi/bin/mpirun /usr/local/mpi/bin/mpirun.real && \ 95 | 
echo '#!/bin/bash' > /usr/local/mpi/bin/mpirun && \ 96 | echo 'mpirun.real --allow-run-as-root --prefix /usr/local/mpi "$@"' >> /usr/local/mpi/bin/mpirun && \ 97 | chmod a+x /usr/local/mpi/bin/mpirun 98 | 99 | ############################################################################## 100 | # Python 101 | ############################################################################## 102 | ARG PYTHON_VERSION=3.8 103 | RUN curl -o ~/miniconda.sh https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 104 | chmod +x ~/miniconda.sh && \ 105 | ~/miniconda.sh -b -p /opt/conda && \ 106 | rm ~/miniconda.sh && \ 107 | /opt/conda/bin/conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ && \ 108 | /opt/conda/bin/conda config --set show_channel_urls yes && \ 109 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include ninja cython typing && \ 110 | /opt/conda/bin/conda clean -ya 111 | 112 | ENV PATH /opt/conda/bin:$PATH 113 | RUN wget https://tuna.moe/oh-my-tuna/oh-my-tuna.py && python oh-my-tuna.py 114 | RUN pip install --upgrade pip setuptools 115 | 116 | ############################################################################## 117 | # Some Packages 118 | ############################################################################## 119 | RUN pip install psutil \ 120 | yappi \ 121 | cffi \ 122 | ipdb \ 123 | h5py \ 124 | pandas \ 125 | matplotlib \ 126 | py3nvml \ 127 | pyarrow \ 128 | graphviz \ 129 | astor \ 130 | boto3 \ 131 | tqdm \ 132 | sentencepiece \ 133 | msgpack \ 134 | requests \ 135 | pandas \ 136 | sphinx \ 137 | sphinx_rtd_theme \ 138 | nvidia-ml-py3 \ 139 | mpi4py \ 140 | filelock \ 141 | lmdb \ 142 | cupy-cuda111 && \ 143 | pip cache purge 144 | 145 | 146 | ############################################################################## 147 | # PyTorch 148 | # The default NCCL from pytorch will be slower, but to download pytorch source code is too slow in China, so we gave up. 149 | ############################################################################## 150 | # RUN git clone --branch v1.8.1 --recursive https://github.com/pytorch/pytorch /opt/pytorch 151 | # RUN cd /opt/pytorch && \ 152 | # git submodule sync && git submodule update --init --recursive 153 | 154 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX 8.0 8.6" 155 | # ENV NCCL_LIBRARY=/usr/lib/x86_64-linux-gnu 156 | # ENV NCCL_INCLUDE_DIR=/usr/include 157 | # RUN conda install -c pytorch magma-cuda111 && \ 158 | # cd /opt/pytorch && TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ 159 | # CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" USE_SYSTEM_NCCL=1 \ 160 | # pip install -v . 
&& rm -rf /opt/pytorch 161 | 162 | ENV TENSORBOARDX_VERSION=1.8 163 | RUN pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html && \ 164 | pip install tensorboardX==${TENSORBOARDX_VERSION} && \ 165 | pip cache purge 166 | 167 | ############################################################################## 168 | # apex 169 | ############################################################################## 170 | # RUN git clone https://github.com/NVIDIA/apex ${STAGE_DIR}/apex 171 | COPY apex-master ${STAGE_DIR}/apex 172 | RUN cd ${STAGE_DIR}/apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \ 173 | && rm -rf ${STAGE_DIR}/apex 174 | 175 | ############################################################################## 176 | # PyYAML build issue 177 | # https://stackoverflow.com/a/53926898 178 | ############################################################################## 179 | RUN rm -rf /usr/lib/python3/dist-packages/yaml && \ 180 | rm -rf /usr/lib/python3/dist-packages/PyYAML-* 181 | 182 | 183 | ############################################################################## 184 | # DeepSpeed 185 | ############################################################################## 186 | # RUN git clone https://github.com/microsoft/DeepSpeed.git ${STAGE_DIR}/DeepSpeed 187 | # COPY DeepSpeed ${STAGE_DIR}/DeepSpeed 188 | # RUN cd ${STAGE_DIR}/DeepSpeed && \ 189 | # git checkout . && \ 190 | # DS_BUILD_OPS=1 ./install.sh -r 191 | # RUN rm -rf ${STAGE_DIR}/DeepSpeed 192 | # RUN python -c "import deepspeed; print(deepspeed.__version__)" 193 | RUN pip install triton==0.2.3 && \ 194 | DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_FUSED_LAMB=1 DS_BUILD_SPARSE_ATTN=1 DS_BUILD_UTILS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8" && \ 195 | pip cache purge && \ 196 | ds_report 197 | 198 | ############################################################################## 199 | ## SSH daemon port inside container cannot conflict with host OS port 200 | ############################################################################### 201 | ARG SSH_PORT=2222 202 | RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \ 203 | echo "PasswordAuthentication no" >> ${STAGE_DIR}/sshd_config && \ 204 | sed "0,/^Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config 205 | EXPOSE ${SSH_PORT} 206 | 207 | # Set SSH KEY 208 | RUN echo "StrictHostKeyChecking no \nUserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ 209 | ssh-keygen -t rsa -f ~/.ssh/id_rsa -N "" && cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys && \ 210 | chmod og-wx ~/.ssh/authorized_keys 211 | -------------------------------------------------------------------------------- /env/ip_list.txt: -------------------------------------------------------------------------------- 1 | 127.0.0.1 -------------------------------------------------------------------------------- /env/ip_list.txt.example: -------------------------------------------------------------------------------- 1 | 172.30.0.214 172.30.0.215 172.30.0.209 2 | -------------------------------------------------------------------------------- /env/setup_connection.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : setup_connection.py 4 | @Time : 2021/01/16 16:50:36 5 | @Author : Ming Ding 6 | @Contact : dm18@mail.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import 
lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | import base64 15 | 16 | if __name__ == "__main__": 17 | ssh_config = '' 18 | line_format = 'Host node{}\n\tUser root\n\tPort 2222\n\tHostname {}\n' 19 | for i, ip in enumerate(sys.argv[1:]): 20 | ssh_config += line_format.format(i, ip) 21 | 22 | ret = os.system(f'echo \"{ssh_config}\" > ~/.ssh/config && chmod 600 ~/.ssh/config') 23 | assert ret == 0 24 | 25 | hostfile_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hostfile') 26 | with open(hostfile_path, 'w') as fout: 27 | for i, ip in enumerate(sys.argv[1:]): 28 | fout.write(f'node{i} slots=8\n') 29 | print(f'Successfully generating hostfile \'{hostfile_path}\'!') 30 | 31 | 32 | -------------------------------------------------------------------------------- /env/start_docker.sh: -------------------------------------------------------------------------------- 1 | script_path=$(realpath $0) 2 | script_dir=$(dirname $script_path) 3 | main_dir=$(dirname $script_dir) 4 | ip_list=$(cat $script_dir/ip_list.txt) 5 | docker run --gpus all -d --ipc=host --cap-add=IPC_LOCK -v /sys/class/net/:/sys/class/net/ --device=/dev/ --privileged --network=host -v $main_dir:/root/cogview --name bg-cogview cogview/cuda111_torch181_deepspeed040:base bash -c "/etc/init.d/ssh start && python /root/cogview/env/setup_connection.py $ip_list && sleep 365d" 6 | -------------------------------------------------------------------------------- /eval_utils/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import IterableDataset 2 | import PIL 3 | import csv 4 | import torch 5 | from io import BytesIO 6 | import base64 7 | 8 | 9 | class TsvDataset(IterableDataset): 10 | def __init__(self, path, transform=None, caption_only=False): 11 | self.f = open(path, "r") 12 | self.tsvreader = csv.reader(self.f, delimiter='\t') 13 | self.transform = transform 14 | self.caption_only = caption_only 15 | def callback_fn(image_base64, id, caption): 16 | try: 17 | img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64))).convert('RGB') 18 | if self.transform is not None: 19 | img = self.transform(img) 20 | return img, id, caption 21 | except (PIL.UnidentifiedImageError, PIL.Image.DecompressionBombError): 22 | print("UnidentifiedImageError") 23 | return torch.zeros((3, 256, 256)), "not_a_image", "not_a_caption" 24 | self.callback_fn = callback_fn 25 | def __iter__(self): 26 | def get_next(): 27 | if self.caption_only: 28 | for line in self.tsvreader: 29 | yield self.callback_fn(torch.zeros((3, 256, 256)), line[0], line[1]) 30 | else: 31 | for line in self.tsvreader: 32 | yield self.callback_fn(line[3], line[0], line[2]) 33 | return iter(get_next()) 34 | def __del__(self): 35 | self.f.close() -------------------------------------------------------------------------------- /eval_utils/fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 
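A typical standalone invocation (a sketch; the flag names match the argument
parser defined below, and each path may be an image directory or a
precomputed .npz statistics file):

    python fid_score.py --path1 /path/to/real_images \
                        --path2 /path/to/generated_images \
                        --batch-size 64 --dims 2048 -c 0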
11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectivly. 15 | 16 | See --help to see further details. 17 | 18 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 19 | of Tensorflow 20 | 21 | Copyright 2018 Institute of Bioinformatics, JKU Linz 22 | 23 | Licensed under the Apache License, Version 2.0 (the "License"); 24 | you may not use this file except in compliance with the License. 25 | You may obtain a copy of the License at 26 | 27 | http://www.apache.org/licenses/LICENSE-2.0 28 | 29 | Unless required by applicable law or agreed to in writing, software 30 | distributed under the License is distributed on an "AS IS" BASIS, 31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | See the License for the specific language governing permissions and 33 | limitations under the License. 34 | """ 35 | import os 36 | import pathlib 37 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 38 | from torchvision.models.inception import inception_v3 39 | 40 | 41 | import torch 42 | import numpy as np 43 | from scipy.misc import imread 44 | from scipy import linalg 45 | from torch.autograd import Variable 46 | from torch.nn.functional import adaptive_avg_pool2d 47 | import torchvision.transforms as transforms 48 | from inception import InceptionV3 49 | import torch.utils.data 50 | from PIL import Image 51 | from torch.utils import data 52 | import img_data as img_data 53 | 54 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 55 | #parser.add_argument('path', type=str, nargs=2, 56 | # help=('Path to the generated images or ' 57 | # 'to .npz statistic files')) 58 | parser.add_argument('--batch-size', type=int, default=64, 59 | help='Batch size to use') 60 | parser.add_argument('--dims', type=int, default=2048, 61 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 62 | help=('Dimensionality of Inception features to use. ' 63 | 'By default, uses pool3 features')) 64 | parser.add_argument('-c', '--gpu', default='', type=str, 65 | help='GPU to use (leave blank for CPU only)') 66 | parser.add_argument('--path1', type=str, default=64) 67 | parser.add_argument('--path2', type=str, default=64) 68 | 69 | def get_activations(images, model, batch_size=64, dims=2048, cuda=False, verbose=True): 70 | """Calculates the activations of the pool_3 layer for all images. 71 | 72 | Params: 73 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 74 | must lie between 0 and 1. 75 | -- model : Instance of inception model 76 | -- batch_size : the images numpy array is split into batches with 77 | batch size batch_size. A reasonable batch size depends 78 | on the hardware. 79 | -- dims : Dimensionality of features returned by Inception 80 | -- cuda : If set to True, use GPU 81 | -- verbose : If set to True and parameter out_step is given, the number 82 | of calculated batches is reported. 83 | Returns: 84 | -- A numpy array of dimension (num images, dims) that contains the 85 | activations of the given tensor when feeding inception with the 86 | query tensor. 87 | """ 88 | model.eval() 89 | 90 | #d0 = images.shape[0] 91 | 92 | d0 = images.__len__() * batch_size 93 | if batch_size > d0: 94 | print(('Warning: batch size is bigger than the data size. 
' 95 | 'Setting batch size to data size')) 96 | batch_size = d0 97 | 98 | n_batches = d0 // batch_size 99 | n_used_imgs = n_batches * batch_size 100 | 101 | pred_arr = np.empty((n_used_imgs, dims)) 102 | #for i in range(n_batches): 103 | for i, batch in enumerate(images): 104 | #batch = batch[0] 105 | #if verbose: 106 | #print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True) 107 | #import ipdb 108 | #ipdb.set_trace() 109 | start = i * batch_size 110 | end = start + batch_size 111 | 112 | #batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 113 | #batch = Variable(batch, volatile=True) 114 | 115 | if cuda: 116 | batch = batch.cuda() 117 | 118 | pred = model(batch)[0] 119 | 120 | # If model output is not scalar, apply global spatial average pooling. 121 | # This happens if you choose a dimensionality not equal 2048. 122 | if pred.shape[2] != 1 or pred.shape[3] != 1: 123 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 124 | 125 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 126 | 127 | if verbose: 128 | print(' done') 129 | 130 | return pred_arr 131 | 132 | 133 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 134 | """Numpy implementation of the Frechet Distance. 135 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 136 | and X_2 ~ N(mu_2, C_2) is 137 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 138 | 139 | Stable version by Dougal J. Sutherland. 140 | 141 | Params: 142 | -- mu1 : Numpy array containing the activations of a layer of the 143 | inception net (like returned by the function 'get_predictions') 144 | for generated samples. 145 | -- mu2 : The sample mean over activations, precalculated on an 146 | representive data set. 147 | -- sigma1: The covariance matrix over activations for generated samples. 148 | -- sigma2: The covariance matrix over activations, precalculated on an 149 | representive data set. 150 | 151 | Returns: 152 | -- : The Frechet Distance. 153 | """ 154 | 155 | mu1 = np.atleast_1d(mu1) 156 | mu2 = np.atleast_1d(mu2) 157 | 158 | sigma1 = np.atleast_2d(sigma1) 159 | sigma2 = np.atleast_2d(sigma2) 160 | 161 | assert mu1.shape == mu2.shape, \ 162 | 'Training and test mean vectors have different lengths' 163 | assert sigma1.shape == sigma2.shape, \ 164 | 'Training and test covariances have different dimensions' 165 | 166 | diff = mu1 - mu2 167 | 168 | # Product might be almost singular 169 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 170 | if not np.isfinite(covmean).all(): 171 | msg = ('fid calculation produces singular product; ' 172 | 'adding %s to diagonal of cov estimates') % eps 173 | print(msg) 174 | offset = np.eye(sigma1.shape[0]) * eps 175 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 176 | 177 | # Numerical error might give slight imaginary component 178 | if np.iscomplexobj(covmean): 179 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 180 | m = np.max(np.abs(covmean.imag)) 181 | raise ValueError('Imaginary component {}'.format(m)) 182 | covmean = covmean.real 183 | 184 | tr_covmean = np.trace(covmean) 185 | 186 | return (diff.dot(diff) + np.trace(sigma1) + 187 | np.trace(sigma2) - 2 * tr_covmean) 188 | 189 | 190 | def calculate_activation_statistics(images, model, batch_size=64, 191 | dims=2048, cuda=False, verbose=True): 192 | """Calculation of the statistics used by the FID. 193 | Params: 194 | -- images : Numpy array of dimension (n_images, 3, hi, wi). 
The values 195 | must lie between 0 and 1. 196 | -- model : Instance of inception model 197 | -- batch_size : The images numpy array is split into batches with 198 | batch size batch_size. A reasonable batch size 199 | depends on the hardware. 200 | -- dims : Dimensionality of features returned by Inception 201 | -- cuda : If set to True, use GPU 202 | -- verbose : If set to True and parameter out_step is given, the 203 | number of calculated batches is reported. 204 | Returns: 205 | -- mu : The mean over samples of the activations of the pool_3 layer of 206 | the inception model. 207 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 208 | the inception model. 209 | """ 210 | act = get_activations(images, model, batch_size, dims, cuda, verbose) 211 | mu = np.mean(act, axis=0) 212 | sigma = np.cov(act, rowvar=False) 213 | return mu, sigma 214 | 215 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 216 | if path.endswith('.npz'): 217 | f = np.load(path) 218 | m, s = f['mu'][:], f['sigma'][:] 219 | f.close() 220 | 221 | else: 222 | dataset = img_data.Dataset(path, transforms.Compose([ 223 | transforms.Resize((299, 299)), 224 | transforms.ToTensor(), 225 | ])) 226 | print(dataset.__len__()) 227 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8) 228 | m, s = calculate_activation_statistics(dataloader, model, batch_size, dims, cuda) 229 | return m, s 230 | 231 | def calculate_fid_given_dataset(dataset1, dataset2, batch_size, cuda=True, dims=2048): 232 | """Calculates the FID of two dataset""" 233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 234 | model = InceptionV3([block_idx]) 235 | if cuda: 236 | model.cuda() 237 | 238 | loader1 = torch.utils.data.DataLoader(dataset=dataset1, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8) 239 | m1, s1 = calculate_activation_statistics(loader1, model, batch_size, dims, cuda) 240 | loader2 = torch.utils.data.DataLoader(dataset=dataset2, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8) 241 | m2, s2 = calculate_activation_statistics(loader2, model, batch_size, dims, cuda) 242 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 243 | return fid_value 244 | 245 | 246 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 247 | """Calculates the FID of two paths""" 248 | for p in paths: 249 | if not os.path.exists(p): 250 | raise RuntimeError('Invalid path: %s' % p) 251 | 252 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 253 | model = InceptionV3([block_idx]) 254 | if cuda: 255 | model.cuda() 256 | 257 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda) 258 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda) 259 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 260 | return fid_value 261 | 262 | if __name__ == '__main__': 263 | args = parser.parse_args() 264 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 265 | paths = ["",""] 266 | paths[0] = args.path1 267 | paths[1] = args.path2 268 | print(paths) 269 | fid_value = calculate_fid_given_paths(paths, args.batch_size,args.gpu,args.dims) 270 | print('FID: ', fid_value) 271 | -------------------------------------------------------------------------------- /eval_utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | 
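# An illustrative sketch of how this wrapper is consumed by fid_score.py above;
# `batch` is a placeholder for any float image tensor with values in [0, 1]:
#
#     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # index 3, final average pooling
#     model = InceptionV3([block_idx]).eval()            # add .cuda() when a GPU is used
#     pool3 = model(batch)[0]                            # (B, 2048, 1, 1) feature maps
#     features = pool3.reshape(batch.shape[0], -1)       # (B, 2048) activations used for FID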
class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, normalizes the input to the statistics the pretrained 43 | Inception network expects 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 114 | range (0, 1) 115 | 116 | Returns 117 | ------- 118 | List of torch.autograd.Variable, corresponding to the selected output 119 | block, sorted ascending by index 120 | """ 121 | outp = [] 122 | x = inp 123 | 124 | if self.resize_input: 125 | x = F.upsample(x, size=(299, 299), mode='bilinear', align_corners=True) 126 | 127 | if self.normalize_input: 128 | x = x.clone() 129 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 130 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 131 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 132 | 133 | for idx, block in enumerate(self.blocks): 134 | x = block(x) 135 | if idx in self.output_blocks: 136 | outp.append(x) 137 | 138 | if idx == self.last_needed_block: 139 | break 140 | 141 | return outp 142 | -------------------------------------------------------------------------------- /eval_utils/inception_score.py: -------------------------------------------------------------------------------- 1 | # 参考:https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py 2 | import torch 3 | from torch import nn 4 | from torchvision.models.inception import inception_v3 5 | from torch.nn import functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from scipy.stats import entropy 9 | 10 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): 11 | """Computes the inception score of the generated images imgs 12 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 13 | cuda -- whether or not to run on GPU 14 | batch_size -- batch size for feeding into Inception v3 15 | splits -- number of splits 16 | """ 17 | N = len(imgs) 18 | 19 | assert batch_size > 0 20 | assert N > batch_size 21 | 22 | # Set up dtype 23 | if cuda: 24 | dtype = torch.cuda.FloatTensor 25 | else: 26 | if torch.cuda.is_available(): 27 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 28 | dtype = torch.FloatTensor 29 | 30 | # Set up dataloader 31 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 32 | 33 | # Load inception model 34 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) 35 | inception_model.eval() 36 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) 37 | def get_pred(x): 38 | if resize: 39 | x = up(x) 40 | x = inception_model(x) 41 | return F.softmax(x).data.cpu().numpy() 42 | 43 | # Get predictions 44 | preds = np.zeros((N, 1000)) 45 | 46 | for i, batch in enumerate(dataloader, 0): 47 | batch = batch.type(dtype) 48 | batchv = Variable(batch) 49 | batch_size_i = batch.size()[0] 50 | 51 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 52 | 53 | # Now compute the mean kl-div 54 | split_scores = [] 55 | 56 | for k in range(splits): 57 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 58 | py = np.mean(part, axis=0) 59 | scores = [] 60 | for i in range(part.shape[0]): 61 | pyx = part[i, :] 62 | scores.append(entropy(pyx, py)) 63 | split_scores.append(np.exp(np.mean(scores))) 64 | 65 | return np.mean(split_scores), np.std(split_scores) 66 | 67 | -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/finetune/__init__.py 
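A minimal end-to-end sketch for the two evaluation entry points above. Everything below is illustrative: `FakeImages` is a stand-in dataset (with real samples, use a dataset that yields decoded image tensors instead), the script is assumed to run from inside eval_utils/ so that the plain `from inception import ...` style imports in fid_score.py resolve, and the legacy module-level imports in fid_score.py (`scipy.misc.imread`, `img_data`) are assumed to be available in the environment.

import torch
from fid_score import calculate_fid_given_dataset
from inception_score import inception_score

class FakeImages(torch.utils.data.Dataset):
    """Random tensors standing in for decoded images; low/high set the value range."""
    def __init__(self, n, low, high):
        self.data = torch.rand(n, 3, 256, 256) * (high - low) + low
    def __len__(self):
        return self.data.shape[0]
    def __getitem__(self, idx):
        return self.data[idx]

# FID: the InceptionV3 wrapper above expects float images in [0, 1].
fid = calculate_fid_given_dataset(FakeImages(128, 0., 1.), FakeImages(128, 0., 1.),
                                  batch_size=32, cuda=True, dims=2048)
# Inception Score: inception_score() documents inputs normalized to [-1, 1].
is_mean, is_std = inception_score(FakeImages(128, -1., 1.), cuda=True,
                                  batch_size=32, resize=True, splits=2)
print(f'FID {fid:.2f} | IS {is_mean:.2f} +/- {is_std:.2f}')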
-------------------------------------------------------------------------------- /fp16/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from .fp16util import ( 16 | BN_convert_float, 17 | network_to_half, 18 | prep_param_lists, 19 | model_grads_to_master_grads, 20 | master_params_to_model_params, 21 | tofp16, 22 | to_python_float, 23 | clip_grad_norm, 24 | convert_module, 25 | convert_network, 26 | FP16Model, 27 | ) 28 | 29 | from .fp16 import * 30 | from .loss_scaler import * 31 | -------------------------------------------------------------------------------- /fp16/fp16util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch.autograd import Variable 19 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 20 | 21 | import mpu 22 | 23 | 24 | class tofp16(nn.Module): 25 | """ 26 | Utility module that implements:: 27 | 28 | def forward(self, input): 29 | return input.half() 30 | """ 31 | 32 | def __init__(self): 33 | super(tofp16, self).__init__() 34 | 35 | def forward(self, input): 36 | return input.half() 37 | 38 | 39 | def BN_convert_float(module): 40 | """ 41 | Utility function for network_to_half(). 42 | 43 | Retained for legacy purposes. 44 | """ 45 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 46 | module.float() 47 | for child in module.children(): 48 | BN_convert_float(child) 49 | return module 50 | 51 | 52 | def network_to_half(network): 53 | """ 54 | Convert model to half precision in a batchnorm-safe way. 55 | 56 | Retained for legacy purposes. It is recommended to use FP16Model. 57 | """ 58 | return nn.Sequential(tofp16(), BN_convert_float(network.half())) 59 | 60 | 61 | def convert_module(module, dtype): 62 | """ 63 | Converts a module's immediate parameters and buffers to dtype. 
64 | """ 65 | for param in module.parameters(recurse=False): 66 | if param is not None: 67 | if param.data.dtype.is_floating_point: 68 | param.data = param.data.to(dtype=dtype) 69 | if param._grad is not None and param._grad.data.dtype.is_floating_point: 70 | param._grad.data = param._grad.data.to(dtype=dtype) 71 | 72 | for buf in module.buffers(recurse=False): 73 | if buf is not None and buf.data.dtype.is_floating_point: 74 | buf.data = buf.data.to(dtype=dtype) 75 | 76 | 77 | def convert_network(network, dtype): 78 | """ 79 | Converts a network's parameters and buffers to dtype. 80 | """ 81 | for module in network.modules(): 82 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 83 | continue 84 | convert_module(module, dtype) 85 | return network 86 | 87 | 88 | class FP16Model(nn.Module): 89 | """ 90 | Convert model to half precision in a batchnorm-safe way. 91 | """ 92 | 93 | def __init__(self, network): 94 | super(FP16Model, self).__init__() 95 | self.network = convert_network(network, dtype=torch.half) 96 | 97 | def forward(self, *inputs): 98 | inputs = tuple(t.half() for t in inputs) 99 | return self.network(*inputs) 100 | 101 | 102 | def backwards_debug_hook(grad): 103 | raise RuntimeError("master_params recieved a gradient in the backward pass!") 104 | 105 | def prep_param_lists(model, flat_master=False): 106 | """ 107 | Creates a list of FP32 master parameters for a given model, as in 108 | `Training Neural Networks with Mixed Precision: Real Examples`_. 109 | 110 | Args: 111 | model (torch.nn.Module): Existing Pytorch model 112 | flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. 113 | Returns: 114 | A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. 115 | 116 | Example:: 117 | 118 | model_params, master_params = prep_param_lists(model) 119 | 120 | .. warning:: 121 | Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. 122 | 123 | .. _`Training Neural Networks with Mixed Precision: Real Examples`: 124 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/ 125 | """ 126 | model_params = [param for param in model.parameters() if param.requires_grad] 127 | 128 | if flat_master: 129 | # Give the user some more useful error messages 130 | try: 131 | # flatten_dense_tensors returns a contiguous flat array. 132 | # http://pytorch.org/docs/master/_modules/torch/_utils.html 133 | master_params = _flatten_dense_tensors([param.data for param in model_params]).float() 134 | except: 135 | print("Error in prep_param_lists: model may contain a mixture of parameters " 136 | "of different types. 
Use flat_master=False, or use F16_Optimizer.") 137 | raise 138 | master_params = torch.nn.Parameter(master_params) 139 | master_params.requires_grad = True 140 | # master_params.register_hook(backwards_debug_hook) 141 | if master_params.grad is None: 142 | master_params.grad = master_params.new(*master_params.size()) 143 | return model_params, [master_params] 144 | else: 145 | master_params = [param.clone().float().detach() for param in model_params] 146 | for param in master_params: 147 | param.requires_grad = True 148 | return model_params, master_params 149 | 150 | 151 | def model_grads_to_master_grads(model_params, master_params, flat_master=False): 152 | """ 153 | Copy model gradients to master gradients. 154 | 155 | Args: 156 | model_params: List of model parameters created by :func:`prep_param_lists`. 157 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. 158 | """ 159 | if flat_master: 160 | # The flattening may incur one more deep copy than is necessary. 161 | master_params[0].grad.data.copy_( 162 | _flatten_dense_tensors([p.grad.data for p in model_params])) 163 | else: 164 | for model, master in zip(model_params, master_params): 165 | if model.grad is not None: 166 | if master.grad is None: 167 | master.grad = Variable(master.data.new(*master.data.size())) 168 | master.grad.data.copy_(model.grad.data) 169 | else: 170 | master.grad = None 171 | 172 | 173 | def master_params_to_model_params(model_params, master_params, flat_master=False): 174 | """ 175 | Copy master parameters to model parameters. 176 | 177 | Args: 178 | model_params: List of model parameters created by :func:`prep_param_lists`. 179 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. 180 | """ 181 | if flat_master: 182 | for model, master in zip(model_params, 183 | _unflatten_dense_tensors(master_params[0].data, model_params)): 184 | model.data.copy_(master) 185 | else: 186 | for model, master in zip(model_params, master_params): 187 | model.data.copy_(master.data) 188 | 189 | # Backward compatibility fixes 190 | 191 | def to_python_float(t): 192 | if hasattr(t, 'item'): 193 | return t.item() 194 | else: 195 | return t[0] 196 | 197 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 198 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 199 | 200 | clip_grad_norm = mpu.clip_grad_norm 201 | #elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: 202 | # clip_grad_norm = torch.nn.utils.clip_grad_norm 203 | #else: 204 | # clip_grad_norm = torch.nn.utils.clip_grad_norm_ 205 | -------------------------------------------------------------------------------- /fp16/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import mpu 18 | 19 | # item() is a recent addition, so this helps with backward compatibility. 20 | def to_python_float(t): 21 | if hasattr(t, 'item'): 22 | return t.item() 23 | else: 24 | return t[0] 25 | 26 | class LossScaler: 27 | """ 28 | Class that manages a static loss scale. This class is intended to interact with 29 | :class:`FP16_Optimizer`, and should not be directly manipulated by the user. 30 | 31 | Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to 32 | :class:`FP16_Optimizer`'s constructor. 33 | 34 | Args: 35 | scale (float, optional, default=1.0): The loss scale. 36 | """ 37 | 38 | def __init__(self, scale=1): 39 | self.cur_scale = scale 40 | 41 | # `params` is a list / generator of torch.Variable 42 | def has_overflow(self, params): 43 | return False 44 | 45 | # `x` is a torch.Tensor 46 | def _has_inf_or_nan(x): 47 | return False 48 | 49 | def update_scale(self, overflow): 50 | pass 51 | 52 | @property 53 | def loss_scale(self): 54 | return self.cur_scale 55 | 56 | def scale_gradient(self, module, grad_in, grad_out): 57 | return tuple(self.loss_scale * g for g in grad_in) 58 | 59 | def backward(self, loss, retain_graph=False): 60 | scaled_loss = loss*self.loss_scale 61 | scaled_loss.backward(retain_graph=retain_graph) 62 | 63 | class DynamicLossScaler: 64 | """ 65 | Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` 66 | indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of 67 | :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` 68 | operates, because the default options can be changed using the 69 | the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. 70 | 71 | Loss scaling is designed to combat the problem of underflowing gradients encountered at long 72 | times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss 73 | scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are 74 | encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has 75 | occurred. 76 | :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, 77 | and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. 78 | If a certain number of iterations occur without overflowing gradients detected, 79 | :class:`DynamicLossScaler` increases the loss scale once more. 80 | In this way :class:`DynamicLossScaler` attempts to "ride the edge" of 81 | always using the highest loss scale possible without incurring overflow. 82 | 83 | Args: 84 | init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` 85 | scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. 
If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 86 | scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. 87 | """ 88 | 89 | def __init__(self, 90 | init_scale=2**32, 91 | scale_factor=2., 92 | scale_window=1000, 93 | min_scale=1, 94 | delayed_shift=1, 95 | consecutive_hysteresis=False): 96 | self.cur_scale = init_scale 97 | self.cur_iter = 0 98 | self.last_overflow_iter = -1 99 | self.scale_factor = scale_factor 100 | self.scale_window = scale_window 101 | self.min_scale = min_scale 102 | self.delayed_shift = delayed_shift 103 | self.cur_hysteresis = delayed_shift 104 | self.consecutive_hysteresis = consecutive_hysteresis 105 | 106 | # `params` is a list / generator of torch.Variable 107 | def has_overflow_serial(self, params): 108 | for p in params: 109 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 110 | return True 111 | 112 | return False 113 | 114 | def has_overflow(self, params): 115 | overflow = self.has_overflow_serial(params) 116 | # Since each model parallel GPU carries only part of the model, 117 | # make sure overflow flag is synced across all the model parallel GPUs 118 | overflow_gpu = torch.cuda.ByteTensor([overflow]) 119 | torch.distributed.all_reduce(overflow_gpu, 120 | op=torch.distributed.ReduceOp.MAX, 121 | group=mpu.get_model_parallel_group()) 122 | overflow = overflow_gpu[0].item() 123 | return bool(overflow) 124 | 125 | 126 | # `x` is a torch.Tensor 127 | def _has_inf_or_nan(x): 128 | try: 129 | # if x is half, the .float() incurs an additional deep copy, but it's necessary if 130 | # Pytorch's .sum() creates a one-element tensor of the same type as x 131 | # (which is true for some recent version of pytorch). 132 | cpu_sum = float(x.float().sum()) 133 | # More efficient version that can be used if .sum() returns a Python scalar 134 | # cpu_sum = float(x.sum()) 135 | except RuntimeError as instance: 136 | # We want to check if inst is actually an overflow exception. 137 | # RuntimeError could come from a different error. 138 | # If so, we still want the exception to propagate. 
139 | if "value cannot be converted" not in instance.args[0]: 140 | raise 141 | return True 142 | else: 143 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 144 | return True 145 | return False 146 | 147 | # `overflow` is boolean indicating whether the gradient overflowed 148 | def update_scale(self, overflow): 149 | 150 | if not hasattr(self, 'min_scale'): 151 | self.min_scale = 1 152 | if not hasattr(self, 'delayed_shift'): 153 | self.delayed_shift = 1 154 | if not hasattr(self, 'cur_hysteresis'): 155 | self.cur_hysteresis = 1 156 | if not hasattr(self, 'consecutive_hysteresis'): 157 | self.consecutive_hysteresis = True 158 | if overflow: 159 | # self.cur_scale /= self.scale_factor 160 | if self.delayed_shift == 1 or self.cur_hysteresis == 1: 161 | self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale) 162 | else: 163 | self.cur_hysteresis -= 1 164 | self.last_overflow_iter = self.cur_iter 165 | else: 166 | if self.consecutive_hysteresis: 167 | self.cur_hysteresis = self.delayed_shift 168 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 169 | if not self.consecutive_hysteresis: 170 | self.cur_hysteresis = self.delayed_shift 171 | self.cur_scale *= self.scale_factor 172 | self.cur_iter += 1 173 | 174 | @property 175 | def loss_scale(self): 176 | return self.cur_scale 177 | 178 | def scale_gradient(self, module, grad_in, grad_out): 179 | return tuple(self.loss_scale * g for g in grad_in) 180 | 181 | def backward(self, loss, retain_graph=False): 182 | scaled_loss = loss*self.loss_scale 183 | scaled_loss.backward(retain_graph=retain_graph) 184 | 185 | ############################################################## 186 | # Example usage below here -- assuming it's in a separate file 187 | ############################################################## 188 | """ 189 | TO-DO separate out into an example. 190 | if __name__ == "__main__": 191 | import torch 192 | from torch.autograd import Variable 193 | from dynamic_loss_scaler import DynamicLossScaler 194 | 195 | # N is batch size; D_in is input dimension; 196 | # H is hidden dimension; D_out is output dimension. 197 | N, D_in, H, D_out = 64, 1000, 100, 10 198 | 199 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 200 | x = Variable(torch.randn(N, D_in), requires_grad=False) 201 | y = Variable(torch.randn(N, D_out), requires_grad=False) 202 | 203 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 204 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 205 | parameters = [w1, w2] 206 | 207 | learning_rate = 1e-6 208 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 209 | loss_scaler = DynamicLossScaler() 210 | 211 | for t in range(500): 212 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 213 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 214 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 215 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 216 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 217 | 218 | # Run backprop 219 | optimizer.zero_grad() 220 | loss.backward() 221 | 222 | # Check for overflow 223 | has_overflow = DynamicLossScaler.has_overflow(parameters) 224 | 225 | # If no overflow, unscale grad and update as usual 226 | if not has_overflow: 227 | for param in parameters: 228 | param.grad.data.mul_(1. 
/ loss_scaler.loss_scale) 229 | optimizer.step() 230 | # Otherwise, don't do anything -- ie, skip iteration 231 | else: 232 | print('OVERFLOW!') 233 | 234 | # Update loss scale for next iteration 235 | loss_scaler.update_scale(has_overflow) 236 | 237 | """ 238 | -------------------------------------------------------------------------------- /generation/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score 2 | from .magnify import magnify -------------------------------------------------------------------------------- /generation/magnify.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : magnify.py 4 | @Time : 2021/01/14 00:41:40 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from .sampling import filling_sequence 20 | 21 | 22 | def magnify(model, tokenizer, tokens_list, text_token_list, args): 23 | # 32 * 32 to 4 16 * 16 24 | s = int(math.sqrt(len(tokens_list)+ 1e-6)) 25 | assert s == 32 26 | code = tokens_list.view(s, s) 27 | 28 | midfix = torch.tensor([tokenizer['[EOI1]'], tokenizer['[ROI2]'], tokenizer['[POS0]'], tokenizer['[BASE]'], tokenizer['[BOI2]']], device=code.device) 29 | 30 | magnified_code = code.new_zeros((s * 2, s * 2), dtype=torch.long) - 1 31 | 32 | windows = [(0, 0, 18), (0, 1, 30), (0, 2, 30), (1, 1, 30), (1, 0, 30), (1, 2, 30), (2, 0, 32), (2, 1, 32), (2, 2, 32)] 33 | for i, j, line in tqdm(windows): 34 | code_part = code[8 * i: 8 * (i+2), 8 * j: 8 * (j+2)].reshape(-1) 35 | 36 | magnified_code_part = magnified_code[16 * i: 16 * i + line, 16 * j: 16 * (j+2)].reshape(-1) 37 | context_tokens_tensor = torch.cat([text_token_list, code_part, midfix], dim=0) 38 | context_len = len(context_tokens_tensor) 39 | seq = torch.cat([context_tokens_tensor, magnified_code_part], dim=0) 40 | 41 | magnified_code_part_completed = filling_sequence(model, seq, args, invalid_slices=[slice(tokenizer.img_tokenizer.num_tokens, None)]) 42 | magnified_code[16 * i: 16 * i + line, 16 * j: 16 * (j+2)] = magnified_code_part_completed[0, context_len:].view(line, 32) 43 | return magnified_code.view(1, s * s * 4) 44 | -------------------------------------------------------------------------------- /generation/sampling.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : sampling.py 4 | @Time : 2021/01/13 19:52:12 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | from pretrain_gpt2 import get_masks_and_position_ids 21 | from data_utils import get_tokenizer 22 | 23 | 24 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 25 | # This function has been mostly taken from huggingface conversational ai code at 26 | # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 27 | 28 | if top_k > 0: 29 | # Remove all tokens with a probability 
less than the last token of the top-k 30 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 31 | logits[indices_to_remove] = filter_value 32 | 33 | if top_p > 0.0: 34 | # convert to 1D 35 | logits = logits.view(logits.size()[1]).contiguous() 36 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 37 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 38 | 39 | # Remove tokens with cumulative probability above the threshold 40 | sorted_indices_to_remove = cumulative_probs > top_p 41 | # Shift the indices to the right to keep also the first token above the threshold 42 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 43 | sorted_indices_to_remove[..., 0] = 0 44 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 45 | logits[indices_to_remove] = filter_value 46 | # going back to 2D 47 | logits = logits.view(1, -1).contiguous() 48 | 49 | return logits 50 | 51 | def get_batch(context_tokens, device, args): 52 | tokens = context_tokens 53 | if len(tokens.shape) == 1: 54 | tokens = tokens.unsqueeze(0).contiguous() 55 | else: 56 | tokens = tokens.view(tokens.shape[0], -1).contiguous() 57 | tokens = tokens.to(device) 58 | 59 | # Get the masks and postition ids. 60 | attention_mask, loss_mask, position_ids = get_masks_and_position_ids( 61 | tokens) 62 | return tokens, attention_mask, position_ids 63 | 64 | def filling_sequence( 65 | model, 66 | seq, 67 | args, 68 | mems=None, 69 | invalid_slices=[], 70 | **kwargs): 71 | ''' 72 | seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1] 73 | context_length: first non(-1)s 74 | ''' 75 | tokenizer = get_tokenizer() 76 | device = seq.device 77 | assert len(seq.shape) == 1 78 | out_seq_length = len(seq) 79 | # building the initial tokens, attention_mask, and position_ids 80 | context_length = 0 81 | offset = 100000 82 | 83 | invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)] 84 | 85 | while seq[context_length] >= 0: 86 | # change what to generate 87 | if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]: 88 | invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] 89 | elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]: 90 | invalid_slices = [ 91 | slice(0, tokenizer.img_tokenizer.num_tokens), 92 | slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)] 93 | 94 | if seq[context_length] == tokenizer['[ROI2]']: 95 | offset = context_length 96 | context_length += 1 97 | tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args) 98 | 99 | counter = context_length - 1 # == len(tokens) - 1 100 | index = 0 # len(mems) 101 | if mems is None: 102 | mems = [] 103 | score = [0] # sum log likelihood for beams 104 | 105 | if args.is_sparse == 2: 106 | tokenizer = get_tokenizer() 107 | img_txt_sep = tokenizer.img_tokenizer.num_tokens 108 | img_indices_bool = (tokens < img_txt_sep) 109 | txt_indices_bool = (~img_indices_bool) 110 | elif args.is_sparse == 0: 111 | txt_indices_bool = img_indices_bool = None 112 | else: 113 | raise ValueError('set is_sparse==2 for inference.') 114 | 115 | while counter < (out_seq_length - 1): 116 | # Now, we want to generate seq[counter + 1] 117 | # token[:, index: counter+1] are just added. 
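# After the [BOI]/[EOI] checks below update which token ranges may be produced,
# the loop body distinguishes three cases:
#   * index == 0: first step, run the model over the whole prompt to build the caches in `mems`;
#   * seq[counter + 1] >= 0: the next token is given by the input, so it is appended directly
#     (no sampling) and the beams are shrunk back to a single sequence;
#   * seq[counter + 1] == -N: N beams are requested, so the model is run on the freshly added
#     tokens and the next token of each beam is sampled from the temperature/top-k/top-p
#     filtered distribution.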
118 | 119 | if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]: 120 | invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] 121 | elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]: 122 | invalid_slices = [ 123 | slice(0, tokenizer.img_tokenizer.num_tokens), 124 | slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)] 125 | 126 | if index == 0: # first 127 | position_ids[position_ids > offset] -= offset 128 | logits, *mems = model(tokens, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse=args.is_sparse, *mems) 129 | index = counter 130 | elif seq[counter + 1] >= 0: # provided 131 | if seq[counter + 1] == tokenizer['[ROI2]']: 132 | offset = counter + 1 133 | tokens, mems, score = shrink_beams(tokens, mems, 1, score) 134 | nb = 1 135 | counter += 1 136 | tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1) 137 | if args.is_sparse == 2: 138 | img_indices_bool = (tokens < img_txt_sep) 139 | txt_indices_bool = (~img_indices_bool) 140 | continue 141 | else: 142 | assert tokens.shape[1] == counter + 1 143 | position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0) 144 | position_ids[position_ids > offset] -= offset 145 | # TODO each time, the feed input cannot be too long (window size), or it will have a discrepcy from sparse training, but this is not very important. 146 | tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score) 147 | logits, *mems = model(tokens[:, index: ], 148 | position_ids, 149 | 0, # rebuild in transformers (sep version) 150 | txt_indices_bool, img_indices_bool, args.is_sparse, 151 | *mems) 152 | index = counter 153 | nb = -seq[counter + 1] 154 | counter += 1 155 | index += 1 156 | 157 | logits = logits[:, -1] # [batch size, vocab size] 158 | 159 | temp = args.temperature 160 | # TODO since the temperature is crucial, how can we find a good setting? 161 | logits /= temp 162 | for invalid_slice in invalid_slices: # forbide to generate other tokens 163 | logits[..., invalid_slice] = -float('Inf') 164 | logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) 165 | log_probs = F.softmax(logits, dim=-1) 166 | 167 | # expand beams 168 | if nb > 1 and tokens.shape[0] == 1: # 1->nb 169 | tokens = tokens.expand(nb, -1).contiguous() 170 | mems = [mem.expand(nb, -1, -1) for mem in mems] 171 | prev = torch.multinomial(log_probs, num_samples=nb, replacement=True) 172 | score = torch.log(torch.gather(log_probs, dim=1, index=prev)[0]).tolist() 173 | else: # nb -> nb 174 | assert tokens.shape[0] == nb 175 | prev = torch.multinomial(log_probs, num_samples=1) 176 | score_plus = torch.log(torch.gather(log_probs, dim=1, index=prev)[:, 0]) 177 | for idx in range(nb): 178 | score[idx] += score_plus[idx] 179 | 180 | tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1) 181 | if args.is_sparse == 2: # update indices 182 | img_indices_bool = (tokens < img_txt_sep) 183 | txt_indices_bool = (~img_indices_bool) 184 | 185 | output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous() 186 | return output_tokens_list 187 | 188 | def shrink_beams(tokens, mems, nb, score): 189 | # beam search is a failed attempt, will be removed soon... 
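# When more beams are alive than requested, keep only the highest-scoring beam
# (together with its cached `mems`) and reset the running log-likelihood.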
190 | if tokens.shape[0] == nb: 191 | return tokens, mems, score 192 | # shrink 193 | maximum = max(score) 194 | max_idx = score.index(maximum) 195 | tokens = tokens[max_idx].unsqueeze(0) 196 | score = [0] 197 | new_mems = [mem[max_idx: max_idx + 1] for mem in mems] 198 | return tokens, new_mems, score 199 | 200 | def add_interlacing_beam_marks(seq, nb=12, period=3000): 201 | assert isinstance(seq, list) or len(seq.shape) == 1 202 | blk_cnt = 0 203 | for i in range(len(seq)): 204 | if seq[i] == -1: 205 | blk_cnt += 1 206 | seq[i] = -nb 207 | if blk_cnt == period: 208 | nb += (nb % 2) * 2 - 1 209 | blk_cnt = 0 210 | else: 211 | blk_cnt = 0 212 | 213 | 214 | def inverse_prompt_score(model, seq, args): 215 | tokenizer = get_tokenizer() 216 | device = seq.device 217 | assert len(seq.shape) == 2 218 | 219 | botext = 2 + 1024 + 1 220 | assert tokenizer['[ROI1]'] == seq[0][botext] 221 | 222 | tokens, attention_mask, position_ids = get_batch(seq, device, args) 223 | logits, *mems = model(tokens, position_ids, attention_mask, None, None, is_sparse=args.is_sparse) 224 | logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf') 225 | log_probs = torch.log(F.softmax(logits, dim=-1)) 226 | 227 | pred = log_probs[:, botext:-1, :] 228 | target = tokens[:, botext+1:].unsqueeze(-1) 229 | scores = torch.gather(pred, dim=2, index=target).squeeze(-1).sum(dim=-1) 230 | return scores 231 | -------------------------------------------------------------------------------- /learning_rates.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch DataLoader for TFRecords""" 16 | 17 | import torch 18 | from torch.optim.lr_scheduler import _LRScheduler 19 | import math 20 | 21 | class AnnealingLR(_LRScheduler): 22 | """Anneals the learning rate from start to zero along a cosine curve.""" 23 | 24 | DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] 25 | 26 | def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5): 27 | assert warmup_iter <= num_iters 28 | self.optimizer = optimizer 29 | self.start_lr = start_lr 30 | self.warmup_iter = warmup_iter 31 | self.num_iters = last_iter + 1 32 | self.end_iter = num_iters 33 | self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None 34 | self.decay_ratio = 1 / decay_ratio 35 | self.step(self.num_iters) 36 | if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: 37 | print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}') 38 | 39 | def get_lr(self): 40 | # https://openreview.net/pdf?id=BJYwwY9ll pg. 
4 41 | if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: 42 | return float(self.start_lr) * self.num_iters / self.warmup_iter 43 | else: 44 | if self.decay_style == self.DECAY_STYLES[0]: 45 | return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter) 46 | elif self.decay_style == self.DECAY_STYLES[1]: 47 | decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter) 48 | return self.start_lr / self.decay_ratio * ( 49 | (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1) 50 | elif self.decay_style == self.DECAY_STYLES[2]: 51 | #TODO: implement exponential decay 52 | return self.start_lr 53 | else: 54 | return self.start_lr 55 | 56 | def step(self, step_num=None): 57 | if step_num is None: 58 | step_num = self.num_iters + 1 59 | self.num_iters = step_num 60 | new_lr = self.get_lr() 61 | for group in self.optimizer.param_groups: 62 | group['lr'] = new_lr 63 | 64 | def state_dict(self): 65 | sd = { 66 | # 'start_lr': self.start_lr, 67 | 'warmup_iter': self.warmup_iter, 68 | 'num_iters': self.num_iters, 69 | 'decay_style': self.decay_style, 70 | 'end_iter': self.end_iter, 71 | 'decay_ratio': self.decay_ratio 72 | } 73 | return sd 74 | 75 | def load_state_dict(self, sd): 76 | # self.start_lr = sd['start_lr'] 77 | self.warmup_iter = sd['warmup_iter'] 78 | self.num_iters = sd['num_iters'] 79 | # self.end_iter = sd['end_iter'] 80 | self.decay_style = sd['decay_style'] 81 | if 'decay_ratio' in sd: 82 | self.decay_ratio = sd['decay_ratio'] 83 | self.step(self.num_iters) 84 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .distributed import * 17 | from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization 18 | from .gpt2_modeling import GPT2Model 19 | 20 | -------------------------------------------------------------------------------- /model/distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
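# Two data-parallel wrappers are defined in this file: PyTorchDistributedDataParallel,
# a thin subclass of torch's DistributedDataParallel that saves/loads the wrapped
# module's state_dict directly, and a hand-rolled DistributedDataParallel that
# broadcasts parameters over the data-parallel group at construction and exposes
# allreduce_params() for manually averaging gradients across that group.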
15 | 16 | import torch 17 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 18 | import torch.distributed as dist 19 | from torch.nn.modules import Module 20 | from torch.autograd import Variable 21 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 22 | 23 | import mpu 24 | 25 | 26 | class PyTorchDistributedDataParallel(DDP): 27 | def state_dict(self, destination=None, prefix='', keep_vars=False): 28 | sd = self.module.state_dict(destination, prefix, keep_vars) 29 | return sd 30 | 31 | def load_state_dict(self, state_dict, strict=True): 32 | self.module.load_state_dict(state_dict, strict=strict) 33 | 34 | 35 | class DistributedDataParallel(Module): 36 | 37 | def __init__(self, module): 38 | super(DistributedDataParallel, self).__init__() 39 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 40 | 41 | self.module = module 42 | self.data_parallel_group = mpu.get_data_parallel_group() 43 | src_rank = mpu.get_model_parallel_rank() 44 | for p in self.module.parameters(): 45 | if torch.is_tensor(p): 46 | dist.broadcast(p, src_rank, group=self.data_parallel_group) 47 | 48 | def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): 49 | if(self.needs_reduction): 50 | self.needs_reduction = False 51 | buckets = {} 52 | for name, param in self.module.named_parameters(): 53 | if param.requires_grad and param.grad is not None: 54 | tp = (param.data.type()) 55 | if tp not in buckets: 56 | buckets[tp] = [] 57 | buckets[tp].append(param) 58 | if self.warn_on_half: 59 | if torch.cuda.HalfTensor in buckets: 60 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 61 | " It is recommended to use the NCCL backend in this case.") 62 | self.warn_on_half = False 63 | for tp in buckets: 64 | bucket = buckets[tp] 65 | grads = [param.grad.data for param in bucket] 66 | coalesced = _flatten_dense_tensors(grads) 67 | if fp32_allreduce: 68 | coalesced = coalesced.float() 69 | if not no_scale and not reduce_after: 70 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 71 | dist.all_reduce(coalesced, group=self.data_parallel_group) 72 | torch.cuda.synchronize() 73 | if not no_scale and reduce_after: 74 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 75 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 76 | buf.copy_(synced) 77 | self.hook_handles = [] 78 | self.hooks = [] 79 | for param in list(self.module.parameters()): 80 | def allreduce_hook(*unused): 81 | Variable._execution_engine.queue_callback(allreduce_params) 82 | # handle = param.register_hook(allreduce_hook) 83 | #self.hooks.append(allreduce_hook) 84 | #self.hook_handles.append(handle) 85 | self.allreduce_params = allreduce_params 86 | 87 | def forward(self, *inputs, **kwargs): 88 | self.needs_reduction = True 89 | return self.module(*inputs, **kwargs) 90 | 91 | def state_dict(self, destination=None, prefix='', keep_vars=False): 92 | #[h.remove() for h in self.hook_handles] 93 | sd = self.module.state_dict(destination, prefix, keep_vars) 94 | # for handle, hook in zip(self.hook_handles, self.hooks): 95 | # d = handle.hooks_dict_ref() 96 | # d[handle.id] = hook 97 | 98 | return sd 99 | 100 | def load_state_dict(self, state_dict, strict=True): 101 | self.module.load_state_dict(state_dict, strict=strict) 102 | 103 | ''' 104 | def _sync_buffers(self): 105 | buffers = list(self.module._all_buffers()) 106 | if len(buffers) > 0: 107 | # cross-node buffer sync 108 | flat_buffers 
= _flatten_dense_tensors(buffers) 109 | dist.broadcast(flat_buffers, 0) 110 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 111 | buf.copy_(synced) 112 | def train(self, mode=True): 113 | # Clear NCCL communicator and CUDA event cache of the default group ID, 114 | # These cache will be recreated at the later call. This is currently a 115 | # work-around for a potential NCCL deadlock. 116 | if dist._backend == dist.dist_backend.NCCL: 117 | dist._clear_group_cache() 118 | super(DistributedDataParallel, self).train(mode) 119 | self.module.train(mode) 120 | ''' 121 | 122 | -------------------------------------------------------------------------------- /model/gpt2_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """GPT-2 model.""" 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | import mpu 22 | 23 | 24 | def init_method_normal(std=0.02): 25 | """Init method based on normal distribution. 26 | 27 | This is only used for embeddings. The transformer has its 28 | own initializer. 29 | """ 30 | def init_(tensor): 31 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 32 | return init_ 33 | 34 | 35 | def gpt2_get_params_for_weight_decay_optimization(module): 36 | 37 | weight_decay_params = {'params': []} 38 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 39 | for module_ in module.modules(): 40 | if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): 41 | no_weight_decay_params['params'].extend( 42 | [p for p in list(module_._parameters.values()) 43 | if p is not None]) 44 | else: 45 | weight_decay_params['params'].extend( 46 | [p for n, p in list(module_._parameters.items()) 47 | if p is not None and n != 'bias']) 48 | no_weight_decay_params['params'].extend( 49 | [p for n, p in list(module_._parameters.items()) 50 | if p is not None and n == 'bias']) 51 | 52 | return weight_decay_params, no_weight_decay_params 53 | 54 | 55 | class GPT2Model(torch.nn.Module): 56 | """GPT-2 Language model. 57 | 58 | The output of the forward method are the logits (parallel or 59 | serial depending on the `parallel_output` flag. 60 | """ 61 | 62 | def __init__(self, 63 | num_layers, 64 | vocab_size, 65 | hidden_size, 66 | num_attention_heads, 67 | embedding_dropout_prob, 68 | attention_dropout_prob, 69 | output_dropout_prob, 70 | max_sequence_length, 71 | max_memory_length, 72 | checkpoint_activations, 73 | checkpoint_num_layers=1, 74 | parallel_output=True, 75 | query_window=128, 76 | key_window_times=6, 77 | num_pivot=768 78 | ): 79 | 80 | super(GPT2Model, self).__init__() 81 | 82 | self.parallel_output = parallel_output 83 | 84 | init_method = init_method_normal(std=0.02) 85 | 86 | # Word embeddings (parallel). 
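# The embedding table is partitioned along the vocabulary dimension across
# model-parallel ranks; forward() reuses the same weight matrix as the output
# projection (weight tying), so the logits are likewise produced in vocab-parallel form.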
87 | self.word_embeddings = mpu.VocabParallelEmbedding( 88 | vocab_size, hidden_size, init_method=init_method) 89 | 90 | # Transformer 91 | self.transformer = mpu.GPT2ParallelTransformer(num_layers, 92 | hidden_size, 93 | num_attention_heads, 94 | max_sequence_length, 95 | max_memory_length, 96 | embedding_dropout_prob, 97 | attention_dropout_prob, 98 | output_dropout_prob, 99 | checkpoint_activations, 100 | checkpoint_num_layers, 101 | query_window=query_window, 102 | key_window_times=key_window_times, 103 | num_pivot=num_pivot 104 | ) 105 | 106 | def forward(self, input_ids, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems): 107 | # Embeddings. 108 | words_embeddings = self.word_embeddings(input_ids) 109 | embeddings = words_embeddings 110 | 111 | # Transformer. 112 | transformer_output = self.transformer(embeddings, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems) 113 | logits, *hidden_layers = transformer_output 114 | # Parallel logits. 115 | logits_parallel = mpu.copy_to_model_parallel_region( 116 | logits) 117 | logits_parallel = F.linear(logits_parallel, 118 | self.word_embeddings.weight) 119 | 120 | if self.parallel_output: 121 | return (logits_parallel, *hidden_layers) 122 | 123 | return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers) 124 | -------------------------------------------------------------------------------- /mpu/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Model parallel utility interface.""" 17 | 18 | from .cross_entropy import vocab_parallel_cross_entropy 19 | 20 | from .data import broadcast_data 21 | 22 | from .grads import clip_grad_norm 23 | 24 | from .initialize import destroy_model_parallel 25 | from .initialize import get_data_parallel_group 26 | from .initialize import get_data_parallel_rank 27 | from .initialize import get_data_parallel_world_size 28 | from .initialize import get_model_parallel_group 29 | from .initialize import get_model_parallel_rank 30 | from .initialize import get_model_parallel_src_rank 31 | from .initialize import get_model_parallel_world_size 32 | from .initialize import initialize_model_parallel 33 | from .initialize import model_parallel_is_initialized 34 | 35 | from .layers import ColumnParallelLinear 36 | from .layers import ParallelEmbedding 37 | from .layers import RowParallelLinear 38 | from .layers import VocabParallelEmbedding 39 | 40 | from .mappings import copy_to_model_parallel_region 41 | from .mappings import gather_from_model_parallel_region 42 | from .mappings import reduce_from_model_parallel_region 43 | from .mappings import scatter_to_model_parallel_region 44 | 45 | from .random import checkpoint 46 | from .random import partition_activations_in_checkpoint 47 | from .random import get_cuda_rng_tracker 48 | from .random import model_parallel_cuda_manual_seed 49 | 50 | from .sparse_transformer import GPT2ParallelTransformer 51 | from .sparse_transformer import LayerNorm 52 | -------------------------------------------------------------------------------- /mpu/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | from .initialize import get_model_parallel_group 20 | from .initialize import get_model_parallel_rank 21 | from .initialize import get_model_parallel_world_size 22 | from .utils import VocabUtility 23 | 24 | 25 | class _VocabParallelCrossEntropy(torch.autograd.Function): 26 | 27 | @staticmethod 28 | def forward(ctx, vocab_parallel_logits, target): 29 | 30 | # Copy so the input remains unchanged. 31 | logits = vocab_parallel_logits.clone() 32 | # Maximum value along vocab dimension across all GPUs. 33 | logits_max = torch.max(logits, dim=-1)[0] 34 | torch.distributed.all_reduce(logits_max, 35 | op=torch.distributed.ReduceOp.MAX, 36 | group=get_model_parallel_group()) 37 | # Subtract the maximum value. 38 | logits.sub_(logits_max.unsqueeze(dim=-1)) 39 | # Sum of exponential of logits along vocab dimension across all GPUs. 
40 | exp_logits = logits.exp() 41 | sum_exp_logits = exp_logits.sum(dim=-1) 42 | torch.distributed.all_reduce(sum_exp_logits, 43 | op=torch.distributed.ReduceOp.SUM, 44 | group=get_model_parallel_group()) 45 | 46 | # Get the partition's vocab indecies 47 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 48 | partition_vocab_size = vocab_parallel_logits.size()[-1] 49 | rank = get_model_parallel_rank() 50 | world_size = get_model_parallel_world_size() 51 | vocab_start_index, vocab_end_index = get_vocab_range( 52 | partition_vocab_size, rank, world_size) 53 | 54 | # Create a mask of valid vocab ids (1 means it needs to be masked). 55 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index) 56 | masked_target = target.clone() - vocab_start_index 57 | masked_target[target_mask] = 0 58 | 59 | # Get predicted-logits = logits[target]. 60 | # For Simplicity, we convert logits to a 2-D tensor with size 61 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 62 | logits_2d = logits.view(-1, partition_vocab_size) 63 | masked_target_1d = masked_target.view(-1) 64 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0], 65 | device=logits_2d.device) 66 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] 67 | predicted_logits = predicted_logits_1d.view_as(target) 68 | predicted_logits[target_mask] = 0.0 69 | # All reduce is needed to get the chunks from other GPUs. 70 | torch.distributed.all_reduce(predicted_logits, 71 | op=torch.distributed.ReduceOp.SUM, 72 | group=get_model_parallel_group()) 73 | 74 | # Loss = log(sum(exp(logits))) - predicted-logit. 75 | loss = torch.log(sum_exp_logits) - predicted_logits 76 | 77 | # Store softmax, target-mask and masked-target for backward pass. 78 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) 79 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) 80 | 81 | return loss 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | 86 | # Retreive tensors from the forward path. 87 | softmax, target_mask, masked_target_1d = ctx.saved_tensors 88 | 89 | # All the inputs have softmax as thier gradient. 90 | grad_input = softmax 91 | # For simplicity, work with the 2D gradient. 92 | partition_vocab_size = softmax.size()[-1] 93 | grad_2d = grad_input.view(-1, partition_vocab_size) 94 | 95 | # Add the gradient from matching classes. 96 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], 97 | device=grad_2d.device) 98 | grad_2d[arange_1d, masked_target_1d] -= ( 99 | 1.0 - target_mask.view(-1).float()) 100 | 101 | # Finally elementwise multiplication with the output gradients. 102 | grad_input.mul_(grad_output.unsqueeze(dim=-1)) 103 | 104 | return grad_input, None 105 | 106 | 107 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target): 108 | """Helper function for the cross entropy.""" 109 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) 110 | -------------------------------------------------------------------------------- /mpu/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
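# A self-contained sketch of the arithmetic implemented by _VocabParallelCrossEntropy
# above, simulating two vocab partitions with plain tensor slices instead of
# torch.distributed (the sizes are arbitrary illustrative choices):
import torch
import torch.nn.functional as F

full_logits = torch.randn(4, 10)                    # [tokens, vocab]
target = torch.randint(0, 10, (4,))
parts = full_logits.split(5, dim=-1)                # two "ranks", 5 vocab entries each
global_max = torch.maximum(parts[0].max(-1).values, parts[1].max(-1).values)   # all-reduce MAX
sum_exp = sum((p - global_max.unsqueeze(-1)).exp().sum(-1) for p in parts)     # all-reduce SUM
pred = full_logits.gather(-1, target.unsqueeze(-1)).squeeze(-1) - global_max   # owning rank's term
loss = torch.log(sum_exp) - pred
assert torch.allclose(loss, F.cross_entropy(full_logits, target, reduction='none'), atol=1e-5)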
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .initialize import get_model_parallel_rank 20 | from .initialize import get_model_parallel_src_rank 21 | 22 | 23 | _MAX_DATA_DIM = 5 24 | 25 | 26 | def _check_data_types(keys, data, target_dtype): 27 | """Check that all the keys have the same target data type.""" 28 | for key in keys: 29 | assert data[key].dtype == target_dtype, '{} has data type {} which '\ 30 | 'is different than {}'.format(key, data[key].dtype, target_dtype) 31 | 32 | 33 | def _build_key_size_numel_dictionaries(keys, data): 34 | """Build the size on rank 0 and broadcast.""" 35 | max_dim = _MAX_DATA_DIM 36 | sizes = [0 for _ in range(max_dim) for _ in keys] 37 | 38 | # Pack the sizes on rank zero. 39 | if get_model_parallel_rank() == 0: 40 | offset = 0 41 | for key in keys: 42 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' 43 | size = data[key].size() 44 | for i, s in enumerate(size): 45 | sizes[i + offset] = s 46 | offset += max_dim 47 | 48 | # Move to GPU and broadcast. 49 | sizes_cuda = torch.cuda.LongTensor(sizes) 50 | torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), 51 | group=get_model_parallel_group()) 52 | 53 | # Move back to cpu and unpack. 54 | sizes_cpu = sizes_cuda.cpu() 55 | key_size = {} 56 | key_numel = {} 57 | total_numel = 0 58 | offset = 0 59 | for key in keys: 60 | i = 0 61 | size = [] 62 | numel = 1 63 | while sizes_cpu[offset + i] > 0: 64 | this_size = sizes_cpu[offset + i] 65 | size.append(this_size) 66 | numel *= this_size 67 | i += 1 68 | key_size[key] = size 69 | key_numel[key] = numel 70 | total_numel += numel 71 | offset += max_dim 72 | 73 | return key_size, key_numel, total_numel 74 | 75 | 76 | def broadcast_data(keys, data, datatype): 77 | """Broadcast data from rank zero of each model parallel group to the 78 | members of the same model parallel group. 79 | 80 | Arguments: 81 | keys: list of keys in the data disctionary to be broadcasted 82 | data: data dictionary of string keys and cpu tensor values. 83 | datatype: torch data type of all tensors in data associated 84 | with keys. 85 | """ 86 | # Build (key, size) and (key, number of elements) dictionaries along 87 | # with the total number of elements on all ranks. 88 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, 89 | data) 90 | 91 | # Pack on rank zero. 92 | if get_model_parallel_rank() == 0: 93 | # Check that all keys have the same data type. 
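        # Only the source rank (model parallel rank 0) holds real data: it flattens
        # every requested tensor into one 1-D buffer so a single broadcast suffices.
        # The other ranks just allocate an empty buffer of `total_numel` elements and
        # reshape the received slices back to their recorded sizes after the broadcast.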
94 | _check_data_types(keys, data, datatype) 95 | # Flatten the data associated with the keys 96 | flatten_data = torch.cat( 97 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() 98 | else: 99 | flatten_data = torch.empty(total_numel, 100 | device=torch.cuda.current_device(), 101 | dtype=datatype) 102 | 103 | # Boradcast 104 | torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), 105 | group=get_model_parallel_group()) 106 | 107 | # Unpack 108 | output = {} 109 | offset = 0 110 | for key in keys: 111 | size = key_size[key] 112 | numel = key_numel[key] 113 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 114 | offset += numel 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /mpu/grads.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import torch 22 | from torch._six import inf 23 | 24 | from .initialize import get_model_parallel_group 25 | from .initialize import get_model_parallel_rank 26 | 27 | 28 | def clip_grad_norm(parameters, max_norm, norm_type=2): 29 | """Clips gradient norm of an iterable of parameters. 30 | 31 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and 32 | added functionality to handle model parallel parameters. Note that 33 | the gradients are modified in place. 34 | 35 | Arguments: 36 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 37 | single Tensor that will have gradients normalized 38 | max_norm (float or int): max norm of the gradients 39 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 40 | infinity norm. 41 | 42 | Returns: 43 | Total norm of the parameters (viewed as a single vector). 44 | """ 45 | if isinstance(parameters, torch.Tensor): 46 | parameters = [parameters] 47 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 48 | max_norm = float(max_norm) 49 | norm_type = float(norm_type) 50 | if norm_type == inf: 51 | total_norm = max(p.grad.data.abs().max() for p in parameters) 52 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 53 | # Take max across all GPUs. 54 | torch.distributed.all_reduce(total_norm_cuda, 55 | op=torch.distributed.ReduceOp.MAX, 56 | group=get_model_parallel_group()) 57 | total_norm = total_norm_cuda[0].item() 58 | else: 59 | total_norm = 0 60 | for p in parameters: 61 | if p.model_parallel or (get_model_parallel_rank() == 0): 62 | param_norm = p.grad.data.norm(norm_type) 63 | total_norm += param_norm.item() ** norm_type 64 | # Sum across all model parallel GPUs. 
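        # Sharded (model parallel) parameters contribute their shard's norm from every
        # rank, while replicated parameters (p.model_parallel == False) are counted on
        # rank 0 only, so the all-reduce SUM below does not double count them.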
65 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 66 | torch.distributed.all_reduce(total_norm_cuda, 67 | op=torch.distributed.ReduceOp.SUM, 68 | group=get_model_parallel_group()) 69 | total_norm = total_norm_cuda[0].item() ** (1. / norm_type) 70 | clip_coef = max_norm / (total_norm + 1e-6) 71 | if clip_coef < 1: 72 | for p in parameters: 73 | p.grad.data.mul_(clip_coef) 74 | return total_norm 75 | -------------------------------------------------------------------------------- /mpu/initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """Model and data parallel groups.""" 18 | 19 | import torch 20 | 21 | from .utils import ensure_divisibility 22 | 23 | 24 | # Model parallel group that the current rank belongs to. 25 | _MODEL_PARALLEL_GROUP = None 26 | # Data parallel group that the current rank belongs to. 27 | _DATA_PARALLEL_GROUP = None 28 | 29 | 30 | def initialize_model_parallel(model_parallel_size_): 31 | """ 32 | Initialize model data parallel groups. 33 | 34 | Arguments: 35 | model_parallel_size: number of GPUs used to parallelize model. 36 | 37 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we 38 | use 2 GPUs to parallelize the model. The present function will 39 | create 4 model parallel groups and 2 data parallel grous as: 40 | 4 model parallel groups: 41 | [g0, g1], [g2, g3], [g4, g5], [g6, g7] 42 | 2 data parallel groups: 43 | [g0, g2, g4, g6], [g1, g3, g5, g7] 44 | Note that for efficiency, the caller should make sure adjacent ranks 45 | are on the same DGX box. For example if we are using 2 DGX-1 boxes 46 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and 47 | ranks 8 to 15 belong to the second box. 48 | """ 49 | if torch.distributed.get_rank() == 0: 50 | print('> initializing model parallel with size {}'.format( 51 | model_parallel_size_)) 52 | # Get world size and rank. Ensure some consistencies. 53 | assert torch.distributed.is_initialized() 54 | world_size = torch.distributed.get_world_size() 55 | model_parallel_size = min(model_parallel_size_, world_size) 56 | ensure_divisibility(world_size, model_parallel_size) 57 | rank = torch.distributed.get_rank() 58 | 59 | # Build the data parallel groups. 60 | global _DATA_PARALLEL_GROUP 61 | assert _DATA_PARALLEL_GROUP is None, \ 62 | 'data parallel group is already initialized' 63 | for i in range(model_parallel_size): 64 | ranks = range(i, world_size, model_parallel_size) 65 | group = torch.distributed.new_group(ranks) 66 | if i == (rank % model_parallel_size): 67 | _DATA_PARALLEL_GROUP = group 68 | 69 | # Build the model parallel groups. 
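    # As above, torch.distributed.new_group is a collective, so every rank constructs
    # every group and keeps only the one it belongs to. With 8 GPUs and
    # model_parallel_size=2: the loop above keeps [g0, g2, g4, g6] on even ranks and
    # [g1, g3, g5, g7] on odd ranks, while the loop below keeps [g0, g1] on ranks 0-1,
    # [g2, g3] on ranks 2-3, and so on.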
70 | global _MODEL_PARALLEL_GROUP 71 | assert _MODEL_PARALLEL_GROUP is None, \ 72 | 'model parallel group is already initialized' 73 | for i in range(world_size // model_parallel_size): 74 | ranks = range(i * model_parallel_size, 75 | (i + 1) * model_parallel_size) 76 | group = torch.distributed.new_group(ranks) 77 | if i == (rank // model_parallel_size): 78 | _MODEL_PARALLEL_GROUP = group 79 | 80 | 81 | def model_parallel_is_initialized(): 82 | """Check if model and data parallel groups are initialized.""" 83 | if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: 84 | return False 85 | return True 86 | 87 | 88 | def get_model_parallel_group(): 89 | """Get the model parallel group the caller rank belongs to.""" 90 | assert _MODEL_PARALLEL_GROUP is not None, \ 91 | 'model parallel group is not initialized' 92 | return _MODEL_PARALLEL_GROUP 93 | 94 | 95 | def get_data_parallel_group(): 96 | """Get the data parallel group the caller rank belongs to.""" 97 | assert _DATA_PARALLEL_GROUP is not None, \ 98 | 'data parallel group is not initialized' 99 | return _DATA_PARALLEL_GROUP 100 | 101 | 102 | def get_model_parallel_world_size(): 103 | """Return world size for the model parallel group.""" 104 | return torch.distributed.get_world_size(group=get_model_parallel_group()) 105 | 106 | 107 | def get_model_parallel_rank(): 108 | """Return my rank for the model parallel group.""" 109 | return torch.distributed.get_rank(group=get_model_parallel_group()) 110 | 111 | 112 | def get_model_parallel_src_rank(): 113 | """Calculate the global rank corresponding to a local rank zeor 114 | in the model parallel group.""" 115 | global_rank = torch.distributed.get_rank() 116 | local_world_size = get_model_parallel_world_size() 117 | return (global_rank // local_world_size) * local_world_size 118 | 119 | 120 | def get_data_parallel_world_size(): 121 | """Return world size for the data parallel group.""" 122 | return torch.distributed.get_world_size(group=get_data_parallel_group()) 123 | 124 | 125 | def get_data_parallel_rank(): 126 | """Return my rank for the data parallel group.""" 127 | return torch.distributed.get_rank(group=get_data_parallel_group()) 128 | 129 | 130 | def destroy_model_parallel(): 131 | """Set the groups to none.""" 132 | global _MODEL_PARALLEL_GROUP 133 | _MODEL_PARALLEL_GROUP = None 134 | global _DATA_PARALLEL_GROUP 135 | _DATA_PARALLEL_GROUP = None 136 | -------------------------------------------------------------------------------- /mpu/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
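# A concrete example for get_model_parallel_src_rank() above: with 2-way model
# parallelism, global ranks {4, 5} form one model parallel group and both compute
# (rank // 2) * 2 == 4, i.e. the lowest global rank in their group, which is the
# rank used as the broadcast source by broadcast_data() in mpu/data.py.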
15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .utils import split_tensor_along_last_dim 20 | 21 | 22 | def _reduce(input_): 23 | """All-reduce the the input tensor across model parallel group.""" 24 | group = get_model_parallel_group() 25 | 26 | # Bypass the function if we are using only 1 GPU. 27 | if torch.distributed.get_world_size(group=group) == 1: 28 | return input_ 29 | 30 | # All-reduce. 31 | torch.distributed.all_reduce(input_, group=group) 32 | 33 | return input_ 34 | 35 | 36 | def _split(input_): 37 | """Split the tensor along its last dimension and keep the 38 | corresponding slice.""" 39 | group = get_model_parallel_group() 40 | 41 | # Bypass the function if we are using only 1 GPU. 42 | if torch.distributed.get_world_size(group=group) == 1: 43 | return input_ 44 | 45 | # Split along last dimension. 46 | world_size = torch.distributed.get_world_size(group=group) 47 | input_list = split_tensor_along_last_dim(input_, world_size) 48 | 49 | # Note: torch.split does not create contiguous tensors by default. 50 | rank = torch.distributed.get_rank(group=group) 51 | output = input_list[rank].contiguous() 52 | 53 | return output 54 | 55 | 56 | def _gather(input_): 57 | """Gather tensors and concatinate along the last dimension.""" 58 | group = get_model_parallel_group() 59 | 60 | # Bypass the function if we are using only 1 GPU. 61 | if torch.distributed.get_world_size(group=group) == 1: 62 | return input_ 63 | 64 | # Size and dimension. 65 | last_dim = input_.dim() - 1 66 | rank = torch.distributed.get_rank(group=group) 67 | world_size = torch.distributed.get_world_size(group=group) 68 | 69 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 70 | tensor_list[rank] = input_ 71 | torch.distributed.all_gather(tensor_list, input_, group=group) 72 | 73 | # Note: torch.cat already creates a contiguous tensor. 74 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 75 | 76 | return output 77 | 78 | 79 | class _CopyToModelParallelRegion(torch.autograd.Function): 80 | """Pass the input to the model parallel region.""" 81 | 82 | @staticmethod 83 | def forward(ctx, input_): 84 | return input_ 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | return _reduce(grad_output) 89 | 90 | 91 | class _ReduceFromModelParallelRegion(torch.autograd.Function): 92 | """All-redcue the input from the model parallel region.""" 93 | 94 | @staticmethod 95 | def forward(ctx, input_): 96 | return _reduce(input_) 97 | 98 | @staticmethod 99 | def backward(ctx, grad_output): 100 | return grad_output 101 | 102 | 103 | class _ScatterToModelParallelRegion(torch.autograd.Function): 104 | """Split the input and keep only the corresponding chuck to the rank.""" 105 | 106 | @staticmethod 107 | def forward(ctx, input_): 108 | return _split(input_) 109 | 110 | @staticmethod 111 | def backward(ctx, grad_output): 112 | return _gather(grad_output) 113 | 114 | 115 | class _GatherFromModelParallelRegion(torch.autograd.Function): 116 | """Gather the input from model parallel region and concatinate.""" 117 | 118 | @staticmethod 119 | def forward(ctx, input_): 120 | return _gather(input_) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | return _split(grad_output) 125 | 126 | 127 | # ----------------- 128 | # Helper functions. 
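# A sketch of how these primitives compose in the model parallel layers (simplified
# from the pattern used by ColumnParallelLinear in mpu/layers.py, which is not shown
# here); the identity-forward / all-reduce-backward pairing keeps autograd consistent
# even though each rank only holds a shard of the weight:
#
#     x = copy_to_model_parallel_region(x)             # fwd: identity,   bwd: all-reduce dx
#     y_local = F.linear(x, weight_shard, bias_shard)  # per-rank slice of the output features
#     y = gather_from_model_parallel_region(y_local)   # fwd: all-gather, bwd: split dy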
129 | # ----------------- 130 | 131 | def copy_to_model_parallel_region(input_): 132 | return _CopyToModelParallelRegion.apply(input_) 133 | 134 | def reduce_from_model_parallel_region(input_): 135 | return _ReduceFromModelParallelRegion.apply(input_) 136 | 137 | def scatter_to_model_parallel_region(input_): 138 | return _ScatterToModelParallelRegion.apply(input_) 139 | 140 | def gather_from_model_parallel_region(input_): 141 | return _GatherFromModelParallelRegion.apply(input_) 142 | -------------------------------------------------------------------------------- /mpu/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | 20 | def ensure_divisibility(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator.""" 22 | assert numerator % denominator == 0, '{} is not divisible by {}'.format( 23 | numerator, denominator) 24 | 25 | 26 | def divide(numerator, denominator): 27 | """Ensure that numerator is divisible by the denominator and return 28 | the division value.""" 29 | ensure_divisibility(numerator, denominator) 30 | return numerator // denominator 31 | 32 | 33 | def split_tensor_along_last_dim(tensor, num_partitions, 34 | contiguous_split_chunks=False): 35 | """Split a tensor along its last dimension. 36 | Arguments: 37 | tensor: input tensor. 38 | num_partitions: number of partitions to split the tensor 39 | contiguous_split_chunks: If True, make each chunk contiguous 40 | in memory. 41 | """ 42 | # Get the size and dimension. 43 | last_dim = tensor.dim() - 1 44 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 45 | # Split. 46 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 47 | # Note: torch.split does not create contiguous tensors by default. 
48 | if contiguous_split_chunks: 49 | return tuple(chunk.contiguous() for chunk in tensor_list) 50 | 51 | return tensor_list 52 | 53 | 54 | class VocabUtility: 55 | """Split the vocabulary into `world_size` chunks amd return the 56 | first and last index of the vocabulary belonging to the `rank` 57 | partition: Note that indecies in [fist, last)""" 58 | 59 | @staticmethod 60 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, 61 | rank, world_size): 62 | index_f = rank * per_partition_vocab_size 63 | index_l = index_f + per_partition_vocab_size 64 | return index_f, index_l 65 | 66 | @staticmethod 67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): 68 | per_partition_vocab_size = divide(global_vocab_size, world_size) 69 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 70 | per_partition_vocab_size, rank, world_size) 71 | 72 | def split_out_sums(x, BLOCK_SIZE=32, all_ret=False): 73 | b, L = x.shape[:2] 74 | rs = x.shape[2:] 75 | x = x.view(b, L // BLOCK_SIZE, BLOCK_SIZE, *rs) 76 | oris, sums = x.split([BLOCK_SIZE-1, 1], dim=2) 77 | if all_ret: 78 | return oris.reshape(b, -1, *rs), sums.reshape(b, -1, *rs) 79 | else: 80 | return sums.reshape(b, -1, *rs) 81 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import show_recover_results 2 | -------------------------------------------------------------------------------- /preprocess/preprocess_text_image_data.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : preprocess_text_image_data.py 4 | @Time : 2021/01/24 15:38:44 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | import lmdb 22 | from .pretokenized_data import make_text_image_batch, make_tuple_text_image_batch, make_super_resolution_batch 23 | import PIL 24 | import timeit 25 | 26 | 27 | 28 | @torch.no_grad() 29 | def extract_code(model, datasets, text_dict, name, device, txt_type): 30 | index = 0 31 | map_size = 1024 * 1024 * 1024 * 1024 32 | lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}', map_size=map_size, writemap=True) 33 | print(f'/root/mnt/lmdb/{name}') 34 | with lmdb_env.begin(write=True) as txn: 35 | for dataset in datasets: 36 | loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1) 37 | print(dataset) 38 | pbar = tqdm(loader) 39 | for raw_imgs, raw_filenames in pbar: 40 | imgs = [] 41 | filenames = [] 42 | for i, filename in enumerate(raw_filenames): 43 | if filename != "not_a_image" and text_dict.__contains__(filename): 44 | imgs.append(raw_imgs[i]) 45 | filenames.append(filename) 46 | else: 47 | print("warning: deleted damaged image") 48 | imgs = torch.stack(imgs) 49 | imgs = imgs.to(device) 50 | try: 51 | if txt_type == "h5": 52 | filenames = filenames.numpy() 53 | txts = [text_dict[filename] for filename in filenames] 54 | if txt_type != "h5": 55 | codes = make_text_image_batch(model, txts, imgs) 56 | else: 57 | codes = make_tuple_text_image_batch(model, txts, imgs) 58 | for code in codes: 59 | 
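                        # Each `code` is one pretokenized text+image token sequence; it is
                        # pickled under its integer index, and the final 'length' entry below
                        # records how many sequences were written, so a reader can recover
                        # item i with pickle.loads(txn.get(str(i).encode('utf-8'))).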
txn.put(str(index).encode('utf-8'), pickle.dumps(code)) 60 | index += 1 61 | except KeyError: 62 | print("warning: KeyError. The text cannot be find") 63 | pass 64 | txn.put('length'.encode('utf-8'), str(index).encode('utf-8')) 65 | 66 | 67 | @torch.no_grad() 68 | def extract_code_super_resolution_patches(model, datasets, text_dict, name, device, txt_type): 69 | index = 0 70 | map_size = 1024 * 1024 * 1024 * 1024 71 | lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}_super_resolution', map_size=map_size, writemap=True) 72 | print(f'/root/mnt/lmdb/{name}_super_resolution') 73 | with lmdb_env.begin(write=True) as txn: 74 | for dataset in datasets: 75 | loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=1) 76 | print(dataset) 77 | pbar = tqdm(loader) 78 | for raw_imgs, raw_filenames in pbar: 79 | imgs = [] 80 | filenames = [] 81 | for i, filename in enumerate(raw_filenames): 82 | if filename != "not_a_image" and text_dict.__contains__(filename): 83 | imgs.append(raw_imgs[i]) 84 | filenames.append(filename) 85 | else: 86 | print("warning: deleted damaged image") 87 | imgs = torch.stack(imgs) 88 | imgs = imgs.to(device) 89 | try: 90 | if txt_type == "h5": 91 | filenames = filenames.numpy() 92 | txts = [text_dict[filename] for filename in filenames] 93 | if txt_type != "h5": 94 | codes = make_super_resolution_batch(model, txts, imgs) 95 | else: 96 | codes = make_tuple_text_image_batch(model, txts, imgs) 97 | for code in codes: 98 | txn.put(str(index).encode('utf-8'), pickle.dumps(code)) 99 | index += 1 100 | except KeyError: 101 | print("warning: KeyError. The text cannot be find") 102 | pass 103 | txn.put('length'.encode('utf-8'), str(index).encode('utf-8')) 104 | 105 | -------------------------------------------------------------------------------- /preprocess/preprocess_text_jsonformat_data.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : preprocess_text_jsonformat_data.py 4 | @Time : 2021/03/14 20:56:28 5 | @Author : Ming Ding 6 | @Contact : dm18@mail.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | import lmdb 22 | from .pretokenized_data import make_cut_text_batch 23 | import timeit 24 | import ujson as json 25 | 26 | def extract_code(datasets, name, seq_len): 27 | ''' 28 | datasets: [json_name1, json_name2, ...] 
29 | ''' 30 | index = 0 31 | map_size = 1024 * 1024 * 1024 * 1024 32 | lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}', map_size=map_size, writemap=True) 33 | with lmdb_env.begin(write=True) as txn: 34 | for dataset in datasets: 35 | with open(dataset, 'r') as fin: 36 | print(f'Loading {dataset}...') 37 | raw_json = json.load(fin)["RECORDS"] 38 | bs = 512 39 | for i in tqdm(range(0, len(raw_json), bs)): 40 | txts = [t["content"] for t in raw_json[i: i + bs]] 41 | txts = make_cut_text_batch(txts, seq_len) 42 | for code in txts: 43 | txn.put(str(index).encode('utf-8'), pickle.dumps(code)) 44 | index += 1 45 | txn.put('length'.encode('utf-8'), str(index).encode('utf-8')) 46 | print(f'/root/mnt/lmdb/{name}, length={index}') 47 | -------------------------------------------------------------------------------- /preprocess/pretokenized_data.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : pretokenized_data.py 4 | @Time : 2021/01/20 15:39:10 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from vqvae import * 20 | from data_utils import Code2CodeTemplate, concat_codes 21 | from torchvision.transforms.functional import resize 22 | from torchvision import transforms 23 | from data_utils import get_tokenizer 24 | 25 | # def make_hierarchical_batch(model, txts, imgs): 26 | # ''' 27 | # model: VQVAE 28 | # txts: ['text1', 'text2', ...] 29 | # imgs: [b, 3, s, s] 30 | # ''' 31 | # s = img.shape[-1] 32 | # assert img.shape[-2] == s # square 33 | # codes_base = img2code(model, img) 34 | # img_tiny = resize(img, size=s//4).numpy() 35 | # codes_tiny = img2code(model, img_tiny).numpy() 36 | # ret = [] 37 | # for i in range(len(txts)): 38 | # text = '[ROI1] ' + txts[i] 39 | # ret.append( 40 | # Code2CodeTemplate(text, codes_tiny[i], codes_base[i]) 41 | # ) 42 | # return ret 43 | 44 | 45 | def make_super_resolution_batch(model, txts, imgs): 46 | ''' 47 | [text...small_img...base_img] 48 | ''' 49 | tokenizer = get_tokenizer() 50 | 51 | if not hasattr(make_super_resolution_batch, 'pos'): 52 | pos = ['左上', '正上', '右上', '左侧', '中间', '右侧', '左下', '正下', '右下'] 53 | pos = [ 54 | tokenizer.parse_query('[ROI1] 是{}部分图'.format(p)) 55 | for p in pos 56 | ] # [[23, 354...], [232, ...]] 57 | pw = [0, 64, 128] * 3 58 | ph = [0, 0, 0, 64, 64, 64, 128, 128, 128] 59 | make_super_resolution_batch.pos = list(zip(pos, ph, pw)) 60 | make_super_resolution_batch.weights = [1] * 9 61 | make_super_resolution_batch.prefix = tokenizer.parse_query('[ROI2] 是 [ROI1] 的放大图') 62 | 63 | s = imgs.shape[-1] 64 | assert s == imgs.shape[-2] == 256 65 | # Crop 128 * 128 patch 66 | selected_poses = random.choices(range(9), weights=make_super_resolution_batch.weights) 67 | pos = make_super_resolution_batch.pos 68 | patches = [ 69 | imgs[i, :, pos[p][1]:pos[p][1] + 128, pos[p][2]: pos[p][2]+128] 70 | for i, p in enumerate(selected_poses) 71 | ] 72 | patches = torch.stack(patches) 73 | small_patches = resize(patches, size=64) 74 | 75 | codes_base = img2code(model, patches).cpu().numpy() 76 | codes_small = img2code(model, small_patches).cpu().numpy() 77 | 78 | ret = [] 79 | for i in range(len(txts)): 80 | code_text = tokenizer(txts[i]) 81 | ret.append( 82 | concat_codes(code_text + make_super_resolution_batch.prefix, 83 | codes_small[i], 
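                         # resulting training sequence: [text tokens + '[ROI2] 是 [ROI1] 的放大图' prefix]
                         # + [64x64 patch codes] + [position tokens '[ROI1] 是{}部分图'] + [128x128 patch codes]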
84 | pos[selected_poses[i]][0], 85 | codes_base[i]) 86 | ) 87 | return ret 88 | 89 | def make_super_resolution_batch(model, txts, imgs, img_size=512, sampling_num=4): 90 | ''' 91 | [text...small_img...base_img] 92 | ''' 93 | tokenizer = get_tokenizer() 94 | t0, t1 = img_size // 4, img_size // 2 95 | if img_size == 512: 96 | size_tk = tokenizer['[BASE]'] 97 | else: 98 | raise NotImplementedError 99 | 100 | pw = [0, t0, t1] * 3 101 | ph = [0, 0, 0, t0, t0, t0, t1, t1, t1] 102 | ptk = [[tokenizer['[EOI1]'], tokenizer['[ROI2]'], tokenizer[f'[POS{i}]'], size_tk, tokenizer['[BOI2]']] 103 | for i in range(9) 104 | ] 105 | pos = list(zip(ptk, ph, pw)) 106 | weights = [1] * 9 107 | 108 | 109 | s = imgs.shape[-1] 110 | assert s == imgs.shape[-2] == img_size 111 | # Crop img_size/2 * img_size/2 patch 112 | selected_poses = random.choices(range(9), weights=weights, k=sampling_num) 113 | pos = pos 114 | patches = [ 115 | imgs[i, :, pos[p][1]:pos[p][1] + t1, pos[p][2]: pos[p][2]+t1] 116 | for i in range(imgs.shape[0]) 117 | for p in selected_poses 118 | ] 119 | patch_prefix = [ 120 | pos[p][0] 121 | for p in selected_poses 122 | ] * imgs.shape[0] 123 | patches = torch.stack(patches) 124 | overviews = torch.nn.functional.interpolate(imgs, size=(t1, t1), mode='bilinear') 125 | 126 | codes_patches = img2code(model, patches).cpu().numpy() 127 | codes_overviews = img2code(model, overviews).cpu().numpy() 128 | ret = [] 129 | for i in range(len(txts)): 130 | code_text = [tokenizer['[ROI1]']] + tokenizer(txts[i]) + [size_tk, tokenizer['[BOI1]']] 131 | for j in range(sampling_num): 132 | ret.append( 133 | concat_codes(code_text, 134 | codes_overviews[i], 135 | patch_prefix[i* sampling_num + j], 136 | codes_patches[i * sampling_num + j], 137 | [tokenizer['[EOI2]']] 138 | ) 139 | ) 140 | return ret 141 | 142 | def make_text_image_batch(model, txts, imgs): 143 | from data_utils import TextCodeTemplate 144 | s = imgs.shape[-1] 145 | assert s == imgs.shape[-2] == 256 146 | tokenizer = get_tokenizer() 147 | codes = img2code(model, imgs).cpu().numpy() 148 | ret = [] 149 | for i in range(len(txts)): 150 | ret.append( 151 | TextCodeTemplate(txts[i], codes[i]) 152 | ) 153 | return ret 154 | 155 | def make_tuple_text_image_batch(model, txts, imgs): 156 | s = imgs.shape[-1] 157 | assert s == imgs.shape[-2] == 256 158 | codes = img2code(model, imgs).cpu().numpy() 159 | ret = [] 160 | for i in range(len(txts)): 161 | ret.append( 162 | (txts[i], codes[i]) 163 | ) 164 | return codes 165 | 166 | import itertools 167 | def make_cut_text_batch(txts, seq_len): 168 | from data_utils import PureTextTemplate 169 | tmp_list = np.array(list( 170 | itertools.chain(*(PureTextTemplate(txt) for txt in txts)) 171 | )) 172 | ret = [ 173 | tmp_list[en - seq_len: en] 174 | for en in range(seq_len, len(tmp_list), seq_len) 175 | ] 176 | return ret 177 | 178 | -------------------------------------------------------------------------------- /preprocess/raw_datasets.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : raw_datasets.py 4 | @Time : 2021/01/24 15:31:34 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | import ctypes 16 | import io 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | from torch.utils.data import Dataset, IterableDataset 23 | from torchvision 
import datasets 24 | import unrar 25 | from PIL import Image 26 | import timeit 27 | from collections import Iterable 28 | 29 | 30 | class ImageFileDataset(datasets.ImageFolder): 31 | def __getitem__(self, index): 32 | sample, target = super().__getitem__(index) 33 | path, _ = self.samples[index] 34 | dirs, filename = os.path.split(path) 35 | filename = filename.split('.')[0] 36 | return sample, filename 37 | 38 | class RarDataset(Dataset): 39 | def __init__(self, path, transform=None): 40 | from unrar import rarfile 41 | self.rar = rarfile.RarFile(path) 42 | self.infos = self.rar.infolist() 43 | self.transform = transform 44 | def __len__(self): 45 | return len(self.infos) 46 | def __getitem__(self, idx): 47 | target_info = self.infos[idx] 48 | img = Image.open(self.rar.open(target_info)) 49 | dirs, filename = os.path.split(self.infos[idx].filename) 50 | filename = filename.split('.')[0] 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | return img, filename 54 | 55 | from unrar import rarfile 56 | from unrar import unrarlib 57 | from unrar import constants 58 | from unrar.rarfile import _ReadIntoMemory, BadRarFile 59 | import zipfile 60 | import PIL 61 | 62 | class ZipDataset(Dataset): 63 | def __init__(self, path, transform=None): 64 | self.zip = zipfile.ZipFile(path) 65 | worker_info = torch.utils.data.get_worker_info() 66 | if worker_info is None: 67 | self.members = [info for info in self.zip.infolist() if info.filename[-1] != os.sep] 68 | else: 69 | all_members = [info for info in self.zip.infolist() if info.filename[-1] != os.sep] 70 | num_workers = worker_info.num_workers 71 | worker_id = worker_info.id 72 | self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id] 73 | 74 | self.transform = transform 75 | def __len__(self): 76 | return len(self.members) 77 | def __getitem__(self, idx): 78 | target_info = self.members[idx] 79 | img = Image.open(self.zip.open(target_info)) 80 | dirs, filename = os.path.split(self.members[idx].filename) 81 | filename = filename.split('.')[0] 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, filename 85 | 86 | import h5py 87 | 88 | class H5Dataset(Dataset): 89 | def __init__(self, path, transform=None): 90 | self.h5 = h5py.File(path, "r") 91 | self.images = self.h5["input_image"] 92 | self.members = None 93 | self.transform = transform 94 | 95 | def create_members(self): 96 | worker_info = torch.utils.data.get_worker_info() 97 | if worker_info is None: 98 | self.members = self.h5['index'][:] 99 | else: 100 | all_members = self.h5['index'][:] 101 | num_workers = worker_info.num_workers 102 | worker_id = worker_info.id 103 | self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id] 104 | 105 | def __len__(self): 106 | if self.members is None: 107 | self.create_members() 108 | return len(self.members) 109 | 110 | def __getitem__(self, idx): 111 | if self.members is None: 112 | self.create_members() 113 | target_info = self.members[idx] 114 | try: 115 | img = Image.fromarray(self.images[target_info][0]) 116 | if self.transform is not None: 117 | img = self.transform(img) 118 | return img, int(target_info) 119 | except(OSError, IndexError): 120 | print("warning: OSError or IndexError") 121 | return Image.new('RGB', (256, 256), (255, 255, 255)), -1 122 | 123 | # class StreamingZipDataset(IterableDataset): 124 | # def __init__(self, path, transform=None): 125 | # self.zip = zipfile.ZipFile(path, "r") 126 | # self.transform = transform 127 | # def 
__len__(self): 128 | # return len(self.zip.filelist) 129 | # def __next__(self): 130 | # img = Image.open(self.rar.open(target_info)) 131 | # 132 | # pass 133 | # def __iter__(self): 134 | # worker_info = torch.utils.data.get_worker_info() 135 | # if worker_info is None: 136 | # self.members = self.zip.namelist() 137 | # else: 138 | # all_members = self.zip.namelist() 139 | # num_workers = worker_info.num_workers 140 | # worker_id = worker_info.id 141 | # self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id] 142 | # self.pointer = 0 143 | # return self 144 | # def __del__(self): 145 | # self.zip.close() 146 | 147 | class StreamingRarDataset(IterableDataset): 148 | def __init__(self, path, transform=None, default_size=256): 149 | from PIL import ImageFile 150 | ImageFile.LOAD_TRUNCATED_IMAGES = True 151 | print("begin open rar") 152 | self.rar = rarfile.RarFile(path) 153 | print("finish open rar") 154 | self.transform = transform 155 | def callback_fn(file_buffer, filename): 156 | try: 157 | img = Image.open(file_buffer.get_bytes()).convert('RGB') 158 | dirs, filename = os.path.split(filename) 159 | filename = filename.split('.')[0] 160 | if self.transform is not None: 161 | img = self.transform(img) 162 | return img, filename 163 | except PIL.UnidentifiedImageError: 164 | print("UnidentifiedImageError") 165 | return torch.zeros((3, default_size, default_size)), "not_a_image" 166 | self.callback_fn = callback_fn 167 | # new handle 168 | self.handle = None 169 | self.callback_fn = callback_fn 170 | 171 | def __len__(self): 172 | return len(self.rar.filelist) 173 | def __next__(self): 174 | if self.pointer >= len(self.members): 175 | raise StopIteration() 176 | if self.handle == None: 177 | archive = unrarlib.RAROpenArchiveDataEx( 178 | self.rar.filename, mode=constants.RAR_OM_EXTRACT) 179 | self.handle = self.rar._open(archive) 180 | # callback to memory 181 | self.data_storage = _ReadIntoMemory() 182 | c_callback = unrarlib.UNRARCALLBACK(self.data_storage._callback) 183 | unrarlib.RARSetCallback(self.handle, c_callback, 0) 184 | handle = self.handle 185 | try: 186 | rarinfo = self.rar._read_header(handle) 187 | while rarinfo is not None: 188 | if rarinfo.filename == self.members[self.pointer]: 189 | self.rar._process_current(handle, constants.RAR_TEST) 190 | break 191 | else: 192 | self.rar._process_current(handle, constants.RAR_SKIP) 193 | rarinfo = self.rar._read_header(handle) 194 | 195 | if rarinfo is None: 196 | self.data_storage = None 197 | 198 | except unrarlib.UnrarException: 199 | raise BadRarFile("Bad RAR archive data.") 200 | 201 | if self.data_storage is None: 202 | raise KeyError('There is no item named %r in the archive' % self.members[self.pointer]) 203 | 204 | # return file-like object 205 | ret = self.data_storage 206 | if self.callback_fn is not None: 207 | ret = self.callback_fn(ret, self.members[self.pointer]) 208 | self.pointer += 1 209 | return ret 210 | 211 | def __iter__(self): 212 | worker_info = torch.utils.data.get_worker_info() 213 | if worker_info is None: 214 | self.members = self.rar.namelist() 215 | else: 216 | all_members = self.rar.namelist() 217 | num_workers = worker_info.num_workers 218 | worker_id = worker_info.id 219 | self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id] 220 | self.pointer = 0 221 | return self 222 | 223 | def __del__(self): 224 | self.rar._close(self.handle) 225 | -------------------------------------------------------------------------------- /preprocess/utils.py: 
-------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : utils.py 4 | @Time : 2021/01/24 16:35:43 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | # here put the import lib 10 | import os 11 | import sys 12 | import math 13 | import random 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from vqvae import code2img, img2code 20 | from torchvision.utils import save_image 21 | 22 | 23 | def show_recover_results(model, imgs): 24 | codes = img2code(model, imgs) 25 | recovered = code2img(model, codes) 26 | mean = torch.tensor([0.79093, 0.76271, 0.75340], device=recovered.device).view(-1, 1, 1) 27 | std = torch.tensor([0.30379, 0.32279, 0.32800], device=recovered.device).view(-1, 1, 1) 28 | recovered = (recovered * std + mean).clamp(0, 1) 29 | imgs = (imgs * std + mean).clamp(0, 1) 30 | out = torch.cat([imgs, recovered], dim=0) 31 | save_image(out, 'samples/show_recover_results.jpg', normalize=False, nrow=len(imgs)) 32 | -------------------------------------------------------------------------------- /preprocess_entry.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import random 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | import argparse 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser(description="preprocess args") 16 | parser.add_argument("--dataset", type=str, required=True) 17 | parser.add_argument("--img_tokenizer_path", type=str, default='vqvae_hard_biggerset_011.pt') 18 | parser.add_argument("--encode_size", type=int, default=32) 19 | parser.add_argument("--device", type=int, default=0) 20 | args = parser.parse_args() 21 | print(args) 22 | img_size = args.encode_size * 8 23 | 24 | # args = argparse.Namespace() 25 | # args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_018.pt'#old path 26 | # args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_biggerset_011.pt' 27 | # args.img_tokenizer_path = '/root/mnt/vqvae_1epoch_64x64.pt' 28 | args.img_tokenizer_num_tokens = None 29 | 30 | device = f'cuda:{args.device}' 31 | torch.cuda.set_device(device) 32 | name = args.dataset + "_" + args.img_tokenizer_path.split(".")[0] + ".lmdb" 33 | args.img_tokenizer_path = f"pretrained/vqvae/{args.img_tokenizer_path}" 34 | 35 | datasets = {} 36 | datasets["ali"] = [ 37 | ['/root/mnt/sq_gouhou_white_pict_title_word_256_fulltitle.tsv'], 38 | ['/root/mnt/dingming/ali_white_picts_256.zip'], 39 | "tsv" 40 | ] 41 | datasets["ks3"] = [ 42 | ['/root/mnt/KS3/a_baidu_image_msg_data.json'], 43 | ['/root/mnt/KS3/downloadImages.rar'], 44 | "json_ks" 45 | ] 46 | datasets["zijian"] = [ 47 | ['/root/mnt/zijian/zj_duomotai_clean_done_data_new.json', 48 | '/root/mnt/zijian/zj_duomotai_local_server_last_surplus_120w.json'], 49 | ['/root/mnt/imageFolder_part01.rar', 50 | '/root/mnt/zijian/imagesFolder_last_surplus_120w.rar'], 51 | "json" 52 | ] 53 | datasets["google"] = [ 54 | ['/root/mnt/google/google_image_message_data.json'], 55 | ['/root/mnt/google/downloadImage_2020_12_16.rar'], 56 | "json_ks" 57 | ] 58 | datasets["zijian1"] = [ 59 | ['/root/mnt/zijian/zj_duomotai_clean_done_data_new.json'], 60 | ['/root/cogview2/data/imageFolder_part01.rar'], 61 | "json" 62 | ] 63 | datasets["zijian2"] = [ 
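        # Like the entries above, this is [text metadata files, image archives, text format].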
64 | ['/root/mnt/zijian/zj_duomotai_local_server_last_surplus_120w.json'], 65 | ['/root/mnt/zijian/imagesFolder_last_surplus_120w.rar'], 66 | "json" 67 | ] 68 | txt_files, img_folders, txt_type = datasets[args.dataset] 69 | 70 | os.environ['UNRAR_LIB_PATH'] = '/usr/local/lib/libunrar.so' 71 | 72 | 73 | from data_utils import get_tokenizer 74 | tokenizer = get_tokenizer(args) 75 | model = tokenizer.img_tokenizer.model 76 | 77 | print("finish init vqvae_model") 78 | 79 | from preprocess.preprocess_text_image_data import extract_code,extract_code_super_resolution_patches 80 | 81 | # ===================== Define Imgs ======================== # 82 | from preprocess.raw_datasets import H5Dataset, StreamingRarDataset, ZipDataset 83 | 84 | datasets = [] 85 | for img_folder in img_folders: 86 | if img_folder[-3:] == "rar": 87 | dataset = StreamingRarDataset(path=img_folder, transform=transforms.Compose([ 88 | transforms.Resize(img_size), 89 | transforms.CenterCrop(img_size), 90 | transforms.ToTensor(), 91 | transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]), 92 | ]), 93 | default_size=img_size) 94 | elif img_folder[-3:] == "zip": 95 | dataset = ZipDataset(path=img_folder, transform=transforms.Compose([ 96 | transforms.Resize(img_size), 97 | transforms.CenterCrop(img_size), 98 | transforms.ToTensor(), 99 | transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]), 100 | ])) 101 | else: 102 | dataset = H5Dataset(path=img_folder, transform=transforms.Compose([ 103 | transforms.Resize(img_size), 104 | transforms.CenterCrop(img_size), 105 | transforms.ToTensor(), 106 | transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]), 107 | ])) 108 | datasets.append(dataset) 109 | print('Finish reading meta-data of dataset.') 110 | # ===================== END OF BLOCK ======================= # 111 | 112 | # from preprocess import show_recover_results 113 | 114 | 115 | # loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8) 116 | # loader = iter(loader) 117 | # samples = [] 118 | # for k in range(8): 119 | # x = next(loader) 120 | # print(x[1]) 121 | # x = x[0].to(device) 122 | # samples.append(x) 123 | # samples = torch.cat(samples, dim=0) 124 | # show_recover_results(model, samples) 125 | 126 | # ===================== Load Text ======================== # 127 | if txt_type == "json": 128 | import json 129 | txt_list = [] 130 | for txt in txt_files: 131 | with open(txt, 'r') as fin: 132 | t = json.load(fin) 133 | txt_list.extend(list(t.items())) 134 | tmp = [] 135 | for k, v in tqdm(txt_list): 136 | tmp.append((v['uniqueKey'], v['cnShortText'])) 137 | text_dict = dict(tmp) 138 | elif txt_type == "json_ks": 139 | import json 140 | txt_list = [] 141 | for txt in txt_files: 142 | with open(txt, 'r') as fin: 143 | t = json.load(fin) 144 | txt_list.extend(t["RECORDS"]) 145 | tmp = [] 146 | for v in tqdm(txt_list): 147 | tmp.append((v['uniqueKey'], v['cnShortText'])) 148 | text_dict = dict(tmp) 149 | elif txt_type == "tsv": 150 | import pandas as pd 151 | txt_list = [] 152 | for txt in txt_files: 153 | t = pd.read_csv(txt, sep='\t') 154 | txt_list.extend(list(t.values)) 155 | tmp = [] 156 | for k, v in tqdm(txt_list): 157 | tmp.append((str(k), v)) 158 | text_dict = dict(tmp) 159 | else: 160 | des = dataset.h5["input_concat_description"] 161 | txt_name = dataset.h5["input_name"] 162 | tmp = [] 163 | for i in tqdm(range(len(des))): 164 | tmp.append((i, des[i][0].decode("latin-1")+txt_name[i][0].decode("latin-1"))) 165 | text_dict = 
dict(tmp) 166 | print('Finish reading texts of dataset.') 167 | # ===================== END OF BLOCK ======================= # 168 | 169 | # extract_code(model, datasets, text_dict, name, device, txt_type) 170 | extract_code_super_resolution_patches(model, datasets, text_dict, name, device, txt_type) -------------------------------------------------------------------------------- /pretrained/chinese_sentencepiece/cog-pretrain.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/pretrained/chinese_sentencepiece/cog-pretrain.model -------------------------------------------------------------------------------- /pretrained/cogview/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/pretrained/cogview/placeholder -------------------------------------------------------------------------------- /pretrained/vqvae/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView/189a12bbe87227d3d127fa74abe968dd795cad39/pretrained/vqvae/placeholder -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 |

5 | Generate vivid Images for Any (Chinese) text 6 |

7 | 8 | ![teaser](assets/cogviewcase.png) 9 | 10 | **News!** The paper of ImageReward is accepted by NeurIPS 2023! 11 | 12 | **News!** The codes of ImageReward ([paper link](https://arxiv.org/abs/2304.05977.pdf)) have been released at https://github.com/THUDM/ImageReward! ImageReward is the first general-purpose text-to-image human preference RM. 13 | 14 | **News!** The codes of CogView2 ([paper link](https://arxiv.org/pdf/2105.13290.pdf)) have been released at https://github.com/THUDM/CogView2! 15 | 16 | **News!** The [demo](https://agc.platform.baai.ac.cn/CogView/index.html) for a better and faster CogView2 (formal version, March 2022) is available! The lastest model also supports English input, but to translate them into Chinese often could be better. 17 | 18 | **News!** The [demo](https://agc.platform.baai.ac.cn/CogView/index.html) for a better and faster CogView2 (new version) is available! 19 | 20 | **News!** The paper of CogView is accepted by NeurIPS 2021! 21 | 22 | CogView is a pretrained (4B-param) transformer for text-to-image generation in general domain. 23 | 24 | * **Read** our paper [CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/pdf/2105.13290.pdf) on ArXiv for a formal introduction. The *PB-relax* and *Sandwich-LN* can also help you train large and deep transformers stably (e.g. eliminating NaN losses). 25 | * **Visit** our demo at [Github Page](https://thudm.github.io/CogView/index.html) or [Wudao](https://wudao.aminer.cn/CogView/)! (Without post-selection or super-resolution, currently only supports simplified Chinese input, but one can translate text from other languages into Chinese for input. Note: *Wudao* provides faster access for users from China mainland.) 26 | * **Download** our pretrained models from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/df21f6d4109b4285bfd9/?dl=1). 27 | * **Cite** our paper if you find our work is helpful~ 28 | ``` 29 | @article{ding2021cogview, 30 | title={CogView: Mastering Text-to-Image Generation via Transformers}, 31 | author={Ding, Ming and Yang, Zhuoyi and Hong, Wenyi and Zheng, Wendi and Zhou, Chang and Yin, Da and Lin, Junyang and Zou, Xu and Shao, Zhou and Yang, Hongxia and Tang, Jie}, 32 | journal={arXiv preprint arXiv:2105.13290}, 33 | year={2021} 34 | ``` 35 | * **Google Colab** Two contributors successfully setup up CogView on Colab [![Links to Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/THUDM/CogView/issues/10)! 36 | ## Getting Started 37 | ### Setup 38 | * Hardware: Linux servers with Nvidia V100s or A100s are recommended, but it is also okay to run the pretrained models with smaller `--max-inference-batch-size` or training smaller models on less powerful GPUs. 39 | * Environment (Option 1): Please first install PyTorch (>=1.7.0) and [apex](https://github.com/NVIDIA/apex), and then install other dependencies via `pip install -r requirements.txt`. 40 | 41 | * Environment (Option 2): We prepare a docker image in case that you fail to handle the environments. Pull the image, create a (background) container and get into it via: 42 | ``` 43 | docker pull cogview/cuda111_torch181_deepspeed040 44 | ./env/start_docker.sh && docker exec -it bg-cogview bash 45 | 46 | cd /root/cogview # in the container 47 | ``` 48 | 49 | ### Download 50 | 0. Download the image tokenizer `vqvae_hard_biggerset_011.pt` from [BAAI website](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744) or Tsinghua Cloud. 
49 | ### Download 50 | 0. Download the image tokenizer `vqvae_hard_biggerset_011.pt` from the [BAAI website](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744) or from Tsinghua Cloud. Place the file under `pretrained/vqvae`. 51 | ``` 52 | wget 'https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1' -O pretrained/vqvae/vqvae_hard_biggerset_011.pt 53 | ``` 54 | 1. Download the models from [Project Wudao-Wenhui](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744). 55 | | FileName | Description | 56 | | ---- | ---- | 57 | | cogview-base.tar | The pretrained text-to-image model. | 58 | | cogview-caption.tar | Finetuned image-to-text model, also used for reranking. | 59 | | cogview-sr.tar | Finetuned super-resolution model. (warning: it runs slowly.) | 60 | 61 | Uncompress them into `pretrained/cogview/`. Adjust the following command to the model name. 62 | ``` 63 | tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/ 64 | ``` 65 | 2. (Only needed for the training tutorial; skip it for inference.) Download a small "bird-and-animal" example dataset from our link at Tsinghua Cloud. 66 | ``` 67 | wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin 68 | ``` 69 | 70 | ### Run CogView! (Model Inference) 71 | We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details. 72 | 73 | #### Text-to-Image Generation 74 | Write text queries (one per line) into `input.txt` and run: 75 | ``` 76 | ./scripts/text2image.sh --debug 77 | ``` 78 | The results will be saved in a new folder `samples_text2image/`. 79 | 80 | The arguments most useful for inference are: 81 | * `--input-source [path or "interactive"]`. The path of the input file; it can also be "interactive", which launches a CLI. 82 | * `--output-path [path]`. The folder containing the results. 83 | * `--batch-size [int]`. The number of samples generated per query. 84 | * `--max-inference-batch-size [int]`. Maximum batch size per forward pass. Reduce it if you run out of GPU memory. 85 | * `--debug`. Only save concatenated images for all generated samples, and name them by input text and date. 86 | * `--with-id`. When it is toggled, you must specify an "id" before each input, e.g. `001\t一个漂亮的女孩` ("a beautiful girl"), \t denoting TAB (**NOT space**). It will generate `batch-size` separate images in a folder named "id" for each input. Conflicts with `--debug`. 87 | * `--device [int]`. The GPU to run on. 88 | 89 | #### Super-resolution 90 | Run the following script and input `text\t{image_path}`, where `{image_path}` is the path of a previously generated image. 91 | ``` 92 | ./scripts/super_resolution.sh 93 | ``` 94 | Note: *It is only effective for images generated by our image tokenizer (due to the token distribution).* 95 | 96 | #### Image-to-Text 97 | The input is one image path per line; the results are printed to stdout. 98 | ``` 99 | ./scripts/image2text.sh 100 | ``` 101 | Note: *Not optimized for this task, so it might not be very competitive (but okay). We will consider releasing a version finetuned on this task for a longer period in the future.* (*TODO*) 102 | 103 | #### Post-selection 104 | This application only takes file inputs, where each line is `{text}\t{image_path1}\t{image_path2}\t{image_path3}...`. 105 | The output is `{output_path}/scores.txt`; each line contains the list of scores for the corresponding input line. 106 | ``` 107 | ./scripts/post_selection.sh 108 | ``` 109 | 110 | Note: *In the released code, for simplicity, we did not expose the raw API, which supports some advanced generation modes, e.g. conditioning on text and part of an image.* 111 | 112 | ## Training 113 | Here we use a small bird-and-animal subset of our dataset for this tutorial. The binary dataset is generated by our [cogdata toolkit](https://github.com/Sleepychord/cogdata); please wait for a formal release of cogdata with tutorials (although the toolkit is already available).
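The pretraining scripts below launch DeepSpeed with a hostfile (`hostfile_single` for the single-node script, `hostfile` for the multi-node one; both are set via `HOST_FILE_PATH` in the scripts). If you have not created these files yet, they follow the standard DeepSpeed `hostname slots=N` format; the host names and slot counts below are placeholders for illustration, so adjust them to your machines. `hostfile_single` could simply be:
```
localhost slots=8
```
and `hostfile` lists one line per server:
```
node1 slots=8
node2 slots=8
```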
114 | ### Single Node 115 | After downloading the dataset, directly run 116 | ``` 117 | ./scripts/pretrain_single_node.sh 118 | ``` 119 | ### Multiple Nodes 120 | If you want to train the models on multiple servers interconnected by InfiniBand without a shared file system (you may need `pdsh` to accelerate this process): 121 | 1. On **each** server, use `git clone` to download this repo, and make sure the data (LMDB format) is moved into the `data` subfolder. 122 | 2. On **each** server, run `echo "ip1 ip2 " > ./env/ip_list.txt`, and then start the docker via `./env/start_docker.sh`. 123 | 3. Get into the docker container **on the first node** via `docker exec -it bg-cogview bash`. 124 | 4. Get into `/root/cogview` and run `./scripts/pretrain_multiple_nodes.sh`. You may need to change the config (especially `OPTIONS_NCCL`) in the shell script. 125 | 126 | See `arguments.py` for advanced training options. 127 | *TODO* 128 | 129 | ## Gallery 130 | ![more_samples](assets/coco_new.png) 131 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | deepspeed 3 | tqdm 4 | lmdb 5 | filelock 6 | sentencepiece 7 | mpi4py 8 | tensorboardX==1.8 -------------------------------------------------------------------------------- /scripts/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 12, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 1000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.01 12 | }, 13 | "optimizer": { 14 | "type": "Adam", 15 | "params": { 16 | "lr": 0.0004, 17 | "weight_decay": 1e-2 18 | } 19 | }, 20 | "activation_checkpointing": { 21 | "partition_activations": false, 22 | "contiguous_memory_optimization": false 23 | }, 24 | "wall_clock_breakdown": false 25 | } -------------------------------------------------------------------------------- /scripts/ds_config_zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 1, 5 | "gradient_clipping": 0.1, 6 | "zero_optimization": { 7 | "stage": 1, 8 | "cpu_offload": false, 9 | "contiguous_gradients": false, 10 | "overlap_comm": true, 11 | "reduce_scatter": true, 12 | "reduce_bucket_size": 100000000, 13 | "allgather_bucket_size": 1000000000 14 | }, 15 | "zero_allow_untested_optimizer": true, 16 | "fp16": { 17 | "enabled": true, 18 | "loss_scale": 0, 19 | "loss_scale_window": 400, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "optimizer": { 24 | "type": "Adam", 25 | "params": { 26 | "lr": 0.00005, 27 | "betas": [ 28 | 0.9, 29 | 0.95 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": 4e-2 33 | } 34 | }, 35 | "activation_checkpointing": { 36 | "partition_activations": false, 37 | "contiguous_memory_optimization": false 38 | }, 39 | "wall_clock_breakdown": false 40 | } 41 | -------------------------------------------------------------------------------- /scripts/image2text.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=pretrained/cogview/cogview-caption 4 | NLAYERS=48 5 | NHIDDEN=2560 6 | NATT=40 7 | MAXSEQLEN=1089 8 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 9 | MPSIZE=1 10 | 11 | #SAMPLING ARGS 12 | TEMP=1. 13 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 14 | TOPK=200 15 | TOPP=0 16 | 17 | script_path=$(realpath $0) 18 | script_dir=$(dirname $script_path) 19 | 20 | MASTER_PORT=${MASTER_PORT} python generate_samples.py \ 21 | --deepspeed \ 22 | --model-parallel-size $MPSIZE \ 23 | --num-layers $NLAYERS \ 24 | --hidden-size $NHIDDEN \ 25 | --load $CHECKPOINT_PATH \ 26 | --num-attention-heads $NATT \ 27 | --max-position-embeddings 1089 \ 28 | --fp16 \ 29 | --temperature $TEMP \ 30 | --top_k $TOPK \ 31 | --top_p $TOPP \ 32 | --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ 33 | --query-window 64 \ 34 | --key-window-times 4 \ 35 | --num-pivot 256 \ 36 | --is-sparse 0 \ 37 | --max-position-embeddings-finetune $MAXSEQLEN \ 38 | --generation-task image2text \ 39 | --input-source interactive \ 40 | --output-path samples_image2text \ 41 | --batch-size 8 \ 42 | --debug \ 43 | --device 1 \ 44 | $@ 45 | 46 | 47 | -------------------------------------------------------------------------------- /scripts/low_level_super_resolution.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=pretrained/cogview/cogview-sr 4 | NLAYERS=48 5 | NHIDDEN=2560 6 | NATT=40 7 | MAXSEQLEN=1345 8 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 9 | MPSIZE=1 10 | 11 | #SAMPLING ARGS 12 | TEMP=1.02 13 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 14 | TOPK=200 15 | TOPP=0 16 | 17 | script_path=$(realpath $0) 18 | script_dir=$(dirname $script_path) 19 | 20 | MASTER_PORT=${MASTER_PORT} python generate_samples.py \ 21 | --deepspeed \ 22 | --model-parallel-size $MPSIZE \ 23 | --num-layers $NLAYERS \ 24 | --hidden-size $NHIDDEN \ 25 | --load $CHECKPOINT_PATH \ 26 | --num-attention-heads $NATT \ 27 | --max-position-embeddings 1089 \ 28 | --fp16 \ 29 | --temperature $TEMP \ 30 | --top_k $TOPK \ 31 | --top_p $TOPP \ 32 | --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ 33 | --query-window 64 \ 34 | --key-window-times 4 \ 35 | --num-pivot 256 \ 36 | --is-sparse 0 \ 37 | --max-position-embeddings-finetune $MAXSEQLEN \ 38 | --generation-task "low-level super-resolution" \ 39 | --input-source interactive \ 40 | --output-path samples_low_level_sr \ 41 | --batch-size 4 \ 42 | --device 6 \ 43 | $@ 44 | 45 | 46 | -------------------------------------------------------------------------------- /scripts/post_selection.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=pretrained/cogview/cogview-caption 4 | NLAYERS=48 5 | NHIDDEN=2560 6 | NATT=40 7 | MAXSEQLEN=1089 8 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 9 | MPSIZE=1 10 | 11 | #SAMPLING ARGS 12 | TEMP=1. 
13 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 14 | TOPK=200 15 | TOPP=0 16 | 17 | script_path=$(realpath $0) 18 | script_dir=$(dirname $script_path) 19 | 20 | MASTER_PORT=${MASTER_PORT} python generate_samples.py \ 21 | --deepspeed \ 22 | --model-parallel-size $MPSIZE \ 23 | --num-layers $NLAYERS \ 24 | --hidden-size $NHIDDEN \ 25 | --load $CHECKPOINT_PATH \ 26 | --num-attention-heads $NATT \ 27 | --max-position-embeddings 1089 \ 28 | --fp16 \ 29 | --temperature $TEMP \ 30 | --top_k $TOPK \ 31 | --top_p $TOPP \ 32 | --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ 33 | --query-window 64 \ 34 | --key-window-times 4 \ 35 | --num-pivot 256 \ 36 | --is-sparse 0 \ 37 | --max-position-embeddings-finetune $MAXSEQLEN \ 38 | --generation-task post-selection \ 39 | --input-source input_select.txt \ 40 | --output-path samples_post_selection \ 41 | --debug \ 42 | --device 2 \ 43 | $@ 44 | # input-source is split by \t, instead of 4 spaces 45 | 46 | -------------------------------------------------------------------------------- /scripts/pretrain_multiple_nodes.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Change for multinode config 4 | 5 | NUM_WORKERS=2 6 | NUM_GPUS_PER_WORKER=8 7 | MP_SIZE=1 8 | 9 | script_path=$(realpath $0) 10 | script_dir=$(dirname $script_path) 11 | main_dir=$(dirname $script_dir) 12 | 13 | # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" 14 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=ib0 NCCL_NET_GDR_LEVEL=2" 15 | HOST_FILE_PATH="hostfile" 16 | # OPTIONS_NCCL="" 17 | # HOST_FILE_PATH="hostfile_single" 18 | 19 | 20 | config_json="$script_dir/ds_config_zero.json" 21 | gpt_options=" \ 22 | --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \ 23 | --img-tokenizer-num-tokens 8192 \ 24 | --dataset-type TokenizedDataset \ 25 | --model-parallel-size ${MP_SIZE} \ 26 | --num-layers 12 \ 27 | --hidden-size 1024 \ 28 | --num-attention-heads 16 \ 29 | --save $main_dir/data/checkpoints \ 30 | --train-iters 40000 \ 31 | --resume-dataloader \ 32 | --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \ 33 | --split 949,50,1 \ 34 | --distributed-backend nccl \ 35 | --lr-decay-style cosine \ 36 | --warmup .1 \ 37 | --checkpoint-activations \ 38 | --deepspeed-activation-checkpointing \ 39 | --max-position-embeddings 1089 \ 40 | --max-memory-length 0 \ 41 | --fp16 \ 42 | " 43 | 44 | 45 | gpt_options="${gpt_options} 46 | --deepspeed \ 47 | --deepspeed_config ${config_json} \ 48 | " 49 | 50 | 51 | run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}" 52 | echo ${run_cmd} 53 | eval ${run_cmd} 54 | 55 | set +x 56 | -------------------------------------------------------------------------------- /scripts/pretrain_single_node.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # Change for multinode config 4 | 5 | NUM_WORKERS=1 6 | NUM_GPUS_PER_WORKER=8 7 | MP_SIZE=1 8 | 9 | script_path=$(realpath $0) 10 | script_dir=$(dirname $script_path) 11 | main_dir=$(dirname $script_dir) 12 | 13 | # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" 14 | OPTIONS_NCCL="NCCL_DEBUG=info" 15 | HOST_FILE_PATH="hostfile_single" 16 | 17 | 18 | config_json="$script_dir/ds_config.json" 19 | gpt_options=" \ 20 | --experiment-name cogview-bird_animal_tutorial-12-1024-16 \ 21 | --img-tokenizer-num-tokens 8192 \ 22 | --dataset-type CompactBinaryDataset \ 23 | --model-parallel-size ${MP_SIZE} \ 24 | --num-layers 12 \ 25 | --hidden-size 1024 \ 26 | --num-attention-heads 16 \ 27 | --save $main_dir/data/checkpoints \ 28 | --train-iters 20000 \ 29 | --resume-dataloader \ 30 | --train-data ./data/bird_animal.bin \ 31 | --split 949,50,1 \ 32 | --distributed-backend nccl \ 33 | --lr-decay-style cosine \ 34 | --warmup .1 \ 35 | --checkpoint-activations \ 36 | --deepspeed-activation-checkpointing \ 37 | --max-position-embeddings 1089 \ 38 | --max-memory-length 0 \ 39 | --fp16 \ 40 | --txt-loss-scale 5 \ 41 | " 42 | 43 | gpt_options="${gpt_options} 44 | --deepspeed \ 45 | --deepspeed_config ${config_json} \ 46 | " 47 | 48 | 49 | run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}" 50 | echo ${run_cmd} 51 | eval ${run_cmd} 52 | 53 | set +x 54 | -------------------------------------------------------------------------------- /scripts/super_resolution.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=pretrained/cogview/cogview-sr 4 | NLAYERS=48 5 | NHIDDEN=2560 6 | NATT=40 7 | MAXSEQLEN=1345 8 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 9 | MPSIZE=1 10 | 11 | #SAMPLING ARGS 12 | TEMP=1.02 13 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 14 | TOPK=200 15 | TOPP=0 16 | 17 | script_path=$(realpath $0) 18 | script_dir=$(dirname $script_path) 19 | 20 | MASTER_PORT=${MASTER_PORT} python generate_samples.py \ 21 | --deepspeed \ 22 | --model-parallel-size $MPSIZE \ 23 | --num-layers $NLAYERS \ 24 | --hidden-size $NHIDDEN \ 25 | --load $CHECKPOINT_PATH \ 26 | --num-attention-heads $NATT \ 27 | --max-position-embeddings 1089 \ 28 | --fp16 \ 29 | --temperature $TEMP \ 30 | --top_k $TOPK \ 31 | --top_p $TOPP \ 32 | --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ 33 | --query-window 64 \ 34 | --key-window-times 4 \ 35 | --num-pivot 256 \ 36 | --is-sparse 0 \ 37 | --max-position-embeddings-finetune $MAXSEQLEN \ 38 | --generation-task "super-resolution" \ 39 | --input-source interactive \ 40 | --output-path samples_sr \ 41 | --debug \ 42 | --device 0 \ 43 | $@ 44 | 45 | 46 | -------------------------------------------------------------------------------- /scripts/text2image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ==== tutorial settings: ===== 4 | # CHECKPOINT_PATH=data/checkpoints/cogview-bird_animal_tutorial-12-1024-1608-10-09-38 5 | # NLAYERS=12 6 | # NHIDDEN=1024 7 | # NATT=16 8 | 9 | CHECKPOINT_PATH=pretrained/cogview/cogview-base 10 | NLAYERS=48 11 | NHIDDEN=2560 12 | NATT=40 13 | MAXSEQLEN=1089 14 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 15 | MPSIZE=1 16 | 17 | #SAMPLING ARGS 18 | TEMP=1. 
19 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 20 | TOPK=200 21 | TOPP=0 22 | 23 | script_path=$(realpath $0) 24 | script_dir=$(dirname $script_path) 25 | 26 | MASTER_PORT=${MASTER_PORT} python generate_samples.py \ 27 | --deepspeed \ 28 | --model-parallel-size $MPSIZE \ 29 | --num-layers $NLAYERS \ 30 | --hidden-size $NHIDDEN \ 31 | --load $CHECKPOINT_PATH \ 32 | --num-attention-heads $NATT \ 33 | --max-position-embeddings 1089 \ 34 | --fp16 \ 35 | --temperature $TEMP \ 36 | --top_k $TOPK \ 37 | --top_p $TOPP \ 38 | --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ 39 | --query-window 64 \ 40 | --key-window-times 4 \ 41 | --num-pivot 256 \ 42 | --is-sparse 0 \ 43 | --max-position-embeddings-finetune $MAXSEQLEN \ 44 | --generation-task text2image \ 45 | --input-source ./input.txt \ 46 | --output-path samples_text2image \ 47 | --batch-size 4 \ 48 | --max-inference-batch-size 4 \ 49 | --device 0 \ 50 | $@ 51 | 52 | 53 | -------------------------------------------------------------------------------- /test_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import os, sys 3 | from data_utils import get_tokenizer 4 | 5 | def initialize(file_name): 6 | env = lmdb.open(file_name, "r") 7 | return env 8 | 9 | def insert(env, sid, name): 10 | txn = env.begin(write=True) 11 | txn.put(str(sid).encode('utf-8'), name.encode('utf-8')) 12 | txn.commit() 13 | 14 | def delete(env, sid): 15 | txn = env.begin(write=True) 16 | txn.delete(str(sid).encode('utf-8')) 17 | txn.commit() 18 | 19 | def update(env, sid, name): 20 | txn = env.begin(write=True) 21 | txn.put(str(sid).encode('utf-8'), name.encode('utf-8')) 22 | txn.commit() 23 | 24 | 25 | import pickle 26 | def search(env, sid): 27 | txn = env.begin() 28 | data = pickle.loads(txn.get(str(sid).encode('utf-8'))) 29 | return data 30 | 31 | import argparse 32 | import torch 33 | from torchvision.utils import save_image 34 | 35 | if __name__ == "__main__": 36 | # settings 37 | lmdb_path = "data/ali_vqvae_hard_biggerset_011.lmdb" 38 | output_path = f"test_lmdb_{lmdb_path.split('/')[-1]}.jpg" 39 | args = argparse.Namespace() 40 | args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_biggerset_011.pt' 41 | args.img_tokenizer_num_tokens = None 42 | device = 'cuda:0' 43 | 44 | torch.cuda.set_device(device) 45 | tokenizer = get_tokenizer(args) 46 | with lmdb.open(lmdb_path, readonly=True, lock=False) as env: 47 | imgs = [] 48 | txts = [] 49 | for i in range(20,50): 50 | txt, images = tokenizer.DecodeIds(search(env, i)) 51 | txts.append(txt) 52 | imgs.append(images[0]) 53 | print(txts) 54 | imgs = torch.cat(imgs, dim=0) 55 | save_image(imgs, output_path, normalize=True, range=None) -------------------------------------------------------------------------------- /vqvae/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | ============================================================================== 25 | Learning rate scheduler and VQ-VAE 26 | ============================================================================== 27 | 28 | Apache License, Version 2.0 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ 29 | 30 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 31 | 32 | 1. Definitions. 33 | 34 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 35 | 36 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 37 | 38 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 39 | 40 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 41 | 42 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 43 | 44 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 45 | 46 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 47 | 48 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 51 | 52 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 53 | 54 | 2. Grant of Copyright License. 55 | 56 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 57 | 58 | 3. Grant of Patent License. 59 | 60 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 61 | 62 | 4. Redistribution. 63 | 64 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 65 | 66 | You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 67 | 68 | 5. Submission of Contributions. 69 | 70 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 71 | 72 | 6. Trademarks. 73 | 74 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 75 | 76 | 7. Disclaimer of Warranty. 77 | 78 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 79 | 80 | 8. Limitation of Liability. 81 | 82 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 83 | 84 | 9. Accepting Warranty or Additional Liability. 85 | 86 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
87 | 88 | -------------------------------------------------------------------------------- /vqvae/README.md: -------------------------------------------------------------------------------- 1 | # vq-vae-2-pytorch 2 | Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch 3 | 4 | ## Update 5 | 6 | * 2020-06-01 7 | 8 | train_vqvae.py and vqvae.py now support distributed training. You can pass --n_gpu [NUM_GPUS] to train_vqvae.py to use [NUM_GPUS] GPUs during training. 9 | 10 | ## Requirements 11 | 12 | * Python >= 3.6 13 | * PyTorch >= 1.1 14 | * lmdb (for storing extracted codes) 15 | 16 | [Checkpoint of VQ-VAE pretrained on FFHQ](vqvae_560.pt) 17 | 18 | ## Usage 19 | 20 | Currently supports 256px (top/bottom hierarchical prior) 21 | 22 | 1. Stage 1 (VQ-VAE) 23 | 24 | > python train_vqvae.py [DATASET PATH] 25 | 26 | If you use FFHQ, I highly recommend preprocessing the images (resize and convert to JPEG). 27 | 28 | 2. Extract codes for stage 2 training 29 | 30 | > python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH] 31 | 32 | 3. Stage 2 (PixelSNAIL) 33 | 34 | > python train_pixelsnail.py [LMDB NAME] 35 | 36 | It may be better to use a larger PixelSNAIL model; the current model size is reduced due to GPU constraints. 37 | 38 | ## Sample 39 | 40 | ### Stage 1 41 | 42 | Note: This is a training sample 43 | 44 | ![Sample from Stage 1 (VQ-VAE)](stage1_sample.png) 45 | -------------------------------------------------------------------------------- /vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import new_model, img2code, code2img -------------------------------------------------------------------------------- /vqvae/api.py: -------------------------------------------------------------------------------- 1 | # This is an API file to export a VQVAE/... for tokenization 2 | # The APIs can be rewritten for VQGAN. 3 | # Don't forget to freeze the relevant .py files. 4 | 5 | import torch 6 | import math 7 | 8 | # production APIs 9 | 10 | from .vqvae_zc import VQVAE 11 | 12 | def new_model(): 13 | '''Return a new instance of VQVAE with the same parameters as the pretrained model. 14 | This is for torch.load(). 15 | ''' 16 | return VQVAE( 17 | channel=512, n_res_block=0, 18 | n_res_channel=32, embed_dim=256, 19 | n_embed=8192, stride=6 20 | ) 21 | 22 | def img2code(model, img): 23 | '''Convert a batch of images to codes 24 | Args: 25 | model: The tokenizer model. 26 | img: [b, c, h, w] 27 | ''' 28 | with torch.no_grad(): 29 | quant_t1, _, id_t1 = model.encode(img) 30 | return id_t1.view(img.shape[0], -1) 31 | 32 | def code2img(model, code): 33 | '''Convert a batch of codes to images 34 | Args: 35 | model: ...
36 | code: [b, h, w] or [b, h*w] LongTensor 37 | ''' 38 | if len(code.shape) == 2: 39 | s = int(math.sqrt(len(code.view(-1))) + 1e-5) 40 | code = code.view(code.shape[0], s, s) 41 | with torch.no_grad(): 42 | out = model.decode_code(code) 43 | out = out * torch.tensor([0.30379, 0.32279, 0.32800], device=out.device).view(1, -1, 1, 1) + torch.tensor([0.79093, 0.76271, 0.75340], device=out.device).view(1, -1, 1, 1) 44 | return out 45 | 46 | -------------------------------------------------------------------------------- /vqvae/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import ( 2 | get_rank, 3 | get_local_rank, 4 | is_primary, 5 | synchronize, 6 | get_world_size, 7 | all_reduce, 8 | all_gather, 9 | reduce_dict, 10 | data_sampler, 11 | LOCAL_PROCESS_GROUP, 12 | ) 13 | from .launch import launch 14 | -------------------------------------------------------------------------------- /vqvae/distributed/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils import data 7 | 8 | 9 | LOCAL_PROCESS_GROUP = None 10 | 11 | 12 | def is_primary(): 13 | return get_rank() == 0 14 | 15 | 16 | def get_rank(): 17 | if not dist.is_available(): 18 | return 0 19 | 20 | if not dist.is_initialized(): 21 | return 0 22 | 23 | return dist.get_rank() 24 | 25 | 26 | def get_local_rank(): 27 | if not dist.is_available(): 28 | return 0 29 | 30 | if not dist.is_initialized(): 31 | return 0 32 | 33 | if LOCAL_PROCESS_GROUP is None: 34 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 35 | 36 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 37 | 38 | 39 | def synchronize(): 40 | if not dist.is_available(): 41 | return 42 | 43 | if not dist.is_initialized(): 44 | return 45 | 46 | world_size = dist.get_world_size() 47 | 48 | if world_size == 1: 49 | return 50 | 51 | dist.barrier() 52 | 53 | 54 | def get_world_size(): 55 | if not dist.is_available(): 56 | return 1 57 | 58 | if not dist.is_initialized(): 59 | return 1 60 | 61 | return dist.get_world_size() 62 | 63 | 64 | def all_reduce(tensor, op=dist.ReduceOp.SUM): 65 | world_size = get_world_size() 66 | 67 | if world_size == 1: 68 | return tensor 69 | 70 | dist.all_reduce(tensor, op=op) 71 | 72 | return tensor 73 | 74 | 75 | def all_gather(data): 76 | world_size = get_world_size() 77 | 78 | if world_size == 1: 79 | return [data] 80 | 81 | buffer = pickle.dumps(data) 82 | storage = torch.ByteStorage.from_buffer(buffer) 83 | tensor = torch.ByteTensor(storage).to("cuda") 84 | 85 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 86 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 87 | dist.all_gather(size_list, local_size) 88 | size_list = [int(size.item()) for size in size_list] 89 | max_size = max(size_list) 90 | 91 | tensor_list = [] 92 | for _ in size_list: 93 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 94 | 95 | if local_size != max_size: 96 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 97 | tensor = torch.cat((tensor, padding), 0) 98 | 99 | dist.all_gather(tensor_list, tensor) 100 | 101 | data_list = [] 102 | 103 | for size, tensor in zip(size_list, tensor_list): 104 | buffer = tensor.cpu().numpy().tobytes()[:size] 105 | data_list.append(pickle.loads(buffer)) 106 | 107 | return data_list 108 | 109 | 110 | def 
reduce_dict(input_dict, average=True): 111 | world_size = get_world_size() 112 | 113 | if world_size < 2: 114 | return input_dict 115 | 116 | with torch.no_grad(): 117 | keys = [] 118 | values = [] 119 | 120 | for k in sorted(input_dict.keys()): 121 | keys.append(k) 122 | values.append(input_dict[k]) 123 | 124 | values = torch.stack(values, 0) 125 | dist.reduce(values, dst=0) 126 | 127 | if dist.get_rank() == 0 and average: 128 | values /= world_size 129 | 130 | reduced_dict = {k: v for k, v in zip(keys, values)} 131 | 132 | return reduced_dict 133 | 134 | 135 | def data_sampler(dataset, shuffle, distributed): 136 | if distributed: 137 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 138 | 139 | if shuffle: 140 | return data.RandomSampler(dataset) 141 | 142 | else: 143 | return data.SequentialSampler(dataset) 144 | -------------------------------------------------------------------------------- /vqvae/distributed/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import distributed as dist 5 | from torch import multiprocessing as mp 6 | 7 | import distributed as dist_fn 8 | 9 | 10 | def find_free_port(): 11 | import socket 12 | 13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | 15 | sock.bind(("", 0)) 16 | port = sock.getsockname()[1] 17 | sock.close() 18 | 19 | return port 20 | 21 | 22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): 23 | world_size = n_machine * n_gpu_per_machine 24 | 25 | if world_size > 1: 26 | if "OMP_NUM_THREADS" not in os.environ: 27 | os.environ["OMP_NUM_THREADS"] = "1" 28 | 29 | if dist_url == "auto": 30 | if n_machine != 1: 31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 32 | 33 | port = find_free_port() 34 | dist_url = f"tcp://127.0.0.1:{port}" 35 | 36 | if n_machine > 1 and dist_url.startswith("file://"): 37 | raise ValueError( 38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 39 | ) 40 | 41 | mp.spawn( 42 | distributed_worker, 43 | nprocs=n_gpu_per_machine, 44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), 45 | daemon=False, 46 | ) 47 | 48 | else: 49 | fn(*args) 50 | 51 | 52 | def distributed_worker( 53 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args 54 | ): 55 | if not torch.cuda.is_available(): 56 | raise OSError("CUDA is not available. 
Please check your environments") 57 | 58 | global_rank = machine_rank * n_gpu_per_machine + local_rank 59 | 60 | try: 61 | dist.init_process_group( 62 | backend="NCCL", 63 | init_method=dist_url, 64 | world_size=world_size, 65 | rank=global_rank, 66 | ) 67 | 68 | except Exception: 69 | raise OSError("failed to initialize NCCL groups") 70 | 71 | dist_fn.synchronize() 72 | 73 | if n_gpu_per_machine > torch.cuda.device_count(): 74 | raise ValueError( 75 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 76 | ) 77 | 78 | torch.cuda.set_device(local_rank) 79 | 80 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 81 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 82 | 83 | n_machine = world_size // n_gpu_per_machine 84 | 85 | for i in range(n_machine): 86 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 87 | pg = dist.new_group(ranks_on_i) 88 | 89 | if i == machine_rank: 90 | dist_fn.distributed.LOCAL_PROCESS_GROUP = pg 91 | 92 | fn(*args) 93 | --------------------------------------------------------------------------------