├── .flake8 ├── .github └── workflows │ └── pythonpublish.yml ├── LICENSE ├── README.md ├── examples ├── __init__.py ├── enwik8 │ ├── __init__.py │ ├── eval_enwik8.py │ ├── train_enwik8.py │ ├── train_enwik8_agp_struct.py │ └── train_enwik8_agp_unstruct.py ├── enwik8_tf │ ├── __init__.py │ ├── data_utils.py │ ├── eval.py │ ├── mem_transformer.py │ ├── train.py │ ├── train_agp_struct.py │ ├── train_agp_unstruct.py │ └── utils │ │ ├── __init__.py │ │ ├── adaptive_softmax.py │ │ ├── data_parallel.py │ │ ├── exp_utils.py │ │ ├── log_uniform_sampler.py │ │ ├── proj_adaptive_softmax.py │ │ └── vocabulary.py └── wt103 │ ├── __init__.py │ ├── eval.py │ ├── train.py │ ├── train_agp_struct.py │ ├── train_agp_unstruct.py │ ├── train_distributed.py │ ├── train_distributed.wgx.py │ └── utils │ ├── __init__.py │ ├── data_parallel.py │ ├── data_utils.py │ └── log_uniform_sampler.py ├── flop ├── __init__.py ├── agp.py ├── embedding.py ├── hardconcrete.py ├── linear.py ├── utils.py └── version.py ├── requirements.txt └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | select = C,E,F,W,B,B950 4 | ignore = E203, E501, W503 5 | -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ASAPP Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLOP 2 | 3 | PyTorch library for L0-based pruning, as proposed in the paper: 4 | [Structured Pruning of Large Language Models](https://arxiv.org/abs/1910.04732) (EMNLP 2020) 5 | 6 | ## Install 7 | 8 | `pip install -U flop` 9 | 10 | ## Usage 11 | 12 | Create a hard concrete mask of size N: 13 | 14 | ```python 15 | from flop import HardConcrete 16 | 17 | N = 100 18 | hardconcrete = HardConcrete(n_in=N) 19 | ``` 20 | 21 | You can then sample masks on the fly with: 22 | 23 | ```python 24 | mask = hardconcrete() 25 | ``` 26 | 27 | Note that during evaluation, a mask is compiled and fixed. 28 | 29 | You may also find these other objects useful: 30 | 31 | - ``ProjectedLinear``: replaces a linear layer to include an intermediate projection. 32 | - ``HardConcreteProjectedLinear``: the hard concrete version of the ``ProjectedLinear`` module. 33 | 34 | You may instantiate the HardConcrete objects directly, or you can choose to first train with 35 | a ``ProjectedLinear`` module, and introduce the hardconcrete mask with: 36 | 37 | ```python 38 | module = ProjectedLinear(...) 39 | # Perform training 40 | 41 | # ... 42 | 43 | # Start pruning 44 | pruning_module = HardConcreteProjectedLinear.from_module(module) 45 | ``` 46 | 47 | We also provide some utility functions to replace all ProjectedLinear modules in a model: 48 | 49 | ```python 50 | from flop import make_hard_concrete 51 | 52 | model = make_hard_concrete(model) 53 | ``` 54 | 55 | ## Replicate results from the paper 56 | 57 | To replicate the SRU numbers, please look at the script ``examples/enwik8/train_enwik8.py``. 58 | 59 | ## Cite 60 | 61 | ```bibtex 62 | @inproceedings{wang-etal-2020-structured, 63 | title = "Structured Pruning of Large Language Models", 64 | author = "Wang, Ziheng and 65 | Wohlwend, Jeremy and 66 | Lei, Tao", 67 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 68 | month = nov, 69 | year = "2020", 70 | address = "Online", 71 | publisher = "Association for Computational Linguistics", 72 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.496", 73 | doi = "10.18653/v1/2020.emnlp-main.496", 74 | pages = "6151--6162" 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/__init__.py -------------------------------------------------------------------------------- /examples/enwik8/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/enwik8/__init__.py -------------------------------------------------------------------------------- /examples/enwik8/eval_enwik8.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | import sru 9 | import flop 10 | 11 | 12 | def read_corpus(path, num_test_symbols=5000000): 13 | raw_data = open(path).read() 14 | raw_data = np.fromstring(raw_data, dtype=np.uint8) 15 | unique, data = np.unique(raw_data, return_inverse=True) 16 | train_data
= data[: -2 * num_test_symbols] 17 | valid_data = data[-2 * num_test_symbols : -num_test_symbols] 18 | test_data = data[-num_test_symbols:] 19 | return train_data, valid_data, test_data, unique 20 | 21 | 22 | def create_batches(data_ids, batch_size): 23 | N = len(data_ids) 24 | L = ((N - 1) // batch_size) * batch_size 25 | x = np.copy(data_ids[:L].reshape(batch_size, -1).T) 26 | y = np.copy(data_ids[1 : L + 1].reshape(batch_size, -1).T) 27 | x, y = torch.from_numpy(x), torch.from_numpy(y) 28 | x, y = x.contiguous(), y.contiguous() 29 | x, y = x.cuda(), y.cuda() 30 | return x, y 31 | 32 | 33 | class CustomLinear(nn.Linear): 34 | def __init__(self, in_features, out_features, bias=False): 35 | super(CustomLinear, self).__init__(in_features, out_features, bias=bias) 36 | 37 | def forward(self, data, **kwargs): 38 | return super().forward(data) 39 | 40 | 41 | class Model(nn.Module): 42 | def __init__(self, words, args): 43 | super(Model, self).__init__() 44 | self.args = args 45 | if args.n_e: 46 | self.n_e = args.n_e 47 | else: 48 | self.n_e = len(words) if len(words) < args.n_d else args.n_d 49 | self.n_d = args.n_d 50 | self.depth = args.depth 51 | self.drop = nn.Dropout(args.dropout) 52 | self.embedding_layer = nn.Embedding(len(words), self.n_e) 53 | self.n_V = len(words) 54 | custom_m_list = [CustomLinear(self.n_e, self.n_d * 4, bias=False)] 55 | for i in range(self.depth - 1): 56 | custom_m_list.append( 57 | flop.ProjectedLinear( 58 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 59 | ) 60 | ) 61 | self.rnn = sru.SRU( 62 | self.n_e, 63 | self.n_d, 64 | self.depth, 65 | dropout=args.dropout, 66 | highway_bias=args.bias, 67 | layer_norm=args.layer_norm, 68 | rescale=args.rescale, 69 | custom_m=custom_m_list, 70 | ) 71 | self.output_layer = nn.Linear(self.n_d, self.n_V) 72 | self.init_weights() 73 | 74 | def init_weights(self, reinit_rnn=False): 75 | params = list(self.embedding_layer.parameters()) + list( 76 | self.output_layer.parameters() 77 | ) 78 | for p in params: 79 | if p.dim() > 1: # matrix 80 | val = (3.0 / p.size(0)) ** 0.5 81 | p.data.uniform_(-val, val) 82 | else: 83 | p.data.zero_() 84 | if reinit_rnn: 85 | for p in self.rnn.parameters(): 86 | if p.dim() > 1: # matrix 87 | val = (3.0 / p.size(0)) ** 0.5 88 | p.data.uniform_(-val, val) 89 | 90 | def forward(self, x, hidden): 91 | emb = self.drop(self.embedding_layer(x)) 92 | output, hidden = self.rnn(emb, hidden) 93 | output = self.drop(output) 94 | output = output.view(-1, output.size(2)) 95 | output = self.output_layer(output) 96 | return output, hidden 97 | 98 | def init_hidden(self, batch_size): 99 | weight = next(self.parameters()).data 100 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 101 | return zeros 102 | 103 | 104 | def calc_norm(lis): 105 | l2_sum = sum(x.norm() ** 2 for x in lis) 106 | return l2_sum ** 0.5 107 | 108 | 109 | def eval_model(model, valid): 110 | with torch.no_grad(): 111 | model.eval() 112 | args = model.args 113 | batch_size = valid[0].size(1) 114 | total_loss = 0.0 115 | unroll_size = args.unroll_size 116 | criterion = nn.CrossEntropyLoss(size_average=False) 117 | hidden = model.init_hidden(batch_size) 118 | N = (len(valid[0]) - 1) // unroll_size + 1 119 | for i in range(N): 120 | x = valid[0][i * unroll_size : (i + 1) * unroll_size] 121 | y = valid[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 122 | hidden.detach_() 123 | output, hidden = model(x, hidden) 124 | loss = criterion(output, y) 125 | total_loss += loss.item() 126 | avg_loss = total_loss / 
valid[1].numel() 127 | ppl = np.exp(avg_loss) 128 | model.train() 129 | return ppl, avg_loss 130 | 131 | 132 | def copy_model(model): 133 | states = model.state_dict() 134 | for k in states: 135 | v = states[k] 136 | states[k] = v.clone().cpu() 137 | return states 138 | 139 | 140 | def main(args): 141 | train, dev, test, words = read_corpus(args.data) 142 | dev_, test_ = dev, test 143 | # train = create_batches(train, args.batch_size) 144 | dev = create_batches(dev, args.batch_size) 145 | test = create_batches(test, args.batch_size) 146 | 147 | model = Model(words, args) 148 | model.cuda() 149 | flop.make_projected_linear_with_mask(model.rnn, in_place=True) 150 | if args.load: 151 | model.load_state_dict(torch.load(args.load)) 152 | 153 | model.cuda() 154 | dev = create_batches(dev_, 1) 155 | test = create_batches(test_, 1) 156 | dev_ppl, dev_loss = eval_model(model, dev) 157 | test_ppl, test_loss = eval_model(model, test) 158 | sys.stdout.write( 159 | "dev_bpc={:.3f} test_bpc={:.3f}\n".format(np.log2(dev_ppl), np.log2(test_ppl)) 160 | ) 161 | 162 | 163 | if __name__ == "__main__": 164 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 165 | argparser.add_argument("--log", type=str, default="") 166 | argparser.add_argument("--noam", type=bool, default=True) 167 | argparser.add_argument("--warmup_steps", type=int, default=16000) 168 | argparser.add_argument("--layer_norm", action="store_true") 169 | argparser.add_argument("--rescale", action="store_true") 170 | argparser.add_argument("--data", type=str, required=True, help="training file") 171 | argparser.add_argument("--batch_size", "--batch", type=int, default=64) 172 | argparser.add_argument("--update_param_freq", type=int, default=1) 173 | argparser.add_argument("--unroll_size", type=int, default=256) 174 | argparser.add_argument("--max_epoch", type=int, default=100) 175 | argparser.add_argument("--n_e", type=int, default=0) 176 | argparser.add_argument("--n_d", "--d", type=int, default=3056) 177 | argparser.add_argument("--n_proj", type=int, default=512) 178 | argparser.add_argument( 179 | "--dropout", type=float, default=0.1, help="dropout probability" 180 | ) 181 | argparser.add_argument( 182 | "--bias", type=float, default=-3, help="intial bias of highway gates", 183 | ) 184 | argparser.add_argument("--depth", type=int, default=6) 185 | argparser.add_argument("--lr", type=float, default=2) 186 | argparser.add_argument("--weight_decay", type=float, default=1e-7) 187 | argparser.add_argument("--clip_grad", type=float, default=0.3) 188 | argparser.add_argument("--log_period", type=int, default=1000000) 189 | argparser.add_argument("--save", type=str, default="") 190 | argparser.add_argument("--load", type=str, default="") 191 | 192 | argparser.add_argument("--prune", type=bool, default=True) 193 | argparser.add_argument("--prune_lr", type=float, default=2) 194 | argparser.add_argument("--prune_warmup", type=int, default=0) 195 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 196 | argparser.add_argument("--prune_sparsity", type=float, default=0.9) 197 | argparser.add_argument("--prune_end_epoch", type=int, default=30) 198 | argparser.add_argument("--l1_lambda", type=float, default=0) 199 | 200 | args = argparser.parse_args() 201 | main(args) 202 | -------------------------------------------------------------------------------- /examples/enwik8/train_enwik8.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import 
time 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import Adam 10 | from tensorboardX import SummaryWriter 11 | 12 | import sru 13 | import flop 14 | 15 | 16 | def read_corpus(path, num_test_symbols=5000000): 17 | raw_data = open(path).read() 18 | raw_data = np.fromstring(raw_data, dtype=np.uint8) 19 | unique, data = np.unique(raw_data, return_inverse=True) 20 | train_data = data[: -2 * num_test_symbols] 21 | valid_data = data[-2 * num_test_symbols : -num_test_symbols] 22 | test_data = data[-num_test_symbols:] 23 | return train_data, valid_data, test_data, unique 24 | 25 | 26 | def create_batches(data_ids, batch_size): 27 | N = len(data_ids) 28 | L = ((N - 1) // batch_size) * batch_size 29 | x = np.copy(data_ids[:L].reshape(batch_size, -1).T) 30 | y = np.copy(data_ids[1 : L + 1].reshape(batch_size, -1).T) 31 | x, y = torch.from_numpy(x), torch.from_numpy(y) 32 | x, y = x.contiguous(), y.contiguous() 33 | x, y = x.cuda(), y.cuda() 34 | return x, y 35 | 36 | 37 | class CustomLinear(nn.Linear): 38 | def __init__(self, in_features, out_features, bias=False): 39 | super(CustomLinear, self).__init__(in_features, out_features, bias=bias) 40 | 41 | def forward(self, data, **kwargs): 42 | return super().forward(data) 43 | 44 | 45 | class Model(nn.Module): 46 | def __init__(self, words, args): 47 | super(Model, self).__init__() 48 | self.args = args 49 | if args.n_e: 50 | self.n_e = args.n_e 51 | else: 52 | self.n_e = len(words) if len(words) < args.n_d else args.n_d 53 | self.n_d = args.n_d 54 | self.depth = args.depth 55 | self.drop = nn.Dropout(args.dropout) 56 | self.embedding_layer = nn.Embedding(len(words), self.n_e) 57 | self.n_V = len(words) 58 | custom_m_list = [CustomLinear(self.n_e, self.n_d * 4, bias=False)] 59 | for i in range(self.depth - 1): 60 | custom_m_list.append( 61 | flop.ProjectedLinear( 62 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 63 | ) 64 | ) 65 | self.rnn = sru.SRU( 66 | self.n_e, 67 | self.n_d, 68 | self.depth, 69 | dropout=args.dropout, 70 | highway_bias=args.bias, 71 | layer_norm=args.layer_norm, 72 | rescale=args.rescale, 73 | custom_m=custom_m_list, 74 | ) 75 | self.output_layer = nn.Linear(self.n_d, self.n_V) 76 | self.init_weights() 77 | 78 | def init_weights(self, reinit_rnn=False): 79 | params = list(self.embedding_layer.parameters()) + list( 80 | self.output_layer.parameters() 81 | ) 82 | for p in params: 83 | if p.dim() > 1: # matrix 84 | val = (3.0 / p.size(0)) ** 0.5 85 | p.data.uniform_(-val, val) 86 | else: 87 | p.data.zero_() 88 | if reinit_rnn: 89 | for p in self.rnn.parameters(): 90 | if p.dim() > 1: # matrix 91 | val = (3.0 / p.size(0)) ** 0.5 92 | p.data.uniform_(-val, val) 93 | 94 | def forward(self, x, hidden): 95 | emb = self.drop(self.embedding_layer(x)) 96 | output, hidden = self.rnn(emb, hidden) 97 | output = self.drop(output) 98 | output = output.view(-1, output.size(2)) 99 | output = self.output_layer(output) 100 | return output, hidden 101 | 102 | def init_hidden(self, batch_size): 103 | weight = next(self.parameters()).data 104 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 105 | return zeros 106 | 107 | 108 | def calc_norm(lis): 109 | l2_sum = sum(x.norm() ** 2 for x in lis) 110 | return l2_sum ** 0.5 111 | 112 | 113 | def eval_model(model, valid): 114 | with torch.no_grad(): 115 | model.eval() 116 | args = model.args 117 | batch_size = valid[0].size(1) 118 | total_loss = 0.0 119 | unroll_size = args.unroll_size 120 | criterion = 
nn.CrossEntropyLoss(size_average=False) 121 | hidden = model.init_hidden(batch_size) 122 | N = (len(valid[0]) - 1) // unroll_size + 1 123 | for i in range(N): 124 | x = valid[0][i * unroll_size : (i + 1) * unroll_size] 125 | y = valid[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 126 | hidden.detach_() 127 | output, hidden = model(x, hidden) 128 | loss = criterion(output, y) 129 | total_loss += loss.item() 130 | avg_loss = total_loss / valid[1].numel() 131 | ppl = np.exp(avg_loss) 132 | model.train() 133 | return ppl, avg_loss 134 | 135 | 136 | def copy_model(model): 137 | states = model.state_dict() 138 | for k in states: 139 | v = states[k] 140 | states[k] = v.clone().cpu() 141 | return states 142 | 143 | 144 | def main(args): 145 | log_path = "{}_{}".format(args.log, random.randint(1, 100)) 146 | train_writer = SummaryWriter(log_dir=log_path + "/train") 147 | dev_writer = SummaryWriter(log_dir=log_path + "/dev") 148 | 149 | train, dev, test, words = read_corpus(args.data) 150 | dev_, test_ = dev, test 151 | train = create_batches(train, args.batch_size) 152 | dev = create_batches(dev, args.batch_size) 153 | test = create_batches(test, args.batch_size) 154 | 155 | model = Model(words, args) 156 | if args.load: 157 | model.load_state_dict(torch.load(args.load)) 158 | model.cuda() 159 | print(model) 160 | print("vocab size: {}".format(model.n_V)) 161 | 162 | lr = 1.0 if not args.noam else 1.0 / (args.n_d ** 0.5) / (args.warmup_steps ** 1.5) 163 | if args.prune: 164 | # in place substituion of linear ops in SRU 165 | flop.make_hard_concrete(model.rnn, in_place=True) 166 | model.cuda() 167 | print("model after inserting hardconcrete:") 168 | print(model) 169 | hc_modules = flop.get_hardconcrete_modules(model) 170 | hc_parameters = [ 171 | p for m in hc_modules for p in m.parameters() if p.requires_grad 172 | ] 173 | optimizer_hc = Adam(hc_parameters, lr=lr * args.prune_lr, weight_decay=0) 174 | num_hardconcrete_params = sum(x.numel() for x in hc_parameters) 175 | print("num of hardconcrete paramters: {}".format(num_hardconcrete_params)) 176 | lambda_1 = nn.Parameter(torch.tensor(0.0).cuda()) 177 | lambda_2 = nn.Parameter(torch.tensor(0.0).cuda()) 178 | optimizer_max = Adam([lambda_1, lambda_2], lr=lr, weight_decay=0) 179 | optimizer_max.param_groups[0]["lr"] = -lr * args.prune_lr 180 | hc_linear_modules = flop.get_hardconcrete_linear_modules(model) 181 | num_prunable_params = sum( 182 | m.num_prunable_parameters() for m in hc_linear_modules 183 | ) 184 | print("num of prunable paramters: {}".format(num_prunable_params)) 185 | else: 186 | args.prune_start_epoch = args.max_epoch 187 | 188 | m_parameters = [ 189 | i[1] 190 | for i in model.named_parameters() 191 | if i[1].requires_grad and "log_alpha" not in i[0] 192 | ] 193 | optimizer = Adam(m_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 194 | num_params = sum(x.numel() for x in m_parameters if x.requires_grad) 195 | print("num of parameters: {}".format(num_params)) 196 | 197 | nbatch = 1 198 | niter = 1 199 | best_dev = 1e8 200 | unroll_size = args.unroll_size 201 | batch_size = args.batch_size 202 | N = (len(train[0]) - 1) // unroll_size + 1 203 | criterion = nn.CrossEntropyLoss() 204 | 205 | model.zero_grad() 206 | if args.prune: 207 | optimizer_max.zero_grad() 208 | optimizer_hc.zero_grad() 209 | 210 | for epoch in range(args.max_epoch): 211 | start_time = time.time() 212 | model.train() 213 | hidden = model.init_hidden(batch_size) 214 | start_prune = epoch >= args.prune_start_epoch 215 | 216 | for i in range(N): 
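            # Descriptive note on the step performed below: each iteration optimizes an
            # augmented Lagrangian
            #   L = lm_loss + lambda_1 * (s_expected - s_target) + lambda_2 * (s_expected - s_target) ** 2
            # where s_expected is the sparsity implied by the current HardConcrete gates.
            # Because optimizer_max was given a negative learning rate above, lambda_1 and
            # lambda_2 are effectively updated by gradient ascent, pushing the expected
            # sparsity toward the (optionally warmed-up) target.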
217 | x = train[0][i * unroll_size : (i + 1) * unroll_size] 218 | y = train[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 219 | hidden.detach_() 220 | 221 | # language model forward and backward 222 | output, hidden = model(x, hidden) 223 | loss = criterion(output, y) 224 | (loss / args.update_param_freq).backward() 225 | loss = loss.item() 226 | lagrangian_loss = 0 227 | target_sparsity = 0 228 | expected_sparsity = 0 229 | 230 | # add lagrangian loss (regularization) when pruning 231 | if start_prune: 232 | # compute target sparsity with (optionally) linear warmup 233 | target_sparsity = args.prune_sparsity 234 | if args.prune_warmup > 0: 235 | niter_ = niter - args.prune_start_epoch * N 236 | target_sparsity *= min(1.0, niter_ / args.prune_warmup) 237 | 238 | # compute expected model size and sparsity 239 | expected_size = sum( 240 | m.num_parameters(train=True) for m in hc_linear_modules 241 | ) 242 | expected_sparsity = 1.0 - expected_size / num_prunable_params 243 | 244 | # compute lagrangian loss 245 | lagrangian_loss = ( 246 | lambda_1 * (expected_sparsity - target_sparsity) 247 | + lambda_2 * (expected_sparsity - target_sparsity) ** 2 248 | ) 249 | (lagrangian_loss / args.update_param_freq).backward() 250 | expected_sparsity = expected_sparsity.item() 251 | lagrangian_loss = lagrangian_loss.item() 252 | 253 | # log training stats 254 | if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0: 255 | if args.prune: 256 | train_writer.add_scalar( 257 | "sparsity/expected_sparsity", expected_sparsity, niter 258 | ) 259 | train_writer.add_scalar( 260 | "sparsity/target_sparsity", target_sparsity, niter 261 | ) 262 | train_writer.add_scalar( 263 | "loss/lagrangian_loss", lagrangian_loss, niter 264 | ) 265 | train_writer.add_scalar("lambda/1", lambda_1.item(), niter) 266 | train_writer.add_scalar("lambda/2", lambda_2.item(), niter) 267 | if (niter - 1) % 3000 == 0: 268 | for index, layer in enumerate(hc_modules): 269 | train_writer.add_histogram( 270 | "log_alpha/{}".format(index), 271 | layer.log_alpha, 272 | niter, 273 | bins="sqrt", 274 | ) 275 | sys.stderr.write( 276 | "\r{:.4f} {:.2f} {:.2f}".format( 277 | loss, lagrangian_loss, expected_sparsity, 278 | ) 279 | ) 280 | train_writer.add_scalar("loss/lm_loss", loss, niter) 281 | train_writer.add_scalar( 282 | "loss/total_loss", loss + lagrangian_loss, niter 283 | ) 284 | train_writer.add_scalar( 285 | "parameter_norm", calc_norm([x.data for x in m_parameters]), niter 286 | ) 287 | train_writer.add_scalar( 288 | "gradient_norm", 289 | calc_norm([x.grad for x in m_parameters if x.grad is not None]), 290 | niter, 291 | ) 292 | 293 | # perform gradient decent every few number of backward() 294 | if nbatch % args.update_param_freq == 0: 295 | if args.clip_grad > 0: 296 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 297 | optimizer.step() 298 | if start_prune: 299 | optimizer_max.step() 300 | optimizer_hc.step() 301 | # clear gradient 302 | model.zero_grad() 303 | if args.prune: 304 | optimizer_max.zero_grad() 305 | optimizer_hc.zero_grad() 306 | niter += 1 307 | 308 | if nbatch % args.log_period == 0 or i == N - 1: 309 | elapsed_time = (time.time() - start_time) / 60.0 310 | dev_ppl, dev_loss = eval_model(model, dev) 311 | dev_writer.add_scalar("loss/lm_loss", dev_loss, niter) 312 | dev_writer.add_scalar("bpc", np.log2(dev_ppl), niter) 313 | sparsity = 0 314 | if args.prune: 315 | pruned_size = sum( 316 | m.num_parameters(train=False) for m in hc_linear_modules 317 | ) 318 | sparsity = 1.0 - pruned_size / 
num_prunable_params 319 | dev_writer.add_scalar("sparsity/hard_sparsity", sparsity, niter) 320 | dev_writer.add_scalar( 321 | "model_size/total_prunable", num_prunable_params, niter 322 | ) 323 | dev_writer.add_scalar( 324 | "model_size/current_prunable", pruned_size, niter 325 | ) 326 | dev_writer.add_scalar("model_size/total", num_params, niter) 327 | dev_writer.add_scalar( 328 | "model_size/current", 329 | num_params - num_prunable_params + pruned_size, 330 | niter, 331 | ) 332 | sys.stdout.write( 333 | "\rIter={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}" 334 | " dev_bpc={:.2f} sparsity={:.2f}\teta={:.1f}m\t[{:.1f}m]\n".format( 335 | niter, 336 | optimizer.param_groups[0]["lr"], 337 | loss, 338 | dev_loss, 339 | np.log2(dev_ppl), 340 | sparsity, 341 | elapsed_time * N / (i + 1), 342 | elapsed_time, 343 | ) 344 | ) 345 | if dev_ppl < best_dev: 346 | if (not args.prune) or sparsity > args.prune_sparsity - 0.02: 347 | best_dev = dev_ppl 348 | checkpoint = copy_model(model) 349 | sys.stdout.write("\n") 350 | sys.stdout.flush() 351 | 352 | nbatch += 1 353 | if args.noam: 354 | lr = min(1.0 / (niter ** 0.5), niter / (args.warmup_steps ** 1.5)) 355 | optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 356 | if args.noam and start_prune: 357 | niter_ = niter - args.prune_start_epoch * N 358 | lr = min(1.0 / (niter_ ** 0.5), niter_ / (args.warmup_steps ** 1.5)) 359 | optimizer_max.param_groups[0]["lr"] = ( 360 | -lr * args.prune_lr / (args.n_d ** 0.5) 361 | ) 362 | optimizer_hc.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 363 | 364 | if args.save and (epoch + 1) % 10 == 0: 365 | torch.save(copy_model(model), "{}.{}.{:.3f}.pt".format( 366 | args.save, 367 | epoch + 1, 368 | sparsity 369 | )) 370 | 371 | train_writer.close() 372 | dev_writer.close() 373 | 374 | model.load_state_dict(checkpoint) 375 | model.cuda() 376 | dev = create_batches(dev_, 1) 377 | test = create_batches(test_, 1) 378 | dev_ppl, dev_loss = eval_model(model, dev) 379 | test_ppl, test_loss = eval_model(model, test) 380 | sys.stdout.write( 381 | "dev_bpc={:.3f} test_bpc={:.3f}\n".format(np.log2(dev_ppl), np.log2(test_ppl)) 382 | ) 383 | 384 | 385 | if __name__ == "__main__": 386 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 387 | argparser.add_argument("--log", type=str, required=True) 388 | argparser.add_argument("--noam", action="store_true") 389 | argparser.add_argument("--warmup_steps", type=int, default=16000) 390 | argparser.add_argument("--layer_norm", action="store_true") 391 | argparser.add_argument("--rescale", action="store_true") 392 | argparser.add_argument("--data", type=str, required=True, help="training file") 393 | argparser.add_argument("--batch_size", "--batch", type=int, default=64) 394 | argparser.add_argument("--update_param_freq", type=int, default=1) 395 | argparser.add_argument("--unroll_size", type=int, default=256) 396 | argparser.add_argument("--max_epoch", type=int, default=100) 397 | argparser.add_argument("--n_e", type=int, default=0) 398 | argparser.add_argument("--n_d", "--d", type=int, default=3056) 399 | argparser.add_argument("--n_proj", type=int, default=512) 400 | argparser.add_argument( 401 | "--dropout", type=float, default=0.1, help="dropout probability" 402 | ) 403 | argparser.add_argument( 404 | "--bias", type=float, default=-3, help="intial bias of highway gates", 405 | ) 406 | argparser.add_argument("--depth", type=int, default=6) 407 | argparser.add_argument("--lr", type=float, default=2) 408 | 
argparser.add_argument("--weight_decay", type=float, default=1e-7) 409 | argparser.add_argument("--clip_grad", type=float, default=0.3) 410 | argparser.add_argument("--log_period", type=int, default=1000000) 411 | argparser.add_argument("--save", type=str, default="") 412 | argparser.add_argument("--load", type=str, default="") 413 | 414 | argparser.add_argument("--prune", action="store_true") 415 | argparser.add_argument("--prune_lr", type=float, default=3) 416 | argparser.add_argument("--prune_warmup", type=int, default=0) 417 | argparser.add_argument("--prune_sparsity", type=float, default=0.0) 418 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 419 | 420 | args = argparser.parse_args() 421 | print(args) 422 | main(args) 423 | -------------------------------------------------------------------------------- /examples/enwik8/train_enwik8_agp_struct.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import time 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import Adam 10 | from tensorboardX import SummaryWriter 11 | 12 | import sru 13 | import flop 14 | 15 | 16 | def read_corpus(path, num_test_symbols=5000000): 17 | raw_data = open(path).read() 18 | raw_data = np.fromstring(raw_data, dtype=np.uint8) 19 | unique, data = np.unique(raw_data, return_inverse=True) 20 | train_data = data[: -2 * num_test_symbols] 21 | valid_data = data[-2 * num_test_symbols : -num_test_symbols] 22 | test_data = data[-num_test_symbols:] 23 | return train_data, valid_data, test_data, unique 24 | 25 | 26 | def create_batches(data_ids, batch_size): 27 | N = len(data_ids) 28 | L = ((N - 1) // batch_size) * batch_size 29 | x = np.copy(data_ids[:L].reshape(batch_size, -1).T) 30 | y = np.copy(data_ids[1 : L + 1].reshape(batch_size, -1).T) 31 | x, y = torch.from_numpy(x), torch.from_numpy(y) 32 | x, y = x.contiguous(), y.contiguous() 33 | x, y = x.cuda(), y.cuda() 34 | return x, y 35 | 36 | 37 | class CustomLinear(nn.Linear): 38 | def __init__(self, in_features, out_features, bias=False): 39 | super(CustomLinear, self).__init__(in_features, out_features, bias=bias) 40 | 41 | def forward(self, data, **kwargs): 42 | return super().forward(data) 43 | 44 | 45 | class Model(nn.Module): 46 | def __init__(self, words, args): 47 | super(Model, self).__init__() 48 | self.args = args 49 | if args.n_e: 50 | self.n_e = args.n_e 51 | else: 52 | self.n_e = len(words) if len(words) < args.n_d else args.n_d 53 | self.n_d = args.n_d 54 | self.depth = args.depth 55 | self.drop = nn.Dropout(args.dropout) 56 | self.embedding_layer = nn.Embedding(len(words), self.n_e) 57 | self.n_V = len(words) 58 | custom_m_list = [CustomLinear(self.n_e, self.n_d * 4, bias=False)] 59 | for i in range(self.depth - 1): 60 | custom_m_list.append( 61 | flop.ProjectedLinear( 62 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 63 | ) 64 | ) 65 | self.rnn = sru.SRU( 66 | self.n_e, 67 | self.n_d, 68 | self.depth, 69 | dropout=args.dropout, 70 | highway_bias=args.bias, 71 | layer_norm=args.layer_norm, 72 | rescale=args.rescale, 73 | custom_m=custom_m_list, 74 | ) 75 | self.output_layer = nn.Linear(self.n_d, self.n_V) 76 | self.init_weights() 77 | 78 | def init_weights(self, reinit_rnn=False): 79 | params = list(self.embedding_layer.parameters()) + list( 80 | self.output_layer.parameters() 81 | ) 82 | for p in params: 83 | if p.dim() > 1: # matrix 84 | val = (3.0 / p.size(0)) ** 0.5 85 | 
p.data.uniform_(-val, val) 86 | else: 87 | p.data.zero_() 88 | if reinit_rnn: 89 | for p in self.rnn.parameters(): 90 | if p.dim() > 1: # matrix 91 | val = (3.0 / p.size(0)) ** 0.5 92 | p.data.uniform_(-val, val) 93 | 94 | def forward(self, x, hidden): 95 | emb = self.drop(self.embedding_layer(x)) 96 | output, hidden = self.rnn(emb, hidden) 97 | output = self.drop(output) 98 | output = output.view(-1, output.size(2)) 99 | output = self.output_layer(output) 100 | return output, hidden 101 | 102 | def init_hidden(self, batch_size): 103 | weight = next(self.parameters()).data 104 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 105 | return zeros 106 | 107 | 108 | def calc_norm(lis): 109 | l2_sum = sum(x.norm() ** 2 for x in lis) 110 | return l2_sum ** 0.5 111 | 112 | 113 | def eval_model(model, valid): 114 | with torch.no_grad(): 115 | model.eval() 116 | args = model.args 117 | batch_size = valid[0].size(1) 118 | total_loss = 0.0 119 | unroll_size = args.unroll_size 120 | criterion = nn.CrossEntropyLoss(size_average=False) 121 | hidden = model.init_hidden(batch_size) 122 | N = (len(valid[0]) - 1) // unroll_size + 1 123 | for i in range(N): 124 | x = valid[0][i * unroll_size : (i + 1) * unroll_size] 125 | y = valid[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 126 | hidden.detach_() 127 | output, hidden = model(x, hidden) 128 | loss = criterion(output, y) 129 | total_loss += loss.item() 130 | avg_loss = total_loss / valid[1].numel() 131 | ppl = np.exp(avg_loss) 132 | model.train() 133 | return ppl, avg_loss 134 | 135 | 136 | def copy_model(model): 137 | states = model.state_dict() 138 | for k in states: 139 | v = states[k] 140 | states[k] = v.clone().cpu() 141 | return states 142 | 143 | 144 | def main(args): 145 | log_path = "{}_{}".format(args.log, random.randint(1, 100)) 146 | train_writer = SummaryWriter(log_dir=log_path + "/train") 147 | dev_writer = SummaryWriter(log_dir=log_path + "/dev") 148 | 149 | train, dev, test, words = read_corpus(args.data) 150 | dev_, test_ = dev, test 151 | train = create_batches(train, args.batch_size) 152 | dev = create_batches(dev, args.batch_size) 153 | test = create_batches(test, args.batch_size) 154 | 155 | model = Model(words, args) 156 | if args.load: 157 | model.load_state_dict(torch.load(args.load)) 158 | model.cuda() 159 | print(model) 160 | print("vocab size: {}".format(model.n_V)) 161 | 162 | lr = 1.0 if not args.noam else 1.0 / (args.n_d ** 0.5) / (args.warmup_steps ** 1.5) 163 | if args.prune: 164 | # in place substituion of linear ops in SRU 165 | flop.make_projected_linear_with_mask(model.rnn, in_place=True, init_zero=True) 166 | model.cuda() 167 | print("model after inserting masks:") 168 | print(model) 169 | mask_params = list(flop.get_projected_linear_masks(model)) 170 | optimizer_pm = Adam(mask_params, lr=0.001, weight_decay=0) 171 | num_masks_params = sum(x.numel() for x in mask_params) 172 | print("num of mask paramters: {}".format(num_masks_params)) 173 | pm_linear_modules = flop.get_projected_linear_with_mask_modules(model) 174 | num_prunable_params = sum( 175 | m.num_prunable_parameters() for m in pm_linear_modules 176 | ) 177 | print("num of prunable paramters: {}".format(num_prunable_params)) 178 | mask_param_names = [ 179 | i[0] 180 | for i in model.named_parameters() 181 | if i[1].requires_grad and "mask" in i[0] 182 | ] 183 | pruner = flop.NervanaPruner( 184 | model, 185 | subpruners={ 186 | "agppruner": { 187 | "class": "AutomatedGradualPruner", 188 | "initial_sparsity": 0.05, 189 | "weights": 
mask_param_names, 190 | "final_sparsity": args.prune_sparsity, 191 | "starting_step": args.prune_start_epoch, 192 | "ending_step": args.prune_end_epoch, 193 | "frequency": 1, 194 | } 195 | }, 196 | ) 197 | else: 198 | args.prune_start_epoch = args.max_epoch 199 | 200 | all_non_mask_params = [ 201 | i[1] 202 | for i in model.named_parameters() 203 | if i[1].requires_grad and "mask" not in i[0] 204 | ] 205 | num_params = sum(x.numel() for x in all_non_mask_params if x.requires_grad) 206 | print("num of parameters: {}".format(num_params)) 207 | 208 | nbatch = 1 209 | niter = 1 210 | best_dev = 1e8 211 | unroll_size = args.unroll_size 212 | batch_size = args.batch_size 213 | N = (len(train[0]) - 1) // unroll_size + 1 214 | criterion = nn.CrossEntropyLoss() 215 | 216 | model.zero_grad() 217 | if args.prune: 218 | optimizer_pm.zero_grad() 219 | 220 | emb_parameters = list(model.embedding_layer.parameters()) + list(model.output_layer.parameters()) 221 | emb_optimizer = Adam(emb_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 222 | emb_optimizer.zero_grad() 223 | # Deactivate all parameters in the RNN 224 | m_parameters = [ 225 | i[1] 226 | for i in model.named_parameters() 227 | if i[1].requires_grad and "mask" not in i[0] 228 | ] 229 | optimizer = None 230 | if args.freeze_period: 231 | for p in m_parameters: 232 | p.requires_grad = False 233 | else: 234 | optimizer = Adam(m_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 235 | 236 | for epoch in range(args.max_epoch): 237 | start_prune = epoch >= args.prune_start_epoch 238 | if args.freeze_period and optimizer is None and start_prune: 239 | for p in mask_params: 240 | p.requires_grad = False 241 | for p in m_parameters: 242 | p.requires_grad = True 243 | optimizer = Adam(m_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 244 | 245 | start_time = time.time() 246 | model.train() 247 | hidden = model.init_hidden(batch_size) 248 | pruner.begin_step(epoch) 249 | 250 | for i in range(N): 251 | # start iter on the first batch 252 | if nbatch % args.update_param_freq == 1: 253 | pruner.begin_iter(epoch, niter, N // args.update_param_freq) 254 | 255 | x = train[0][i * unroll_size : (i + 1) * unroll_size] 256 | y = train[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 257 | hidden.detach_() 258 | 259 | # language model forward and backward 260 | output, hidden = model(x, hidden) 261 | model_loss = criterion(output, y) 262 | expected_sparsity = 0 263 | l1_loss = 0 264 | 265 | # add lagrangian loss (regularization) when pruning 266 | if start_prune: 267 | # compute expected model size and sparsity 268 | expected_size = sum( 269 | m.num_parameters(train=True) for m in pm_linear_modules 270 | ) 271 | expected_sparsity = 1.0 - expected_size / num_prunable_params 272 | expected_sparsity = expected_sparsity.item() 273 | 274 | l1_loss_aggr = 0 275 | if args.l1_lambda > 0 and expected_sparsity < args.prune_sparsity: 276 | for p in mask_params: 277 | l1_loss_aggr += torch.sum(torch.abs(p)) 278 | 279 | l1_loss = args.l1_lambda * l1_loss_aggr 280 | 281 | if args.l1_lambda > 0: 282 | loss = model_loss + l1_loss 283 | else: 284 | loss = model_loss 285 | 286 | (loss / args.update_param_freq).backward() 287 | model_loss = model_loss.item() 288 | l1_loss = l1_loss.item() if isinstance(l1_loss, torch.Tensor) else l1_loss 289 | 290 | # log training stats 291 | if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0: 292 | if args.prune: 293 | train_writer.add_scalar( 294 | "sparsity/expected_sparsity", 
expected_sparsity, niter 295 | ) 296 | if (niter - 1) % 3000 == 0: 297 | for index, layer in enumerate(mask_params): 298 | train_writer.add_histogram( 299 | "log_alpha/{}".format(index), layer, niter, bins="sqrt", 300 | ) 301 | # sys.stderr.write( 302 | # "\r{:.4f} {:.2f}".format( 303 | # model_loss, expected_sparsity, 304 | # ) 305 | # ) 306 | train_writer.add_scalar("loss/lm_loss", model_loss, niter) 307 | train_writer.add_scalar("loss/l1_loss", l1_loss, niter) 308 | train_writer.add_scalar("loss/total_loss", model_loss + l1_loss, niter) 309 | train_writer.add_scalar( 310 | "parameter_norm", calc_norm([x.data for x in m_parameters]), niter 311 | ) 312 | train_writer.add_scalar( 313 | "gradient_norm", 314 | calc_norm([x.grad for x in m_parameters if x.grad is not None]), 315 | niter, 316 | ) 317 | 318 | # perform gradient decent every few number of backward() 319 | if nbatch % args.update_param_freq == 0: 320 | if args.clip_grad > 0: 321 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 322 | if emb_optimizer is not None: 323 | emb_optimizer.step() 324 | if optimizer is not None: 325 | optimizer.step() 326 | if start_prune or args.freeze_period: 327 | optimizer_pm.step() 328 | # clear gradient 329 | model.zero_grad() 330 | if args.prune: 331 | optimizer_pm.zero_grad() 332 | 333 | # End iter on the last batch 334 | pruner.end_iter(epoch, niter, N // args.update_param_freq) 335 | niter += 1 336 | 337 | if nbatch % args.log_period == 0 or i == N - 1: 338 | elapsed_time = (time.time() - start_time) / 60.0 339 | dev_ppl, dev_loss = eval_model(model, dev) 340 | dev_writer.add_scalar("loss/lm_loss", dev_loss, niter) 341 | dev_writer.add_scalar("bpc", np.log2(dev_ppl), niter) 342 | sparsity = 0 343 | if args.prune: 344 | pruned_size = sum( 345 | m.num_parameters(train=False) for m in pm_linear_modules 346 | ) 347 | sparsity = 1.0 - pruned_size / num_prunable_params 348 | # agp_sparsity = pruner.get_step_logs() 349 | dev_writer.add_scalar("sparsity/hard_sparsity", sparsity, niter) 350 | # dev_writer.add_scalar("sparsity/agp_sparsity", agp_sparsity, niter) 351 | dev_writer.add_scalar( 352 | "model_size/total_prunable", num_prunable_params, niter 353 | ) 354 | dev_writer.add_scalar( 355 | "model_size/current_prunable", pruned_size, niter 356 | ) 357 | dev_writer.add_scalar("model_size/total", num_params, niter) 358 | dev_writer.add_scalar( 359 | "model_size/current", 360 | num_params - num_prunable_params + pruned_size, 361 | niter, 362 | ) 363 | sys.stdout.write( 364 | "\rIter={} train_loss={:.4f} dev_loss={:.4f}" 365 | " dev_bpc={:.2f} sparsity={:.2f}\teta={:.1f}m\t[{:.1f}m]\n".format( 366 | niter, 367 | loss, 368 | dev_loss, 369 | np.log2(dev_ppl), 370 | sparsity, 371 | elapsed_time * N / (i + 1), 372 | elapsed_time, 373 | ) 374 | ) 375 | checkpoint = copy_model(model) 376 | sys.stdout.write("\n") 377 | sys.stdout.flush() 378 | 379 | nbatch += 1 380 | if args.noam: 381 | niter_ = niter 382 | lr = min(1.0 / (niter_ ** 0.5), niter_ / (args.warmup_steps ** 1.5)) 383 | emb_optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 384 | if args.noam and optimizer is not None: 385 | niter_ = niter - args.prune_start_epoch * N if args.freeze_period else niter 386 | lr = min(1.0 / (niter_ ** 0.5), niter_ / (args.warmup_steps ** 1.5)) 387 | optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 388 | # if args.noam and (start_prune or args.freeze_period): 389 | # niter_ = niter if args.freeze_period else niter - args.prune_start_epoch * N 390 | # lr = min(1.0 / (niter_ ** 
0.5), niter_ / (args.warmup_steps ** 1.5)) 391 | # optimizer_pm.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 392 | 393 | pruner.end_step(epoch) 394 | if args.save and (epoch + 1) % 10 == 0: 395 | torch.save(copy_model(model), "{}.{}.{:.3f}.pt".format( 396 | args.save, 397 | epoch + 1, 398 | sparsity 399 | )) 400 | 401 | train_writer.close() 402 | dev_writer.close() 403 | 404 | model.load_state_dict(checkpoint) 405 | model.cuda() 406 | dev = create_batches(dev_, 1) 407 | test = create_batches(test_, 1) 408 | dev_ppl, dev_loss = eval_model(model, dev) 409 | test_ppl, test_loss = eval_model(model, test) 410 | sys.stdout.write( 411 | "dev_bpc={:.3f} test_bpc={:.3f}\n".format(np.log2(dev_ppl), np.log2(test_ppl)) 412 | ) 413 | 414 | 415 | if __name__ == "__main__": 416 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 417 | argparser.add_argument("--log", type=str, default="") 418 | argparser.add_argument("--noam", type=bool, default=True) 419 | argparser.add_argument("--warmup_steps", type=int, default=16000) 420 | argparser.add_argument("--layer_norm", action="store_true") 421 | argparser.add_argument("--rescale", action="store_true") 422 | argparser.add_argument("--data", type=str, required=True, help="training file") 423 | argparser.add_argument("--batch_size", "--batch", type=int, default=64) 424 | argparser.add_argument("--update_param_freq", type=int, default=1) 425 | argparser.add_argument("--unroll_size", type=int, default=256) 426 | argparser.add_argument("--max_epoch", type=int, default=100) 427 | argparser.add_argument("--n_e", type=int, default=0) 428 | argparser.add_argument("--n_d", "--d", type=int, default=3056) 429 | argparser.add_argument("--n_proj", type=int, default=512) 430 | argparser.add_argument( 431 | "--dropout", type=float, default=0.1, help="dropout probability" 432 | ) 433 | argparser.add_argument( 434 | "--bias", type=float, default=-3, help="intial bias of highway gates", 435 | ) 436 | argparser.add_argument("--depth", type=int, default=6) 437 | argparser.add_argument("--lr", type=float, default=2) 438 | argparser.add_argument("--weight_decay", type=float, default=1e-7) 439 | argparser.add_argument("--clip_grad", type=float, default=0.3) 440 | argparser.add_argument("--log_period", type=int, default=1000000) 441 | argparser.add_argument("--save", type=str, default="") 442 | argparser.add_argument("--load", type=str, default="") 443 | 444 | argparser.add_argument("--prune", type=bool, default=True) 445 | argparser.add_argument("--prune_lr", type=float, default=2) 446 | argparser.add_argument("--prune_warmup", type=int, default=0) 447 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 448 | argparser.add_argument("--prune_sparsity", type=float, default=0.9) 449 | argparser.add_argument("--prune_end_epoch", type=int, default=30) 450 | argparser.add_argument("--l1_lambda", type=float, default=0) 451 | argparser.add_argument("--freeze_period", type=bool, default=False) 452 | 453 | args = argparser.parse_args() 454 | main(args) 455 | -------------------------------------------------------------------------------- /examples/enwik8/train_enwik8_agp_unstruct.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import time 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.optim import Adam 11 | from tensorboardX import SummaryWriter 12 | 13 | import sru 14 | import flop 15 | 16 | 
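# Descriptive note: unlike train_enwik8_agp_struct.py (which prunes the mask parameters
# attached to the ProjectedLinear factors), this script registers the raw SRU weight
# tensors (model.rnn.named_parameters()) with flop.NervanaPruner, so the
# AutomatedGradualPruner schedule zeroes out individual weights between
# --prune_start_epoch and --prune_end_epoch.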
17 | def read_corpus(path, num_test_symbols=5000000): 18 | raw_data = open(path).read() 19 | raw_data = np.fromstring(raw_data, dtype=np.uint8) 20 | unique, data = np.unique(raw_data, return_inverse=True) 21 | train_data = data[: -2 * num_test_symbols] 22 | valid_data = data[-2 * num_test_symbols : -num_test_symbols] 23 | test_data = data[-num_test_symbols:] 24 | return train_data, valid_data, test_data, unique 25 | 26 | 27 | def create_batches(data_ids, batch_size): 28 | N = len(data_ids) 29 | L = ((N - 1) // batch_size) * batch_size 30 | x = np.copy(data_ids[:L].reshape(batch_size, -1).T) 31 | y = np.copy(data_ids[1 : L + 1].reshape(batch_size, -1).T) 32 | x, y = torch.from_numpy(x), torch.from_numpy(y) 33 | x, y = x.contiguous(), y.contiguous() 34 | x, y = x.cuda(), y.cuda() 35 | return x, y 36 | 37 | 38 | class CustomLinear(nn.Linear): 39 | def __init__(self, in_features, out_features, bias=False): 40 | super(CustomLinear, self).__init__(in_features, out_features, bias=bias) 41 | 42 | def forward(self, data, **kwargs): 43 | return super().forward(data) 44 | 45 | 46 | class Model(nn.Module): 47 | def __init__(self, words, args): 48 | super(Model, self).__init__() 49 | self.args = args 50 | if args.n_e: 51 | self.n_e = args.n_e 52 | else: 53 | self.n_e = len(words) if len(words) < args.n_d else args.n_d 54 | self.n_d = args.n_d 55 | self.depth = args.depth 56 | self.drop = nn.Dropout(args.dropout) 57 | self.embedding_layer = nn.Embedding(len(words), self.n_e) 58 | self.n_V = len(words) 59 | custom_m_list = [CustomLinear(self.n_e, self.n_d * 4, bias=False)] 60 | for i in range(self.depth - 1): 61 | custom_m_list.append( 62 | flop.ProjectedLinear( 63 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 64 | ) 65 | ) 66 | self.rnn = sru.SRU( 67 | self.n_e, 68 | self.n_d, 69 | self.depth, 70 | dropout=args.dropout, 71 | highway_bias=args.bias, 72 | layer_norm=args.layer_norm, 73 | rescale=args.rescale, 74 | custom_m=custom_m_list, 75 | ) 76 | self.output_layer = nn.Linear(self.n_d, self.n_V) 77 | self.init_weights() 78 | 79 | def init_weights(self, reinit_rnn=False): 80 | params = list(self.embedding_layer.parameters()) + list( 81 | self.output_layer.parameters() 82 | ) 83 | for p in params: 84 | if p.dim() > 1: # matrix 85 | val = (3.0 / p.size(0)) ** 0.5 86 | p.data.uniform_(-val, val) 87 | else: 88 | p.data.zero_() 89 | if reinit_rnn: 90 | for p in self.rnn.parameters(): 91 | if p.dim() > 1: # matrix 92 | val = (3.0 / p.size(0)) ** 0.5 93 | p.data.uniform_(-val, val) 94 | 95 | def forward(self, x, hidden): 96 | emb = self.drop(self.embedding_layer(x)) 97 | output, hidden = self.rnn(emb, hidden) 98 | output = self.drop(output) 99 | output = output.view(-1, output.size(2)) 100 | output = self.output_layer(output) 101 | return output, hidden 102 | 103 | def init_hidden(self, batch_size): 104 | weight = next(self.parameters()).data 105 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 106 | return zeros 107 | 108 | 109 | def calc_norm(lis): 110 | l2_sum = sum(x.norm() ** 2 for x in lis) 111 | return l2_sum ** 0.5 112 | 113 | 114 | def eval_model(model, valid): 115 | with torch.no_grad(): 116 | model.eval() 117 | args = model.args 118 | batch_size = valid[0].size(1) 119 | total_loss = 0.0 120 | unroll_size = args.unroll_size 121 | criterion = nn.CrossEntropyLoss(size_average=False) 122 | hidden = model.init_hidden(batch_size) 123 | N = (len(valid[0]) - 1) // unroll_size + 1 124 | for i in range(N): 125 | x = valid[0][i * unroll_size : (i + 1) * unroll_size] 126 | y = 
valid[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 127 | hidden.detach_() 128 | output, hidden = model(x, hidden) 129 | loss = criterion(output, y) 130 | total_loss += loss.item() 131 | avg_loss = total_loss / valid[1].numel() 132 | ppl = np.exp(avg_loss) 133 | model.train() 134 | return ppl, avg_loss 135 | 136 | 137 | def copy_model(model): 138 | states = model.state_dict() 139 | for k in states: 140 | v = states[k] 141 | states[k] = v.clone().cpu() 142 | return states 143 | 144 | 145 | def main(args): 146 | log_path = "{}_{}".format(args.log, random.randint(1, 100)) 147 | train_writer = SummaryWriter(log_dir=log_path + "/train") 148 | dev_writer = SummaryWriter(log_dir=log_path + "/dev") 149 | 150 | train, dev, test, words = read_corpus(args.data) 151 | dev_, test_ = dev, test 152 | train = create_batches(train, args.batch_size) 153 | dev = create_batches(test, args.batch_size) 154 | test = create_batches(test, args.batch_size) 155 | 156 | model = Model(words, args) 157 | if args.load: 158 | model.load_state_dict(torch.load(args.load)) 159 | model.cuda() 160 | print(model) 161 | print("vocab size: {}".format(model.n_V)) 162 | 163 | lr = 1.0 if not args.noam else 1.0 / (args.n_d ** 0.5) / (args.warmup_steps ** 1.5) 164 | if args.prune: 165 | model.cuda() 166 | print(model) 167 | num_mask_params = sum(x.numel() for x in model.rnn.parameters()) 168 | num_prunable_params = num_mask_params 169 | print("num of mask parameters: {}".format(num_mask_params)) 170 | print("num of prunable parameters: {}".format(num_prunable_params)) 171 | param_names = [ 172 | i[0] 173 | for i in model.rnn.named_parameters() 174 | if i[1].requires_grad 175 | ] 176 | pruner = flop.NervanaPruner( 177 | model.rnn, 178 | subpruners={ 179 | "agppruner": { 180 | "class": "AutomatedGradualPruner", 181 | "initial_sparsity": 0.05, 182 | "weights": param_names, 183 | "final_sparsity": args.prune_sparsity, 184 | "starting_step": args.prune_start_epoch, 185 | "ending_step": args.prune_end_epoch, 186 | "frequency": 1, 187 | } 188 | }, 189 | ) 190 | else: 191 | args.prune_start_epoch = args.max_epoch 192 | 193 | m_parameters = [ 194 | i[1] 195 | for i in model.named_parameters() 196 | if i[1].requires_grad 197 | ] 198 | optimizer = Adam(m_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 199 | num_params = sum(x.numel() for x in m_parameters if x.requires_grad) 200 | print("num of parameters: {}".format(num_params)) 201 | 202 | nbatch = 1 203 | niter = 1 204 | best_dev = 1e8 205 | unroll_size = args.unroll_size 206 | batch_size = args.batch_size 207 | N = (len(train[0]) - 1) // unroll_size + 1 208 | criterion = nn.CrossEntropyLoss() 209 | 210 | model.zero_grad() 211 | 212 | for epoch in range(args.max_epoch): 213 | start_time = time.time() 214 | model.train() 215 | hidden = model.init_hidden(batch_size) 216 | 217 | pruner.begin_step(epoch) 218 | for i in range(N): 219 | # start iter on the first batch 220 | if nbatch % args.update_param_freq == 1: 221 | pruner.begin_iter(epoch, niter, N // args.update_param_freq) 222 | 223 | x = train[0][i * unroll_size : (i + 1) * unroll_size] 224 | y = train[1][i * unroll_size : (i + 1) * unroll_size].view(-1) 225 | hidden.detach_() 226 | 227 | # language model forward and backward 228 | output, hidden = model(x, hidden) 229 | model_loss = criterion(output, y) 230 | 231 | loss = model_loss 232 | (loss / args.update_param_freq).backward() 233 | model_loss = model_loss.item() 234 | 235 | # log training stats 236 | if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq 
== 0: 237 | sys.stderr.write( 238 | "\r{:.4f}".format( 239 | model_loss, 240 | ) 241 | ) 242 | train_writer.add_scalar("loss/lm_loss", model_loss, niter) 243 | train_writer.add_scalar( 244 | "parameter_norm", calc_norm([x.data for x in m_parameters]), niter 245 | ) 246 | train_writer.add_scalar( 247 | "gradient_norm", 248 | calc_norm([x.grad for x in m_parameters if x.grad is not None]), 249 | niter, 250 | ) 251 | 252 | # perform gradient decent every few number of backward() 253 | if nbatch % args.update_param_freq == 0: 254 | if args.clip_grad > 0: 255 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 256 | optimizer.step() 257 | # clear gradient 258 | model.zero_grad() 259 | 260 | # End iter on the last batch 261 | pruner.end_iter(epoch, niter, N // args.update_param_freq) 262 | niter += 1 263 | 264 | if nbatch % args.log_period == 0 or i == N - 1: 265 | elapsed_time = (time.time() - start_time) / 60.0 266 | dev_ppl, dev_loss = eval_model(model, dev) 267 | dev_writer.add_scalar("loss/lm_loss", dev_loss, niter) 268 | dev_writer.add_scalar("bpc", np.log2(dev_ppl), niter) 269 | sparsity = 0 270 | if args.prune: 271 | agp_sparsity = pruner.get_step_logs()['sparsity'] 272 | dev_writer.add_scalar("sparsity/hard_sparsity", agp_sparsity, niter) 273 | # dev_writer.add_scalar("sparsity/agp_sparsity", agp_sparsity, niter) 274 | dev_writer.add_scalar( 275 | "model_size/total_prunable", num_prunable_params, niter 276 | ) 277 | dev_writer.add_scalar("model_size/total", num_params, niter) 278 | sys.stdout.write( 279 | "\rIter={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}" 280 | " dev_bpc={:.2f} sparsity={:.2f}\teta={:.1f}m\t[{:.1f}m]\n".format( 281 | niter, 282 | optimizer.param_groups[0]["lr"], 283 | loss, 284 | dev_loss, 285 | np.log2(dev_ppl), 286 | sparsity, 287 | elapsed_time * N / (i + 1), 288 | elapsed_time, 289 | ) 290 | ) 291 | 292 | checkpoint = copy_model(model) 293 | sys.stdout.write("\n") 294 | sys.stdout.flush() 295 | 296 | nbatch += 1 297 | if args.noam: 298 | lr = min(1.0 / (niter ** 0.5), niter / (args.warmup_steps ** 1.5)) 299 | optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 300 | 301 | pruner.end_step(epoch) 302 | if args.save and (epoch + 1) % 10 == 0: 303 | torch.save(copy_model(model), "{}.{}.{:.3f}.pt".format( 304 | args.save, 305 | epoch + 1, 306 | sparsity 307 | )) 308 | 309 | train_writer.close() 310 | dev_writer.close() 311 | 312 | model.load_state_dict(checkpoint) 313 | model.cuda() 314 | dev = create_batches(dev_, 1) 315 | test = create_batches(test_, 1) 316 | dev_ppl, dev_loss = eval_model(model, dev) 317 | test_ppl, test_loss = eval_model(model, test) 318 | sys.stdout.write( 319 | "dev_bpc={:.3f} test_bpc={:.3f}\n".format(np.log2(dev_ppl), np.log2(test_ppl)) 320 | ) 321 | 322 | 323 | if __name__ == "__main__": 324 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 325 | argparser.add_argument("--log", type=str, default="") 326 | argparser.add_argument("--noam", type=bool, default=True) 327 | argparser.add_argument("--warmup_steps", type=int, default=16000) 328 | argparser.add_argument("--layer_norm", action="store_true") 329 | argparser.add_argument("--rescale", action="store_true") 330 | argparser.add_argument("--data", type=str, required=True, help="training file") 331 | argparser.add_argument("--batch_size", "--batch", type=int, default=64) 332 | argparser.add_argument("--update_param_freq", type=int, default=1) 333 | argparser.add_argument("--unroll_size", type=int, default=256) 334 | 
argparser.add_argument("--max_epoch", type=int, default=100) 335 | argparser.add_argument("--n_e", type=int, default=0) 336 | argparser.add_argument("--n_d", "--d", type=int, default=3056) 337 | argparser.add_argument("--n_proj", type=int, default=512) 338 | argparser.add_argument( 339 | "--dropout", type=float, default=0.1, help="dropout probability" 340 | ) 341 | argparser.add_argument( 342 | "--bias", type=float, default=-3, help="intial bias of highway gates", 343 | ) 344 | argparser.add_argument("--depth", type=int, default=6) 345 | argparser.add_argument("--lr", type=float, default=2) 346 | argparser.add_argument("--weight_decay", type=float, default=1e-7) 347 | argparser.add_argument("--clip_grad", type=float, default=0.3) 348 | argparser.add_argument("--log_period", type=int, default=1000000) 349 | argparser.add_argument("--save", type=str, default="") 350 | argparser.add_argument("--load", type=str, default="") 351 | 352 | argparser.add_argument("--prune", type=bool, default=True) 353 | argparser.add_argument("--prune_lr", type=float, default=2) 354 | argparser.add_argument("--prune_warmup", type=int, default=0) 355 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 356 | argparser.add_argument("--prune_sparsity", type=float, default=0.9) 357 | argparser.add_argument("--prune_end_epoch", type=int, default=30) 358 | argparser.add_argument("--l1_lambda", type=float, default=0) 359 | 360 | args = argparser.parse_args() 361 | main(args) 362 | -------------------------------------------------------------------------------- /examples/enwik8_tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/enwik8_tf/__init__.py -------------------------------------------------------------------------------- /examples/enwik8_tf/data_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | 4 | from collections import Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | 8 | from flop.scripts.enwik8_tf.utils.vocabulary import Vocab 9 | 10 | class LMOrderedIterator(object): 11 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 12 | """ 13 | data -- LongTensor -- the LongTensor is strictly ordered 14 | """ 15 | self.bsz = bsz 16 | self.bptt = bptt 17 | self.ext_len = ext_len if ext_len is not None else 0 18 | 19 | self.device = device 20 | 21 | # Work out how cleanly we can divide the dataset into bsz parts. 22 | self.n_step = data.size(0) // bsz 23 | 24 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 25 | data = data.narrow(0, 0, self.n_step * bsz) 26 | 27 | # Evenly divide the data across the bsz batches. 
28 | self.data = data.view(bsz, -1).t().contiguous().to(device) 29 | 30 | # Number of mini-batches 31 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 32 | 33 | def get_batch(self, i, bptt=None): 34 | if bptt is None: bptt = self.bptt 35 | seq_len = min(bptt, self.data.size(0) - 1 - i) 36 | 37 | end_idx = i + seq_len 38 | beg_idx = max(0, i - self.ext_len) 39 | 40 | data = self.data[beg_idx:end_idx] 41 | target = self.data[i+1:i+1+seq_len] 42 | 43 | return data, target, seq_len 44 | 45 | def get_fixlen_iter(self, start=0): 46 | for i in range(start, self.data.size(0) - 1, self.bptt): 47 | yield self.get_batch(i) 48 | 49 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 50 | max_len = self.bptt + max_deviation * std 51 | i = start 52 | while True: 53 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 54 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 55 | data, target, seq_len = self.get_batch(i, bptt) 56 | i += seq_len 57 | yield data, target, seq_len 58 | if i >= self.data.size(0) - 2: 59 | break 60 | 61 | def __iter__(self): 62 | return self.get_fixlen_iter() 63 | 64 | 65 | class LMShuffledIterator(object): 66 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 67 | """ 68 | data -- list[LongTensor] -- there is no order among the LongTensors 69 | """ 70 | self.data = data 71 | 72 | self.bsz = bsz 73 | self.bptt = bptt 74 | self.ext_len = ext_len if ext_len is not None else 0 75 | 76 | self.device = device 77 | self.shuffle = shuffle 78 | 79 | def get_sent_stream(self): 80 | # index iterator 81 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 82 | else np.array(range(len(self.data))) 83 | 84 | # sentence iterator 85 | for idx in epoch_indices: 86 | yield self.data[idx] 87 | 88 | def stream_iterator(self, sent_stream): 89 | # streams for each data in the batch 90 | streams = [None] * self.bsz 91 | 92 | data = torch.LongTensor(self.bptt, self.bsz) 93 | target = torch.LongTensor(self.bptt, self.bsz) 94 | 95 | n_retain = 0 96 | 97 | while True: 98 | # data : [n_retain+bptt x bsz] 99 | # target : [bptt x bsz] 100 | data[n_retain:].fill_(-1) 101 | target.fill_(-1) 102 | 103 | valid_batch = True 104 | 105 | for i in range(self.bsz): 106 | n_filled = 0 107 | try: 108 | while n_filled < self.bptt: 109 | if streams[i] is None or len(streams[i]) <= 1: 110 | streams[i] = next(sent_stream) 111 | # number of new tokens to fill in 112 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 113 | # first n_retain tokens are retained from last batch 114 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 115 | streams[i][:n_new] 116 | target[n_filled:n_filled+n_new, i] = \ 117 | streams[i][1:n_new+1] 118 | streams[i] = streams[i][n_new:] 119 | n_filled += n_new 120 | except StopIteration: 121 | valid_batch = False 122 | break 123 | 124 | if not valid_batch: 125 | return 126 | 127 | data = data.to(self.device) 128 | target = target.to(self.device) 129 | 130 | yield data, target, self.bptt 131 | 132 | n_retain = min(data.size(0), self.ext_len) 133 | if n_retain > 0: 134 | data[:n_retain] = data[-n_retain:] 135 | data.resize_(n_retain + self.bptt, data.size(1)) 136 | 137 | def __iter__(self): 138 | # sent_stream is an iterator 139 | sent_stream = self.get_sent_stream() 140 | 141 | for batch in self.stream_iterator(sent_stream): 142 | yield batch 143 | 144 | 145 | class LMMultiFileIterator(LMShuffledIterator): 146 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', 
ext_len=None, 147 | shuffle=False): 148 | 149 | self.paths = paths 150 | self.vocab = vocab 151 | 152 | self.bsz = bsz 153 | self.bptt = bptt 154 | self.ext_len = ext_len if ext_len is not None else 0 155 | 156 | self.device = device 157 | self.shuffle = shuffle 158 | 159 | def get_sent_stream(self, path): 160 | sents = self.vocab.encode_file(path, add_double_eos=True) 161 | if self.shuffle: 162 | np.random.shuffle(sents) 163 | sent_stream = iter(sents) 164 | 165 | return sent_stream 166 | 167 | def __iter__(self): 168 | if self.shuffle: 169 | np.random.shuffle(self.paths) 170 | 171 | for path in self.paths: 172 | # sent_stream is an iterator 173 | sent_stream = self.get_sent_stream(path) 174 | for batch in self.stream_iterator(sent_stream): 175 | yield batch 176 | 177 | 178 | class Corpus(object): 179 | def __init__(self, path, dataset, *args, **kwargs): 180 | self.dataset = dataset 181 | self.vocab = Vocab(*args, **kwargs) 182 | 183 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: 184 | self.vocab.count_file(os.path.join(path, 'train.txt')) 185 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 186 | self.vocab.count_file(os.path.join(path, 'test.txt')) 187 | elif self.dataset == 'wt103': 188 | self.vocab.count_file(os.path.join(path, 'train.txt')) 189 | elif self.dataset == 'lm1b': 190 | train_path_pattern = os.path.join( 191 | path, '1-billion-word-language-modeling-benchmark-r13output', 192 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 193 | train_paths = glob.glob(train_path_pattern) 194 | # the vocab will load from file when build_vocab() is called 195 | 196 | self.vocab.build_vocab() 197 | 198 | if self.dataset in ['ptb', 'wt2', 'wt103']: 199 | self.train = self.vocab.encode_file( 200 | os.path.join(path, 'train.txt'), ordered=True) 201 | self.valid = self.vocab.encode_file( 202 | os.path.join(path, 'valid.txt'), ordered=True) 203 | self.test = self.vocab.encode_file( 204 | os.path.join(path, 'test.txt'), ordered=True) 205 | elif self.dataset in ['enwik8', 'text8']: 206 | self.train = self.vocab.encode_file( 207 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 208 | self.valid = self.vocab.encode_file( 209 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 210 | self.test = self.vocab.encode_file( 211 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False) 212 | elif self.dataset == 'lm1b': 213 | self.train = train_paths 214 | self.valid = self.vocab.encode_file( 215 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) 216 | self.test = self.vocab.encode_file( 217 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) 218 | 219 | def get_iterator(self, split, *args, **kwargs): 220 | if split == 'train': 221 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 222 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 223 | elif self.dataset == 'lm1b': 224 | kwargs['shuffle'] = True 225 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 226 | elif split in ['valid', 'test']: 227 | data = self.valid if split == 'valid' else self.test 228 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 229 | data_iter = LMOrderedIterator(data, *args, **kwargs) 230 | elif self.dataset == 'lm1b': 231 | data_iter = LMShuffledIterator(data, *args, **kwargs) 232 | 233 | return data_iter 234 | 235 | 236 | def get_lm_corpus(datadir, dataset): 237 | fn = os.path.join(datadir, 'cache.pt') 238 | if os.path.exists(fn): 239 | print('Loading cached dataset...') 240 
| corpus = torch.load(fn) 241 | else: 242 | print('Producing dataset {}...'.format(dataset)) 243 | kwargs = {} 244 | if dataset in ['wt103', 'wt2']: 245 | kwargs['special'] = [''] 246 | kwargs['lower_case'] = False 247 | elif dataset == 'ptb': 248 | kwargs['special'] = [''] 249 | kwargs['lower_case'] = True 250 | elif dataset == 'lm1b': 251 | kwargs['special'] = [] 252 | kwargs['lower_case'] = False 253 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') 254 | elif dataset in ['enwik8', 'text8']: 255 | pass 256 | 257 | corpus = Corpus(datadir, dataset, **kwargs) 258 | torch.save(corpus, fn) 259 | 260 | return corpus 261 | 262 | if __name__ == '__main__': 263 | import argparse 264 | parser = argparse.ArgumentParser(description='unit test') 265 | parser.add_argument('--datadir', type=str, default='../data/text8', 266 | help='location of the data corpus') 267 | parser.add_argument('--dataset', type=str, default='text8', 268 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], 269 | help='dataset name') 270 | args = parser.parse_args() 271 | 272 | corpus = get_lm_corpus(args.datadir, args.dataset) 273 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 274 | -------------------------------------------------------------------------------- /examples/enwik8_tf/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | 7 | import torch 8 | 9 | from flop.scripts.enwik8_tf.data_utils import get_lm_corpus 10 | from flop.scripts.enwik8_tf.mem_transformer import MemTransformerLM 11 | from flop.scripts.enwik8_tf.utils.exp_utils import get_logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 14 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 15 | help='location of the data corpus') 16 | parser.add_argument('--dataset', type=str, default='wt103', 17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'], 18 | help='dataset name') 19 | parser.add_argument('--split', type=str, default='all', 20 | choices=['all', 'valid', 'test'], 21 | help='which split to evaluate') 22 | parser.add_argument('--batch_size', type=int, default=10, 23 | help='batch size') 24 | parser.add_argument('--tgt_len', type=int, default=5, 25 | help='number of tokens to predict') 26 | parser.add_argument('--ext_len', type=int, default=0, 27 | help='length of the extended context') 28 | parser.add_argument('--mem_len', type=int, default=0, 29 | help='length of the retained previous heads') 30 | parser.add_argument('--clamp_len', type=int, default=-1, 31 | help='max positional embedding index') 32 | parser.add_argument('--cuda', action='store_true', 33 | help='use CUDA') 34 | parser.add_argument('--work_dir', type=str, required=True, 35 | help='path to the work_dir') 36 | parser.add_argument('--no_log', action='store_true', 37 | help='do not log the eval result') 38 | parser.add_argument('--same_length', action='store_true', 39 | help='set same length attention with masking') 40 | args = parser.parse_args() 41 | assert args.ext_len >= 0, 'extended context length must be non-negative' 42 | 43 | device = torch.device("cuda" if args.cuda else "cpu") 44 | 45 | # Get logger 46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'), 47 | log_=not args.no_log) 48 | 49 | # Load dataset 50 | corpus = get_lm_corpus(args.data, args.dataset) 51 | ntokens = len(corpus.vocab) 52 | 53 | va_iter = corpus.get_iterator('valid', 
args.batch_size, args.tgt_len, 54 | device=device, ext_len=args.ext_len) 55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 56 | device=device, ext_len=args.ext_len) 57 | 58 | # Load the best saved model. 59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 60 | model = torch.load(f) 61 | model.backward_compatible() 62 | model = model.to(device) 63 | 64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( 65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 66 | 67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 68 | if args.clamp_len > 0: 69 | model.clamp_len = args.clamp_len 70 | if args.same_length: 71 | model.same_length = True 72 | 73 | ############################################################################### 74 | # Evaluation code 75 | ############################################################################### 76 | def evaluate(eval_iter): 77 | # Reset compiled masks and weights 78 | with torch.no_grad(): 79 | model.train() 80 | mems = tuple() 81 | for i, (data, target, seq_len) in enumerate(eval_iter): 82 | ret = model(data, target, *mems) 83 | break 84 | 85 | # Turn on evaluation mode which disables dropout. 86 | with torch.no_grad(): 87 | model.train() 88 | mems = tuple() 89 | for i, (data, target, seq_len) in enumerate(eval_iter): 90 | ret = model(data, target, *mems) 91 | break 92 | model.eval() 93 | total_len, total_loss = 0, 0. 94 | start_time = time.time() 95 | with torch.no_grad(): 96 | mems = tuple() 97 | for idx, (data, target, seq_len) in enumerate(eval_iter): 98 | ret = model(data, target, *mems) 99 | loss, mems = ret[0], ret[1:] 100 | loss = loss.mean() 101 | total_loss += seq_len * loss.item() 102 | total_len += seq_len 103 | total_time = time.time() - start_time 104 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format( 105 | total_time, 1000 * total_time / (idx+1))) 106 | return total_loss / total_len 107 | 108 | # Run on test data. 
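# The evaluate() calls below return the average loss in nats per symbol;
# format_log() further down converts this to bits-per-character for
# enwik8/text8 (loss / ln 2) or to perplexity otherwise (exp(loss)). As a
# quick, purely illustrative check: a character-level loss of 0.72 nats is
# about 0.72 / 0.6931 ~= 1.04 bpc, while the same value on a word-level
# dataset would be reported as exp(0.72) ~= 2.05 perplexity.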
109 | if args.split == 'all': 110 | test_loss = evaluate(te_iter) 111 | valid_loss = evaluate(va_iter) 112 | elif args.split == 'valid': 113 | valid_loss = evaluate(va_iter) 114 | test_loss = None 115 | elif args.split == 'test': 116 | test_loss = evaluate(te_iter) 117 | valid_loss = None 118 | 119 | def format_log(loss, split): 120 | if args.dataset in ['enwik8', 'text8']: 121 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format( 122 | split, loss, loss / math.log(2)) 123 | else: 124 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 125 | split, loss, math.exp(loss)) 126 | return log_str 127 | 128 | log_str = '' 129 | if valid_loss is not None: 130 | log_str += format_log(valid_loss, 'valid') 131 | if test_loss is not None: 132 | log_str += format_log(test_loss, 'test') 133 | 134 | logging('=' * 100) 135 | logging(log_str) 136 | logging('=' * 100) 137 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/enwik8_tf/utils/__init__.py -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = [0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if 
indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). 
To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 20 | if debug: 21 | print('Debug Mode : no experiment dir created') 22 | return functools.partial(logging, log_path=None, log_=False) 23 | 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | print('Experiment dir : {}'.format(dir_path)) 28 | if scripts_to_save is not None: 29 | script_path = os.path.join(dir_path, 'scripts') 30 | if not os.path.exists(script_path): 31 | os.makedirs(script_path) 32 | for script in scripts_to_save: 33 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 34 | shutil.copyfile(script, dst_file) 35 | 36 | return 
get_logger(log_path=os.path.join(dir_path, 'log.txt')) 37 | 38 | def save_checkpoint(model, optimizer, path, epoch): 39 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 40 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 41 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - log(class + 
1)) / log(range_max + 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 11 | 12 | class ProjectedAdaptiveLogSoftmax(nn.Module): 13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 14 | keep_order=False): 15 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 16 | 17 | self.n_token = n_token 18 | self.d_embed = d_embed 19 | self.d_proj = d_proj 20 | 21 | self.cutoffs = cutoffs + [n_token] 22 | self.cutoff_ends = [0] + self.cutoffs 23 | self.div_val = div_val 24 | 25 | self.shortlist_size = self.cutoffs[0] 26 | self.n_clusters = len(self.cutoffs) - 1 27 | self.head_size = self.shortlist_size + self.n_clusters 28 | 29 | if self.n_clusters > 0: 30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 32 | 33 | self.out_layers = nn.ModuleList() 34 | 
self.out_projs = nn.ParameterList() 35 | 36 | if div_val == 1: 37 | for i in range(len(self.cutoffs)): 38 | if d_proj != d_embed: 39 | self.out_projs.append( 40 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 41 | ) 42 | else: 43 | self.out_projs.append(None) 44 | 45 | self.out_layers.append(nn.Linear(d_embed, n_token)) 46 | else: 47 | for i in range(len(self.cutoffs)): 48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 49 | d_emb_i = d_embed // (div_val ** i) 50 | 51 | self.out_projs.append( 52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 53 | ) 54 | 55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 56 | 57 | self.keep_order = keep_order 58 | 59 | def _compute_logit(self, hidden, weight, bias, proj): 60 | if proj is None: 61 | logit = F.linear(hidden, weight, bias=bias) 62 | else: 63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 64 | proj_hid = F.linear(hidden, proj.t().contiguous()) 65 | logit = F.linear(proj_hid, weight, bias=bias) 66 | # else: 67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 68 | # if bias is not None: 69 | # logit = logit + bias 70 | 71 | return logit 72 | 73 | def forward(self, hidden, target, keep_order=False): 74 | ''' 75 | hidden :: [len*bsz x d_proj] 76 | target :: [len*bsz] 77 | ''' 78 | 79 | if hidden.size(0) != target.size(0): 80 | raise RuntimeError('Input and target should have the same size ' 81 | 'in the batch dimension.') 82 | 83 | if self.n_clusters == 0: 84 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 85 | self.out_layers[0].bias, self.out_projs[0]) 86 | nll = -F.log_softmax(logit, dim=-1) \ 87 | .gather(1, target.unsqueeze(1)).squeeze(1) 88 | else: 89 | # construct weights and biases 90 | weights, biases = [], [] 91 | for i in range(len(self.cutoffs)): 92 | if self.div_val == 1: 93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 94 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 95 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 96 | else: 97 | weight_i = self.out_layers[i].weight 98 | bias_i = self.out_layers[i].bias 99 | 100 | if i == 0: 101 | weight_i = torch.cat( 102 | [weight_i, self.cluster_weight], dim=0) 103 | bias_i = torch.cat( 104 | [bias_i, self.cluster_bias], dim=0) 105 | 106 | weights.append(weight_i) 107 | biases.append(bias_i) 108 | 109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 110 | 111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 112 | head_logprob = F.log_softmax(head_logit, dim=1) 113 | 114 | nll = torch.zeros_like(target, 115 | dtype=hidden.dtype, device=hidden.device) 116 | 117 | offset = 0 118 | cutoff_values = [0] + self.cutoffs 119 | for i in range(len(cutoff_values) - 1): 120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 121 | 122 | mask_i = (target >= l_idx) & (target < r_idx) 123 | indices_i = mask_i.nonzero().squeeze() 124 | 125 | if indices_i.numel() == 0: 126 | continue 127 | 128 | target_i = target.index_select(0, indices_i) - l_idx 129 | head_logprob_i = head_logprob.index_select(0, indices_i) 130 | 131 | if i == 0: 132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 133 | else: 134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 135 | 136 | hidden_i = hidden.index_select(0, indices_i) 137 | 138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 140 | 141 | logprob_i = head_logprob_i[:, -i] \ 142 | + tail_logprob_i.gather(1, 
target_i[:,None]).squeeze(1) 143 | 144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 145 | nll.index_copy_(0, indices_i, -logprob_i) 146 | else: 147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 148 | 149 | offset += logprob_i.size(0) 150 | 151 | return nll 152 | -------------------------------------------------------------------------------- /examples/enwik8_tf/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | class Vocab(object): 7 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 8 | delimiter=None, vocab_file=None): 9 | self.counter = Counter() 10 | self.special = special 11 | self.min_freq = min_freq 12 | self.max_size = max_size 13 | self.lower_case = lower_case 14 | self.delimiter = delimiter 15 | self.vocab_file = vocab_file 16 | 17 | def tokenize(self, line, add_eos=False, add_double_eos=False): 18 | line = line.strip() 19 | # convert to lower case 20 | if self.lower_case: 21 | line = line.lower() 22 | 23 | # empty delimiter '' will evaluate False 24 | if self.delimiter == '': 25 | symbols = line 26 | else: 27 | symbols = line.split(self.delimiter) 28 | 29 | if add_double_eos: # lm1b 30 | return [''] + symbols + [''] 31 | elif add_eos: 32 | return symbols + [''] 33 | else: 34 | return symbols 35 | 36 | def count_file(self, path, verbose=False, add_eos=False): 37 | if verbose: print('counting file {} ...'.format(path)) 38 | assert os.path.exists(path) 39 | 40 | sents = [] 41 | with open(path, 'r', encoding='utf-8') as f: 42 | for idx, line in enumerate(f): 43 | if verbose and idx > 0 and idx % 500000 == 0: 44 | print(' line {}'.format(idx)) 45 | symbols = self.tokenize(line, add_eos=add_eos) 46 | self.counter.update(symbols) 47 | sents.append(symbols) 48 | 49 | return sents 50 | 51 | def count_sents(self, sents, verbose=False): 52 | """ 53 | sents : a list of sentences, each a list of tokenized symbols 54 | """ 55 | if verbose: print('counting {} sents ...'.format(len(sents))) 56 | for idx, symbols in enumerate(sents): 57 | if verbose and idx > 0 and idx % 500000 == 0: 58 | print(' line {}'.format(idx)) 59 | self.counter.update(symbols) 60 | 61 | def _build_from_file(self, vocab_file): 62 | self.idx2sym = [] 63 | self.sym2idx = OrderedDict() 64 | 65 | with open(vocab_file, 'r', encoding='utf-8') as f: 66 | for line in f: 67 | symb = line.strip().split()[0] 68 | self.add_symbol(symb) 69 | self.unk_idx = self.sym2idx[''] 70 | 71 | def build_vocab(self): 72 | if self.vocab_file: 73 | print('building vocab from {}'.format(self.vocab_file)) 74 | self._build_from_file(self.vocab_file) 75 | print('final vocab size {}'.format(len(self))) 76 | else: 77 | print('building vocab with min_freq={}, max_size={}'.format( 78 | self.min_freq, self.max_size)) 79 | self.idx2sym = [] 80 | self.sym2idx = OrderedDict() 81 | 82 | for sym in self.special: 83 | self.add_special(sym) 84 | 85 | for sym, cnt in self.counter.most_common(self.max_size): 86 | if cnt < self.min_freq: break 87 | self.add_symbol(sym) 88 | 89 | print('final vocab size {} from {} unique tokens'.format( 90 | len(self), len(self.counter))) 91 | 92 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 93 | add_double_eos=False): 94 | if verbose: print('encoding file {} ...'.format(path)) 95 | assert os.path.exists(path) 96 | encoded = [] 97 | with open(path, 'r', encoding='utf-8') as f: 98 | for idx, line in 
enumerate(f): 99 | if verbose and idx > 0 and idx % 500000 == 0: 100 | print(' line {}'.format(idx)) 101 | symbols = self.tokenize(line, add_eos=add_eos, 102 | add_double_eos=add_double_eos) 103 | encoded.append(self.convert_to_tensor(symbols)) 104 | 105 | if ordered: 106 | encoded = torch.cat(encoded) 107 | 108 | return encoded 109 | 110 | def encode_sents(self, sents, ordered=False, verbose=False): 111 | if verbose: print('encoding {} sents ...'.format(len(sents))) 112 | encoded = [] 113 | for idx, symbols in enumerate(sents): 114 | if verbose and idx > 0 and idx % 500000 == 0: 115 | print(' line {}'.format(idx)) 116 | encoded.append(self.convert_to_tensor(symbols)) 117 | 118 | if ordered: 119 | encoded = torch.cat(encoded) 120 | 121 | return encoded 122 | 123 | def add_special(self, sym): 124 | if sym not in self.sym2idx: 125 | self.idx2sym.append(sym) 126 | self.sym2idx[sym] = len(self.idx2sym) - 1 127 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 128 | 129 | def add_symbol(self, sym): 130 | if sym not in self.sym2idx: 131 | self.idx2sym.append(sym) 132 | self.sym2idx[sym] = len(self.idx2sym) - 1 133 | 134 | def get_sym(self, idx): 135 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 136 | return self.idx2sym[idx] 137 | 138 | def get_idx(self, sym): 139 | if sym in self.sym2idx: 140 | return self.sym2idx[sym] 141 | else: 142 | # print('encounter unk {}'.format(sym)) 143 | assert '' not in sym 144 | assert hasattr(self, 'unk_idx') 145 | return self.sym2idx.get(sym, self.unk_idx) 146 | 147 | def get_symbols(self, indices): 148 | return [self.get_sym(idx) for idx in indices] 149 | 150 | def get_indices(self, symbols): 151 | return [self.get_idx(sym) for sym in symbols] 152 | 153 | def convert_to_tensor(self, symbols): 154 | return torch.LongTensor(self.get_indices(symbols)) 155 | 156 | def convert_to_sent(self, indices, exclude=None): 157 | if exclude is None: 158 | return ' '.join([self.get_sym(idx) for idx in indices]) 159 | else: 160 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 161 | 162 | def __len__(self): 163 | return len(self.idx2sym) 164 | -------------------------------------------------------------------------------- /examples/wt103/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/wt103/__init__.py -------------------------------------------------------------------------------- /examples/wt103/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import random 4 | import os 5 | import zipfile 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | import sru 12 | import flop 13 | from flop.embedding import AdaptiveEmbedding, AdaptiveLogSoftmax 14 | from flop.embedding import AdaptiveEmbeddingWithMask, AdaptiveLogSoftmaxWithMask 15 | from flop.scripts.wt103.utils.data_utils import get_lm_corpus 16 | 17 | 18 | class Model(nn.Module): 19 | def __init__(self, args): 20 | super(Model, self).__init__() 21 | self.args = args 22 | # self.cutoffs = [20000, 60000] 23 | self.cutoffs = [10000, 20000, 40000, 60000, 100000] 24 | self.n_V = args.n_token 25 | self.n_e = args.n_e or args.n_proj 26 | self.n_d = args.n_d 27 | self.depth = args.depth 28 | self.drop = nn.Dropout(args.dropout) 29 | self.embedding_layer = AdaptiveEmbedding( 30 | self.n_V, 31 | self.n_e, 32 | self.n_d, 
33 | self.cutoffs, 34 | div_val=args.div_val, 35 | div_freq=2, 36 | dropout=args.dropout_e, 37 | ) 38 | self.rnn = sru.SRU( 39 | self.n_d, 40 | self.n_d, 41 | self.depth, 42 | projection_size=args.n_proj, 43 | dropout=args.dropout, 44 | highway_bias=args.bias, 45 | layer_norm=args.layer_norm, 46 | rescale=args.rescale, 47 | custom_m=flop.ProjectedLinear( 48 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 49 | ), 50 | ) 51 | self.output_layer = AdaptiveLogSoftmax( 52 | self.n_V, 53 | self.n_e, 54 | self.n_d, 55 | self.cutoffs, 56 | div_val=args.div_val, 57 | div_freq=2, 58 | dropout=args.dropout_e, 59 | keep_order=False, 60 | ) 61 | self.init_weights() 62 | if not args.not_tie: 63 | self.tie_weights() 64 | 65 | def tie_weights(self): 66 | for i in range(len(self.output_layer.out_layers)): 67 | self.embedding_layer.emb_layers[i].weight = self.output_layer.out_layers[ 68 | i 69 | ].weight 70 | 71 | for i in range(len(self.output_layer.out_projs)): 72 | self.embedding_layer.emb_projs[i] = self.output_layer.out_projs[i] 73 | 74 | if hasattr(self.embedding_layer, "masks") and hasattr( 75 | self.output_layer, "masks" 76 | ): 77 | delattr(self.output_layer, "masks") 78 | setattr(self.output_layer, "masks", self.embedding_layer.masks) 79 | 80 | def init_weights(self, init_range=0.03, reinit_rnn=False): 81 | params = list(self.embedding_layer.parameters()) + list( 82 | self.output_layer.parameters() 83 | ) 84 | for p in params: 85 | if p.dim() > 1: # matrix 86 | p.data.uniform_(-init_range, init_range) 87 | else: 88 | p.data.zero_() 89 | if reinit_rnn: 90 | for p in self.rnn.parameters(): 91 | if p.dim() > 1: # matrix 92 | p.data.uniform_(-init_range, init_range) 93 | 94 | def forward(self, x, y, hidden): 95 | emb = self.drop(self.embedding_layer(x)) 96 | output, hidden = self.rnn(emb, hidden) 97 | output = self.drop(output) 98 | output = output.view(-1, output.size(2)) 99 | loss = self.output_layer(output, y.view(-1)) 100 | loss = loss.view(y.size(0), -1) 101 | return loss, hidden 102 | 103 | def init_hidden(self, batch_size): 104 | weight = next(self.parameters()).data 105 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 106 | return zeros 107 | 108 | 109 | def calc_norm(lis): 110 | l2_sum = sum(x.norm() ** 2 for x in lis) 111 | return l2_sum ** 0.5 112 | 113 | 114 | def eval_model(model, valid): 115 | with torch.no_grad(): 116 | # Important: reset compiled masks. When multiple GPUs are used, model() and DDP model() 117 | # are not the same instance although they share the same parameters. 118 | # Calling model(..) 
in training mode will reset all compiled weights and cached masks 119 | for x, y, seq_len in valid: 120 | model(x, y, hidden=None) 121 | break 122 | model.eval() 123 | args = model.args 124 | batch_size = args.eval_batch_size or args.batch_size 125 | total_loss = 0.0 126 | total_tok = 0.0 127 | hidden = model.init_hidden(batch_size) 128 | for x, y, seq_len in valid: 129 | loss, hidden = model(x, y, hidden) 130 | total_loss += loss.sum().item() 131 | total_tok += y.numel() 132 | avg_loss = total_loss / total_tok 133 | ppl = np.exp(avg_loss) 134 | model.train() 135 | return ppl, avg_loss 136 | 137 | 138 | def copy_model(model): 139 | states = model.state_dict() 140 | for k in states: 141 | v = states[k] 142 | states[k] = v.clone().cpu() 143 | return states 144 | 145 | 146 | def set_seed(seed): 147 | random.seed(seed) 148 | np.random.seed(seed) 149 | torch.manual_seed(seed) 150 | torch.cuda.manual_seed_all(seed) 151 | 152 | 153 | def main(args): 154 | 155 | # set up distributed training 156 | torch.cuda.set_device(args.local_rank) 157 | device = torch.device("cuda", args.local_rank) 158 | # torch.distributed.init_process_group(backend="nccl") 159 | set_seed(1234) 160 | args.n_gpu = 1 161 | args.device = device 162 | local_rank = args.local_rank 163 | 164 | corpus = get_lm_corpus(args.data, "wt103") 165 | n_token = args.n_token = len(corpus.vocab) 166 | args.eval_batch_size = args.eval_batch_size or args.batch_size 167 | args.eval_unroll_size = args.eval_unroll_size or args.unroll_size 168 | eval_unroll_size = args.eval_unroll_size 169 | eval_batch_size = args.eval_batch_size 170 | dev = corpus.get_iterator("valid", eval_batch_size, eval_unroll_size, device=device) 171 | if local_rank == 0: 172 | print("vocab size: {}".format(n_token)) 173 | 174 | model = Model(args) 175 | 176 | # in place substituion of linear ops in SRU 177 | flop.make_projected_linear_with_mask( 178 | model.rnn, in_place=True 179 | ) 180 | model.embedding_layer = AdaptiveEmbeddingWithMask.from_module( 181 | model.embedding_layer 182 | ) 183 | model.output_layer = AdaptiveLogSoftmaxWithMask.from_module( 184 | model.output_layer 185 | ) 186 | 187 | if args.load: 188 | model.load_state_dict(torch.load(args.load, map_location='cpu')) 189 | 190 | # tie weights again 191 | model.tie_weights() 192 | model.to(device) 193 | 194 | model_ = model 195 | if local_rank == 0: 196 | 197 | # dev = create_batches(dev_, 1) 198 | # test = create_batches(test_, 1) 199 | test = corpus.get_iterator( 200 | "test", eval_batch_size, eval_unroll_size, device=device 201 | ) 202 | dev_ppl, dev_loss = eval_model(model_, dev) 203 | test_ppl, test_loss = eval_model(model_, test) 204 | sys.stdout.write("dev_ppl={:.3f} test_ppl={:.3f}\n".format(dev_ppl, test_ppl)) 205 | 206 | 207 | if __name__ == "__main__": 208 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 209 | argparser.add_argument("--log", type=str, default="") 210 | argparser.add_argument("--noam", type=bool, default=True) 211 | argparser.add_argument("--warmup_steps", type=int, default=4000) 212 | argparser.add_argument("--layer_norm", type=bool, default=True) 213 | argparser.add_argument("--rescale", action="store_true") 214 | argparser.add_argument("--not_tie", action="store_true") 215 | argparser.add_argument("--data", type=str, required=True, help="training file") 216 | argparser.add_argument("--update_param_freq", type=int, default=2) 217 | argparser.add_argument("--batch_size", "--batch", type=int, default=32) 218 | argparser.add_argument("--eval_batch_size", 
type=int, default=10) 219 | argparser.add_argument("--unroll_size", type=int, default=256) 220 | argparser.add_argument("--eval_unroll_size", type=int, default=0) 221 | argparser.add_argument("--max_epoch", type=int, default=100) 222 | argparser.add_argument("--n_e", type=int, default=1024) 223 | argparser.add_argument("--n_d", "--d", type=int, default=2048) 224 | argparser.add_argument("--n_proj", type=int, default=512) 225 | argparser.add_argument("--div_val", type=float, default=4) 226 | argparser.add_argument( 227 | "--dropout", type=float, default=0.1, help="dropout probability" 228 | ) 229 | argparser.add_argument("--dropout_e", type=float, default=0.1) 230 | argparser.add_argument( 231 | "--bias", type=float, default=-3, help="intial bias of highway gates", 232 | ) 233 | argparser.add_argument("--depth", type=int, default=12) 234 | argparser.add_argument("--lr", type=float, default=2) 235 | argparser.add_argument("--weight_decay", type=float, default=0.01) 236 | argparser.add_argument("--clip_grad", type=float, default=0.3) 237 | argparser.add_argument("--log_period", type=int, default=1000000) 238 | argparser.add_argument("--save", type=str, default="") 239 | argparser.add_argument("--load", type=str, default="") 240 | 241 | argparser.add_argument("--prune", type=bool, default=True) 242 | argparser.add_argument("--prune_lr", type=float, default=2) 243 | argparser.add_argument("--prune_warmup", type=int, default=0) 244 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 245 | argparser.add_argument("--prune_sparsity", type=float, default=0.8) 246 | argparser.add_argument("--prune_end_epoch", type=int, default=30) 247 | argparser.add_argument("--l1_lambda", type=float, default=0) 248 | 249 | argparser.add_argument("--local_rank", type=int, default=0) 250 | args = argparser.parse_args() 251 | 252 | dirname = os.path.dirname(args.data) 253 | with zipfile.ZipFile(args.data, 'r') as f: 254 | f.extractall(dirname) 255 | args.data = os.path.join(dirname, 'wikitext-103') 256 | os.makedirs(args.log, exist_ok=True) 257 | os.makedirs(args.save, exist_ok=True) 258 | print(args) 259 | main(args) 260 | -------------------------------------------------------------------------------- /examples/wt103/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import time 5 | import random 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn.init import xavier_uniform_ 13 | from torch.optim import Adam 14 | from tensorboardX import SummaryWriter 15 | 16 | import sru 17 | import flop 18 | from flop.embedding import AdaptiveEmbedding, AdaptiveLogSoftmax 19 | from flop.embedding import HardConcreteAdaptiveEmbedding, HardConcreteAdaptiveLogSoftmax 20 | from utils.data_utils import get_lm_corpus 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, args): 25 | super(Model, self).__init__() 26 | self.args = args 27 | #self.cutoffs = [20000, 60000] 28 | self.cutoffs = [10000, 20000, 40000, 60000, 100000] 29 | self.n_V = args.n_token 30 | self.n_e = args.n_e or args.n_proj 31 | self.n_d = args.n_d 32 | self.depth = args.depth 33 | self.drop = nn.Dropout(args.dropout) 34 | self.embedding_layer = AdaptiveEmbedding(self.n_V, 35 | self.n_e, 36 | self.n_d, 37 | self.cutoffs, 38 | div_val=args.div_val, 39 | div_freq=2, 40 | dropout=0.1 41 | ) 42 | self.rnn = sru.SRU(self.n_d, self.n_d, self.depth, 43 | 
projection_size=args.n_proj, 44 | dropout=args.dropout, 45 | highway_bias=args.bias, 46 | layer_norm=args.layer_norm, 47 | rescale=args.rescale, 48 | custom_m=flop.ProjectedLinear( 49 | self.n_d, self.n_d * 3, 50 | proj_features=args.n_proj, 51 | bias=False 52 | ) 53 | ) 54 | self.output_layer = AdaptiveLogSoftmax(self.n_V, 55 | self.n_e, 56 | self.n_d, 57 | self.cutoffs, 58 | div_val=args.div_val, 59 | div_freq=2, 60 | dropout=0.1, 61 | keep_order=False 62 | ) 63 | self.init_weights() 64 | if not args.not_tie: 65 | self.tie_weights() 66 | 67 | def tie_weights(self): 68 | for i in range(len(self.output_layer.out_layers)): 69 | self.embedding_layer.emb_layers[i].weight = self.output_layer.out_layers[i].weight 70 | 71 | for i in range(len(self.output_layer.out_projs)): 72 | self.embedding_layer.emb_projs[i] = self.output_layer.out_projs[i] 73 | 74 | if hasattr(self.embedding_layer, 'masks') and hasattr(self.output_layer, 'masks'): 75 | delattr(self.output_layer, 'masks') 76 | setattr(self.output_layer, 'masks', self.embedding_layer.masks) 77 | 78 | def init_weights(self, init_range=0.03, reinit_rnn=False): 79 | params = list(self.embedding_layer.parameters()) + list(self.output_layer.parameters()) 80 | for p in params: 81 | if p.dim() > 1: # matrix 82 | p.data.uniform_(-init_range, init_range) 83 | else: 84 | p.data.zero_() 85 | if reinit_rnn: 86 | for p in self.rnn.parameters(): 87 | if p.dim() > 1: # matrix 88 | p.data.uniform_(-init_range, init_range) 89 | 90 | def forward(self, x, y, hidden): 91 | emb = self.drop(self.embedding_layer(x)) 92 | output, hidden = self.rnn(emb, hidden) 93 | output = self.drop(output) 94 | output = output.view(-1, output.size(2)) 95 | loss = self.output_layer(output, y.view(-1)) 96 | loss = loss.view(y.size(0), -1) 97 | return loss, hidden 98 | 99 | def init_hidden(self, batch_size): 100 | weight = next(self.parameters()).data 101 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 102 | return zeros 103 | 104 | def calc_norm(lis): 105 | l2_sum = sum(x.norm()**2 for x in lis) 106 | return l2_sum**0.5 107 | 108 | def eval_model(model, valid): 109 | with torch.no_grad(): 110 | # Important: reset compiled masks. When multiple GPUs are used, model() and para_model() 111 | # are not the same instance although they share the same parameters. 112 | # Calling model(..) 
in training mode will reset all compiled weights and cached masks 113 | for x, y, seq_len in valid: 114 | model(x, y, hidden=None) 115 | break 116 | model.eval() 117 | args = model.args 118 | batch_size = args.eval_batch_size or args.batch_size 119 | total_loss = 0.0 120 | total_tok = 0.0 121 | hidden = model.init_hidden(batch_size) 122 | for x, y, seq_len in valid: 123 | loss, hidden = model(x, y, hidden) 124 | total_loss += loss.sum().item() 125 | total_tok += y.numel() 126 | avg_loss = total_loss / total_tok 127 | ppl = np.exp(avg_loss) 128 | model.train() 129 | return ppl, avg_loss 130 | 131 | def copy_model(model): 132 | states = model.state_dict() 133 | for k in states: 134 | v = states[k] 135 | states[k] = v.clone().cpu() 136 | return states 137 | 138 | def main(args): 139 | log_path = "{}_{}".format(args.log, random.randint(1,100)) 140 | train_writer = SummaryWriter(log_dir=log_path+"/train") 141 | dev_writer = SummaryWriter(log_dir=log_path+"/dev") 142 | 143 | device = torch.device('cuda') 144 | corpus = get_lm_corpus(args.data, 'wt103') 145 | n_token = args.n_token = len(corpus.vocab) 146 | args.eval_batch_size = args.eval_batch_size or args.batch_size 147 | args.eval_unroll_size = args.eval_unroll_size or args.unroll_size 148 | train = corpus.get_iterator('train', args.batch_size, args.unroll_size, device=device) 149 | dev = corpus.get_iterator('valid', args.eval_batch_size, args.eval_unroll_size, device=device) 150 | test = corpus.get_iterator('test', args.eval_batch_size, args.eval_unroll_size, device=device) 151 | print("vocab size: {}".format(n_token)) 152 | 153 | model = Model(args) 154 | if args.load: 155 | model.load_state_dict(torch.load(args.load)) 156 | model.cuda() 157 | print(model) 158 | if torch.cuda.device_count() > 1: 159 | para_model = torch.nn.DataParallel(model, dim=1)#, output_device=1) 160 | else: 161 | para_model = model 162 | lr = 1.0 if not args.noam else 1.0/(args.n_d**0.5)/(args.warmup_steps**1.5) 163 | if args.prune: 164 | # in place substituion of linear ops in SRU 165 | flop.make_hard_concrete(model.rnn, in_place=True, init_mean=args.prune_init_mean) 166 | model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module( 167 | model.embedding_layer, 168 | init_mean=args.prune_init_mean 169 | ) 170 | model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module( 171 | model.output_layer, 172 | init_mean=args.prune_init_mean 173 | ) 174 | # tie weights again 175 | model.tie_weights() 176 | model.cuda() 177 | print("model after inserting hardconcrete:") 178 | print(model) 179 | hc_modules = flop.get_hardconcrete_modules(model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer) 180 | print(len(flop.get_hardconcrete_modules(model))) 181 | print(len(hc_modules)) 182 | hc_parameters = [p for m in hc_modules for p in m.parameters() if p.requires_grad] 183 | optimizer_hc = Adam( 184 | hc_parameters, 185 | lr = lr * args.prune_lr, 186 | weight_decay = 0 187 | ) 188 | num_hardconcrete_params = sum(x.numel() for x in hc_parameters) 189 | print("num of hardconcrete paramters: {}".format(num_hardconcrete_params)) 190 | lambda_1 = nn.Parameter(torch.tensor(0.).cuda()) 191 | lambda_2 = nn.Parameter(torch.tensor(0.).cuda()) 192 | optimizer_max = Adam( 193 | [lambda_1, lambda_2], 194 | lr = lr, 195 | weight_decay = 0 196 | ) 197 | optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr 198 | hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + \ 199 | [model.embedding_layer] 200 | num_prunable_params = sum(m.num_prunable_parameters() for m in 
hc_linear_modules) 201 | print("num of prunable paramters: {}".format(num_prunable_params)) 202 | else: 203 | args.prune_start_epoch = args.max_epoch 204 | 205 | m_parameters = [i[1] for i in model.named_parameters() if i[1].requires_grad and 'log_alpha' not in i[0]] 206 | optimizer = Adam( 207 | m_parameters, 208 | lr = lr * args.lr, 209 | weight_decay = args.weight_decay 210 | ) 211 | num_params = sum(x.numel() for x in m_parameters if x.requires_grad) 212 | print("num of parameters: {}".format(num_params)) 213 | 214 | nbatch = 1 215 | niter = 1 216 | best_dev = 1e+8 217 | unroll_size = args.unroll_size 218 | batch_size = args.batch_size 219 | N = train.n_batch 220 | checkpoint = None 221 | print("num of mini-batches: {}".format(N)) 222 | 223 | model.zero_grad() 224 | if args.prune: 225 | optimizer_max.zero_grad() 226 | optimizer_hc.zero_grad() 227 | 228 | for epoch in range(args.max_epoch): 229 | start_time = time.time() 230 | model.train() 231 | total_loss = 0.0 232 | hidden = model.init_hidden(batch_size) 233 | start_prune = epoch >= args.prune_start_epoch 234 | i = 0 235 | 236 | for x, y, seq_len in train: 237 | i += 1 238 | hidden.detach_() 239 | 240 | # language model forward and backward 241 | loss, hidden = para_model(x, y, hidden) 242 | loss = loss.mean() 243 | (loss / args.update_param_freq).backward() 244 | loss = loss.item() 245 | lagrangian_loss = 0 246 | target_sparsity = 0 247 | expected_sparsity = 0 248 | 249 | # add lagrangian loss (regularization) when pruning 250 | if start_prune: 251 | # compute target sparsity with (optionally) linear warmup 252 | target_sparsity = args.prune_sparsity 253 | if args.prune_warmup > 0: 254 | niter_ = niter - args.prune_start_epoch * N 255 | target_sparsity *= min(1.0, niter_ / args.prune_warmup) 256 | 257 | # compute expected model size and sparsity 258 | expected_size = sum(m.num_parameters(train=True) for m in hc_linear_modules) 259 | expected_sparsity = 1.0 - expected_size / num_prunable_params 260 | 261 | # compute lagrangian loss 262 | lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \ 263 | lambda_2 * (expected_sparsity - target_sparsity)**2 264 | (lagrangian_loss / args.update_param_freq).backward() 265 | expected_sparsity = expected_sparsity.item() 266 | lagrangian_loss = lagrangian_loss.item() 267 | 268 | # log training stats 269 | if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0: 270 | if args.prune: 271 | train_writer.add_scalar('sparsity/expected_sparsity', expected_sparsity, niter) 272 | train_writer.add_scalar('sparsity/target_sparsity', target_sparsity, niter) 273 | train_writer.add_scalar('loss/lagrangian_loss', lagrangian_loss, niter) 274 | train_writer.add_scalar('lambda/1', lambda_1.item(), niter) 275 | train_writer.add_scalar('lambda/2', lambda_2.item(), niter) 276 | if (nbatch - 1) % 3000 == 0: 277 | for index, layer in enumerate(hc_modules): 278 | train_writer.add_histogram( 279 | 'log_alpha/{}'.format(index), 280 | layer.log_alpha, 281 | niter, 282 | bins='sqrt', 283 | ) 284 | sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format( 285 | math.exp(loss), 286 | lagrangian_loss, 287 | expected_sparsity, 288 | (time.time()-start_time)/60.0/(i+1)*(N-i-1), 289 | )) 290 | train_writer.add_scalar('loss/ppl', math.exp(loss), niter) 291 | train_writer.add_scalar('loss/lm_loss', loss, niter) 292 | train_writer.add_scalar('loss/total_loss', loss + lagrangian_loss, niter) 293 | train_writer.add_scalar('parameter_norm', 294 | calc_norm([ x.data for x in m_parameters ]), 295 | niter 
296 | ) 297 | train_writer.add_scalar('gradient_norm', 298 | calc_norm([ x.grad for x in m_parameters if x.grad is not None]), 299 | niter 300 | ) 301 | 302 | # perform gradient decent every few number of backward() 303 | if nbatch % args.update_param_freq == 0: 304 | if args.clip_grad > 0: 305 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 306 | optimizer.step() 307 | if start_prune: 308 | optimizer_max.step() 309 | optimizer_hc.step() 310 | # clear gradient 311 | model.zero_grad() 312 | if args.prune: 313 | optimizer_max.zero_grad() 314 | optimizer_hc.zero_grad() 315 | niter += 1 316 | 317 | if nbatch % args.log_period == 0 or i == N: 318 | elapsed_time = (time.time()-start_time)/60.0 319 | dev_ppl, dev_loss = eval_model(model, dev) 320 | dev_writer.add_scalar('loss/lm_loss', dev_loss, niter) 321 | dev_writer.add_scalar('loss/ppl', dev_ppl, niter) 322 | dev_writer.add_scalar('ppl', dev_ppl, niter) 323 | sparsity = 0 324 | if args.prune: 325 | pruned_size = sum(m.num_parameters(train=False) for m in hc_linear_modules) 326 | sparsity = 1.0 - pruned_size / num_prunable_params 327 | dev_writer.add_scalar('sparsity/hard_sparsity', sparsity, niter) 328 | dev_writer.add_scalar('model_size/total_prunable', num_prunable_params, niter) 329 | dev_writer.add_scalar('model_size/current_prunable', pruned_size, niter) 330 | dev_writer.add_scalar('model_size/total', num_params, niter) 331 | dev_writer.add_scalar('model_size/current', 332 | num_params - num_prunable_params + pruned_size, 333 | niter 334 | ) 335 | dev_writer.add_scalar('model_size/current_embedding', 336 | model.embedding_layer.num_parameters(train=False), 337 | niter 338 | ) 339 | dev_writer.add_scalar('model_size/current_output_layer', 340 | model.output_layer.num_parameters(train=False), 341 | niter 342 | ) 343 | sys.stdout.write("\rnum_batches={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}" 344 | " dev_bpc={:.2f} sparsity={:.2f}\t[{:.1f}m]\n".format( 345 | nbatch, 346 | optimizer.param_groups[0]['lr'], 347 | loss, 348 | dev_loss, 349 | dev_ppl, 350 | sparsity, 351 | elapsed_time 352 | )) 353 | if dev_ppl < best_dev: 354 | best_dev = dev_ppl 355 | checkpoint = copy_model(model) 356 | sys.stdout.write("\n") 357 | sys.stdout.flush() 358 | 359 | nbatch += 1 360 | if args.noam: 361 | lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5)) 362 | optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5) 363 | if args.noam and start_prune: 364 | niter_ = niter - args.prune_start_epoch * N 365 | lr = min(1.0 / (niter_**0.5), niter_ / (args.warmup_steps**1.5)) 366 | optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (args.n_d**0.5) 367 | optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5) 368 | 369 | if args.save and (epoch + 1) % 5 == 0: 370 | torch.save(checkpoint, "{}.{}.{}.pt".format( 371 | args.save, 372 | epoch + 1, 373 | int(dev_ppl) 374 | #sparsity 375 | )) 376 | 377 | train_writer.close() 378 | dev_writer.close() 379 | 380 | if checkpoint is not None: 381 | model.load_state_dict(checkpoint) 382 | model.cuda() 383 | #dev = create_batches(dev_, 1) 384 | #test = create_batches(test_, 1) 385 | dev_ppl, dev_loss = eval_model(model, dev) 386 | test_ppl, test_loss = eval_model(model, test) 387 | sys.stdout.write("dev_ppl={:.3f} test_ppl={:.3f}\n".format( 388 | dev_ppl, test_ppl 389 | )) 390 | 391 | if __name__ == "__main__": 392 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve') 393 | argparser.add_argument("--log", type=str, required=True) 394 | 
argparser.add_argument("--noam", action="store_true") 395 | argparser.add_argument("--warmup_steps", type=int, default=32000) 396 | argparser.add_argument("--layer_norm", action="store_true") 397 | argparser.add_argument("--rescale", action="store_true") 398 | argparser.add_argument("--not_tie", action="store_true") 399 | argparser.add_argument("--data", type=str, required=True, help="training file") 400 | argparser.add_argument("--update_param_freq", type=int, default=1) 401 | argparser.add_argument("--batch_size", "--batch", type=int, default=64) 402 | argparser.add_argument("--eval_batch_size", type=int, default=10) 403 | argparser.add_argument("--unroll_size", type=int, default=128) 404 | argparser.add_argument("--eval_unroll_size", type=int, default=0) 405 | argparser.add_argument("--max_epoch", type=int, default=100) 406 | argparser.add_argument("--n_e", type=int, default=0) 407 | argparser.add_argument("--n_d", "--d", type=int, default=2048) 408 | argparser.add_argument("--n_proj", type=int, default=512) 409 | argparser.add_argument("--div_val", type=float, default=2) 410 | argparser.add_argument("--dropout", type=float, default=0.1, 411 | help="dropout probability" 412 | ) 413 | argparser.add_argument("--bias", type=float, default=-3, 414 | help="intial bias of highway gates", 415 | ) 416 | argparser.add_argument("--depth", type=int, default=6) 417 | argparser.add_argument("--lr", type=float, default=0.00025) 418 | argparser.add_argument("--weight_decay", type=float, default=0) 419 | argparser.add_argument("--clip_grad", type=float, default=0.3) 420 | argparser.add_argument("--log_period", type=int, default=1000000) 421 | argparser.add_argument("--save", type=str, default="") 422 | argparser.add_argument("--load", type=str, default="") 423 | 424 | argparser.add_argument("--prune", action="store_true") 425 | argparser.add_argument("--prune_lr", type=float, default=3) 426 | argparser.add_argument("--prune_warmup", type=int, default=0) 427 | argparser.add_argument("--prune_sparsity", type=float, default=0.) 
428 | argparser.add_argument("--prune_init_mean", type=float, default=0.1) 429 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 430 | 431 | args = argparser.parse_args() 432 | print (args) 433 | main(args) 434 | -------------------------------------------------------------------------------- /examples/wt103/train_agp_unstruct.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import time 4 | import random 5 | import math 6 | import os 7 | import zipfile 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from tensorboardX import SummaryWriter 13 | 14 | import sru 15 | import flop 16 | from flop.embedding import AdaptiveEmbedding, AdaptiveLogSoftmax 17 | from flop.scripts.wt103.utils.data_utils import get_lm_corpus 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, args): 22 | super(Model, self).__init__() 23 | self.args = args 24 | # self.cutoffs = [20000, 60000] 25 | self.cutoffs = [10000, 20000, 40000, 60000, 100000] 26 | self.n_V = args.n_token 27 | self.n_e = args.n_e or args.n_proj 28 | self.n_d = args.n_d 29 | self.depth = args.depth 30 | self.drop = nn.Dropout(args.dropout) 31 | self.embedding_layer = AdaptiveEmbedding( 32 | self.n_V, 33 | self.n_e, 34 | self.n_d, 35 | self.cutoffs, 36 | div_val=args.div_val, 37 | div_freq=2, 38 | dropout=args.dropout_e, 39 | ) 40 | self.rnn = sru.SRU( 41 | self.n_d, 42 | self.n_d, 43 | self.depth, 44 | projection_size=args.n_proj, 45 | dropout=args.dropout, 46 | highway_bias=args.bias, 47 | layer_norm=args.layer_norm, 48 | rescale=args.rescale, 49 | custom_m=flop.ProjectedLinear( 50 | self.n_d, self.n_d * 3, proj_features=args.n_proj, bias=False 51 | ), 52 | ) 53 | self.output_layer = AdaptiveLogSoftmax( 54 | self.n_V, 55 | self.n_e, 56 | self.n_d, 57 | self.cutoffs, 58 | div_val=args.div_val, 59 | div_freq=2, 60 | dropout=args.dropout_e, 61 | keep_order=False, 62 | ) 63 | self.init_weights() 64 | if not args.not_tie: 65 | self.tie_weights() 66 | 67 | def tie_weights(self): 68 | for i in range(len(self.output_layer.out_layers)): 69 | self.embedding_layer.emb_layers[i].weight = self.output_layer.out_layers[ 70 | i 71 | ].weight 72 | 73 | for i in range(len(self.output_layer.out_projs)): 74 | self.embedding_layer.emb_projs[i] = self.output_layer.out_projs[i] 75 | 76 | if hasattr(self.embedding_layer, "masks") and hasattr( 77 | self.output_layer, "masks" 78 | ): 79 | delattr(self.output_layer, "masks") 80 | setattr(self.output_layer, "masks", self.embedding_layer.masks) 81 | 82 | def init_weights(self, init_range=0.03, reinit_rnn=False): 83 | params = list(self.embedding_layer.parameters()) + list( 84 | self.output_layer.parameters() 85 | ) 86 | for p in params: 87 | if p.dim() > 1: # matrix 88 | p.data.uniform_(-init_range, init_range) 89 | else: 90 | p.data.zero_() 91 | if reinit_rnn: 92 | for p in self.rnn.parameters(): 93 | if p.dim() > 1: # matrix 94 | p.data.uniform_(-init_range, init_range) 95 | 96 | def forward(self, x, y, hidden): 97 | emb = self.drop(self.embedding_layer(x)) 98 | output, hidden = self.rnn(emb, hidden) 99 | output = self.drop(output) 100 | output = output.view(-1, output.size(2)) 101 | loss = self.output_layer(output, y.view(-1)) 102 | loss = loss.view(y.size(0), -1) 103 | return loss, hidden 104 | 105 | def init_hidden(self, batch_size): 106 | weight = next(self.parameters()).data 107 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 108 | return zeros 109 | 110 | 111 | def 
calc_norm(lis): 112 | l2_sum = sum(x.norm() ** 2 for x in lis) 113 | return l2_sum ** 0.5 114 | 115 | 116 | def eval_model(model, valid): 117 | with torch.no_grad(): 118 | # Important: reset compiled masks. When multiple GPUs are used, model() and DDP model() 119 | # are not the same instance although they share the same parameters. 120 | # Calling model(..) in training mode will reset all compiled weights and cached masks 121 | for x, y, seq_len in valid: 122 | model(x, y, hidden=None) 123 | break 124 | model.eval() 125 | args = model.args 126 | batch_size = args.eval_batch_size or args.batch_size 127 | total_loss = 0.0 128 | total_tok = 0.0 129 | hidden = model.init_hidden(batch_size) 130 | for x, y, seq_len in valid: 131 | loss, hidden = model(x, y, hidden) 132 | total_loss += loss.sum().item() 133 | total_tok += y.numel() 134 | avg_loss = total_loss / total_tok 135 | ppl = np.exp(avg_loss) 136 | model.train() 137 | return ppl, avg_loss 138 | 139 | 140 | def copy_model(model): 141 | states = model.state_dict() 142 | for k in states: 143 | v = states[k] 144 | states[k] = v.clone().cpu() 145 | return states 146 | 147 | 148 | def set_seed(seed): 149 | random.seed(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.cuda.manual_seed_all(seed) 153 | 154 | 155 | def main(args): 156 | 157 | if args.local_rank == 0: 158 | log_path = "{}_{}".format(args.log, random.randint(1, 100)) 159 | train_writer = SummaryWriter(log_dir=log_path + "/train") 160 | dev_writer = SummaryWriter(log_dir=log_path + "/dev") 161 | 162 | # set up distributed training 163 | # torch.cuda.set_device(args.local_rank) 164 | device = 'cuda' 165 | # torch.distributed.init_process_group(backend="nccl") 166 | set_seed(1234) 167 | args.n_gpu = 1 168 | args.device = device 169 | local_rank = args.local_rank 170 | 171 | corpus = get_lm_corpus(args.data, "wt103") 172 | n_token = args.n_token = len(corpus.vocab) 173 | args.eval_batch_size = args.eval_batch_size or args.batch_size 174 | args.eval_unroll_size = args.eval_unroll_size or args.unroll_size 175 | unroll_size = args.unroll_size 176 | eval_unroll_size = args.eval_unroll_size 177 | batch_size = args.batch_size 178 | eval_batch_size = args.eval_batch_size 179 | # n_nodes = torch.cuda.device_count() 180 | # train = corpus.get_distributed_iterator('train', batch_size, 181 | # unroll_size, n_nodes=n_nodes, 182 | # rank=local_rank, device=device) 183 | train = corpus.get_iterator("train", batch_size, unroll_size, device=device) 184 | dev = corpus.get_iterator("test", eval_batch_size, eval_unroll_size, device=device) 185 | if local_rank == 0: 186 | print("vocab size: {}".format(n_token)) 187 | 188 | model = Model(args) 189 | if args.load: 190 | model.load_state_dict(torch.load(args.load)) 191 | lr = 1.0 if not args.noam else 1.0 / (args.n_d ** 0.5) / (args.warmup_steps ** 1.5) 192 | if args.prune: 193 | # tie weights again 194 | model.tie_weights() 195 | model.to(device) 196 | num_mask_params = sum(x.numel() for x in model.parameters()) 197 | num_prunable_params = num_mask_params 198 | if local_rank == 0: 199 | print("num of mask parameters: {}".format(num_mask_params)) 200 | print("num of prunable parameters: {}".format(num_prunable_params)) 201 | 202 | mask_param_names = [ 203 | i[0] 204 | for i in model.named_parameters() 205 | if i[1].requires_grad 206 | ] 207 | pruner = flop.NervanaPruner( 208 | model, 209 | subpruners={ 210 | "agppruner": { 211 | "class": "AutomatedGradualPruner", 212 | "initial_sparsity": 0.05, 213 | "weights": mask_param_names, 214 | 
"final_sparsity": args.prune_sparsity, 215 | "starting_step": args.prune_start_epoch, 216 | "ending_step": args.prune_end_epoch, 217 | "frequency": 1, 218 | } 219 | }, 220 | ) 221 | else: 222 | model.to(device) 223 | args.prune_start_epoch = args.max_epoch 224 | 225 | m_parameters = [ 226 | i[1] 227 | for i in model.named_parameters() 228 | if i[1].requires_grad 229 | ] 230 | optimizer = torch.optim.Adam(m_parameters, lr=lr * args.lr, weight_decay=args.weight_decay) 231 | num_params = sum(x.numel() for x in m_parameters if x.requires_grad) 232 | 233 | model_ = model 234 | # model = torch.nn.parallel.DistributedDataParallel( 235 | # model, dim=1, device_ids=[local_rank], output_device=local_rank, 236 | # ) 237 | model = nn.DataParallel(model, dim=1).to('cuda') 238 | nbatch = 1 239 | niter = 1 240 | best_dev = 1e8 241 | unroll_size = args.unroll_size 242 | batch_size = args.batch_size 243 | N = train.n_batch 244 | checkpoint = None 245 | if local_rank == 0: 246 | print(model) 247 | print("num of parameters: {}".format(num_params)) 248 | print("num of mini-batches: {}".format(N)) 249 | 250 | model.zero_grad() 251 | 252 | for epoch in range(args.max_epoch): 253 | start_time = time.time() 254 | model.train() 255 | hidden = model_.init_hidden(batch_size) 256 | i = 0 257 | 258 | pruner.begin_step(epoch) 259 | 260 | for x, y, seq_len in train: 261 | # start iter on the first batch 262 | if nbatch % args.update_param_freq == 1: 263 | pruner.begin_iter(epoch, niter, N // args.update_param_freq) 264 | 265 | i += 1 266 | hidden.detach_() 267 | 268 | # language model forward and backward 269 | model_loss, hidden = model(x, y, hidden) 270 | 271 | model_loss = model_loss.mean() 272 | loss = model_loss 273 | 274 | (loss / args.update_param_freq).backward() 275 | model_loss = model_loss.item() 276 | 277 | # log training stats 278 | if ( 279 | local_rank == 0 280 | and (niter - 1) % 100 == 0 281 | and nbatch % args.update_param_freq == 0 282 | ): 283 | sys.stderr.write( 284 | "\r{:.4f} eta={:.1f}m".format( 285 | math.exp(model_loss), 286 | (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1), 287 | ) 288 | ) 289 | train_writer.add_scalar("loss/ppl", math.exp(model_loss), niter) 290 | train_writer.add_scalar( 291 | "loss/total_loss", model_loss, niter 292 | ) 293 | train_writer.add_scalar( 294 | "parameter_norm", calc_norm([x.data for x in m_parameters]), niter 295 | ) 296 | train_writer.add_scalar( 297 | "gradieånt_norm", 298 | calc_norm([x.grad for x in m_parameters if x.grad is not None]), 299 | niter, 300 | ) 301 | 302 | # perform gradient decent every few number of backward() 303 | if nbatch % args.update_param_freq == 0: 304 | if args.clip_grad > 0: 305 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 306 | optimizer.step() 307 | # clear gradient 308 | model.zero_grad() 309 | 310 | # End iter on the last batch 311 | pruner.end_iter(epoch, niter, N // args.update_param_freq) 312 | niter += 1 313 | 314 | if local_rank == 0 and (nbatch % args.log_period == 0 or i == N): 315 | elapsed_time = (time.time() - start_time) / 60.0 316 | dev_ppl, dev_loss = eval_model(model_, dev) 317 | dev_writer.add_scalar("loss/lm_loss", dev_loss, niter) 318 | dev_writer.add_scalar("loss/ppl", dev_ppl, niter) 319 | dev_writer.add_scalar("ppl", dev_ppl, niter) 320 | sparsity = 0 321 | if args.prune: 322 | agp_sparsity = pruner.get_step_logs()['sparsity'] 323 | dev_writer.add_scalar("sparsity/hard_sparsity", agp_sparsity, niter) 324 | dev_writer.add_scalar( 325 | "model_size/total_prunable", 
num_prunable_params, niter 326 | ) 327 | dev_writer.add_scalar("model_size/total", num_params, niter) 328 | sys.stdout.write( 329 | "\rnum_batches={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}" 330 | " dev_bpc={:.2f} sparsity={:.2f}\t[{:.1f}m]\n".format( 331 | nbatch, 332 | optimizer.param_groups[0]["lr"], 333 | loss, 334 | dev_loss, 335 | dev_ppl, 336 | sparsity, 337 | elapsed_time, 338 | ) 339 | ) 340 | checkpoint = copy_model(model_) 341 | sys.stdout.write("\n") 342 | sys.stdout.flush() 343 | 344 | nbatch += 1 345 | if args.noam: 346 | lr = min(1.0 / (niter ** 0.5), niter / (args.warmup_steps ** 1.5)) 347 | optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d ** 0.5) 348 | 349 | pruner.end_step(epoch) 350 | if local_rank == 0 and args.save and (epoch + 1) % 10 == 0: 351 | torch.save( 352 | checkpoint, 353 | "{}.{}.{}.pt".format( 354 | args.save, 355 | epoch + 1, 356 | int(dev_ppl) 357 | # sparsity 358 | ), 359 | ) 360 | 361 | if local_rank == 0: 362 | train_writer.close() 363 | dev_writer.close() 364 | 365 | if checkpoint is not None: 366 | model_.load_state_dict(checkpoint) 367 | model_.to(device) 368 | # dev = create_batches(dev_, 1) 369 | # test = create_batches(test_, 1) 370 | test = corpus.get_iterator( 371 | "test", eval_batch_size, eval_unroll_size, device=device 372 | ) 373 | dev_ppl, dev_loss = eval_model(model_, dev) 374 | test_ppl, test_loss = eval_model(model_, test) 375 | sys.stdout.write("dev_ppl={:.3f} test_ppl={:.3f}\n".format(dev_ppl, test_ppl)) 376 | 377 | 378 | if __name__ == "__main__": 379 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler="resolve") 380 | argparser.add_argument("--log", type=str, default="") 381 | argparser.add_argument("--noam", type=bool, default=True) 382 | argparser.add_argument("--warmup_steps", type=int, default=4000) 383 | argparser.add_argument("--layer_norm", type=bool, default=True) 384 | argparser.add_argument("--rescale", action="store_true") 385 | argparser.add_argument("--not_tie", action="store_true") 386 | argparser.add_argument("--data", type=str, required=True, help="training file") 387 | argparser.add_argument("--update_param_freq", type=int, default=2) 388 | argparser.add_argument("--batch_size", "--batch", type=int, default=32) 389 | argparser.add_argument("--eval_batch_size", type=int, default=10) 390 | argparser.add_argument("--unroll_size", type=int, default=256) 391 | argparser.add_argument("--eval_unroll_size", type=int, default=0) 392 | argparser.add_argument("--max_epoch", type=int, default=100) 393 | argparser.add_argument("--n_e", type=int, default=1024) 394 | argparser.add_argument("--n_d", "--d", type=int, default=2048) 395 | argparser.add_argument("--n_proj", type=int, default=512) 396 | argparser.add_argument("--div_val", type=float, default=4) 397 | argparser.add_argument( 398 | "--dropout", type=float, default=0.1, help="dropout probability" 399 | ) 400 | argparser.add_argument("--dropout_e", type=float, default=0.1) 401 | argparser.add_argument( 402 | "--bias", type=float, default=-3, help="intial bias of highway gates", 403 | ) 404 | argparser.add_argument("--depth", type=int, default=12) 405 | argparser.add_argument("--lr", type=float, default=2) 406 | argparser.add_argument("--weight_decay", type=float, default=0.01) 407 | argparser.add_argument("--clip_grad", type=float, default=0.3) 408 | argparser.add_argument("--log_period", type=int, default=1000000) 409 | argparser.add_argument("--save", type=str, default="") 410 | argparser.add_argument("--load", type=str, default="") 411 | 412 
| argparser.add_argument("--prune", type=bool, default=True) 413 | argparser.add_argument("--prune_lr", type=float, default=2) 414 | argparser.add_argument("--prune_warmup", type=int, default=0) 415 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 416 | argparser.add_argument("--prune_sparsity", type=float, default=0.8) 417 | argparser.add_argument("--prune_end_epoch", type=int, default=30) 418 | argparser.add_argument("--l1_lambda", type=float, default=0) 419 | 420 | argparser.add_argument("--local_rank", type=int, default=0) 421 | args = argparser.parse_args() 422 | 423 | dirname = os.path.dirname(args.data) 424 | with zipfile.ZipFile(args.data, 'r') as f: 425 | f.extractall(dirname) 426 | args.data = os.path.join(dirname, 'wikitext-103') 427 | os.makedirs(args.log, exist_ok=True) 428 | os.makedirs(args.save, exist_ok=True) 429 | print(args) 430 | main(args) 431 | -------------------------------------------------------------------------------- /examples/wt103/train_distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import time 5 | import random 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn.init import xavier_uniform_ 13 | from torch.optim import Adam 14 | from tensorboardX import SummaryWriter 15 | 16 | import sru 17 | import flop 18 | from flop.embedding import AdaptiveEmbedding, AdaptiveLogSoftmax 19 | from flop.embedding import HardConcreteAdaptiveEmbedding, HardConcreteAdaptiveLogSoftmax 20 | from utils.data_utils import get_lm_corpus 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, args): 25 | super(Model, self).__init__() 26 | self.args = args 27 | #self.cutoffs = [20000, 60000] 28 | self.cutoffs = [10000, 20000, 40000, 60000, 100000] 29 | self.n_V = args.n_token 30 | self.n_e = args.n_e or args.n_proj 31 | self.n_d = args.n_d 32 | self.depth = args.depth 33 | self.drop = nn.Dropout(args.dropout) 34 | self.embedding_layer = AdaptiveEmbedding(self.n_V, 35 | self.n_e, 36 | self.n_d, 37 | self.cutoffs, 38 | div_val=args.div_val, 39 | div_freq=2, 40 | dropout=args.dropout_e 41 | ) 42 | self.rnn = sru.SRU(self.n_d, self.n_d, self.depth, 43 | projection_size=args.n_proj, 44 | dropout=args.dropout, 45 | highway_bias=args.bias, 46 | layer_norm=args.layer_norm, 47 | rescale=args.rescale, 48 | custom_m=flop.ProjectedLinear( 49 | self.n_d, self.n_d * 3, 50 | proj_features=args.n_proj, 51 | bias=False 52 | ) 53 | ) 54 | self.output_layer = AdaptiveLogSoftmax(self.n_V, 55 | self.n_e, 56 | self.n_d, 57 | self.cutoffs, 58 | div_val=args.div_val, 59 | div_freq=2, 60 | dropout=args.dropout_e, 61 | keep_order=False 62 | ) 63 | self.init_weights() 64 | if not args.not_tie: 65 | self.tie_weights() 66 | 67 | def tie_weights(self): 68 | for i in range(len(self.output_layer.out_layers)): 69 | self.embedding_layer.emb_layers[i].weight = self.output_layer.out_layers[i].weight 70 | 71 | for i in range(len(self.output_layer.out_projs)): 72 | self.embedding_layer.emb_projs[i] = self.output_layer.out_projs[i] 73 | 74 | if hasattr(self.embedding_layer, 'masks') and hasattr(self.output_layer, 'masks'): 75 | delattr(self.output_layer, 'masks') 76 | setattr(self.output_layer, 'masks', self.embedding_layer.masks) 77 | 78 | def init_weights(self, init_range=0.03, reinit_rnn=False): 79 | params = list(self.embedding_layer.parameters()) + list(self.output_layer.parameters()) 80 | for p in 
params: 81 | if p.dim() > 1: # matrix 82 | p.data.uniform_(-init_range, init_range) 83 | else: 84 | p.data.zero_() 85 | if reinit_rnn: 86 | for p in self.rnn.parameters(): 87 | if p.dim() > 1: # matrix 88 | p.data.uniform_(-init_range, init_range) 89 | 90 | def forward(self, x, y, hidden): 91 | emb = self.drop(self.embedding_layer(x)) 92 | output, hidden = self.rnn(emb, hidden) 93 | output = self.drop(output) 94 | output = output.view(-1, output.size(2)) 95 | loss = self.output_layer(output, y.view(-1)) 96 | loss = loss.view(y.size(0), -1) 97 | return loss, hidden 98 | 99 | def init_hidden(self, batch_size): 100 | weight = next(self.parameters()).data 101 | zeros = weight.new(self.depth, batch_size, self.n_d).zero_() 102 | return zeros 103 | 104 | def calc_norm(lis): 105 | l2_sum = sum(x.norm()**2 for x in lis) 106 | return l2_sum**0.5 107 | 108 | def eval_model(model, valid): 109 | with torch.no_grad(): 110 | # Important: reset compiled masks. When multiple GPUs are used, model() and DDP model() 111 | # are not the same instance although they share the same parameters. 112 | # Calling model(..) in training mode will reset all compiled weights and cached masks 113 | for x, y, seq_len in valid: 114 | model(x, y, hidden=None) 115 | break 116 | model.eval() 117 | args = model.args 118 | batch_size = args.eval_batch_size or args.batch_size 119 | total_loss = 0.0 120 | total_tok = 0.0 121 | hidden = model.init_hidden(batch_size) 122 | for x, y, seq_len in valid: 123 | loss, hidden = model(x, y, hidden) 124 | total_loss += loss.sum().item() 125 | total_tok += y.numel() 126 | avg_loss = total_loss / total_tok 127 | ppl = np.exp(avg_loss) 128 | model.train() 129 | return ppl, avg_loss 130 | 131 | def copy_model(model): 132 | states = model.state_dict() 133 | for k in states: 134 | v = states[k] 135 | states[k] = v.clone().cpu() 136 | return states 137 | 138 | def set_seed(seed): 139 | random.seed(seed) 140 | np.random.seed(seed) 141 | torch.manual_seed(seed) 142 | torch.cuda.manual_seed_all(seed) 143 | 144 | def main(args): 145 | 146 | if args.local_rank == 0: 147 | log_path = "{}_{}".format(args.log, random.randint(1,100)) 148 | train_writer = SummaryWriter(log_dir=log_path+"/train") 149 | dev_writer = SummaryWriter(log_dir=log_path+"/dev") 150 | 151 | # set up distributed training 152 | torch.cuda.set_device(args.local_rank) 153 | device = torch.device("cuda", args.local_rank) 154 | torch.distributed.init_process_group(backend="nccl") 155 | set_seed(1234) 156 | args.n_gpu = 1 157 | args.device = device 158 | local_rank = args.local_rank 159 | 160 | corpus = get_lm_corpus(args.data, 'wt103') 161 | n_token = args.n_token = len(corpus.vocab) 162 | args.eval_batch_size = args.eval_batch_size or args.batch_size 163 | args.eval_unroll_size = args.eval_unroll_size or args.unroll_size 164 | unroll_size = args.unroll_size 165 | eval_unroll_size = args.eval_unroll_size 166 | batch_size = args.batch_size 167 | eval_batch_size = args.eval_batch_size 168 | n_nodes = torch.cuda.device_count() 169 | train = corpus.get_distributed_iterator('train', batch_size, 170 | unroll_size, n_nodes=n_nodes, 171 | rank=local_rank, device=device) 172 | dev = corpus.get_iterator('valid', eval_batch_size, eval_unroll_size, device=device) 173 | if local_rank == 0: 174 | print("vocab size: {}".format(n_token)) 175 | 176 | model = Model(args) 177 | if args.load: 178 | model.load_state_dict(torch.load(args.load)) 179 | lr = 1.0 if not args.noam else 1.0/(args.n_d**0.5)/(args.warmup_steps**1.5) 180 | if args.prune: 181 | # in place 
substituion of linear ops in SRU 182 | flop.make_hard_concrete(model.rnn, in_place=True, init_mean=args.prune_init_mean) 183 | model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module( 184 | model.embedding_layer, 185 | init_mean=args.prune_init_mean 186 | ) 187 | model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module( 188 | model.output_layer, 189 | init_mean=args.prune_init_mean 190 | ) 191 | # tie weights again 192 | model.tie_weights() 193 | model.to(device) 194 | hc_modules = flop.get_hardconcrete_modules(model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer) 195 | #print(len(flop.get_hardconcrete_modules(model))) 196 | #print(len(hc_modules)) 197 | hc_parameters = [p for m in hc_modules for p in m.parameters() if p.requires_grad] 198 | optimizer_hc = torch.optim.Adam( 199 | hc_parameters, 200 | lr = lr * args.prune_lr, 201 | weight_decay = 0 202 | ) 203 | 204 | lambda_1 = nn.Parameter(torch.tensor(0.).cuda()) 205 | lambda_2 = nn.Parameter(torch.tensor(0.).cuda()) 206 | optimizer_max = torch.optim.Adam( 207 | [lambda_1, lambda_2], 208 | lr = lr, 209 | weight_decay = 0 210 | ) 211 | optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr 212 | hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + \ 213 | [model.embedding_layer] 214 | 215 | num_hardconcrete_params = sum(x.numel() for x in hc_parameters) 216 | num_prunable_params = sum(m.num_prunable_parameters() for m in hc_linear_modules) 217 | if local_rank == 0: 218 | print("num of hardconcrete paramters: {}".format(num_hardconcrete_params)) 219 | print("num of prunable paramters: {}".format(num_prunable_params)) 220 | else: 221 | model.to(device) 222 | args.prune_start_epoch = args.max_epoch 223 | 224 | m_parameters = [i[1] for i in model.named_parameters() if i[1].requires_grad and 'log_alpha' not in i[0]] 225 | optimizer = torch.optim.Adam( 226 | m_parameters, 227 | lr = lr * args.lr, 228 | weight_decay = args.weight_decay 229 | ) 230 | num_params = sum(x.numel() for x in m_parameters if x.requires_grad) 231 | 232 | model_ = model 233 | model = torch.nn.parallel.DistributedDataParallel( 234 | model, 235 | dim=1, 236 | device_ids=[local_rank], 237 | output_device=local_rank, 238 | ) 239 | 240 | nbatch = 1 241 | niter = 1 242 | best_dev = 1e+8 243 | unroll_size = args.unroll_size 244 | batch_size = args.batch_size 245 | N = train.n_batch 246 | checkpoint = None 247 | if local_rank == 0: 248 | print(model) 249 | print("num of parameters: {}".format(num_params)) 250 | print("num of mini-batches: {}".format(N)) 251 | 252 | model.zero_grad() 253 | if args.prune: 254 | optimizer_max.zero_grad() 255 | optimizer_hc.zero_grad() 256 | 257 | for epoch in range(args.max_epoch): 258 | start_time = time.time() 259 | model.train() 260 | total_loss = 0.0 261 | hidden = model_.init_hidden(batch_size) 262 | start_prune = epoch >= args.prune_start_epoch 263 | i = 0 264 | 265 | for x, y, seq_len in train: 266 | i += 1 267 | hidden.detach_() 268 | 269 | # language model forward and backward 270 | loss, hidden = model(x, y, hidden) 271 | loss = loss.mean() 272 | (loss / args.update_param_freq).backward() 273 | loss = loss.item() 274 | lagrangian_loss = 0 275 | target_sparsity = 0 276 | expected_sparsity = 0 277 | 278 | # add lagrangian loss (regularization) when pruning 279 | if start_prune: 280 | # compute target sparsity with (optionally) linear warmup 281 | target_sparsity = args.prune_sparsity 282 | if args.prune_warmup > 0: 283 | niter_ = niter - args.prune_start_epoch * N 284 | target_sparsity *= min(1.0, 
niter_ / args.prune_warmup) 285 | 286 | # compute expected model size and sparsity 287 | expected_size = sum(m.num_parameters(train=True) for m in hc_linear_modules) 288 | expected_sparsity = 1.0 - expected_size / num_prunable_params 289 | 290 | # compute lagrangian loss 291 | lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \ 292 | lambda_2 * (expected_sparsity - target_sparsity)**2 * args.prune_beta 293 | (lagrangian_loss / args.update_param_freq).backward() 294 | expected_sparsity = expected_sparsity.item() 295 | lagrangian_loss = lagrangian_loss.item() 296 | 297 | # log training stats 298 | if local_rank == 0 and (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0: 299 | if args.prune: 300 | train_writer.add_scalar('sparsity/expected_sparsity', expected_sparsity, niter) 301 | train_writer.add_scalar('sparsity/target_sparsity', target_sparsity, niter) 302 | train_writer.add_scalar('loss/lagrangian_loss', lagrangian_loss, niter) 303 | train_writer.add_scalar('lambda/1', lambda_1.item(), niter) 304 | train_writer.add_scalar('lambda/2', lambda_2.item(), niter) 305 | if (nbatch - 1) % 3000 == 0: 306 | for index, layer in enumerate(hc_modules): 307 | train_writer.add_histogram( 308 | 'log_alpha/{}'.format(index), 309 | layer.log_alpha, 310 | niter, 311 | bins='sqrt', 312 | ) 313 | sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format( 314 | math.exp(loss), 315 | lagrangian_loss, 316 | expected_sparsity, 317 | (time.time()-start_time)/60.0/(i+1)*(N-i-1), 318 | )) 319 | train_writer.add_scalar('loss/ppl', math.exp(loss), niter) 320 | train_writer.add_scalar('loss/lm_loss', loss, niter) 321 | train_writer.add_scalar('loss/total_loss', loss + lagrangian_loss, niter) 322 | train_writer.add_scalar('parameter_norm', 323 | calc_norm([ x.data for x in m_parameters ]), 324 | niter 325 | ) 326 | train_writer.add_scalar('gradient_norm', 327 | calc_norm([ x.grad for x in m_parameters if x.grad is not None]), 328 | niter 329 | ) 330 | 331 | # perform gradient decent every few number of backward() 332 | if nbatch % args.update_param_freq == 0: 333 | if args.clip_grad > 0: 334 | torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad) 335 | optimizer.step() 336 | if start_prune: 337 | optimizer_max.step() 338 | optimizer_hc.step() 339 | # clear gradient 340 | model.zero_grad() 341 | if args.prune: 342 | optimizer_max.zero_grad() 343 | optimizer_hc.zero_grad() 344 | niter += 1 345 | 346 | if local_rank == 0 and (nbatch % args.log_period == 0 or i == N): 347 | elapsed_time = (time.time()-start_time)/60.0 348 | dev_ppl, dev_loss = eval_model(model_, dev) 349 | dev_writer.add_scalar('loss/lm_loss', dev_loss, niter) 350 | dev_writer.add_scalar('loss/ppl', dev_ppl, niter) 351 | dev_writer.add_scalar('ppl', dev_ppl, niter) 352 | sparsity = 0 353 | if args.prune: 354 | pruned_size = sum(m.num_parameters(train=False) for m in hc_linear_modules) 355 | sparsity = 1.0 - pruned_size / num_prunable_params 356 | dev_writer.add_scalar('sparsity/hard_sparsity', sparsity, niter) 357 | dev_writer.add_scalar('model_size/total_prunable', num_prunable_params, niter) 358 | dev_writer.add_scalar('model_size/current_prunable', pruned_size, niter) 359 | dev_writer.add_scalar('model_size/total', num_params, niter) 360 | dev_writer.add_scalar('model_size/current', 361 | num_params - num_prunable_params + pruned_size, 362 | niter 363 | ) 364 | dev_writer.add_scalar('model_size/current_embedding', 365 | model_.embedding_layer.num_parameters(train=False), 366 | niter 367 | ) 368 | 
dev_writer.add_scalar('model_size/current_output_layer', 369 | model_.output_layer.num_parameters(train=False), 370 | niter 371 | ) 372 | sys.stdout.write("\rnum_batches={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}" 373 | " dev_bpc={:.2f} sparsity={:.2f}\t[{:.1f}m]\n".format( 374 | nbatch, 375 | optimizer.param_groups[0]['lr'], 376 | loss, 377 | dev_loss, 378 | dev_ppl, 379 | sparsity, 380 | elapsed_time 381 | )) 382 | if dev_ppl < best_dev: 383 | if (not args.prune) or sparsity > args.prune_sparsity - 0.005: 384 | best_dev = dev_ppl 385 | checkpoint = copy_model(model_) 386 | sys.stdout.write("\n") 387 | sys.stdout.flush() 388 | 389 | nbatch += 1 390 | if args.noam: 391 | lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5)) 392 | optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5) 393 | if args.noam and start_prune: 394 | niter_ = niter - args.prune_start_epoch * N 395 | lr = min(1.0 / (niter_**0.5), niter_ / (args.warmup_steps**1.5)) 396 | optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (args.n_d**0.5) 397 | optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5) 398 | 399 | if local_rank == 0 and args.save and checkpoint is not None: 400 | torch.save(checkpoint, "{}.pt".format( 401 | args.save, 402 | )) 403 | 404 | if local_rank == 0: 405 | train_writer.close() 406 | dev_writer.close() 407 | 408 | if checkpoint is not None: 409 | model_.load_state_dict(checkpoint) 410 | model_.to(device) 411 | #dev = create_batches(dev_, 1) 412 | #test = create_batches(test_, 1) 413 | test = corpus.get_iterator('test', eval_batch_size, eval_unroll_size, device=device) 414 | dev_ppl, dev_loss = eval_model(model_, dev) 415 | test_ppl, test_loss = eval_model(model_, test) 416 | sys.stdout.write("dev_ppl={:.3f} test_ppl={:.3f}\n".format( 417 | dev_ppl, test_ppl 418 | )) 419 | 420 | if __name__ == "__main__": 421 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve') 422 | argparser.add_argument("--log", type=str, required=True) 423 | argparser.add_argument("--noam", action="store_true") 424 | argparser.add_argument("--warmup_steps", type=int, default=4000) 425 | argparser.add_argument("--layer_norm", action="store_true") 426 | argparser.add_argument("--rescale", action="store_true") 427 | argparser.add_argument("--not_tie", action="store_true") 428 | argparser.add_argument("--data", type=str, required=True, help="training file") 429 | argparser.add_argument("--update_param_freq", type=int, default=1) 430 | argparser.add_argument("--batch_size", "--batch", type=int, default=24) 431 | argparser.add_argument("--eval_batch_size", type=int, default=10) 432 | argparser.add_argument("--unroll_size", type=int, default=256) 433 | argparser.add_argument("--eval_unroll_size", type=int, default=0) 434 | argparser.add_argument("--max_epoch", type=int, default=100) 435 | argparser.add_argument("--n_e", type=int, default=1024) 436 | argparser.add_argument("--n_d", "--d", type=int, default=2048) 437 | argparser.add_argument("--n_proj", type=int, default=512) 438 | argparser.add_argument("--div_val", type=float, default=4) 439 | argparser.add_argument("--dropout", type=float, default=0.1, 440 | help="dropout probability" 441 | ) 442 | argparser.add_argument("--dropout_e", type=float, default=0.1) 443 | argparser.add_argument("--bias", type=float, default=-3, 444 | help="intial bias of highway gates", 445 | ) 446 | argparser.add_argument("--depth", type=int, default=12) 447 | argparser.add_argument("--lr", type=float, default=2) 448 | 
argparser.add_argument("--weight_decay", type=float, default=0) 449 | argparser.add_argument("--clip_grad", type=float, default=0.3) 450 | argparser.add_argument("--log_period", type=int, default=1000000) 451 | argparser.add_argument("--save", type=str, default="") 452 | argparser.add_argument("--load", type=str, default="") 453 | 454 | argparser.add_argument("--prune", action="store_true") 455 | argparser.add_argument("--prune_lr", type=float, default=5) 456 | argparser.add_argument("--prune_beta", type=float, default=1) 457 | argparser.add_argument("--prune_warmup", type=int, default=64000) 458 | argparser.add_argument("--prune_sparsity", type=float, default=0.8) 459 | argparser.add_argument("--prune_init_mean", type=float, default=0.5) 460 | argparser.add_argument("--prune_start_epoch", type=int, default=0) 461 | 462 | argparser.add_argument("--local_rank", type=int, default=0) 463 | args = argparser.parse_args() 464 | print (args) 465 | main(args) 466 | -------------------------------------------------------------------------------- /examples/wt103/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/flop/bdfc1845dbdddde70e65ce5a98ef7d0070833541/examples/wt103/utils/__init__.py -------------------------------------------------------------------------------- /examples/wt103/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code taken from Transformer-XL 3 | https://github.com/kimiyoung/transformer-xl 4 | ''' 5 | 6 | from torch.nn.parallel import DataParallel 7 | import torch 8 | from torch.nn.parallel._functions import Scatter 9 | from torch.nn.parallel.parallel_apply import parallel_apply 10 | 11 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 12 | r""" 13 | Slices tensors into approximately equal chunks and 14 | distributes them across given GPUs. Duplicates 15 | references to objects that are not tensors. 16 | """ 17 | def scatter_map(obj): 18 | if isinstance(obj, torch.Tensor): 19 | try: 20 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 21 | except: 22 | print('obj', obj.size()) 23 | print('dim', dim) 24 | print('chunk_sizes', chunk_sizes) 25 | quit() 26 | if isinstance(obj, tuple) and len(obj) > 0: 27 | return list(zip(*map(scatter_map, obj))) 28 | if isinstance(obj, list) and len(obj) > 0: 29 | return list(map(list, zip(*map(scatter_map, obj)))) 30 | if isinstance(obj, dict) and len(obj) > 0: 31 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 32 | return [obj for targets in target_gpus] 33 | 34 | # After scatter_map is called, a scatter_map cell will exist. This cell 35 | # has a reference to the actual function scatter_map, which has references 36 | # to a closure that has a reference to the scatter_map cell (because the 37 | # fn is recursive). 
To avoid this reference cycle, we set the function to 38 | # None, clearing the cell 39 | try: 40 | return scatter_map(inputs) 41 | finally: 42 | scatter_map = None 43 | 44 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 45 | r"""Scatter with support for kwargs dictionary""" 46 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 47 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 48 | if len(inputs) < len(kwargs): 49 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 50 | elif len(kwargs) < len(inputs): 51 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 52 | inputs = tuple(inputs) 53 | kwargs = tuple(kwargs) 54 | return inputs, kwargs 55 | 56 | class BalancedDataParallel(DataParallel): 57 | def __init__(self, gpu0_bsz, *args, **kwargs): 58 | self.gpu0_bsz = gpu0_bsz 59 | super().__init__(*args, **kwargs) 60 | 61 | def forward(self, *inputs, **kwargs): 62 | if not self.device_ids: 63 | return self.module(*inputs, **kwargs) 64 | if self.gpu0_bsz == 0: 65 | device_ids = self.device_ids[1:] 66 | else: 67 | device_ids = self.device_ids 68 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 69 | if len(self.device_ids) == 1: 70 | return self.module(*inputs[0], **kwargs[0]) 71 | replicas = self.replicate(self.module, self.device_ids) 72 | if self.gpu0_bsz == 0: 73 | replicas = replicas[1:] 74 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 75 | return self.gather(outputs, self.output_device) 76 | 77 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 78 | return parallel_apply(replicas, inputs, kwargs, device_ids) 79 | 80 | def scatter(self, inputs, kwargs, device_ids): 81 | bsz = inputs[0].size(self.dim) 82 | num_dev = len(self.device_ids) 83 | gpu0_bsz = self.gpu0_bsz 84 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 85 | if gpu0_bsz < bsz_unit: 86 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 87 | delta = bsz - sum(chunk_sizes) 88 | for i in range(delta): 89 | chunk_sizes[i + 1] += 1 90 | if gpu0_bsz == 0: 91 | chunk_sizes = chunk_sizes[1:] 92 | else: 93 | return super().scatter(inputs, kwargs, device_ids) 94 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 95 | 96 | -------------------------------------------------------------------------------- /examples/wt103/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code taken from Transformer-XL 3 | https://github.com/kimiyoung/transformer-xl 4 | ''' 5 | 6 | import os, sys 7 | import glob 8 | 9 | from collections import Counter, OrderedDict 10 | import numpy as np 11 | import torch 12 | 13 | class Vocab(object): 14 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 15 | delimiter=None, vocab_file=None): 16 | self.counter = Counter() 17 | self.special = special 18 | self.min_freq = min_freq 19 | self.max_size = max_size 20 | self.lower_case = lower_case 21 | self.delimiter = delimiter 22 | self.vocab_file = vocab_file 23 | 24 | def tokenize(self, line, add_eos=False, add_double_eos=False): 25 | line = line.strip() 26 | # convert to lower case 27 | if self.lower_case: 28 | line = line.lower() 29 | 30 | # empty delimiter '' will evaluate False 31 | if self.delimiter == '': 32 | symbols = line 33 | else: 34 | symbols = line.split(self.delimiter) 35 | 36 | if add_double_eos: # lm1b 37 | return [''] + symbols + [''] 38 | elif add_eos: 39 | return symbols + [''] 40 | else: 41 | 
return symbols 42 | 43 | def count_file(self, path, verbose=False, add_eos=False): 44 | if verbose: print('counting file {} ...'.format(path)) 45 | assert os.path.exists(path) 46 | 47 | sents = [] 48 | with open(path, 'r', encoding='utf-8') as f: 49 | for idx, line in enumerate(f): 50 | if verbose and idx > 0 and idx % 500000 == 0: 51 | print(' line {}'.format(idx)) 52 | symbols = self.tokenize(line, add_eos=add_eos) 53 | self.counter.update(symbols) 54 | sents.append(symbols) 55 | 56 | return sents 57 | 58 | def count_sents(self, sents, verbose=False): 59 | """ 60 | sents : a list of sentences, each a list of tokenized symbols 61 | """ 62 | if verbose: print('counting {} sents ...'.format(len(sents))) 63 | for idx, symbols in enumerate(sents): 64 | if verbose and idx > 0 and idx % 500000 == 0: 65 | print(' line {}'.format(idx)) 66 | self.counter.update(symbols) 67 | 68 | def _build_from_file(self, vocab_file): 69 | self.idx2sym = [] 70 | self.sym2idx = OrderedDict() 71 | 72 | with open(vocab_file, 'r', encoding='utf-8') as f: 73 | for line in f: 74 | symb = line.strip().split()[0] 75 | self.add_symbol(symb) 76 | self.unk_idx = self.sym2idx[''] 77 | 78 | def build_vocab(self): 79 | if self.vocab_file: 80 | print('building vocab from {}'.format(self.vocab_file)) 81 | self._build_from_file(self.vocab_file) 82 | print('final vocab size {}'.format(len(self))) 83 | else: 84 | print('building vocab with min_freq={}, max_size={}'.format( 85 | self.min_freq, self.max_size)) 86 | self.idx2sym = [] 87 | self.sym2idx = OrderedDict() 88 | 89 | for sym in self.special: 90 | self.add_special(sym) 91 | 92 | for sym, cnt in self.counter.most_common(self.max_size): 93 | if cnt < self.min_freq: break 94 | self.add_symbol(sym) 95 | 96 | print('final vocab size {} from {} unique tokens'.format( 97 | len(self), len(self.counter))) 98 | 99 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 100 | add_double_eos=False): 101 | if verbose: print('encoding file {} ...'.format(path)) 102 | assert os.path.exists(path) 103 | encoded = [] 104 | with open(path, 'r', encoding='utf-8') as f: 105 | for idx, line in enumerate(f): 106 | if verbose and idx > 0 and idx % 500000 == 0: 107 | print(' line {}'.format(idx)) 108 | symbols = self.tokenize(line, add_eos=add_eos, 109 | add_double_eos=add_double_eos) 110 | encoded.append(self.convert_to_tensor(symbols)) 111 | 112 | if ordered: 113 | encoded = torch.cat(encoded) 114 | 115 | return encoded 116 | 117 | def encode_sents(self, sents, ordered=False, verbose=False): 118 | if verbose: print('encoding {} sents ...'.format(len(sents))) 119 | encoded = [] 120 | for idx, symbols in enumerate(sents): 121 | if verbose and idx > 0 and idx % 500000 == 0: 122 | print(' line {}'.format(idx)) 123 | encoded.append(self.convert_to_tensor(symbols)) 124 | 125 | if ordered: 126 | encoded = torch.cat(encoded) 127 | 128 | return encoded 129 | 130 | def add_special(self, sym): 131 | if sym not in self.sym2idx: 132 | self.idx2sym.append(sym) 133 | self.sym2idx[sym] = len(self.idx2sym) - 1 134 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 135 | 136 | def add_symbol(self, sym): 137 | if sym not in self.sym2idx: 138 | self.idx2sym.append(sym) 139 | self.sym2idx[sym] = len(self.idx2sym) - 1 140 | 141 | def get_sym(self, idx): 142 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 143 | return self.idx2sym[idx] 144 | 145 | def get_idx(self, sym): 146 | if sym in self.sym2idx: 147 | return self.sym2idx[sym] 148 | else: 149 | # print('encounter 
unk {}'.format(sym)) 150 | assert '' not in sym 151 | assert hasattr(self, 'unk_idx') 152 | return self.sym2idx.get(sym, self.unk_idx) 153 | 154 | def get_symbols(self, indices): 155 | return [self.get_sym(idx) for idx in indices] 156 | 157 | def get_indices(self, symbols): 158 | return [self.get_idx(sym) for sym in symbols] 159 | 160 | def convert_to_tensor(self, symbols): 161 | return torch.LongTensor(self.get_indices(symbols)) 162 | 163 | def convert_to_sent(self, indices, exclude=None): 164 | if exclude is None: 165 | return ' '.join([self.get_sym(idx) for idx in indices]) 166 | else: 167 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 168 | 169 | def __len__(self): 170 | return len(self.idx2sym) 171 | 172 | class LMOrderedIterator(object): 173 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 174 | """ 175 | data -- LongTensor -- the LongTensor is strictly ordered 176 | """ 177 | self.bsz = bsz 178 | self.bptt = bptt 179 | self.ext_len = ext_len if ext_len is not None else 0 180 | 181 | self.device = device 182 | 183 | # Work out how cleanly we can divide the dataset into bsz parts. 184 | self.n_step = data.size(0) // bsz 185 | 186 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 187 | data = data.narrow(0, 0, self.n_step * bsz) 188 | 189 | # Evenly divide the data across the bsz batches. 190 | self.data = data.view(bsz, -1).t().contiguous().to(device) 191 | 192 | # Number of mini-batches 193 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 194 | 195 | def get_batch(self, i, bptt=None): 196 | if bptt is None: bptt = self.bptt 197 | seq_len = min(bptt, self.data.size(0) - 1 - i) 198 | 199 | end_idx = i + seq_len 200 | beg_idx = max(0, i - self.ext_len) 201 | 202 | data = self.data[beg_idx:end_idx] 203 | target = self.data[i+1:i+1+seq_len] 204 | 205 | return data, target, seq_len 206 | 207 | def get_fixlen_iter(self, start=0): 208 | for i in range(start, self.data.size(0) - 1, self.bptt): 209 | yield self.get_batch(i) 210 | 211 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 212 | max_len = self.bptt + max_deviation * std 213 | i = start 214 | while True: 215 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 216 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 217 | data, target, seq_len = self.get_batch(i, bptt) 218 | i += seq_len 219 | yield data, target, seq_len 220 | if i >= self.data.size(0) - 2: 221 | break 222 | 223 | def __iter__(self): 224 | return self.get_fixlen_iter() 225 | 226 | 227 | class DistributedLMOrderedIterator(object): 228 | def __init__(self, data, bsz, bptt, n_nodes=1, rank=0, device='cpu', ext_len=None): 229 | """ 230 | data -- LongTensor -- the LongTensor is strictly ordered 231 | """ 232 | ebsz = bsz * n_nodes 233 | self.ebsz = ebsz 234 | self.bsz = bsz 235 | self.bptt = bptt 236 | self.rank = rank 237 | self.ext_len = ext_len if ext_len is not None else 0 238 | 239 | self.device = device 240 | 241 | # Work out how cleanly we can divide the dataset into bsz parts. 242 | self.n_step = data.size(0) // ebsz 243 | 244 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 245 | data = data.narrow(0, 0, self.n_step * ebsz) 246 | data = data.view(n_nodes, bsz, -1)[rank] 247 | 248 | # Evenly divide the data across the bsz batches. 
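        # Editor's note (added comment, not part of the original source): a small worked example
        # of the layout produced by the line below. With a per-rank token stream 0..9 and bsz=2,
        #   data.view(2, -1) -> [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
        #   .t()             -> shape (n_step, bsz); column 0 holds 0..4, column 1 holds 5..9.
        # Each column is one contiguous slice of the corpus, so the recurrent hidden state can be
        # carried across consecutive calls to get_batch(). With bptt=2 and ext_len=0, get_batch(0)
        # returns data = rows [0:2] and target = rows [1:3], i.e. the targets are the inputs
        # shifted by one position.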
249 | self.data = data.view(bsz, -1).t().contiguous().to(device) 250 | 251 | # Number of mini-batches 252 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt 253 | 254 | def get_batch(self, i, bptt=None): 255 | if bptt is None: bptt = self.bptt 256 | seq_len = min(bptt, self.data.size(0) - 1 - i) 257 | 258 | end_idx = i + seq_len 259 | beg_idx = max(0, i - self.ext_len) 260 | 261 | data = self.data[beg_idx:end_idx] 262 | target = self.data[i+1:i+1+seq_len] 263 | 264 | return data, target, seq_len 265 | 266 | def get_fixlen_iter(self, start=0): 267 | for i in range(start, self.data.size(0) - 1, self.bptt): 268 | yield self.get_batch(i) 269 | 270 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 271 | max_len = self.bptt + max_deviation * std 272 | i = start 273 | while True: 274 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 275 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 276 | data, target, seq_len = self.get_batch(i, bptt) 277 | i += seq_len 278 | yield data, target, seq_len 279 | if i >= self.data.size(0) - 2: 280 | break 281 | 282 | def __iter__(self): 283 | return self.get_fixlen_iter() 284 | 285 | 286 | class LMShuffledIterator(object): 287 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 288 | """ 289 | data -- list[LongTensor] -- there is no order among the LongTensors 290 | """ 291 | self.data = data 292 | 293 | self.bsz = bsz 294 | self.bptt = bptt 295 | self.ext_len = ext_len if ext_len is not None else 0 296 | 297 | self.device = device 298 | self.shuffle = shuffle 299 | 300 | def get_sent_stream(self): 301 | # index iterator 302 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 303 | else np.array(range(len(self.data))) 304 | 305 | # sentence iterator 306 | for idx in epoch_indices: 307 | yield self.data[idx] 308 | 309 | def stream_iterator(self, sent_stream): 310 | # streams for each data in the batch 311 | streams = [None] * self.bsz 312 | 313 | data = torch.LongTensor(self.bptt, self.bsz) 314 | target = torch.LongTensor(self.bptt, self.bsz) 315 | 316 | n_retain = 0 317 | 318 | while True: 319 | # data : [n_retain+bptt x bsz] 320 | # target : [bptt x bsz] 321 | data[n_retain:].fill_(-1) 322 | target.fill_(-1) 323 | 324 | valid_batch = True 325 | 326 | for i in range(self.bsz): 327 | n_filled = 0 328 | try: 329 | while n_filled < self.bptt: 330 | if streams[i] is None or len(streams[i]) <= 1: 331 | streams[i] = next(sent_stream) 332 | # number of new tokens to fill in 333 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 334 | # first n_retain tokens are retained from last batch 335 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 336 | streams[i][:n_new] 337 | target[n_filled:n_filled+n_new, i] = \ 338 | streams[i][1:n_new+1] 339 | streams[i] = streams[i][n_new:] 340 | n_filled += n_new 341 | except StopIteration: 342 | valid_batch = False 343 | break 344 | 345 | if not valid_batch: 346 | return 347 | 348 | data = data.to(self.device) 349 | target = target.to(self.device) 350 | 351 | yield data, target, self.bptt 352 | 353 | n_retain = min(data.size(0), self.ext_len) 354 | if n_retain > 0: 355 | data[:n_retain] = data[-n_retain:] 356 | data.resize_(n_retain + self.bptt, data.size(1)) 357 | 358 | def __iter__(self): 359 | # sent_stream is an iterator 360 | sent_stream = self.get_sent_stream() 361 | 362 | for batch in self.stream_iterator(sent_stream): 363 | yield batch 364 | 365 | 366 | class 
LMMultiFileIterator(LMShuffledIterator): 367 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 368 | shuffle=False): 369 | 370 | self.paths = paths 371 | self.vocab = vocab 372 | 373 | self.bsz = bsz 374 | self.bptt = bptt 375 | self.ext_len = ext_len if ext_len is not None else 0 376 | 377 | self.device = device 378 | self.shuffle = shuffle 379 | 380 | def get_sent_stream(self, path): 381 | sents = self.vocab.encode_file(path, add_double_eos=True) 382 | if self.shuffle: 383 | np.random.shuffle(sents) 384 | sent_stream = iter(sents) 385 | 386 | return sent_stream 387 | 388 | def __iter__(self): 389 | if self.shuffle: 390 | np.random.shuffle(self.paths) 391 | 392 | for path in self.paths: 393 | # sent_stream is an iterator 394 | sent_stream = self.get_sent_stream(path) 395 | for batch in self.stream_iterator(sent_stream): 396 | yield batch 397 | 398 | 399 | class Corpus(object): 400 | def __init__(self, path, dataset, *args, **kwargs): 401 | self.dataset = dataset 402 | self.vocab = Vocab(*args, **kwargs) 403 | 404 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: 405 | self.vocab.count_file(os.path.join(path, 'train.txt')) 406 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 407 | self.vocab.count_file(os.path.join(path, 'test.txt')) 408 | elif self.dataset == 'wt103': 409 | self.vocab.count_file(os.path.join(path, 'wiki.train.tokens')) 410 | elif self.dataset == 'lm1b': 411 | train_path_pattern = os.path.join( 412 | path, '1-billion-word-language-modeling-benchmark-r13output', 413 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 414 | train_paths = glob.glob(train_path_pattern) 415 | # the vocab will load from file when build_vocab() is called 416 | 417 | self.vocab.build_vocab() 418 | 419 | if self.dataset in ['ptb', 'wt2']: 420 | self.train = self.vocab.encode_file( 421 | os.path.join(path, 'train.txt'), ordered=True) 422 | self.valid = self.vocab.encode_file( 423 | os.path.join(path, 'valid.txt'), ordered=True) 424 | self.test = self.vocab.encode_file( 425 | os.path.join(path, 'test.txt'), ordered=True) 426 | elif self.dataset == 'wt103': 427 | self.train = self.vocab.encode_file( 428 | os.path.join(path, 'wiki.train.tokens'), ordered=True) 429 | self.valid = self.vocab.encode_file( 430 | os.path.join(path, 'wiki.valid.tokens'), ordered=True) 431 | self.test = self.vocab.encode_file( 432 | os.path.join(path, 'wiki.test.tokens'), ordered=True) 433 | elif self.dataset in ['enwik8', 'text8']: 434 | self.train = self.vocab.encode_file( 435 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 436 | self.valid = self.vocab.encode_file( 437 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 438 | self.test = self.vocab.encode_file( 439 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False) 440 | elif self.dataset == 'lm1b': 441 | self.train = train_paths 442 | self.valid = self.vocab.encode_file( 443 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) 444 | self.test = self.vocab.encode_file( 445 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) 446 | 447 | def get_iterator(self, split, *args, **kwargs): 448 | if split == 'train': 449 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 450 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 451 | elif self.dataset == 'lm1b': 452 | kwargs['shuffle'] = True 453 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 454 | elif split in ['valid', 'test']: 455 | data = self.valid if split 
== 'valid' else self.test 456 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 457 | data_iter = LMOrderedIterator(data, *args, **kwargs) 458 | elif self.dataset == 'lm1b': 459 | data_iter = LMShuffledIterator(data, *args, **kwargs) 460 | 461 | return data_iter 462 | 463 | def get_distributed_iterator(self, split, *args, **kwargs): 464 | data_iter = None 465 | if split == 'train': 466 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 467 | data_iter = DistributedLMOrderedIterator(self.train, *args, **kwargs) 468 | #elif self.dataset == 'lm1b': 469 | # kwargs['shuffle'] = True 470 | # data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) 471 | elif split in ['valid', 'test']: 472 | data = self.valid if split == 'valid' else self.test 473 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: 474 | data_iter = DistributedLMOrderedIterator(data, *args, **kwargs) 475 | #elif self.dataset == 'lm1b': 476 | # data_iter = LMShuffledIterator(data, *args, **kwargs) 477 | 478 | return data_iter 479 | 480 | 481 | def get_lm_corpus(datadir, dataset): 482 | fn = os.path.join(datadir, 'cache.pt') 483 | if os.path.exists(fn): 484 | print('Loading cached dataset...') 485 | corpus = torch.load(fn) 486 | else: 487 | print('Producing dataset {}...'.format(dataset)) 488 | kwargs = {} 489 | if dataset in ['wt103', 'wt2']: 490 | kwargs['special'] = ['<eos>'] 491 | kwargs['lower_case'] = False 492 | elif dataset == 'ptb': 493 | kwargs['special'] = ['<unk>'] 494 | kwargs['lower_case'] = True 495 | elif dataset == 'lm1b': 496 | kwargs['special'] = [] 497 | kwargs['lower_case'] = False 498 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') 499 | elif dataset in ['enwik8', 'text8']: 500 | pass 501 | 502 | corpus = Corpus(datadir, dataset, **kwargs) 503 | torch.save(corpus, fn) 504 | 505 | return corpus 506 | 507 | if __name__ == '__main__': 508 | import argparse 509 | parser = argparse.ArgumentParser(description='unit test') 510 | parser.add_argument('--datadir', type=str, default='../data/text8', 511 | help='location of the data corpus') 512 | parser.add_argument('--dataset', type=str, default='text8', 513 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], 514 | help='dataset name') 515 | args = parser.parse_args() 516 | 517 | corpus = get_lm_corpus(args.datadir, args.dataset) 518 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 519 | -------------------------------------------------------------------------------- /examples/wt103/utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code taken from Transformer-XL 3 | https://github.com/kimiyoung/transformer-xl 4 | ''' 5 | 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | class LogUniformSampler(object): 11 | def __init__(self, range_max, n_sample): 12 | """ 13 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 14 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 15 | 16 | expected count can be approximated by 1 - (1 - p)^n 17 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 18 | 19 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 20 | """ 21 | with torch.no_grad(): 22 | self.range_max = range_max 23 | log_indices = torch.arange(1., range_max+2., 1.).log_() 24 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
25 | # print('P', self.dist.numpy().tolist()[-30:]) 26 | 27 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 28 | 29 | self.n_sample = n_sample 30 | 31 | def sample(self, labels): 32 | """ 33 | labels: [b1, b2] 34 | Return 35 | true_log_probs: [b1, b2] 36 | samp_log_probs: [n_sample] 37 | neg_samples: [n_sample] 38 | """ 39 | 40 | # neg_samples = torch.empty(0).long() 41 | n_sample = self.n_sample 42 | n_tries = 2 * n_sample 43 | 44 | with torch.no_grad(): 45 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 46 | device = labels.device 47 | neg_samples = neg_samples.to(device) 48 | true_log_probs = self.log_q[labels].to(device) 49 | samp_log_probs = self.log_q[neg_samples].to(device) 50 | return true_log_probs, samp_log_probs, neg_samples 51 | 52 | def sample_logits(embedding, bias, labels, inputs, sampler): 53 | """ 54 | embedding: an nn.Embedding layer 55 | bias: [n_vocab] 56 | labels: [b1, b2] 57 | inputs: [b1, b2, n_emb] 58 | sampler: you may use a LogUniformSampler 59 | Return 60 | logits: [b1, b2, 1 + n_sample] 61 | """ 62 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 63 | n_sample = neg_samples.size(0) 64 | b1, b2 = labels.size(0), labels.size(1) 65 | all_ids = torch.cat([labels.view(-1), neg_samples]) 66 | all_w = embedding(all_ids) 67 | true_w = all_w[: -n_sample].view(b1, b2, -1) 68 | sample_w = all_w[- n_sample:].view(n_sample, -1) 69 | 70 | all_b = bias[all_ids] 71 | true_b = all_b[: -n_sample].view(b1, b2) 72 | sample_b = all_b[- n_sample:] 73 | 74 | hit = (labels[:, :, None] == neg_samples).detach() 75 | 76 | true_logits = torch.einsum('ijk,ijk->ij', 77 | [true_w, inputs]) + true_b - true_log_probs 78 | sample_logits = torch.einsum('lk,ijk->ijl', 79 | [sample_w, inputs]) + sample_b - samp_log_probs 80 | sample_logits.masked_fill_(hit, -1e30) 81 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 82 | 83 | return logits 84 | 85 | 86 | # class LogUniformSampler(object): 87 | # def __init__(self, range_max, unique=False): 88 | # """ 89 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 90 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 91 | # """ 92 | # self.range_max = range_max 93 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 94 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 95 | 96 | # self.unique = unique 97 | 98 | # if self.unique: 99 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 100 | 101 | # def sample(self, n_sample, labels): 102 | # pos_sample, new_labels = labels.unique(return_inverse=True) 103 | # n_pos_sample = pos_sample.size(0) 104 | # n_neg_sample = n_sample - n_pos_sample 105 | 106 | # if self.unique: 107 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 108 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 109 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 110 | # else: 111 | # sample_dist = self.dist 112 | 113 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 114 | 115 | # sample = torch.cat([pos_sample, neg_sample]) 116 | # sample_prob = self.dist[sample] 117 | 118 | # return new_labels, sample, sample_prob 119 | 120 | 121 | if __name__ == '__main__': 122 | S, B = 3, 4 123 | n_vocab = 10000 124 | n_sample = 5 125 | H = 32 126 | 127 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 128 | 129 | # sampler = LogUniformSampler(n_vocab, unique=False) 130 | # new_labels, 
sample, sample_prob = sampler.sample(n_sample, labels) 131 | 132 | sampler = LogUniformSampler(n_vocab, unique=True) 133 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 134 | 135 | # print('true_probs', true_probs.numpy().tolist()) 136 | # print('samp_probs', samp_probs.numpy().tolist()) 137 | # print('neg_samples', neg_samples.numpy().tolist()) 138 | 139 | # print('sum', torch.sum(sampler.dist).item()) 140 | 141 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 142 | 143 | embedding = nn.Embedding(n_vocab, H) 144 | bias = torch.zeros(n_vocab) 145 | inputs = torch.Tensor(S, B, H).normal_() 146 | 147 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 148 | print('logits', logits.detach().numpy().tolist()) 149 | print('logits shape', logits.size()) 150 | print('out_labels', out_labels.detach().numpy().tolist()) 151 | print('out_labels shape', out_labels.size()) 152 | 153 | -------------------------------------------------------------------------------- /flop/__init__.py: -------------------------------------------------------------------------------- 1 | from flop.hardconcrete import HardConcrete 2 | from flop.linear import ProjectedLinear, HardConcreteProjectedLinear 3 | from flop.linear import HardConcreteLinear, ProjectedLinearWithMask 4 | from flop.train import HardConcreteTrainer 5 | from flop.utils import make_hard_concrete, make_projected_linear, make_projected_linear_with_mask 6 | from flop.utils import get_hardconcrete_modules, get_hardconcrete_proj_linear_modules 7 | from flop.utils import get_hardconcrete_linear_modules 8 | from flop.utils import get_projected_linear_with_mask_modules 9 | from flop.utils import get_projected_linear_masks 10 | from flop.agp import NervanaPruner 11 | 12 | 13 | __all__ = ['HardConcrete', 'ProjectedLinear', 'HardConcreteLinear', 14 | 'HardConcreteProjectedLinear', 'HardConcreteTrainer', 'ProjectedLinearWithMask', 15 | 'make_hard_concrete', 'make_projected_linear', 16 | 'get_hardconcrete_modules', 'get_hardconcrete_proj_linear_modules', 17 | 'get_hardconcrete_linear_modules', 18 | 'get_projected_linear_with_mask_modules', 'get_projected_linear_masks', 19 | 'make_projected_linear_with_mask', 20 | 'NervanaPruner'] 21 | -------------------------------------------------------------------------------- /flop/agp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch.nn as nn 4 | 5 | try: 6 | import distiller 7 | from distiller.config import file_config, dict_config 8 | except: 9 | print("distiller not installed.") 10 | 11 | 12 | class NervanaPruner(object): 13 | def __init__(self, model: nn.Module, subpruners: Dict[str, Dict[str, Any]]): 14 | 15 | # Reorganize dictionary in Nervana format 16 | pruners = {} 17 | policies = [] 18 | for name, kwargs in subpruners.items(): 19 | 20 | # Split kwargs into pruner kwargs and policy kwargs 21 | pruner_kwargs = {} 22 | policy_kwargs = {"pruner": {"instance_name": name}} 23 | for key, value in kwargs.items(): 24 | 25 | if key in {"starting_epoch", "ending_epoch", "epochs"}: 26 | raise ValueError( 27 | "Please provide arguments by step (e.g. `starting_step`, " 28 | " `ending_step`, `steps`) instead of by epoch (e.g. " 29 | "`starting_epoch`, `ending_epoch`, `epochs`)." 
30 | ) 31 | 32 | # Search for policy kwargs 33 | if key == "starting_step": 34 | policy_kwargs["starting_epoch"] = value 35 | elif key == "ending_step": 36 | policy_kwargs["ending_epoch"] = value 37 | elif key == "steps": 38 | policy_kwargs["steps"] = value 39 | elif key == "frequency": 40 | policy_kwargs["frequency"] = value 41 | else: 42 | pruner_kwargs[key] = value 43 | 44 | pruners[name] = pruner_kwargs 45 | policies.append(policy_kwargs) 46 | 47 | self.compression_scheduler = dict_config( 48 | model, None, {"pruners": pruners, "policies": policies} 49 | ) 50 | 51 | # Verify that all weights marked for pruning exist in model 52 | model_param_names = set(n for n, _ in model.named_parameters()) 53 | policies_by_step = self.compression_scheduler.policies.items() 54 | for step, policy_lst in policies_by_step: 55 | for policy in policy_lst: 56 | for name in policy.pruner.params_names: 57 | if name not in model_param_names: 58 | raise ValueError( 59 | f"Weight `{name}` was marked for pruning at step {step}, but does not exist in model!" 60 | ) 61 | 62 | def begin_step(self, step: int): 63 | self.compression_scheduler.on_epoch_begin(step) 64 | 65 | def end_step(self, step: int): 66 | self.compression_scheduler.on_epoch_end(step) 67 | 68 | def begin_iter(self, step: int, n_iter: int, iter_per_step: int): 69 | self.compression_scheduler.on_minibatch_begin( 70 | step, minibatch_id=n_iter, minibatches_per_epoch=iter_per_step 71 | ) 72 | 73 | def end_iter(self, step: int, n_iter: int, iter_per_step: int): 74 | self.compression_scheduler.on_minibatch_end( 75 | step, minibatch_id=n_iter, minibatches_per_epoch=iter_per_step 76 | ) 77 | 78 | def get_step_logs(self): 79 | model = self.compression_scheduler.model 80 | t, total = distiller.weights_sparsity_tbl_summary( 81 | model, return_total_sparsity=True 82 | ) 83 | return {"sparsity": total / 100} 84 | -------------------------------------------------------------------------------- /flop/hardconcrete.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class HardConcrete(nn.Module): 9 | """A HardConcrete module. 10 | 11 | Use this module to create a mask of size N, which you can 12 | then use to perform L0 regularization. Note that in general, 13 | we also provide utilities which introduce HardConcrete modules 14 | in the desired places in your model. See ``utils`` for details. 15 | 16 | To obtain a mask, simply run a forward pass through the module 17 | with no input data. The mask is sampled in training mode, and 18 | fixed during evaluation mode: 19 | 20 | >>> module = HardConcrete(n_in=100) 21 | >>> mask = module() 22 | >>> norm = module.l0_norm() 23 | 24 | """ 25 | 26 | def __init__(self, 27 | n_in: int, 28 | init_mean: float = 0.5, 29 | init_std: float = 0.01, 30 | temperature: float = 1.0, 31 | stretch: float = 0.1, 32 | eps: float = 1e-6) -> None: 33 | """Initialize the HardConcrete module. 34 | 35 | Parameters 36 | ---------- 37 | n_in : int 38 | The number of hard concrete variables in this mask. 39 | init_mean : float, optional 40 | Initialization value for hard concrete parameter, 41 | by default 0.5. 42 | init_std: float, optional 43 | Used to initialize the hard concrete parameters, 44 | by default 0.01.
45 | temperature : float, optional 46 | Temperature used to control the sharpness of the 47 | distribution, by default 1.0 48 | stretch : float, optional 49 | Stretch the sampled value from [0, 1] to the interval 50 | [-stretch, 1 + stretch], by default 0.1. 51 | 52 | """ 53 | super().__init__() 54 | 55 | self.n_in = n_in 56 | self.limit_l = -stretch 57 | self.limit_r = 1.0 + stretch 58 | self.log_alpha = nn.Parameter(torch.zeros(n_in)) # type: ignore 59 | self.beta = temperature 60 | self.init_mean = init_mean 61 | self.init_std = init_std 62 | self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) 63 | 64 | self.eps = eps 65 | self.compiled_mask = None 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | """Reset the parameters of this module.""" 70 | self.compiled_mask = None 71 | mean = math.log(1 - self.init_mean) - math.log(self.init_mean) 72 | self.log_alpha.data.normal_(mean, self.init_std) 73 | 74 | def l0_norm(self) -> torch.Tensor: 75 | """Compute the expected L0 norm of this mask. 76 | 77 | Returns 78 | ------- 79 | torch.Tensor 80 | The expected L0 norm. 81 | 82 | """ 83 | return (self.log_alpha + self.bias).sigmoid().sum() 84 | 85 | def forward(self) -> torch.Tensor: # type: ignore 86 | """Sample a hard concrete mask. 87 | 88 | Returns 89 | ------- 90 | torch.Tensor 91 | The sampled binary mask 92 | 93 | """ 94 | if self.training: 95 | # Reset the compiled mask 96 | self.compiled_mask = None 97 | # Sample mask dynamically 98 | u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) # type: ignore 99 | s = F.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) 100 | s = s * (self.limit_r - self.limit_l) + self.limit_l 101 | mask = s.clamp(min=0., max=1.) 102 | 103 | else: 104 | # Compile new mask if not cached 105 | if self.compiled_mask is None: 106 | # Get expected sparsity 107 | expected_num_zeros = self.n_in - self.l0_norm().item() 108 | num_zeros = round(expected_num_zeros) 109 | # Approximate expected value of each mask variable z; 110 | # We use an empirically validated magic number 0.8 111 | soft_mask = F.sigmoid(self.log_alpha / self.beta * 0.8) 112 | # Prune small values to set to 0 113 | _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) 114 | soft_mask[indices] = 0. 115 | self.compiled_mask = soft_mask 116 | mask = self.compiled_mask 117 | 118 | return mask 119 | 120 | def extra_repr(self) -> str: 121 | return str(self.n_in) 122 | 123 | def __repr__(self) -> str: 124 | return "{}({})".format(self.__class__.__name__, self.extra_repr()) 125 | -------------------------------------------------------------------------------- /flop/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Tuple 2 | from copy import deepcopy 3 | 4 | import torch.nn as nn 5 | 6 | from flop.hardconcrete import HardConcrete 7 | from flop.linear import ( 8 | ProjectedLinear, 9 | HardConcreteProjectedLinear, 10 | HardConcreteLinear, 11 | ProjectedLinearWithMask, 12 | ) 13 | 14 | 15 | def make_projected_linear(module: nn.Module, in_place: bool = True) -> nn.Module: 16 | """Replace all nn.Linear with ProjectedLinear. 17 | 18 | Parameters 19 | ---------- 20 | module : nn.Module 21 | The input module to modify 22 | in_place : bool, optional 23 | Whether to modify in place, by default True 24 | 25 | Returns 26 | ------- 27 | nn.Module 28 | The updated module.
29 | 30 | """ 31 | # First find all nn.Linear modules 32 | modules = [] 33 | for name, child in module.named_children(): 34 | if isinstance(child, nn.Linear): 35 | modules.append((name, child)) 36 | else: 37 | make_projected_linear(child, in_place) 38 | 39 | # Replace all modules found 40 | new_module = module if in_place else deepcopy(module) 41 | for name, child in modules: 42 | new_child = ProjectedLinear.from_module(child) 43 | setattr(new_module, name, new_child) 44 | 45 | return new_module 46 | 47 | 48 | def make_hard_concrete( 49 | module: nn.Module, 50 | in_place: bool = True, 51 | init_mean: float = 0.5, 52 | init_std: float = 0.01, 53 | ) -> nn.Module: 54 | """Replace all ProjectedLinear with HardConcreteProjectedLinear. 55 | 56 | Parameters 57 | ---------- 58 | module : nn.Module 59 | The input module to modify 60 | in_place : bool, optional 61 | Whether to modify in place, by default True 62 | 63 | Returns 64 | ------- 65 | nn.Module 66 | The updated module. 67 | 68 | """ 69 | # First find all ProjectedLinear modules 70 | modules: List[Tuple[str, Union[ProjectedLinear, nn.Linear]]] = [] 71 | for name, child in module.named_children(): 72 | if isinstance(child, ProjectedLinear): 73 | modules.append((name, child)) 74 | elif isinstance(child, nn.Linear): 75 | modules.append((name, child)) 76 | else: 77 | make_hard_concrete(child, in_place, init_mean, init_std) 78 | 79 | # Replace all modules found 80 | new_module = module if in_place else deepcopy(module) 81 | for name, child in modules: 82 | if isinstance(child, ProjectedLinear): 83 | new_child = HardConcreteProjectedLinear.from_module( 84 | child, init_mean, init_std 85 | ) 86 | else: # must be nn.Linear 87 | new_child = HardConcreteLinear.from_module(child, init_mean, init_std) 88 | setattr(new_module, name, new_child) 89 | 90 | return new_module 91 | 92 | 93 | def make_projected_linear_with_mask( 94 | module: nn.Module, in_place: bool = True, init_zero: bool = False 95 | ) -> nn.Module: 96 | """Replace all ProjectedLinear with ProjectedLinearWithMask. 97 | 98 | Parameters 99 | ---------- 100 | module : nn.Module 101 | The input module to modify 102 | in_place : bool, optional 103 | Whether to modify in place, by default True 104 | 105 | Returns 106 | ------- 107 | nn.Module 108 | The updated module. 109 | 110 | """ 111 | # First find all ProjectedLinear modules 112 | modules = [] 113 | for name, child in module.named_children(): 114 | if isinstance(child, ProjectedLinear): 115 | modules.append((name, child)) 116 | else: 117 | make_projected_linear_with_mask(child, in_place, init_zero=init_zero) 118 | 119 | # Replace all modules found 120 | new_module = module if in_place else deepcopy(module) 121 | for name, child in modules: 122 | new_child = ProjectedLinearWithMask.from_module(child, init_zero=init_zero) 123 | setattr(new_module, name, new_child) 124 | 125 | return new_module 126 | 127 | 128 | def get_hardconcrete_linear_modules( 129 | module: nn.Module, 130 | ) -> List[Union[HardConcreteProjectedLinear, HardConcreteLinear]]: 131 | """Get all HardConcrete*Linear modules. 132 | 133 | Parameters 134 | ---------- 135 | module : nn.Module 136 | The input module 137 | 138 | Returns 139 | ------- 140 | List[nn.Module] 141 | A list of the HardConcrete*Linear module. 
142 | 143 | """ 144 | modules: List[Union[HardConcreteProjectedLinear, HardConcreteLinear]] = [] 145 | for m in module.children(): 146 | if isinstance(m, HardConcreteProjectedLinear): 147 | modules.append(m) 148 | elif isinstance(m, HardConcreteLinear): 149 | modules.append(m) 150 | else: 151 | modules.extend(get_hardconcrete_linear_modules(m)) 152 | return modules 153 | 154 | 155 | def get_hardconcrete_proj_linear_modules( 156 | module: nn.Module, 157 | ) -> List[HardConcreteProjectedLinear]: 158 | """Get all HardConcreteProjectedLinear modules. 159 | 160 | Parameters 161 | ---------- 162 | module : nn.Module 163 | The input module 164 | 165 | Returns 166 | ------- 167 | List[HardConcreteProjectedLinear] 168 | A list of the HardConcreteProjectedLinear module. 169 | 170 | """ 171 | modules = [] 172 | for m in module.children(): 173 | if isinstance(m, HardConcreteProjectedLinear): 174 | modules.append(m) 175 | else: 176 | modules.extend(get_hardconcrete_proj_linear_modules(m)) 177 | return modules 178 | 179 | 180 | def get_hardconcrete_modules(module: nn.Module) -> List[HardConcrete]: 181 | """Get all HardConcrete modules. 182 | 183 | Parameters 184 | ---------- 185 | module : nn.Module 186 | The input module 187 | 188 | Returns 189 | ------- 190 | List[HardConcrete] 191 | A list of the HardConcrete module. 192 | 193 | """ 194 | modules = [] 195 | for m in module.children(): 196 | if isinstance(m, HardConcrete): 197 | modules.append(m) 198 | else: 199 | modules.extend(get_hardconcrete_modules(m)) 200 | return modules 201 | 202 | 203 | def get_projected_linear_with_mask_modules( 204 | module: nn.Module, 205 | ) -> List[ProjectedLinearWithMask]: 206 | """Get all ProjectedLinearWithMask modules. 207 | 208 | Parameters 209 | ---------- 210 | module : nn.Module 211 | The input module 212 | 213 | Returns 214 | ------- 215 | List[HardConcreteProjectedLinear] 216 | A list of the ProjectedLinearWithMask module. 217 | 218 | """ 219 | modules = [] 220 | for m in module.children(): 221 | if isinstance(m, ProjectedLinearWithMask): 222 | modules.append(m) 223 | else: 224 | modules.extend(get_projected_linear_with_mask_modules(m)) 225 | return modules 226 | 227 | 228 | def get_projected_linear_masks(module: nn.Module) -> List[nn.Parameter]: 229 | """Get all masks from ProjectedLinearWithMask modules. 230 | 231 | Parameters 232 | ---------- 233 | module : nn.Module 234 | The input module 235 | 236 | Returns 237 | ------- 238 | List[HardConcrete] 239 | A list of the masks. 
240 | 241 | """ 242 | modules = [] 243 | for m in module.children(): 244 | if isinstance(m, ProjectedLinearWithMask): 245 | modules.append(m.mask) 246 | else: 247 | modules.extend(get_projected_linear_masks(m)) 248 | return modules 249 | 250 | 251 | def get_num_prunable_params(modules) -> int: 252 | return sum([module.num_prunable_parameters() for module in modules]) 253 | 254 | 255 | def get_num_params(modules, train=True) -> int: 256 | return sum([module.num_parameters(train) for module in modules]) 257 | -------------------------------------------------------------------------------- /flop/version.py: -------------------------------------------------------------------------------- 1 | MAJOR = "1" 2 | MINOR = "0" 3 | PATCH = "2" 4 | 5 | VERSION = f'{MAJOR}.{MINOR}.{PATCH}' 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sru>=2.3.3 2 | GitPython==3.1.0 3 | scikit-learn==0.21.2 4 | setuptools>=41.0.0 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | """ 4 | 5 | from setuptools import setup, find_packages 6 | from typing import Dict 7 | import os 8 | 9 | 10 | NAME = "flop" 11 | AUTHOR = "ASAPP Inc." 12 | EMAIL = "jeremy@asapp.com" 13 | DESCRIPTION = "Pytorch based library for L0 based pruning." 14 | 15 | 16 | def readme(): 17 | with open("README.md", encoding="utf-8") as f: 18 | return f.read() 19 | 20 | 21 | def required(): 22 | with open("requirements.txt") as f: 23 | return f.read().splitlines() 24 | 25 | 26 | VERSION: Dict[str, str] = {} 27 | with open("flop/version.py", "r") as version_file: 28 | exec(version_file.read(), VERSION) 29 | 30 | 31 | setup( 32 | name=NAME, 33 | version=os.environ.get("TAG_VERSION", VERSION["VERSION"]), 34 | description=DESCRIPTION, 35 | # Author information 36 | author=AUTHOR, 37 | author_email=EMAIL, 38 | # What is packaged here. 39 | packages=find_packages(), 40 | install_requires=required(), 41 | dependency_links=[ 42 | "git+git://github.com/asappresearch/sru@custom-submodules#egg=sru", 43 | ], 44 | include_package_data=True, 45 | python_requires=">=3.6.1", 46 | zip_safe=True, 47 | ) 48 | --------------------------------------------------------------------------------
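
A minimal end-to-end sketch tying together the pieces dumped above (`flop/hardconcrete.py`, `flop/utils.py`, `flop/__init__.py`): wrap a model's linear layers with hard concrete masks, add the expected L0 norm of the masks to the training loss, and switch to eval mode to freeze the compiled masks. The toy model, the fixed `l0_lambda` weight, and the random data are illustrative assumptions only; the paper's actual training recipe (including the sparsity schedule) lives in the `examples/` scripts, and the forward behaviour of the wrapped layers is defined in `flop/linear.py`, which is not reproduced in this dump.

```python
import torch
import torch.nn as nn

from flop import make_hard_concrete, get_hardconcrete_modules

# Toy model: make_hard_concrete (flop/utils.py) swaps nn.Linear / ProjectedLinear
# children for their HardConcrete counterparts.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
model = make_hard_concrete(model, in_place=True, init_mean=0.5, init_std=0.01)

# Collect the HardConcrete masks so their expected L0 norm can be penalized.
masks = get_hardconcrete_modules(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
l0_lambda = 1e-5  # illustrative penalty weight, not the paper's schedule

model.train()
for step in range(100):
    x = torch.randn(32, 128)
    out = model(x)                                # masks are sampled in training mode
    task_loss = out.pow(2).mean()                 # stand-in for the real objective
    l0_penalty = sum(m.l0_norm() for m in masks)  # differentiable expected L0 norm
    loss = task_loss + l0_lambda * l0_penalty
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()  # each HardConcrete now compiles and caches a fixed mask (see forward())
```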
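The `NervanaPruner` wrapper in `flop/agp.py` reorganizes a flat per-subpruner dict into distiller's `pruners`/`policies` schema: `starting_step`, `ending_step`, `steps`, and `frequency` are mapped onto distiller's epoch-based policy fields, and every other key is forwarded as a pruner argument. Below is a hedged sketch of how it might be driven, assuming distiller is installed and using distiller's `AutomatedGradualPruner`; the subpruner name, sparsity targets, weight name (`"0.weight"`), and step counts are illustrative and not taken from the repo's training scripts.

```python
import torch
import torch.nn as nn

from flop import NervanaPruner

model = nn.Sequential(nn.Linear(64, 64))  # exposes a parameter named "0.weight"
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

pruner = NervanaPruner(
    model,
    subpruners={
        "agp_linear": {
            # Forwarded to distiller's dict_config as pruner arguments (assumed schema).
            "class": "AutomatedGradualPruner",
            "initial_sparsity": 0.05,
            "final_sparsity": 0.80,
            "weights": ["0.weight"],
            # Step-based schedule, translated to distiller's epoch-based policy.
            "starting_step": 0,
            "ending_step": 30,
            "frequency": 1,
        }
    },
)

iter_per_step = 10
for step in range(40):
    pruner.begin_step(step)
    for n_iter in range(iter_per_step):
        pruner.begin_iter(step, n_iter, iter_per_step)
        loss = model(torch.randn(32, 64)).pow(2).mean()  # stand-in objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pruner.end_iter(step, n_iter, iter_per_step)
    pruner.end_step(step)

print(pruner.get_step_logs())  # {"sparsity": ...} from distiller's sparsity summary
```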