├── .gitignore ├── 001_sharing.py ├── 002_xor_backprop.py ├── 003_misc_tpu.py ├── 004_swarm.py ├── 005_visualization.py ├── 117M.json ├── 1558M.json ├── 774M.json ├── BigGAN.py ├── braces.py ├── chess.json ├── chess345m.json ├── configs └── biggan_run01.gin ├── el.py ├── input_fns.py ├── losses.py ├── main_biggan.py ├── main_gpt2.py ├── memory_saving_gradients.py ├── metric_fns.py ├── mnist_classifier.py ├── model_fns.py ├── models └── gpt2 │ ├── __init__.py │ ├── encoder.json │ ├── encoder.py │ ├── gpt2.py │ ├── gpt2_rev.py │ ├── sample.py │ └── vocab.bpe ├── optimizers.py ├── run.sh ├── runs ├── train-astra.sh ├── train-biggan.sh ├── train-checkpointing-117m.sh ├── train-chess-345m.sh ├── train-chess.sh ├── train-novels-1558m-eu.sh ├── train-novels-1558m.sh ├── train-rev-117m.sh ├── train-run0-117m-tensorflow.sh └── train-vanilla-117m.sh ├── tf_timeline.py ├── tf_tools.py ├── tfjpg_parser.py ├── tflex.py ├── tflex_tpu_device_assignment.py ├── tflex_tpu_topology.py ├── tftorch.py ├── tpu_normalization.py ├── tputil.py ├── train_biggan.py ├── train_flags.py ├── train_runner.py ├── utils.py └── wrapper.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | /bug 3 | -------------------------------------------------------------------------------- /001_sharing.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow.core.protobuf import config_pb2 8 | from tensorflow.core.protobuf import tensorflow_server_pb2 9 | from tensorflow.python.client import session 10 | from tensorflow.python.framework import constant_op 11 | from tensorflow.python.framework import dtypes 12 | from tensorflow.python.framework import errors_impl 13 | from tensorflow.python.framework import ops 14 | from tensorflow.python.framework import test_util 15 | from tensorflow.python.ops import array_ops 16 | from 
class GrpcServerTest(test.TestCase):
  """Tests run against a local in-process server created by `server_lib`.

  Covers running a single step, running the same graph from two sessions,
  the effect of `isolate_session_state` on variable visibility, and a run of
  `train_runner.TrainRunner` configured from the JSON file named by
  `FLAGS.params`.
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(GrpcServerTest, self).__init__(methodName)
    # Every test method below talks to this one server instance.
    self._cached_server = server_lib.Server.create_local_server()

  def testRunStep(self):
    # A 1x2 by 2x1 matmul evaluated through a session bound to the server.
    server = self._cached_server

    with session.Session(server.target) as sess:
      c = constant_op.constant([[2, 1]])
      d = constant_op.constant([[1], [2]])
      e = math_ops.matmul(c, d)
      self.assertAllEqual([[4]], sess.run(e))
    # TODO(mrry): Add `server.stop()` and `server.join()` when these work.

  @test_util.run_v1_only("b/120545219")
  def testMultipleSessions(self):
    # Two sessions attached to the same server can both run the same graph.
    server = self._cached_server

    c = constant_op.constant([[2, 1]])
    d = constant_op.constant([[1], [2]])
    e = math_ops.matmul(c, d)

    sess_1 = session.Session(server.target)
    sess_2 = session.Session(server.target)

    self.assertAllEqual([[4]], sess_1.run(e))
    self.assertAllEqual([[4]], sess_2.run(e))

    sess_1.close()
    sess_2.close()
    # TODO(mrry): Add `server.stop()` and `server.join()` when these work.

  @test_util.run_v1_only("b/120545219")
  def testIsolateSessionState(self):
    # Sessions created with isolate_session_state=False share variable state;
    # sessions created with isolate_session_state=True each keep their own.
    server = self._cached_server

    init_value = array_ops.placeholder(dtypes.int32)
    v = variables.VariableV1(init_value, validate_shape=False, name="v")

    sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
    sharing_sess_0 = session.Session(server.target, config=sharing_config)
    sharing_sess_1 = session.Session(server.target, config=sharing_config)

    isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
    isolate_sess_0 = session.Session(server.target, config=isolate_config)
    isolate_sess_1 = session.Session(server.target, config=isolate_config)

    # Initially the variable is uninitialized in every session, so reading it
    # fails with FailedPreconditionError.
    for sess in [sharing_sess_0, sharing_sess_1,
                 isolate_sess_0, isolate_sess_1]:
      with self.assertRaises(errors_impl.FailedPreconditionError):
        sess.run(v)

    # Shared sessions will see each other's updates, but isolated sessions
    # will not.
    sharing_sess_0.run(v.initializer, feed_dict={init_value: 86})
    self.assertAllEqual(86, sharing_sess_0.run(v))
    self.assertAllEqual(86, sharing_sess_1.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_0.run(v)
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Changing the shape works because `validate_shape` is False.
    sharing_sess_1.run(v.initializer, feed_dict={init_value: [86, 99]})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_0.run(v)
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Initializing in an isolated session will only affect the state in that
    # session.
    isolate_sess_0.run(v.initializer, feed_dict={init_value: 37})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    self.assertAllEqual(37, isolate_sess_0.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Isolated sessions can have different shapes for the same variable.
    isolate_sess_1.run(v.initializer, feed_dict={init_value: [19, 86]})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    self.assertAllEqual(37, isolate_sess_0.run(v))
    self.assertAllEqual([19, 86], isolate_sess_1.run(v))

  @test_util.run_v1_only("b/120545219")
  def testTrainRunner(self):
    # Loads hyperparameters from the JSON file named by FLAGS.params, derives
    # the batch size and loop flags from them, then initializes, trains and
    # shuts down a TrainRunner using gpt2_input / gpt2_model.
    with open(FLAGS.params) as f:
      params = json.load(f)
    params['use_tpu'] = True
    # NOTE(review): the JSON configs that ship next to this file spell this
    # key 'batch_per_core', so this lookup presumably always falls back to 1
    # for them -- confirm which spelling is intended before relying on it.
    batch_size_per_core = params.get('batch_size_per_core', 1)
    FLAGS.train_batch_size = FLAGS.num_cores * batch_size_per_core
    FLAGS.iterations_per_loop = params.get('iterations', 20)
    FLAGS.train_steps = 2000
    params['batch_size'] = FLAGS.train_batch_size
    params.setdefault('precision', 'float32')
    pp(params)
    trunner = train_runner.TrainRunner(
        iterations=FLAGS.iterations_per_loop, train_steps=FLAGS.train_steps)
    trunner.initialize(gpt2_input, gpt2_model, params)
    # Printed a second time after initialize(), as the original did.
    pp(params)
    tf.logging.info('trunner.initialize(): Done. Training...')
    trunner.train()
    tf.logging.info('trunner.train(): Done. Shutting down...')
    trunner.shutdown()
    tf.logging.info('trunner.shutdown(): Done.')
np.random.uniform(size=(hiddenLayerNeurons,outputLayerNeurons)) 22 | output_bias = np.random.uniform(size=(1,outputLayerNeurons)) 23 | 24 | print("Initial hidden weights: ",end='') 25 | print(*hidden_weights) 26 | print("Initial hidden biases: ",end='') 27 | print(*hidden_bias) 28 | print("Initial output weights: ",end='') 29 | print(*output_weights) 30 | print("Initial output biases: ",end='') 31 | print(*output_bias) 32 | 33 | 34 | #Training algorithm 35 | for _ in range(epochs): 36 | #Forward Propagation 37 | hidden_layer_activation = np.dot(inputs,hidden_weights) 38 | hidden_layer_activation += hidden_bias 39 | hidden_layer_output = sigmoid(hidden_layer_activation) 40 | 41 | output_layer_activation = np.dot(hidden_layer_output,output_weights) 42 | output_layer_activation += output_bias 43 | predicted_output = sigmoid(output_layer_activation) 44 | 45 | #Backpropagation 46 | error = expected_output - predicted_output 47 | d_predicted_output = error * sigmoid_derivative(predicted_output) 48 | 49 | error_hidden_layer = d_predicted_output.dot(output_weights.T) 50 | d_hidden_layer = error_hidden_layer * sigmoid_derivative(hidden_layer_output) 51 | 52 | #Updating Weights and Biases 53 | output_weights += hidden_layer_output.T.dot(d_predicted_output) * lr 54 | output_bias += np.sum(d_predicted_output,axis=0,keepdims=True) * lr 55 | hidden_weights += inputs.T.dot(d_hidden_layer) * lr 56 | hidden_bias += np.sum(d_hidden_layer,axis=0,keepdims=True) * lr 57 | 58 | print("Final hidden weights: ",end='') 59 | print(*hidden_weights) 60 | print("Final hidden bias: ",end='') 61 | print(*hidden_bias) 62 | print("Final output weights: ",end='') 63 | print(*output_weights) 64 | print("Final output bias: ",end='') 65 | print(*output_bias) 66 | 67 | print("\nOutput from neural network after 10,000 epochs: ",end='') 68 | print(*predicted_output) 69 | -------------------------------------------------------------------------------- /117M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "n_vocab": 50257, 3 | "n_ctx": 1024, 4 | "n_head": 12, 5 | "n_layer": 12, 6 | "n_embd": 768, 7 | "iterations": 20, 8 | "batch_per_core": 1, 9 | "precision": "float32", 10 | "encoder_path": "encoder", 11 | "embed_dropout": 0.0, 12 | "lr": 0.00025, 13 | "warmup_steps": 0, 14 | "beta1": 0.9, 15 | "beta2": 0.999, 16 | "epsilon": 1e-9, 17 | "opt_name": "adam", 18 | "attn_dropout": 0.0, 19 | "train_steps": -1, 20 | "eval_steps": 10, 21 | "max_steps": 5000000, 22 | "res_dropout": 0.0, 23 | "predict_batch_size": 1, 24 | "eval_batch_size": 32, 25 | "input": "my_input", 26 | "model": "GPT2", 27 | "predict_path": "logs/predictions.txt", 28 | "scale_by_depth": false, 29 | "scale_by_in": false 30 | } 31 | -------------------------------------------------------------------------------- /1558M.json: -------------------------------------------------------------------------------- 1 | { 2 | "iterations": 20, 3 | "batch_per_core": 1, 4 | "n_head": 25, 5 | "n_vocab": 50257, 6 | "embed_dropout": 0.0, 7 | "lr": 0.00025, 8 | "warmup_steps": 0, 9 | "beta1": 0.0, 10 | "beta2": 0.999, 11 | "epsilon": 1e-9, 12 | "decay_type": "none", 13 | "decay_exponent": 0.8, 14 | "opt_name": "adafactor", 15 | "attn_dropout": 0.0, 16 | "train_steps": -1, 17 | "eval_steps": 10, 18 | "max_steps": 300000, 19 | "res_dropout": 0.0, 20 | "predict_batch_size": 1, 21 | "eval_batch_size": 8, 22 | "n_embd": 1600, 23 | "n_ctx": 1024, 24 | "n_layer": 48, 25 | "precision": "bfloat16", 26 | "scale_by_depth": false, 27 | "scale_by_in": false 28 | } 29 | -------------------------------------------------------------------------------- /774M.json: -------------------------------------------------------------------------------- 1 | { 2 | "iterations": 20, 3 | "batch_per_core": 1, 4 | "n_head": 20, 5 | "n_vocab": 50257, 6 | "embed_dropout": 0.0, 7 | "lr": 0.00025, 8 | "warmup_steps": 0, 9 | "beta1": 0.0, 10 | "beta2": 0.999, 11 | 
"epsilon": 1e-9, 12 | "decay_type": "none", 13 | "decay_exponent": 0.8, 14 | "opt_name": "adafactor", 15 | "train_batch_size": 512, 16 | "attn_dropout": 0.0, 17 | "train_steps": -1, 18 | "eval_steps": 10, 19 | "max_steps": 300000, 20 | "res_dropout": 0.0, 21 | "predict_batch_size": 1, 22 | "eval_batch_size": 8, 23 | "n_embd": 1280, 24 | "n_ctx": 1024, 25 | "n_layer": 36, 26 | "precision": "bfloat16", 27 | "scale_by_depth": true, 28 | "scale_by_in": true 29 | } -------------------------------------------------------------------------------- /braces.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 2015 Stanis Trendelenburg 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | """Bash-style brace expansion""" 23 | import re 24 | import string 25 | import sys 26 | from itertools import chain, product 27 | 28 | __version__ = '0.1.5' 29 | 30 | __all__ = ['braceexpand', 'alphabet', 'UnbalancedBracesError'] 31 | 32 | class UnbalancedBracesError(ValueError): pass 33 | 34 | PY3 = sys.version_info[0] >= 3 35 | 36 | if PY3: 37 | xrange = range 38 | 39 | alphabet = string.ascii_uppercase + string.ascii_lowercase 40 | 41 | int_range_re = re.compile(r'^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$') 42 | char_range_re = re.compile(r'^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$') 43 | 44 | 45 | def braceexpand(pattern, escape=True): 46 | """braceexpand(pattern) -> iterator over generated strings 47 | 48 | Returns an iterator over the strings resulting from brace expansion 49 | of pattern. This function implements Brace Expansion as described in 50 | bash(1), with the following limitations: 51 | 52 | * A pattern containing unbalanced braces will raise an 53 | UnbalancedBracesError exception. In bash, unbalanced braces will either 54 | be partly expanded or ignored. 55 | 56 | * A mixed-case character range like '{Z..a}' or '{a..Z}' will not 57 | include the characters '[]^_`' between 'Z' and 'a'. 58 | 59 | When escape is True (the default), characters in pattern can be 60 | prefixed with a backslash to cause them not to be interpreted as 61 | special characters for brace expansion (such as '{', '}', ','). 62 | To pass through a a literal backslash, double it ('\\\\'). 63 | 64 | When escape is False, backslashes in pattern have no special 65 | meaning and will be preserved in the output. 
66 | 67 | Examples: 68 | 69 | >>> from braceexpand import braceexpand 70 | 71 | # Integer range 72 | >>> list(braceexpand('item{1..3}')) 73 | ['item1', 'item2', 'item3'] 74 | 75 | # Character range 76 | >>> list(braceexpand('{a..c}')) 77 | ['a', 'b', 'c'] 78 | 79 | # Sequence 80 | >>> list(braceexpand('index.html{,.backup}')) 81 | ['index.html', 'index.html.backup'] 82 | 83 | # Nested patterns 84 | >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) 85 | ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] 86 | 87 | # Prefixing an integer with zero causes all numbers to be padded to 88 | # the same width. 89 | >>> list(braceexpand('{07..10}')) 90 | ['07', '08', '09', '10'] 91 | 92 | # An optional increment can be specified for ranges. 93 | >>> list(braceexpand('{a..g..2}')) 94 | ['a', 'c', 'e', 'g'] 95 | 96 | # Ranges can go in both directions. 97 | >>> list(braceexpand('{4..1}')) 98 | ['4', '3', '2', '1'] 99 | 100 | # Numbers can be negative 101 | >>> list(braceexpand('{2..-1}')) 102 | ['2', '1', '0', '-1'] 103 | 104 | # Unbalanced braces raise an exception. 105 | >>> list(braceexpand('{1{2,3}')) 106 | Traceback (most recent call last): 107 | ... 108 | UnbalancedBracesError: Unbalanced braces: '{1{2,3}' 109 | 110 | # By default, the backslash is the escape character. 111 | >>> list(braceexpand(r'{1\\{2,3}')) 112 | ['1{2', '3'] 113 | 114 | # Setting 'escape' to False disables backslash escaping. 
def parse_pattern(pattern, escape):
    """Split *pattern* into literal and brace parts and return their product.

    The pattern is scanned once, left to right.  Text outside any brace
    becomes a one-element list; each top-level ``{...}`` group is handed to
    ``parse_expression``.  When that returns ``None`` (the group is neither a
    range nor a comma sequence) the braces are kept literally and the text
    between them is parsed recursively.  The result is
    ``itertools.product`` over all collected parts.

    When *escape* is true, a backslash causes the character after it to be
    skipped by the scanner.

    Raises:
        UnbalancedBracesError: if the brace depth is non-zero at the end of
            the pattern.
    """
    parts = []
    segment_start = 0
    depth = 0
    idx = 0
    length = len(pattern)

    while idx < length:
        ch = pattern[idx]
        if escape and ch == '\\':
            # Step over the backslash and the character it protects.
            idx += 2
            continue
        if ch == '{':
            if depth == 0 and idx > segment_start:
                # Flush the literal text that precedes this top-level group.
                parts.append([pattern[segment_start:idx]])
                segment_start = idx
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                group = pattern[segment_start + 1:idx]
                expanded = parse_expression(group, escape)
                if expanded is None:
                    # Not a range or sequence: keep the braces as literals.
                    parts.extend(
                        [['{'], parse_pattern(group, escape), ['}']])
                else:
                    parts.append(expanded)
                # Resume just past the closing brace.
                segment_start = idx + 1
        idx += 1

    if depth != 0:
        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)

    if segment_start < idx:
        # Trailing literal text after the last group.
        parts.append([pattern[segment_start:]])

    return product(*parts)
def make_char_range(start, end, step=None):
    """Return the characters from *start* to *end*, inclusive, as a string.

    *start* and *end* are single characters looked up in the module-level
    ``alphabet``; the range runs forwards or backwards through ``alphabet``
    depending on which index is larger.  *step* is an optional digit string
    giving the increment; a missing or zero increment counts as 1, matching
    ``make_int_range``.

    Raises:
        ValueError: if *start* or *end* is not in ``alphabet``.
    """
    # int('0') is falsy, so `or 1` turns a zero increment into 1 instead of
    # letting the slice below fail with "slice step cannot be zero".
    step = (int(step) or 1) if step else 1
    start = alphabet.index(start)
    end = alphabet.index(end)
    if start < end:
        return alphabet[start:end + 1:step]
    # Descending (or single-character) range.  The exclusive stop is normally
    # end - 1, but when end is index 0 that would be -1, which a slice reads
    # as "the last character" and produces an empty string.  A stop of None
    # runs through the first character instead.
    stop = end - 1 if end > 0 else None
    return alphabet[start:stop:-step]
236 | return escape_re.sub(r'\1', s) if escape else s 237 | 238 | 239 | if __name__ == '__main__': 240 | import doctest 241 | import sys 242 | failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) 243 | if failed: 244 | sys.exit(1) 245 | -------------------------------------------------------------------------------- /chess.json: -------------------------------------------------------------------------------- 1 | { 2 | "iterations": 50, 3 | "batch_per_core": 1, 4 | "precision": "float32", 5 | "n_head": 12, 6 | "encoder_path": "encoder", 7 | "n_vocab": 100, 8 | "embed_dropout": 0.0, 9 | "lr": 0.00025, 10 | "warmup_steps": 0, 11 | "beta1": 0.9, 12 | "beta2": 0.999, 13 | "epsilon": 1e-9, 14 | "opt_name": "adam", 15 | "train_batch_size": 1024, 16 | "attn_dropout": 0.0, 17 | "train_steps": -1, 18 | "eval_steps": 10, 19 | "max_steps": 500000, 20 | "data_path": "gs://gpt-2-poetry/data/kingbase2019/epd/tfrecords/", 21 | "res_dropout": 0.0, 22 | "predict_batch_size": 1, 23 | "eval_batch_size": 32, 24 | "n_embd": 768, 25 | "input": "my_input", 26 | "model": "GPT2", 27 | "model_path": "gs://gpt-2-poetry/models/GPT2-117M-test6-newinit", 28 | "n_ctx": 1024, 29 | "predict_path": "logs/predictions.txt", 30 | "n_layer": 12, 31 | "scale_by_depth": true, 32 | "scale_by_in": true 33 | } 34 | -------------------------------------------------------------------------------- /chess345m.json: -------------------------------------------------------------------------------- 1 | { 2 | "iterations": 40, 3 | "batch_per_core": 1, 4 | "precision": "float32", 5 | "encoder_path": "encoder", 6 | "n_vocab": 100, 7 | "embed_dropout": 0.0, 8 | "lr": 0.00025, 9 | "warmup_steps": 0, 10 | "beta1": 0.9, 11 | "beta2": 0.999, 12 | "epsilon": 1e-9, 13 | "opt_name": "adam", 14 | "train_batch_size": 1024, 15 | "attn_dropout": 0.0, 16 | "train_steps": -1, 17 | "eval_steps": 10, 18 | "max_steps": 500000, 19 | "data_path": "gs://gpt-2-poetry/data/kingbase2019/epd/tfrecords/", 20 | 
"res_dropout": 0.0, 21 | "predict_batch_size": 1, 22 | "eval_batch_size": 32, 23 | "input": "my_input", 24 | "model": "GPT2", 25 | "model_path": "gs://gpt-2-poetry/models/GPT2-117M-test6-newinit", 26 | "predict_path": "logs/predictions.txt", 27 | "n_ctx": 1024, 28 | "n_embd": 1024, 29 | "n_head": 16, 30 | "n_layer": 24, 31 | "scale_by_depth": true, 32 | "scale_by_in": true 33 | } 34 | -------------------------------------------------------------------------------- /configs/biggan_run01.gin: -------------------------------------------------------------------------------- 1 | # BigGAN architecture and settings on ImageNet 128. 2 | # http://arxiv.org/abs/1809.11096 3 | 4 | # This should be similar to row 7 in Table 1. 5 | # It does not include orthogonal regularization (which would be row 8) and uses 6 | # a different learning rate. 7 | 8 | # Recommended training platform: TPU v3-128. 9 | 10 | # dataset.name = "imagenet_128" 11 | # options.batch_per_core = 1 12 | # options.training_steps = -1 13 | # options.iterations = 1 14 | BigGAN256.use_ema = False 15 | GAN.gan = @BigGAN256 16 | options.resolution = 256 17 | 18 | # options.architecture = "resnet_biggan_arch" 19 | # ModularGAN.conditional = True 20 | # options.batch_size = 2048 21 | # options.gan_class = @ModularGAN 22 | # options.lamba = 1 23 | # options.training_steps = 250000 24 | # weights.initializer = "orthogonal" 25 | # spectral_norm.singular_value = "auto" 26 | options.batch_per_core = 1 27 | #options.dataset = "gs://tpu-usc1/datasets/danbooru2019-s/danbooru2019-s-0*" 28 | #options.dataset = "gs://tpu-usc1/datasets/imagenet/train-0*" 29 | options.dataset = "gs://tpu-usc1/datasets/imagenet/validation-0*" 30 | options.model_dir = "gs://tpu-usc1/runs/biggan_run01/b" 31 | options.no_save = True 32 | 33 | py/macro.value = '%py' 34 | 35 | options.body = [%py, """ 36 | print('Hello, world!') 37 | """] 38 | 39 | # # Generator 40 | # G.batch_norm_fn = @conditional_batch_norm 41 | # G.spectral_norm = True 42 | # 
def stringp(x):
  """Return True when x is a str."""
  return isinstance(x, str)

def keywordp(x):
  """Return True when x is a string longer than one character starting with ':'."""
  return stringp(x) and len(x) > 1 and x[0] == ':'

def symbolp(x):
  """Return True when x is a keyword or a string of the form |name|."""
  if keywordp(x):
    return True
  return stringp(x) and len(x) > 2 and x[0] == "|" and x[-1] == "|"

def inner(x):
  """Return x with its first and last elements removed."""
  return x[1:-1]

def symbol_name(x):
  """Return the name of symbol x.

  For a |name| symbol this is the text between the bars.  A keyword is its
  own name (colon included), so distinct keywords keep distinct names and a
  keyword never shares a name with the bar symbol spelled the same way.

  Raises AssertionError when x is not a symbol.
  """
  assert symbolp(x), "Expected a symbol"
  if keywordp(x):
    # inner() would also drop the keyword's last character (":foo" -> "fo").
    return x
  return inner(x)

def symbol_id(x):
  """Return the name of symbol x with every '-' replaced by '_'.

  For a keyword only the leading ':' is removed before the replacement.

  Raises AssertionError when x is not a symbol.
  """
  assert symbolp(x), "Expected a symbol"
  # Keywords have no closing delimiter, so only the first character goes.
  name = x[1:] if keywordp(x) else inner(x)
  return name.replace('-', '_')
def lax_plist_get(plist, property):
  """Return the value stored after property in plist, comparing with equal().

  plist is a flat sequence of alternating properties and values; only the
  even positions are examined as properties.  Returns None when plist is
  None, when no property matches, or when the matching property is the last
  element and therefore has no value.
  """
  # plist_get, plist_put and lax_plist_put all accept None; do the same here
  # instead of failing in len().
  if plist is None:
    return None
  for i in range(0, len(plist), 2):
    if equal(plist[i], property):
      try:
        return plist[i + 1]
      except IndexError:
        # Odd-length plist: the matching property has no value after it.
        return None
  return None
def Or(*args):
  """Return the first argument for which nil() is false.

  Arguments are tested left to right and testing stops at the first match.
  The last argument is returned without being tested when every earlier one
  is nil.  With no arguments at all the result is an empty list.
  """
  # NOTE(review): `nil` is not defined in the part of this module visible
  # here -- confirm it exists before calling Or with two or more arguments.
  if not args:
    return []
  *leading, last = args
  for candidate in leading:
    if not nil(candidate):
      return candidate
  return last
eitherf(apply(car(fs), args, kws), lambda: self(cdr(fs))) 220 | 221 | 222 | 223 | def eitherf(x, body): 224 | if null(x): 225 | return body() 226 | if not null(x) and len(ys) > 0: 227 | return either(*ys) 228 | return x 229 | 230 | def chunks(l, n): 231 | for i in range(0, len(l), n): 232 | yield l[i:i + n] 233 | 234 | def tuples(l, n): 235 | return list(chunks(l, n)) 236 | 237 | def pair(l): 238 | return tuples(l, 2) 239 | 240 | def dbind(n, l): 241 | for v in l: 242 | v = list(v) 243 | while len(v) < n: 244 | v.append(None) 245 | while len(v) > n: 246 | v.pop() 247 | yield v 248 | 249 | def replace(x, *subs): 250 | for a, b in dbind(2, pair(subs)): 251 | if equal(x, a): 252 | x = b 253 | return x 254 | 255 | def inner(x): 256 | return x[1:-1] 257 | 258 | def set(k, v): 259 | assert symbolp(k), "Expected a symbol" 260 | G[k] = v 261 | return v 262 | 263 | def keyword(x): 264 | if keywordp(x): 265 | return x 266 | if stringp(x) and len(x) > 0: 267 | return ":" + x 268 | 269 | def setq(k, v): 270 | return set(make_symbol(k), v) 271 | 272 | def rep(x): 273 | x = repr(x) 274 | if x in ['()', '[]', 'None']: 275 | return 'nil' 276 | return x 277 | 278 | def equal(a, b): 279 | if inspect.isfunction(a) and inspect.isfunction(b) and a.__qualname__ == b.__qualname__: 280 | return True 281 | return rep(a) == rep(b) 282 | 283 | def cons(a, b=None): 284 | assert iterable(b) or null(b) 285 | return [a, *([] if not iterable(b) else b)] 286 | 287 | def push(element, listname): 288 | symbol = make_symbol(listname) 289 | G[symbol] = cons(element, G.get(symbol, [])) 290 | return G[symbol] 291 | 292 | 293 | def add_to_list(symbol, element, *, append=None, compare_fn=None): 294 | assert symbolp(symbol), "Expected a symbol" 295 | if compare_fn is None: 296 | compare_fn = equal 297 | l = G.get(symbol, []) 298 | for x in l: 299 | if compare_fn(x, element): 300 | return l 301 | if yes(append): 302 | l.append(element) 303 | else: 304 | l = cons(element, l) 305 | set(symbol, l) 306 | return 
l 307 | 308 | def named(name, value, *, qualname=None): 309 | value.__name__ = name 310 | value.__qualname__ = name if qualname is None else qualname 311 | return value 312 | 313 | def defalias(name, definition, *, doc=None): 314 | assert symbolp(name), "Expected a symbol" 315 | if doc is not None: 316 | definition.__doc__ = doc 317 | G[symbol_id(name)] = named(name, definition) 318 | return name 319 | 320 | import inspect 321 | 322 | def eval(x): 323 | if symbolp(x): 324 | return symbol_value(x) 325 | return x 326 | 327 | def call(f, *args, **kws): 328 | if symbolp(f): 329 | f = symbol_function(f) 330 | return f(*args, **kws) 331 | 332 | def run_hooks(*hookvars): 333 | for hookvar in hookvars: 334 | run_hook_with_args(hookvar) 335 | 336 | def run_hook_with_args(hook, *args, **kws): 337 | if symbolp(hook): 338 | hook = symbol_value(hook) 339 | if inspect.isfunction(hook): 340 | hook = [hook] 341 | if hook is not None: 342 | for fn in hook: 343 | call(fn, *args, **kws) 344 | 345 | def run_hook_with_args_until_success(hook, *args, **kws): 346 | if symbolp(hook): 347 | hook = symbol_value(hook) 348 | if inspect.isfunction(hook): 349 | hook = [hook] 350 | if hook is not None: 351 | for fn in hook: 352 | result = call(fn, *args, **kws) 353 | if result is not None: 354 | return result 355 | 356 | def run_hook_with_args_until_failure(hook, *args, **kws): 357 | if symbolp(hook): 358 | hook = symbol_value(hook) 359 | if inspect.isfunction(hook): 360 | hook = [hook] 361 | if hook is not None: 362 | for fn in hook: 363 | result = call(fn, *args, **kws) 364 | if no(result): 365 | return result 366 | return True 367 | 368 | 369 | def y_for(h, upto=None): 370 | if inspect.ismodule(h): 371 | h = vars(h) 372 | if isinstance(h, abc.Mapping): 373 | for k, v in h.items(): 374 | yield k, v 375 | return 376 | try: 377 | it = iter(h) 378 | except TypeError: 379 | return 380 | try: 381 | i = -1 382 | while True: 383 | v = next(it) 384 | if keywordp(v): 385 | k = y_key(v) 386 | v = 
next(it) 387 | yield k, v 388 | else: 389 | if upto is not None: 390 | if i >= upto: 391 | return 392 | i += 1 393 | yield i, v 394 | except StopIteration: 395 | pass 396 | 397 | def maybe_int(x): 398 | if string63(x): 399 | try: 400 | return int(x) 401 | except ValueError: 402 | pass 403 | return x 404 | 405 | def isa(x, *types): 406 | return isinstance(x, types) 407 | 408 | def clamp(n, lo=None, hi=None): 409 | if lo is not None and n < lo: 410 | return lo 411 | if hi is not None and n > hi: 412 | return hi 413 | return n 414 | 415 | def iterable(x): 416 | return isinstance(x, abc.Iterable) 417 | 418 | # def no(x): 419 | # if x is None: 420 | # return True 421 | # if x is False: 422 | # return True 423 | # if iterable(x) and len(x) == 0: 424 | # return True 425 | # return False 426 | 427 | # def yes(x): 428 | # return not no(x) 429 | 430 | def null(x): 431 | return x is None 432 | 433 | def nil(x): 434 | return null(x) or none63(x) 435 | 436 | def t(x): 437 | return not nil(x) 438 | 439 | def no(x): 440 | return null(x) or x is False 441 | 442 | def yes(x): 443 | return not no(x) 444 | 445 | def at(x, i): 446 | if nil(x): 447 | return x 448 | return x[i] 449 | 450 | def cut(x, lo=None, hi=None): 451 | if nil(x): 452 | return x 453 | return x[lo:hi] 454 | 455 | def hd(x): 456 | return at(x, 0) 457 | 458 | def tl(x): 459 | return cut(x, 1) 460 | 461 | 462 | def only(f): 463 | def fn(*args, **kws): 464 | if t(hd(args)): 465 | return f(*args, **kws) 466 | 467 | def iterate(x, upto=None): 468 | if hasattr(x, 'items'): 469 | for k, v in x.items(): 470 | i = number(k) 471 | if nil(i): 472 | yield [k, v] 473 | 474 | def cut(x, i): 475 | if i < 0: 476 | return x 477 | return 478 | if nil(x): 479 | return x 480 | 481 | def length(x, upto=None): 482 | if nil(x): 483 | return 0 484 | # if upto is None: 485 | # return len(x) 486 | # else: 487 | if True: 488 | it = iter(x) 489 | i = 0 490 | try: 491 | while True: 492 | if is63(upto): 493 | if i > upto: 494 | return i 495 | 
next(it) 496 | i += 1 497 | except StopIteration: 498 | return i 499 | 500 | 501 | def many63(x): return length(x, 1) == 2 502 | 503 | def some63(x): return length(x, 0) == 1 504 | 505 | def none63(x): return length(x, 0) == 0 506 | 507 | def one63(x): return length(x, 1) == 1 508 | 509 | def two63(x): return length(x, 2) == 2 510 | 511 | def either(x, *ys): 512 | if x is None: 513 | if len(ys) > 0: 514 | return either(*ys) 515 | return x 516 | 517 | def number(x): 518 | try: 519 | return int(x) 520 | except ValueError: 521 | try: 522 | return float(x) 523 | except ValueError: 524 | pass 525 | 526 | def maybe_number(x): 527 | r = number(x) 528 | return x if r is None else r 529 | -------------------------------------------------------------------------------- /input_fns.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tflex 4 | import os 5 | import tqdm 6 | import tputil 7 | from pprint import pprint as pp 8 | 9 | from train_flags import FLAGS 10 | 11 | def _int64_feature(value): 12 | """Returns an int64_list from a bool / enum / int / uint.""" 13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 14 | 15 | def _bytes_feature(value): 16 | """Returns a bytes_list from a string / byte.""" 17 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 18 | 19 | class TFRecordExporter: 20 | def __init__(self, tfrecord_dir, expected_examples, print_progress=True, progress_interval=10): 21 | self.tfrecord_dir = tfrecord_dir 22 | self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir)) 23 | self.expected_examples = expected_examples 24 | self.cur_examples = 0 25 | self.shape = None 26 | self.resolution_log2 = None 27 | self.tfr_writers = [] 28 | self.print_progress = print_progress 29 | self.progress_interval = progress_interval 30 | 31 | if self.print_progress: 32 | print('Creating dataset "%s"' % tfrecord_dir) 33 
| if not os.path.isdir(self.tfrecord_dir): 34 | os.makedirs(self.tfrecord_dir) 35 | assert os.path.isdir(self.tfrecord_dir) 36 | 37 | def close(self): 38 | if self.print_progress: 39 | print('%-40s\r' % 'Flushing data...', end='', flush=True) 40 | for tfr_writer in self.tfr_writers: 41 | tfr_writer.close() 42 | self.tfr_writers = [] 43 | if self.print_progress: 44 | print('%-40s\r' % '', end='', flush=True) 45 | print('Added %d images.' % self.cur_examples) 46 | 47 | def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order. 48 | order = np.arange(self.expected_examples) 49 | np.random.RandomState(123).shuffle(order) 50 | return order 51 | 52 | def add_tokens(self, tokens): 53 | if self.print_progress and self.cur_examples % self.progress_interval == 0: 54 | print('%d / %d\r' % (self.cur_examples, self.expected_examples), end='', flush=True) 55 | if len(self.tfr_writers) <= 0: 56 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 57 | tfr_file = self.tfr_prefix + '.tfrecords' 58 | self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt)) 59 | for lod, tfr_writer in enumerate(self.tfr_writers): 60 | #import pdb; pdb.set_trace() 61 | data = np.array(tokens, dtype=np.int32) 62 | feature = { 63 | #"hash": _bytes_feature(hash.encode()), 64 | "text": _int64_feature(data) 65 | } 66 | ex = tf.train.Example(features=tf.train.Features(feature=feature)) 67 | s = ex.SerializeToString() 68 | tfr_writer.write(s) 69 | self.cur_examples += 1 70 | 71 | def __enter__(self): 72 | return self 73 | 74 | def __exit__(self, *args): 75 | self.close() 76 | 77 | 78 | 79 | # Sample 1024(+1) tokens from the stitched together text 80 | def sample_text(x, amount, batch_size=None): 81 | if batch_size is not None: 82 | features, labels = [], [] 83 | for i in range(batch_size): 84 | features1, labels1 = sample_text(x, amount) 85 | features.append(features1) 86 | labels.append(labels1) 87 | features = tf.stack(features) 
88 | labels = tf.stack(labels) 89 | return features, labels 90 | s = tf.size(x, out_type=tf.dtypes.int64) 91 | r = tf.random.uniform([], maxval=s-(amount+1), dtype=tf.dtypes.int64) 92 | r1 = tf.range(r, r+amount) 93 | r2 = tf.range(r+1, (r+1)+amount) 94 | r1 = tf.reshape(r1, [amount]) # Somehow, this makes the compiler happy 95 | r2 = tf.reshape(r2, [amount]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input 96 | vals1 = tf.gather(x, r1) 97 | vals2 = tf.gather(x, r2) 98 | vals1 = tf.cast(vals1, tf.dtypes.int32) 99 | vals2 = tf.cast(vals2, tf.dtypes.int32) 100 | features, labels = vals1, vals2 101 | return features, labels 102 | 103 | def make_source_dataset(index, num_hosts, batch_size, n_ctx): 104 | pp({'op': 'make_source_dataset', 'index': index, 'num_hosts': num_hosts, 'batch_size': batch_size}) 105 | tokens = [[(_ + 0) for _ in range(0, n_ctx)]] * batch_size 106 | labels = [[(_ + 1) for _ in range(0, n_ctx)]] * batch_size 107 | #with tflex.device('/tpu:%d' % index): 108 | #with tf.device('/job:worker/replica:0/task:0/device:CPU:0'): 109 | with tflex.nullcontext(): 110 | t = tf.broadcast_to(tokens, [len(tokens), len(tokens[0])]) 111 | l = tf.broadcast_to(labels, [len(labels), len(labels[0])]) 112 | #dset1 = tf.data.Dataset.from_tensor_slices(t); 113 | #dset2 = tf.data.Dataset.from_tensor_slices(l); 114 | dset1 = tf.data.Dataset.from_tensors(t); 115 | dset2 = tf.data.Dataset.from_tensors(l); 116 | dset = tf.data.Dataset.zip((dset1, dset2)) 117 | #dset = dset.shuffle() 118 | dset = dset.repeat() 119 | return dset 120 | 121 | def export_source_tokens(tfrecord_dir, tokens): 122 | tf.logging.info("Exporting tokens to %s...", FLAGS.export_dataset) 123 | with TFRecordExporter(tfrecord_dir, 1) as tfr: 124 | tfr.add_tokens(tokens) 125 | tf.logging.info("Exported tokens to %s", FLAGS.export_dataset) 126 | 127 | if 'api' not in globals(): 128 | api = tflex.Dictator() 129 | api.tokens = None 130 | 131 | def 
unload_source_tokens(): 132 | api.tokens = None 133 | 134 | def load_source_tokens(dataset, export_dataset=None, quit_after_exporting=True): 135 | if dataset is None: 136 | tf.logging.info("Generating random fake tokens") 137 | tokens = [(_ + 0) % n_vocab for _ in range(0, 100000)] 138 | elif dataset.endswith('.tok16'): 139 | tf.logging.info("Reading tokens from %s...", dataset) 140 | with tf.io.gfile.GFile(dataset, 'rb') as f: 141 | data = f.read() 142 | tf.logging.info("len(data)=%s; np.frombuffer(%s, dtype=np.uint16)...", len(data), repr(dataset)) 143 | tokens = np.frombuffer(data, dtype=np.uint16) 144 | else: 145 | tf.logging.info("Loading tokens from %s...", dataset) 146 | tokens = [] 147 | npz = np.load(dataset) 148 | for item in npz.files: 149 | tokens.extend(npz[item]) 150 | tf.logging.info("Finished reading tokens.") 151 | if export_dataset: 152 | export_source_tokens(export_dataset, tokens) 153 | if quit_after_exporting: 154 | tf.logging.info("Tokens exported; quitting.") 155 | import posix 156 | posix._exit(0) 157 | return tokens 158 | 159 | def get_source_tokens(dataset=None, reload=False, export_dataset=None): 160 | if dataset is None: 161 | dataset = FLAGS.dataset 162 | if export_dataset is None: 163 | export_dataset = FLAGS.export_dataset 164 | if api.tokens is None or reload: 165 | unload_source_tokens() 166 | api.tokens = load_source_tokens(dataset) 167 | return api.tokens 168 | 169 | def make_source_tokens(index, num_hosts, n_vocab): 170 | tokens = get_source_tokens() 171 | n = len(tokens) 172 | k = n // num_hosts 173 | i = index * k 174 | j = (index + 1) * k 175 | tokens = tokens[i:j] 176 | tf.logging.info("Shard %d/%d has %d tokens", index, num_hosts, len(tokens)) 177 | dset = None 178 | step = int(10e6) 179 | for offset in tqdm.trange(0, len(tokens), step): 180 | t = tokens[offset:offset+step] 181 | #t = tf.broadcast_to(tf.cast(t, tf.int32), [len(t)]) 182 | t = tf.data.Dataset.from_tensors(t); 183 | dset = t if dset is None else 
dset.concatenate(t) 184 | if _loaded_dataset is not None: 185 | if index >= num_hosts - 1: 186 | tf.logging.info('Resetting tokens') 187 | if not isinstance(_loaded_dataset, np.ndarray): 188 | if isinstance(_loaded_dataset, list): 189 | while len(_loaded_dataset) > 0: 190 | _loaded_dataset.pop() 191 | _loaded_dataset = None 192 | return dset 193 | 194 | def bpe_text(batch_size, files, iterations, stitch, amount=1024, batch=True): 195 | dataset = tf.data.Dataset.from_tensor_slices(files) 196 | dataset = dataset.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=True)) 197 | 198 | def _parse_function(example_proto): 199 | features = { 200 | #"hash": tf.VarLenFeature(tf.string), 201 | "text": tf.VarLenFeature(tf.int64) 202 | } 203 | parsed_features = tf.parse_single_example(example_proto, features) 204 | return parsed_features["text"], parsed_features["text"].dense_shape[0] 205 | 206 | dataset = dataset.map(_parse_function, num_parallel_calls=1).shuffle(1000 * stitch) 207 | 208 | # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples 209 | # to have a text at least 1024 tokens long. 
For this to work the stitch parameter must be correctly tuned so that 210 | # stitch * min(characters_in_text) >= amount 211 | def _stitch_text(x, y): 212 | x = tf.sparse.to_dense(x) 213 | 214 | def _get_x(i): 215 | return tf.gather(x[i], tf.range(y[i])) 216 | 217 | out = _get_x(0) 218 | for i in range(1, stitch): 219 | #out = tf.concat([out, [50256], _get_x(i)], axis=0) # text1<|endoftext|>text2 220 | out = tf.concat([out, _get_x(i)], axis=0) # text1+text2 221 | 222 | return out 223 | 224 | # Hack-y way to stitch together multiple texts 225 | dataset = dataset.batch(stitch, drop_remainder=True).map(_stitch_text, num_parallel_calls=tf.data.experimental.AUTOTUNE) 226 | 227 | # Sample 1024(+1) tokens from the stitched together text 228 | def _sample_text(x): 229 | s = tf.size(x) 230 | r = tf.random.uniform([], maxval=s-(amount+1), dtype=tf.dtypes.int32) 231 | r1 = tf.range(r, r+amount) 232 | r2 = tf.range(r+1, (r+1)+amount) 233 | r1 = tf.reshape(r1, [amount]) # Somehow, this makes the compiler happy 234 | r2 = tf.reshape(r2, [amount]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input 235 | vals1 = tf.gather(x, r1) 236 | vals2 = tf.gather(x, r2) 237 | vals1 = tf.cast(vals1, tf.dtypes.int32) 238 | vals2 = tf.cast(vals2, tf.dtypes.int32) 239 | return vals1, vals2 240 | 241 | if batch: 242 | dataset = dataset.apply(tf.data.experimental.map_and_batch( 243 | map_func=_sample_text, batch_size=batch_size, 244 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 245 | drop_remainder=True)) 246 | dataset = dataset.repeat().prefetch(iterations) 247 | 248 | else: 249 | dataset = dataset.map(_sample_text, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat() 250 | 251 | return dataset 252 | 253 | 254 | def gpt2_input(params): 255 | pp({'op': 'gpt2_input', 'params': params}) 256 | batch_size = params['batch_size'] 257 | iterations = FLAGS.iterations_per_loop 258 | # TODO(dehao): Replace the following with 
params['context'].current_host 259 | if 'context' in params: 260 | current_host = params['context'].current_input_fn_deployment()[1] 261 | num_hosts = params['context'].num_hosts 262 | else: 263 | if 'dataset_index' in params: 264 | current_host = params['dataset_index'] 265 | num_hosts = params['dataset_num_shards'] 266 | else: 267 | current_host = 0 268 | num_hosts = 1 269 | if False: 270 | dset = make_source_dataset(current_host, num_hosts, batch_size, n_ctx=params['n_ctx']) 271 | elif FLAGS.dataset is not None and FLAGS.dataset.startswith('gs://') and '*' in FLAGS.dataset: 272 | files = [] 273 | for fname in FLAGS.dataset.split(','): 274 | files.extend(sorted(tf.io.gfile.glob(fname))) 275 | assert len(files) > 0 276 | dset = bpe_text(batch_size, files, iterations=iterations, stitch=min(2, len(files)), amount=params['n_ctx'], batch=True) 277 | elif False: 278 | dset = make_source_tokens(current_host, num_hosts, n_vocab=params['n_vocab']) 279 | batch=True 280 | def _sample_text(*args, **kws): 281 | return sample_text(*args, **kws, amount=params['n_ctx']) 282 | if batch: 283 | dset = dset.apply(tf.data.experimental.map_and_batch( 284 | map_func=_sample_text, batch_size=batch_size, 285 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 286 | drop_remainder=True)) 287 | dset = dset.repeat().prefetch(iterations) 288 | else: 289 | dset = dset.map(_sample_text, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat() 290 | elif FLAGS.dataset.endswith('.tok16') and FLAGS.dataset.startswith('gs://'): 291 | tokens_var = tputil.tf_shard_variable(FLAGS.dataset, tf.uint16, current_host, num_hosts, use_resource=False) 292 | def sample_fn(): 293 | return tputil.sample_text(tokens_var, amount=params['n_ctx'], batch_size=batch_size) 294 | def init_fn(): 295 | return tokens_var.initializer 296 | def upload_fn(session=None): 297 | if session is None: 298 | session = tf.get_default_session() 299 | #n = len(tokens) 300 | n = tokens_var.shape[0].value 301 | 
tf.logging.info('Loading %s tokens to TPU host %d...', tflex.num(n), current_host) 302 | assert session is not None 303 | pass 304 | dset = tflex.make_dataset_function(sample_fn=sample_fn, init_fn=init_fn, upload_fn=upload_fn) 305 | else: 306 | #dset = make_source_tokens(current_host, num_hosts, n_vocab=params['n_vocab']) 307 | all_tokens = get_source_tokens() 308 | assert all_tokens.ndim == 1 309 | n = len(all_tokens) 310 | k = n // num_hosts 311 | i = current_host * k 312 | j = (current_host + 1) * k 313 | tokens = all_tokens[i:j] 314 | tf.logging.info("Shard %d/%d has %s tokens out of %s total", current_host, num_hosts, tflex.num(len(tokens)), tflex.num(len(all_tokens))) 315 | with tf.variable_scope('cpu%d' % current_host): 316 | tokens_var = tf.get_local_variable('input_tokens', dtype=tf.uint16, shape=[len(tokens)], use_resource=True) 317 | def sample_fn(): 318 | return tputil.sample_text(tokens_var, amount=params['n_ctx'], batch_size=batch_size) 319 | def init_fn(): 320 | return tokens_var.initializer 321 | def upload_fn(session=None): 322 | if session is None: 323 | session = tf.get_default_session() 324 | tf.logging.info('Loading %s tokens to TPU host %d...', tflex.num(len(tokens)), current_host) 325 | assert session is not None 326 | with tflex.with_elapsed(tflex.assign_values, [tokens_var], [tokens], session=session) as (elapsed, result): 327 | tf.logging.info('Loaded %s tokens to TPU host %d in %.2fs', tflex.num(len(tokens)), current_host, elapsed) 328 | tf.logging.info('Unloading source tokens.') 329 | unload_source_tokens() 330 | dset = tflex.make_dataset_function(sample_fn=sample_fn, init_fn=init_fn, upload_fn=upload_fn) 331 | return dset 332 | return dset 333 | 334 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | #import torch 2 | #import torch.nn.functional as F 3 | import tftorch as torch; F = torch; nn = F 4 | 5 | # 
# DCGAN loss
def loss_dcgan_dis(dis_fake, dis_real):
    """Return the DCGAN discriminator losses as ``(real_loss, fake_loss)``.

    Each term is the mean softplus of the (negated, for real samples)
    discriminator output.
    """
    L1 = torch.mean(F.softplus(-dis_real))
    L2 = torch.mean(F.softplus(dis_fake))
    return L1, L2


def loss_dcgan_gen(dis_fake):
    """Return the DCGAN generator loss: mean softplus of ``-dis_fake``."""
    loss = torch.mean(F.softplus(-dis_fake))
    return loss


# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
    """Return the hinge discriminator losses as ``(real_loss, fake_loss)``."""
    loss_real = torch.mean(F.relu(1. - dis_real))
    loss_fake = torch.mean(F.relu(1. + dis_fake))
    return loss_real, loss_fake
# def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss
#   loss = torch.mean(F.relu(1. - dis_real))
#   loss += torch.mean(F.relu(1. + dis_fake))
#   return loss


def loss_hinge_gen(dis_fake):
    """Return the hinge generator loss: the negated mean of `dis_fake`."""
    loss = -torch.mean(dis_fake)
    return loss

# Default to hinge loss
generator_loss = loss_hinge_gen
discriminator_loss = loss_hinge_dis


# --------------------------------------------------------------------------------
# /main_biggan.py:
# --------------------------------------------------------------------------------
import time
import os

import numpy as np

import tensorflow.compat.v1 as tf

from absl import app
from absl import logging

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import server_lib
from tensorflow.python.tpu import tpu

import train_runner
from train_flags import flags, FLAGS

from pprint import pprint as pp
from pprint import pformat as pf

# from model_fns import gpt2_model, gpt2_rev_model
# from input_fns import gpt2_input

import json

from tfjpg_parser import ImageNet, iterate_dataset

import tflex

import BigGAN

import gin

flags.DEFINE_multi_string(
    "gin_config", [],
    "List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_bindings", [],
    "Newline separated list of Gin parameter bindings.")


def parseval(value, dtype, default=None):
    """Convert the string `value` to the requested or inferred type.

    Args:
        value: raw string (e.g. from an environment variable).
        dtype: ``'str'``/``'int'``/``'float'``/``'bool'``, the matching
            builtin type, any other callable converter, or `None` to infer
            the type from `default`.
        default: value whose type selects the conversion when `dtype` is
            `None`.

    Returns:
        The converted value. Booleans are true for ``'1'`` and for ``'true'``
        in any letter case, false otherwise.

    Raises:
        AssertionError: if neither `dtype` nor `default` identifies a type.
    """
    if dtype is None:
        # bool must be tested before int: bool is a subclass of int, so the
        # original order sent boolean defaults through int() ('true' raised
        # ValueError, '1' became 1 instead of True).
        if isinstance(default, bool):
            dtype = 'bool'
        elif isinstance(default, str):
            dtype = 'str'
        elif isinstance(default, int):
            dtype = 'int'
        elif isinstance(default, float):
            dtype = 'float'
    if dtype == 'str' or dtype is str:
        return value
    if dtype == 'int' or dtype is int:
        return int(value)
    if dtype == 'float' or dtype is float:
        return float(value)
    if dtype == 'bool' or dtype is bool:
        return value == '1' or value.lower() == 'true'
    assert dtype is not None
    return dtype(value)


def getval(name, default, dtype=None):
    """Look up a setting by name.

    The upper-cased `name` is read from the environment first (parsed with
    `parseval`); otherwise the module-level `params` dict, which `main`
    assigns, is consulted with `default` as fallback. The outcome is logged.
    """
    if name.upper() in os.environ:
        value = os.environ[name.upper()]
        value = parseval(value, dtype=dtype, default=default)
        tf.logging.info('getval(%s, %s) = os.environ[%s] = %s', repr(name), repr(default), repr(name.upper()), repr(value))
    else:
        value = params.get(name, default)
        tf.logging.info('getval(%s, %s) = params[%s] = %s', repr(name), repr(default), repr(name), repr(value))
    return value


@gin.configurable
def options(**kwargs):
    """Return the gin-bound keyword arguments as a plain dict."""
    return dict(**kwargs)


def main(unused_argv):
    """Parse the gin config, connect to the TPU and build the BigGAN graph.

    NOTE(review): this entry point is work in progress -- it drops into
    `pdb` after constructing the model and contains no training loop yet.
    """
    logging.info("Gin config: %s\nGin bindings: %s",
                 FLAGS.gin_config, FLAGS.gin_bindings)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
    global params
    #FLAGS.iterations_per_loop = 100
    #params = {'batch_size': FLAGS.train_batch_size}
    #params = {'batch_size': 128, 'use_tpu': True, 'precision': 'float32'}
    # with open(FLAGS.params) as f:
    #   params = json.load(f)
    params = options()
    params['use_tpu'] = getval('use_tpu', True)
    params['batch_per_core'] = getval('batch_per_core', 1)
    params['iterations'] = getval('iterations', 20)
    params['batch_size'] = FLAGS.num_cores * params['batch_per_core']
    params['opt_name'] = getval('opt_name', 'adam')
    params['beta1'] = getval('beta1', 0.9)
    params['beta2'] = getval('beta2', 0.999)
    params['epsilon'] = getval('epsilon', 1e-9)
    params['lr'] = getval('lr', 0.00025)
    FLAGS.train_batch_size = params['batch_size']
    FLAGS.iterations_per_loop = params['iterations']
    FLAGS.train_steps = getval('train_steps', int(2e6))
    params['precision'] = getval('precision', 'float32')
    params['model'] = getval('model', 'GPT2')
    assert params['model'] in ['GPT2', 'GPT2Rev']

    graph = tf.Graph()
    with graph.as_default():
        master = FLAGS.tpu or FLAGS.master or getval('TPU_NAME', 'unknown')
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            master,
            zone=FLAGS.tpu_zone,
            project=FLAGS.gcp_project)
        config = tf.ConfigProto(operation_timeout_in_ms=600 * 60 * 1000,
                                # graph_options=tf.GraphOptions(
                                #   rewrite_options=rewriter_config_pb2.RewriterConfig(
                                #     disable_meta_optimizer=True,
                                #   ),
                                # ),
                                isolate_session_state=True)
        cluster_spec = cluster_resolver.cluster_spec()
        if cluster_spec:
            config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
        sess = tf.InteractiveSession(cluster_resolver.get_master(), graph=graph, config=config)
        devices = sess.list_devices()
        cores = sorted([x.name for x in devices if ':TPU:' in x.name])
        num_cores = len(cores)
        assert num_cores % 8 == 0
        num_hosts = num_cores // 8
        print(config.cluster_def)
        print('cores: %d hosts: %d ip: %s' % (num_cores, num_hosts, master))
        tf.logging.info("TrainRunner: initializing TPU session...")
        if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
            tflex.run(sess, tf.tpu.initialize_system())
        tf.logging.info("TrainRunner: initializing TPU session (done)")
        gan = BigGAN.GAN()
        pp(tf.trainable_variables())
        import pdb; pdb.set_trace()

    # seed = 0
    # dataset = ImageNet.make_dataset(FLAGS.dataset or "gs://dota-euw4a/datasets/danbooru2019-s/danbooru2019-s-0*", 0, 1, seed=seed)
    # it = iterate_dataset(dataset)

    # def go():
    #   zz = next(it)
    #   images = [zz['image']]
    #   labels = [zz['label']]

    #   #import IPython
    #   print('label', labels[0])
    #   #print(labels[0] - 1, imagenet_label_names[labels[0] - 1])
    #   print(images[0].shape)
    #   print('embedding', zz['parsed']['image/class/embedding'].values.shape)
    #   print('filename', zz['parsed']['image/filename'])
    #   print('hash', zz['parsed']['image/hash'])
    #   op = tf.io.encode_jpeg(images[0])
    #   with open('test.png', 'wb') as f:
    #     f.write(sess.run(op))
    # go()

    import pdb; pdb.set_trace()

    # A no-op `dataset = dataset` stood here; `dataset` is never assigned in
    # this function (its construction above is commented out), so the line
    # could only raise UnboundLocalError and has been removed.

    # model = gpt2_rev_model if params['model'] == 'GPT2Rev' else gpt2_model
    # pp(params)
    # trunner = train_runner.TrainRunner(
    #   iterations=FLAGS.iterations_per_loop, train_steps=FLAGS.train_steps)
    # def input_fn(params):
    #   tokens = [[_ for _ in range(0, 1024)]] * params['batch_size']
    #   labels = [[_ for _ in range(1, 1025)]] * params['batch_size']
    #   t = tf.broadcast_to(tokens, [len(tokens), len(tokens[0])])
    #   l = tf.broadcast_to(labels, [len(labels), len(labels[0])])
    #   #dset1 = tf.data.Dataset.from_tensor_slices(t);
    #   #dset2 = tf.data.Dataset.from_tensor_slices(l);
    #   dset1 = tf.data.Dataset.from_tensors(t);
    #   dset2 = tf.data.Dataset.from_tensors(l);
    #   dset = tf.data.Dataset.zip((dset1, dset2))
    #   dset = dset.repeat()
    #   return dset
    # def create_train_op(loss, params):
    #   return tf.identity(loss)
    # def model_fn(features, labels, mode, params):
    #   pp(['features', features])
    #   pp(['labels', labels])
    #   pp(['mode', mode])
    #   pp(['params', params])
    #   loss = tf.constant(0.0)
    #   if mode == tf.estimator.ModeKeys.TRAIN:
    #     train_op = create_train_op(loss, params)
    #     if params['use_tpu']:
    #       return tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
    #     else:
    #       return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    # trunner.initialize(gpt2_input, model, params)
    # tf.logging.info('trunner.initialize(): Done. Training...')
    # trunner.train()
    # tf.logging.info('trunner.train(): Done. Shutting down...')
    # trunner.shutdown()
    # tf.logging.info('trunner.shutdown(): Done.')


if __name__ == "__main__":
    app.run(main)


# --------------------------------------------------------------------------------
# /main_gpt2.py:
# --------------------------------------------------------------------------------
import time
import os

import numpy as np

import tensorflow as tf

from absl import app

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import server_lib

import train_runner
from train_flags import FLAGS

from pprint import pprint as pp

from model_fns import gpt2_model, gpt2_rev_model
from input_fns import gpt2_input

import json


def parseval(value, dtype, default=None):
    """Convert the string `value` to the requested or inferred type.

    Args:
        value: raw string (e.g. from an environment variable).
        dtype: ``'str'``/``'int'``/``'float'``/``'bool'``, the matching
            builtin type, any other callable converter, or `None` to infer
            the type from `default`.
        default: value whose type selects the conversion when `dtype` is
            `None`.

    Returns:
        The converted value. Booleans are true for ``'1'`` and for ``'true'``
        in any letter case, false otherwise.

    Raises:
        AssertionError: if neither `dtype` nor `default` identifies a type.
    """
    if dtype is None:
        # bool must be tested before int: bool is a subclass of int, so the
        # original order sent boolean defaults through int() ('true' raised
        # ValueError, '1' became 1 instead of True).
        if isinstance(default, bool):
            dtype = 'bool'
        elif isinstance(default, str):
            dtype = 'str'
        elif isinstance(default, int):
            dtype = 'int'
        elif isinstance(default, float):
            dtype = 'float'
    if dtype == 'str' or dtype is str:
        return value
    if dtype == 'int' or dtype is int:
        return int(value)
    if dtype == 'float' or dtype is float:
        return float(value)
    if dtype == 'bool' or dtype is bool:
        return value == '1' or value.lower() == 'true'
    assert dtype is not None
    return dtype(value)


def getval(name, default, dtype=None):
    """Look up a setting by name.

    The upper-cased `name` is read from the environment first (parsed with
    `parseval`); otherwise the module-level `params` dict, which `main`
    assigns, is consulted with `default` as fallback. The outcome is logged.
    """
    if name.upper() in os.environ:
        value = os.environ[name.upper()]
        value = parseval(value, dtype=dtype, default=default)
        tf.logging.info('getval(%s, %s) = os.environ[%s] = %s', repr(name), repr(default), repr(name.upper()), repr(value))
    else:
        value = params.get(name, default)
        tf.logging.info('getval(%s, %s) = params[%s] = %s', repr(name), repr(default), repr(name), repr(value))
    return value


# NOTE(review): `main` continues past the end of this chunk. Its visible
# statements are reproduced unchanged; the last visible statement is cut off
# mid-call and is therefore kept as a comment.
def main(unused_argv):
    global params
    #FLAGS.iterations_per_loop = 100
    #params = {'batch_size': FLAGS.train_batch_size}
    #params = {'batch_size': 128, 'use_tpu': True, 'precision': 'float32'}
    with open(FLAGS.params) as f:
        params = json.load(f)
    params['use_tpu'] = getval('use_tpu', True)
    params['batch_per_core'] = getval('batch_per_core', 1)
    params['iterations'] = getval('iterations', 20)
    params['batch_size'] = FLAGS.num_cores * params['batch_per_core']
    params['n_ctx'] = getval('n_ctx', 1024)
    params['n_embd'] = getval('n_embd', 768)
    params['n_head'] = getval('n_head', 12)
    params['n_layer'] = getval('n_layer', 12)
    params['n_vocab'] = getval('n_vocab', 50257)
    params['opt_name'] = getval('opt_name', 'adam')
    params['beta1'] = getval('beta1', 0.9)
    params['beta2'] = getval('beta2', 0.999)
    params['epsilon'] = getval('epsilon', 1e-9)
    params['lr'] = getval('lr', 0.00025)
    FLAGS.train_batch_size = params['batch_size']
    FLAGS.iterations_per_loop = params['iterations']
    FLAGS.train_steps = getval('train_steps', int(2e6))
    params['precision'] = getval('precision', 'float32')
    params['model'] = getval('model', 'GPT2')
    assert params['model'] in ['GPT2', 'GPT2Rev']
    model = gpt2_rev_model if params['model'] == 'GPT2Rev' else gpt2_model
    pp(params)
    # trunner = train_runner.TrainRunner(   <- statement continues beyond this chunk
| iterations=FLAGS.iterations_per_loop, train_steps=FLAGS.train_steps) 97 | def input_fn(params): 98 | tokens = [[_ for _ in range(0, 1024)]] * params['batch_size'] 99 | labels = [[_ for _ in range(1, 1025)]] * params['batch_size'] 100 | t = tf.broadcast_to(tokens, [len(tokens), len(tokens[0])]) 101 | l = tf.broadcast_to(labels, [len(labels), len(labels[0])]) 102 | #dset1 = tf.data.Dataset.from_tensor_slices(t); 103 | #dset2 = tf.data.Dataset.from_tensor_slices(l); 104 | dset1 = tf.data.Dataset.from_tensors(t); 105 | dset2 = tf.data.Dataset.from_tensors(l); 106 | dset = tf.data.Dataset.zip((dset1, dset2)) 107 | dset = dset.repeat() 108 | return dset 109 | def create_train_op(loss, params): 110 | return tf.identity(loss) 111 | def model_fn(features, labels, mode, params): 112 | pp(['features', features]) 113 | pp(['labels', labels]) 114 | pp(['mode', mode]) 115 | pp(['params', params]) 116 | loss = tf.constant(0.0) 117 | if mode == tf.estimator.ModeKeys.TRAIN: 118 | train_op = create_train_op(loss, params) 119 | if params['use_tpu']: 120 | return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op) 121 | else: 122 | return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) 123 | trunner.initialize(gpt2_input, model, params) 124 | tf.logging.info('trunner.initialize(): Done. Training...') 125 | trunner.train() 126 | tf.logging.info('trunner.train(): Done. 
def gradients_speed(ys, xs, grad_ys=None, **kwargs):
    """tf.gradients-compatible entry point that always uses the 'speed' strategy.

    Forwards every argument unchanged to `gradients` with
    `checkpoints='speed'`, so it can be swapped in for tf.gradients
    process-wide.
    """
    strategy = 'speed'
    return gradients(ys, xs, grad_ys=grad_ys, checkpoints=strategy, **kwargs)
2016 (https://arxiv.org/abs/1604.06174) 38 | 39 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients 40 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 41 | 42 | 'checkpoints' can either be 43 | - a list consisting of tensors from the forward pass of the neural net 44 | that we should re-use when calculating the gradients in the backward pass 45 | all other tensors that do not appear in this list will be re-computed 46 | - a string specifying how this list should be determined. currently we support 47 | - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, 48 | so checkpointing them maximizes the running speed 49 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) 50 | - 'memory': try to minimize the memory usage 51 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) 52 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint 53 | ''' 54 | 55 | # print("Calling memsaving gradients with", checkpoints) 56 | if not isinstance(ys,list): 57 | ys = [ys] 58 | if not isinstance(xs,list): 59 | xs = [xs] 60 | 61 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], 62 | inclusive=True) 63 | 64 | debug_print("bwd_ops: %s", bwd_ops) 65 | 66 | # forward ops are all ops that are candidates for recomputation 67 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], 68 | inclusive=True, 69 | within_ops=bwd_ops) 70 | debug_print("fwd_ops: %s", fwd_ops) 71 | 72 | # exclude ops with no inputs 73 | fwd_ops = [op for op in fwd_ops if op.inputs] 74 | 75 | # don't recompute xs, remove variables 76 | xs_ops = _to_ops(xs) 77 | fwd_ops = [op for op in fwd_ops if not op in xs_ops] 78 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] 79 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] 80 | 
fwd_ops = [op for op in fwd_ops if not '/read' in op.name] 81 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors 82 | ts_all = [t for t in ts_all if '/read' not in t.name] 83 | ts_all = set(ts_all) - set(xs) - set(ys) 84 | 85 | # construct list of tensors to checkpoint during forward pass, if not 86 | # given as input 87 | if type(checkpoints) is not list: 88 | if checkpoints == 'collection': 89 | checkpoints = tf.get_collection('checkpoints') 90 | 91 | elif checkpoints == 'speed': 92 | # checkpoint all expensive ops to maximize running speed 93 | checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') 94 | 95 | elif checkpoints == 'memory': 96 | 97 | # remove very small tensors and some weird ops 98 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually 99 | try: 100 | return [int(e if e.value is not None else 64) for e in t] 101 | except: 102 | return [0] # unknown shape 103 | ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] 104 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name] 105 | ts_all = [t for t in ts_all if 'entropy' not in t.name] 106 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] 107 | ts_all = [t for t in ts_all if 'Switch' not in t.name] 108 | ts_all = [t for t in ts_all if 'dropout' not in t.name] 109 | # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 110 | ts_all = [t for t in ts_all if 'Cast' not in t.name] 111 | 112 | # filter out all tensors that are inputs of the backward graph 113 | with util.capture_ops() as bwd_ops: 114 | tf_gradients(ys, xs, grad_ys, **kwargs) 115 | 116 | bwd_inputs = [t for op in bwd_ops for t in op.inputs] 117 | # list of tensors in forward graph that is in input to bwd graph 118 | ts_filtered = list(set(bwd_inputs).intersection(ts_all)) 119 | debug_print("Using tensors %s", ts_filtered) 120 | 121 | # try two slightly different ways of getting bottlenecks tensors 122 | # to checkpoint 123 | 
for ts in [ts_filtered, ts_all]: 124 | 125 | # get all bottlenecks in the graph 126 | bottleneck_ts = [] 127 | for t in ts: 128 | b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) 129 | f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) 130 | # check that there are not shortcuts 131 | b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) 132 | f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) 133 | if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all): 134 | bottleneck_ts.append(t) # we have a bottleneck! 135 | else: 136 | debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp))) 137 | 138 | # success? or try again without filtering? 139 | if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found! 140 | break 141 | 142 | if not bottleneck_ts: 143 | raise Exception('unable to find bottleneck tensors! 
please provide checkpoint nodes manually, or use checkpoints="speed".') 144 | 145 | # sort the bottlenecks 146 | bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) 147 | sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] 148 | 149 | # save an approximately optimal number ~ sqrt(N) 150 | N = len(ts_filtered) 151 | if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): 152 | checkpoints = sorted_bottlenecks 153 | else: 154 | step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) 155 | checkpoints = sorted_bottlenecks[step::step] 156 | 157 | else: 158 | raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,)) 159 | 160 | checkpoints = list(set(checkpoints).intersection(ts_all)) 161 | 162 | # at this point automatic selection happened and checkpoints is list of nodes 163 | assert isinstance(checkpoints, list) 164 | 165 | debug_print("Checkpoint nodes used: %s", checkpoints) 166 | # better error handling of special cases 167 | # xs are already handled as checkpoint nodes, so no need to include them 168 | xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) 169 | if xs_intersect_checkpoints: 170 | debug_print("Warning, some input nodes are also checkpoint nodes: %s", 171 | xs_intersect_checkpoints) 172 | ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) 173 | debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, 174 | ys_intersect_checkpoints) 175 | # saving an output node (ys) gives no benefit in memory while creating 176 | # new edge cases, exclude them 177 | if ys_intersect_checkpoints: 178 | debug_print("Warning, some output nodes are also checkpoints nodes: %s", 179 | format_ops(ys_intersect_checkpoints)) 180 | 181 | # remove initial and terminal nodes from checkpoints list if present 182 | checkpoints = list(set(checkpoints) - set(ys) - set(xs)) 183 | 184 | # check that we have some nodes to checkpoint 185 | # if not checkpoints: 186 | # raise Exception('no 
checkpoints nodes found or given as input! ') 187 | 188 | # disconnect dependencies between checkpointed tensors 189 | checkpoints_disconnected = {} 190 | for x in checkpoints: 191 | if x.op and x.op.name is not None: 192 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") 193 | else: 194 | grad_node = tf.stop_gradient(x) 195 | checkpoints_disconnected[x] = grad_node 196 | 197 | # partial derivatives to the checkpointed tensors and xs 198 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], 199 | stop_at_ts=checkpoints, within_ops=fwd_ops) 200 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", 201 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) 202 | debug_print("ops_to_copy = %s", ops_to_copy) 203 | debug_print("Processing list %s", ys) 204 | sgv_ops_to_copy = ge.sgv(ops_to_copy) 205 | copied_sgv, info = ge.copy_with_input_replacements(sgv_ops_to_copy, {}) 206 | for origin_op, op in info._transformed_ops.items(): 207 | op._set_device(origin_op.node_def.device) 208 | copied_ops = info._transformed_ops.values() 209 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 210 | ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) 211 | debug_print("Rewired %s in place of %s restricted to %s", 212 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) 213 | 214 | # get gradients with respect to current boundary + original x's 215 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] 216 | boundary = list(checkpoints_disconnected.values()) 217 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) 218 | debug_print("Got gradients %s", dv) 219 | debug_print("for %s", copied_ys) 220 | debug_print("with respect to %s", boundary+xs) 221 | 222 | inputs_to_do_before = [y.op for y in ys] 223 | if grad_ys is not None: 224 | inputs_to_do_before += grad_ys 225 | wait_to_do_ops = list(copied_ops) + [g.op for g in 
dv if g is not None] 226 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 227 | 228 | # partial derivatives to the checkpointed nodes 229 | # dictionary of "node: backprop" for nodes in the boundary 230 | d_checkpoints = {r: dr for r,dr in zip(checkpoints_disconnected.keys(), 231 | dv[:len(checkpoints_disconnected)])} 232 | # partial derivatives to xs (usually the params of the neural net) 233 | d_xs = dv[len(checkpoints_disconnected):] 234 | 235 | # incorporate derivatives flowing through the checkpointed nodes 236 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) 237 | for ts in checkpoints_sorted_lists[::-1]: 238 | debug_print("Processing list %s", ts) 239 | checkpoints_other = [r for r in checkpoints if r not in ts] 240 | checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other] 241 | 242 | # copy part of the graph below current checkpoint node, stopping at 243 | # other checkpoints nodes 244 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) 245 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", 246 | len(ops_to_copy), fwd_ops, [r.op for r in ts], 247 | checkpoints_other) 248 | debug_print("ops_to_copy = %s", ops_to_copy) 249 | if not ops_to_copy: # we're done! 
250 | break 251 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 252 | for origin_op, op in info._transformed_ops.items(): 253 | op._set_device(origin_op.node_def.device) 254 | copied_ops = info._transformed_ops.values() 255 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 256 | ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) 257 | debug_print("Rewired %s in place of %s restricted to %s", 258 | checkpoints_disconnected_other, checkpoints_other, copied_ops) 259 | 260 | # gradient flowing through the checkpointed node 261 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] 262 | substitute_backprops = [d_checkpoints[r] for r in ts] 263 | dv = tf_gradients(boundary, 264 | checkpoints_disconnected_other+xs, 265 | grad_ys=substitute_backprops, **kwargs) 266 | debug_print("Got gradients %s", dv) 267 | debug_print("for %s", boundary) 268 | debug_print("with respect to %s", checkpoints_disconnected_other+xs) 269 | debug_print("with boundary backprop substitutions %s", substitute_backprops) 270 | 271 | inputs_to_do_before = [d_checkpoints[r].op for r in ts] 272 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 273 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 274 | 275 | # partial derivatives to the checkpointed nodes 276 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): 277 | if dr is not None: 278 | if d_checkpoints[r] is None: 279 | d_checkpoints[r] = dr 280 | else: 281 | d_checkpoints[r] += dr 282 | def _unsparsify(x): 283 | if not isinstance(x, tf.IndexedSlices): 284 | return x 285 | assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape" 286 | indices = x.indices 287 | while indices.shape.ndims < x.values.shape.ndims: 288 | indices = tf.expand_dims(indices, -1) 289 | return tf.scatter_nd(indices, x.values, x.dense_shape) 290 | 291 | # partial derivatives to xs (usually 
def tf_toposort(ts, within_ops=None):
    """Group the tensors in `ts` into topologically ordered layers.

    Walks forward from the ops producing `ts` (optionally restricted to
    `within_ops`), records for every output tensor the set of input tensors
    of its op, and runs `toposort` over that dependency map.

    Returns:
      A list of lists; each inner list holds the members of `ts` found in
      one toposort layer. Layers containing none of `ts` are dropped.
    """
    seed_ops = [tensor.op for tensor in ts]
    reachable_ops = ge.get_forward_walk_ops(seed_ops, within_ops=within_ops)

    # output tensor -> tensors its producing op consumes
    dependencies = {
        produced: set(op.inputs)
        for op in reachable_ops
        for produced in op.outputs
    }

    # keep, per layer, only the tensors the caller asked about
    layers = (list(set(layer).intersection(ts)) for layer in toposort(dependencies))
    return [kept for kept in layers if kept]
def format_ops(ops, sort_outputs=True):
    """Turn ops/tensors into printable names.

    An object that has a `name` attribute is represented by that name,
    anything else by `str(obj)`. A non-string iterable produces a list of
    such labels (sorted unless `sort_outputs` is false); any other argument
    produces a single label.
    """
    def _label(item):
        if hasattr(item, "name"):
            return item.name
        return str(item)

    is_collection = hasattr(ops, '__iter__') and not isinstance(ops, str)
    if not is_collection:
        return _label(ops)

    labels = list(map(_label, ops))
    if sort_outputs:
        labels.sort()
    return labels
6 | 7 | Example usage: 8 | 9 | >>> import mnist_classifier 10 | >>> mnist = mnist_classifier.MNISTClassifier() 11 | >>> mnist 12 | MNISTClassifier( 13 | (0): Conv2d(1, 10, kernel_size=(1, 1), stride=(1, 1), padding=VALID) 14 | (1): ResidualBlock( 15 | (0): ReLU() 16 | (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME) 17 | (2): ReLU() 18 | (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME) 19 | ) 20 | (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 21 | (3): ResidualBlock( 22 | (0): ReLU() 23 | (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME) 24 | (2): ReLU() 25 | (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME) 26 | ) 27 | (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 28 | (5): Flatten(start_dim=1, end_dim=-1) 29 | (6): Linear(in_features=490, out_features=10, bias=True) 30 | (7): LogSoftmax(dim=-1) 31 | ) 32 | >>> pp(list(mnist.parameters())) 33 | [, 34 | , 35 | , 36 | , 37 | , 38 | , 39 | , 40 | , 41 | , 42 | , 43 | , 44 | ] 45 | >>> batch_size = 16 46 | >>> input = tf.placeholder(tf.float32, [batch_size, 28, 28, 1], name="mnist_in") 47 | >>> output = mnist(input) 48 | >>> output 49 | 50 | >>> mnist 51 | MNISTClassifier( 52 | IN: f32[16,28,28,1, name='mnist_in_12:0'], 53 | OUT: f32[16,10, name='mnist_23/log_softmax/LogSoftmax:0'] 54 | (0): Conv2d( 55 | 1, 10, kernel_size=(1, 1), stride=(1, 1), padding=VALID 56 | IN: f32[16,28,28,1, name='mnist_in_12:0'], 57 | OUT: f32[16,28,28,10, name='mnist_23/pre_conv/BiasAdd:0'] 58 | ) 59 | (1): ResidualBlock( 60 | IN: f32[16,28,28,10, name='mnist_23/pre_conv/BiasAdd:0'], 61 | OUT: f32[16,28,28,10, name='mnist_23/residual/add:0'] 62 | (0): ReLU( 63 | IN: f32[16,28,28,10, name='mnist_23/residual/mnist_23/pre_conv/BiasAdd_clone:0'], 64 | OUT: f32[16,28,28,10, name='mnist_23/residual/ReLU/Relu:0'] 65 | ) 66 | (1): Conv2d( 67 | 10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME 68 | 
IN: f32[16,28,28,10, name='mnist_23/residual/ReLU/Relu:0'], 69 | OUT: f32[16,28,28,10, name='mnist_23/residual/conv_2d/BiasAdd:0'] 70 | ) 71 | (2): ReLU( 72 | IN: f32[16,28,28,10, name='mnist_23/residual/conv_2d/BiasAdd:0'], 73 | OUT: f32[16,28,28,10, name='mnist_23/residual/ReLU_1/Relu:0'] 74 | ) 75 | (3): Conv2d( 76 | 10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME 77 | IN: f32[16,28,28,10, name='mnist_23/residual/ReLU_1/Relu:0'], 78 | OUT: f32[16,28,28,10, name='mnist_23/residual/conv_2d_1/BiasAdd:0'] 79 | ) 80 | ) 81 | (2): MaxPool2d( 82 | kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False 83 | IN: f32[16,28,28,10, name='mnist_23/residual/add:0'], 84 | OUT: f32[16,14,14,10, name='mnist_23/max_pool2d/MaxPool2d:0'] 85 | ) 86 | (3): ResidualBlock( 87 | IN: f32[16,14,14,10, name='mnist_23/max_pool2d/MaxPool2d:0'], 88 | OUT: f32[16,14,14,10, name='mnist_23/residual_1/add:0'] 89 | (0): ReLU( 90 | IN: f32[16,14,14,10, name='mnist_23/residual_1/mnist_23/max_pool2d/MaxPool2d_clone:0'], 91 | OUT: f32[16,14,14,10, name='mnist_23/residual_1/ReLU/Relu:0'] 92 | ) 93 | (1): Conv2d( 94 | 10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME 95 | IN: f32[16,14,14,10, name='mnist_23/residual_1/ReLU/Relu:0'], 96 | OUT: f32[16,14,14,10, name='mnist_23/residual_1/conv_2d/BiasAdd:0'] 97 | ) 98 | (2): ReLU( 99 | IN: f32[16,14,14,10, name='mnist_23/residual_1/conv_2d/BiasAdd:0'], 100 | OUT: f32[16,14,14,10, name='mnist_23/residual_1/ReLU_1/Relu:0'] 101 | ) 102 | (3): Conv2d( 103 | 10, 10, kernel_size=(3, 3), stride=(1, 1), padding=SAME 104 | IN: f32[16,14,14,10, name='mnist_23/residual_1/ReLU_1/Relu:0'], 105 | OUT: f32[16,14,14,10, name='mnist_23/residual_1/conv_2d_1/BiasAdd:0'] 106 | ) 107 | ) 108 | (4): MaxPool2d( 109 | kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False 110 | IN: f32[16,14,14,10, name='mnist_23/residual_1/add:0'], 111 | OUT: f32[16,7,7,10, name='mnist_23/max_pool2d_1/MaxPool2d:0'] 112 | ) 113 | (5): Flatten( 114 | start_dim=1, 
end_dim=-1 115 | IN: f32[16,7,7,10, name='mnist_23/max_pool2d_1/MaxPool2d:0'], 116 | OUT: f32[16,490, name='mnist_23/flatten/Reshape:0'] 117 | ) 118 | (6): Linear( 119 | in_features=490, out_features=10, bias=True 120 | IN: f32[16,490, name='mnist_23/flatten/Reshape:0'], 121 | OUT: f32[16,10, name='mnist_23/linear/BiasAdd:0'] 122 | ) 123 | (7): LogSoftmax( 124 | dim=-1 125 | IN: f32[16,10, name='mnist_23/linear/BiasAdd:0'], 126 | OUT: f32[16,10, name='mnist_23/log_softmax/LogSoftmax:0'] 127 | ) 128 | ) 129 | >>> 130 | """ 131 | def __init__(self, scope='mnist', **kwargs): 132 | super().__init__(scope=scope, **kwargs, body=lambda: [ 133 | nn.Conv2d(1, 10, 1, scope='pre_conv'), 134 | nn.ResidualBlock(body=lambda: [ 135 | nn.ReLU(), 136 | nn.Conv2d(10, 10, 3, padding=1), 137 | nn.ReLU(), 138 | nn.Conv2d(10, 10, 3, padding=1, index=1), 139 | ]), 140 | nn.MaxPool2d(2), 141 | nn.ResidualBlock(index=1, body=lambda: [ 142 | nn.ReLU(), 143 | nn.Conv2d(10, 10, 3, padding=1), 144 | nn.ReLU(), 145 | nn.Conv2d(10, 10, 3, padding=1, index=1), 146 | ]), 147 | nn.MaxPool2d(2), 148 | nn.Flatten(), 149 | nn.Linear(7*7*10, 10), 150 | nn.LogSoftmax(dim=-1), 151 | ]) 152 | 153 | -------------------------------------------------------------------------------- /model_fns.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from models.gpt2 import gpt2, gpt2_rev 3 | from optimizers import create_train_op 4 | from metric_fns import perplexity_metric 5 | from models.gpt2 import sample 6 | 7 | 8 | def gpt2_rev_model(features, labels, mode, params): 9 | tf.logging.info('model_fns.py: gpt2_rev_model(features=%s, labels=%s, mode=%s, params=%s)', features, labels, mode, params) 10 | 11 | if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL: 12 | if params["precision"] == 'bfloat16': 13 | with tf.contrib.tpu.bfloat16_scope(): 14 | output = gpt2_rev.model_grad(X=features, 15 | params=params, 16 | labels=labels, 
def gpt2_model(features, labels, mode, params):
    """Estimator model function for GPT-2.

    Depending on `mode` this builds the training, evaluation or prediction
    graph and returns the matching estimator spec; a `TPUEstimatorSpec` is
    returned when `params["use_tpu"]` is truthy, a plain `EstimatorSpec`
    otherwise.

    Args:
      features: passed to the model as `X` (TRAIN/EVAL) or to the sampler as
        `context` (PREDICT).
      labels: targets for the sparse softmax cross-entropy loss.
      mode: a `tf.estimator.ModeKeys` value.
      params: dict of settings. Keys read here: "precision", "use_tpu" and,
        for PREDICT, "length", "text_len", "n_ctx", "batch_size", "top_k".
        NOTE: in PREDICT mode a missing "top_k" is written back into this
        dict as 0.

    Returns:
      The estimator spec for `mode`; falls through (returns None) for any
      mode other than TRAIN, EVAL or PREDICT.
    """
    tf.logging.info('model_fns.py: gpt2_model(features=%s, labels=%s, mode=%s, params=%s)', features, labels, mode, params)

    # TRAIN and EVAL share the forward pass and the loss.
    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        if params["precision"] == 'bfloat16':
            with tf.contrib.tpu.bfloat16_scope():
                output = gpt2.model(X=features, params=params,
                                    labels=labels,
                                    past=None, reuse=tf.AUTO_REUSE,
                                    train=mode==tf.estimator.ModeKeys.TRAIN)

            # the loss below is computed on float32 logits
            output["logits"] = tf.cast(output["logits"], tf.float32)

        else:
            output = gpt2.model(X=features, params=params,
                                labels=labels,
                                past=None, reuse=tf.AUTO_REUSE,
                                train=mode==tf.estimator.ModeKeys.TRAIN)

        # loss_batch keeps the unreduced loss (it is what the eval metric
        # receives); loss is its mean.
        loss_batch = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output["logits"], labels=labels)
        loss = tf.reduce_mean(loss_batch)

    if mode == tf.estimator.ModeKeys.TRAIN:
        #from optimizers import create_train_op
        train_op = create_train_op(params, loss=loss)

        if params["use_tpu"]:
            return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
        else:
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


    if mode == tf.estimator.ModeKeys.EVAL:
        #from metric_fns import perplexity_metric

        if params["use_tpu"]:
            # Metric inputs are transferred to CPU and must preserve batch dimension
            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                loss=loss, eval_metrics=(perplexity_metric, {"loss": loss_batch}))
        else:
            return tf.estimator.EstimatorSpec(mode=mode,
                loss=loss, eval_metric_ops=perplexity_metric(loss_batch))


    if mode == tf.estimator.ModeKeys.PREDICT:
        #from models.gpt2 import sample

        # default top_k to 0 when the caller did not configure it
        if not "top_k" in params.keys():
            params["top_k"] = 0

        # sample at most n_ctx tokens, and no more than the requested total
        # length minus the length of the supplied text
        output = sample.sample_sequence(
            params=params, length=min(params['length'] - params['text_len'], params["n_ctx"]),
            context=features,
            batch_size=params["batch_size"],
            temperature=1.0, top_k=params["top_k"]
        )

        predictions = {
            "tokens": output
        }

        if params["use_tpu"]:
            return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)
        else:
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings). A word with fewer than two symbols, including an empty word,
    has no pairs and yields an empty set.
    """
    # Pair every symbol with its successor; zip stops at the shorter
    # sequence, so empty and single-symbol words need no special case.
    return set(zip(word, word[1:]))
self.bpe_ranks.get(pair, float('inf'))) 66 | if bigram not in self.bpe_ranks: 67 | break 68 | first, second = bigram 69 | new_word = [] 70 | i = 0 71 | while i < len(word): 72 | try: 73 | j = word.index(first, i) 74 | new_word.extend(word[i:j]) 75 | i = j 76 | except: 77 | new_word.extend(word[i:]) 78 | break 79 | 80 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 81 | new_word.append(first+second) 82 | i += 2 83 | else: 84 | new_word.append(word[i]) 85 | i += 1 86 | new_word = tuple(new_word) 87 | word = new_word 88 | if len(word) == 1: 89 | break 90 | else: 91 | pairs = get_pairs(word) 92 | word = ' '.join(word) 93 | self.cache[token] = word 94 | while len(self.cache) > 1000: 95 | self.cache.popitem() 96 | return word 97 | 98 | def encode(self, text): 99 | bpe_tokens = [] 100 | for token in re.findall(self.pat, text): 101 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 102 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 103 | return bpe_tokens 104 | 105 | def decode(self, tokens): 106 | text = ''.join([self.decoder[token] for token in tokens]) 107 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 108 | return text 109 | 110 | try: 111 | from tokenizers import Tokenizer, models, pre_tokenizers, decoders 112 | use_high_speed_tokenizer = True 113 | print('Using high-speed tokenizer') 114 | except: 115 | use_high_speed_tokenizer = False 116 | 117 | class HighSpeedTokenizer(object): 118 | def __init__(self, vocab_path, bpe_merges_path): 119 | tokenizer = Tokenizer(models.BPE.from_files(vocab_path, bpe_merges_path)) 120 | # Use the byte level 121 | add_prefix_spaces = False # Whether to automatically prefix the sequences with a space if none found 122 | tokenizer.with_pre_tokenizer(pre_tokenizers.ByteLevel.new(add_prefix_spaces)) 123 | tokenizer.with_decoder(decoders.ByteLevel.new()) 124 | # Setup truncation if needed 125 | truncate = False 126 
| max_length = 1024 127 | if truncate: 128 | stride = 0 129 | strategy = 'longest_first' # Can also be `only_first` or `only_second` 130 | tokenizer.with_truncation(max_length, stride, strategy) 131 | # Setup padding if needed 132 | padding = False 133 | # Whether to always pad to max_length. If this is false, we will pad to the 134 | # longest sequence in the batch. 135 | pad_to_max_length = False 136 | padding_side = "right" # Can also be "left" 137 | pad_token_id = 0 138 | pad_token_type_id = 0 139 | pad_token = "[PAD]" 140 | if padding: 141 | tokenizer.with_padding( 142 | max_length if pad_to_max_length else None, 143 | padding_side, 144 | pad_token_id, 145 | pad_token_type_id, 146 | pad_token 147 | ) 148 | self.tokenizer = tokenizer 149 | 150 | def encode(self, text): 151 | tokens = [] 152 | lines = text.splitlines() 153 | c = '\n' 154 | n = len(lines) - 1 155 | for i, line in enumerate(lines): 156 | if i >= n: 157 | c = '' 158 | encoding = self.tokenizer.encode(line + c) 159 | tokens.extend(encoding.ids) 160 | if text.endswith('\n'): 161 | tokens.extend(self.tokenizer.encode('\n').ids) 162 | return tokens 163 | 164 | def decode(self, tokens): 165 | text = self.tokenizer.decode(tokens, False) 166 | return text 167 | 168 | def read_bucket(path, mode='rb'): 169 | if os.path.isfile(path): 170 | with open(path, mode) as f: 171 | return f.read() 172 | else: 173 | import tensorflow as tf 174 | with tf.io.gfile.GFile(path, mode=mode) as f: 175 | return f.read() 176 | 177 | import tempfile 178 | from contextlib import contextmanager 179 | 180 | @contextmanager 181 | def bucket_file(path): 182 | if os.path.isfile(path): 183 | with open(path, "rb") as f: 184 | data = f.read() 185 | yield path, data 186 | else: 187 | data = read_bucket(path) 188 | with tempfile.NamedTemporaryFile() as tmp: 189 | tmp.write(data) 190 | tmp.seek(0) 191 | yield tmp.name, data 192 | 193 | def bucket_path(path, *parts): 194 | if len(parts) <= 0: 195 | return path 196 | if 
def get_encoder(model_path=None):
    """Return a tokenizer for the model stored under ``model_path``.

    With no ``model_path`` the pretrained ``"gpt2"`` tokenizer from
    ``transformers`` is returned. Otherwise ``encoder.json`` and
    ``vocab.bpe`` are fetched from ``model_path`` (local directory or
    bucket) and wrapped in a ``HighSpeedTokenizer`` when the ``tokenizers``
    package was importable, else in the pure-Python ``Encoder``.
    """
    if model_path is None:
        # Imported lazily so the dependency is only needed on this path.
        from transformers import GPT2TokenizerFast
        return GPT2TokenizerFast.from_pretrained("gpt2")

    vocab_location = bucket_path(model_path, 'encoder.json')
    with bucket_file(vocab_location) as (vocab_path, vocab_data), \
            bucket_file(bucket_path(model_path, 'vocab.bpe')) as (bpe_merges_path, bpe_data):
        vocab = json.loads(vocab_data.decode('utf8'))
        if use_high_speed_tokenizer:
            fast = HighSpeedTokenizer(vocab_path=vocab_path, bpe_merges_path=bpe_merges_path)
            fast.encoder = vocab
            return fast
        # First line of vocab.bpe is a header and the last split piece is
        # the text after the final newline; both are skipped.
        merge_lines = bpe_data.decode('utf8').split('\n')[1:-1]
        merges = [tuple(entry.split()) for entry in merge_lines]
        return Encoder(encoder=vocab, bpe_merges=merges)
def conv1d(x, scope, nf, *, w_init_stdev=0.02, params=None, scale=False):
    """Dense projection of the last dimension of ``x`` to ``nf`` features.

    Creates (or reuses) variables ``w`` of shape ``[1, nx, nf]`` and ``b`` of
    shape ``[nf]`` under ``scope`` and returns ``x @ w + b`` reshaped to
    ``x``'s leading dimensions plus ``[nf]``.

    ``params`` must be a dict providing ``scale_by_depth``, ``scale_by_in``,
    ``n_layer`` and ``precision``; the ``None`` default is not usable.
    ``scale`` only has an effect when ``params["scale_by_depth"]`` is set.
    """
    if params["scale_by_depth"] and scale:
        # Scale by sqrt(num_layers); only used for the final projection
        # before a residual block output.
        w_init_stdev *= (1. / math.sqrt(params["n_layer"]))
    if params["scale_by_in"]:
        # Scale by sqrt(num_input_features).
        w_init_stdev *= (1. / math.sqrt(shape_list(x)[-1]))

    with tf1.variable_scope(scope):
        *lead_dims, nx = shape_list(x)
        if params["precision"] == "bfloat16":
            kernel = tf1.get_variable(
                'w', [1, nx, nf],
                initializer=tf.random_normal_initializer(stddev=w_init_stdev, dtype=tf.bfloat16),
                dtype=tf.bfloat16)
            bias = tf1.get_variable(
                'b', [nf],
                initializer=tf.constant_initializer(0, dtype=tf.bfloat16),
                dtype=tf.bfloat16)
        else:
            kernel = tf1.get_variable(
                'w', [1, nx, nf],
                initializer=tf.random_normal_initializer(stddev=w_init_stdev))
            bias = tf1.get_variable(
                'b', [nf],
                initializer=tf.constant_initializer(0))
        # Flatten to 2-D for a single matmul, then restore leading dims.
        flat_x = tf.reshape(x, [-1, nx])
        flat_kernel = tf.reshape(kernel, [-1, nf])
        projected = tf.matmul(flat_x, flat_kernel) + bias
        return tf.reshape(projected, lead_dims + [nf])
def attn(x, scope, n_state, *, past, params, batch_size, seq_length, train=False):
    """Masked multi-head self-attention over a flattened activation tensor.

    Args:
        x: rank-2 tensor (asserted below), treated as
            ``[batch*sequence, features]``.
        scope: variable scope name for the ``c_attn`` / ``c_proj`` weights.
        n_state: projection width; must be divisible by ``params["n_head"]``.
        past: optional rank-5 tensor of cached keys/values, or ``None``.
        params: hyperparameter dict; reads ``n_head``, ``attn_dropout`` and
            ``res_dropout`` here.
        batch_size, seq_length: used to reshape between the flat 2-D layout
            and the per-head 4-D layout.
        train: forwarded to ``dropout``.

    Returns:
        ``(a, present)``: the attention output reshaped back to
        ``[batch_size * seq_length, heads * size_per_head]`` and the new
        keys/values stacked along axis 1.
    """
    assert x.shape.ndims == 2  # Should be [batch*sequence, features]
    assert n_state % params["n_head"] == 0
    *start, hidden_size = shape_list(x)
    num_attention_heads = params["n_head"]
    assert(hidden_size % num_attention_heads == 0)
    size_per_head = hidden_size // num_attention_heads

    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        x = tf.reshape(x, [batch_size, seq_length, num_attention_heads, size_per_head])
        # NOTE(review): x is already 4-D here, so split_states merges the last
        # two dims and splits them again into [n_head, size_per_head]; this
        # looks like a shape no-op — confirm before removing.
        x = split_states(x, params["n_head"])
        return tf.transpose(x, [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        x = tf.transpose(x, [0, 2, 1, 3])
        x = merge_states(x)
        x = tf.reshape(x, [batch_size * seq_length, num_attention_heads * size_per_head])
        return x

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Masked-out positions get a large negative logit so softmax sends
        # them to ~0.
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        # Scale logits by 1/sqrt(per-head feature size).
        w = w * tf.math.rsqrt(tf.cast(shape_list(v)[-1], w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)

        w = dropout(w, params["attn_dropout"], train)

        a = tf.matmul(w, v)
        return a

    with tf1.variable_scope(scope):
        # One projection produces q, k and v side by side on the last axis.
        c = conv1d(x, 'c_attn', n_state*3, params=params)
        q, k, v = map(split_heads, tf.split(c, 3, axis=-1))
        # ``present`` holds only this call's keys/values (before past is
        # prepended).
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state, params=params)
        a = dropout(a, params["res_dropout"], train)
        return a, present
def model(X, params, labels=None, past=None, scope='model', reuse=False, train=False):
    """Build the transformer graph and return its outputs.

    Args:
        X: token ids; unpacked as ``[batch, sequence]``.
        params: hyperparameter dict. Reads ``precision``, ``n_ctx``,
            ``n_embd``, ``n_vocab``, ``n_layer``, ``embed_dropout`` and the
            optional ``memory_saving_gradients`` (bool) and
            ``memory_saving_checkpoints`` (int stride or container of layer
            indices).
        labels: not used by this function.
        past: optional cached keys/values, unstacked along axis 1 into one
            entry per layer.
        scope: variable scope name.
        reuse: forwarded to ``variable_scope``.
        train: enables dropout.

    Returns:
        dict with ``'present'`` (per-layer keys/values stacked on axis 1),
        ``'activations'`` (list of per-layer outputs) and ``'logits'``
        (``[batch, sequence, n_vocab]``).

    The environment variable ``GRADIENT_CHECKPOINTING`` (``0``/``1``) selects
    whether each layer is wrapped in a recomputing custom gradient.
    """
    with tf1.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        if params["precision"] == "bfloat16":
            wpe = tf1.get_variable('wpe', [params["n_ctx"], params["n_embd"]],  # Position encoding
                                   initializer=tf.random_normal_initializer(stddev=0.01, dtype=tf.bfloat16), dtype=tf.bfloat16)
            wte = tf1.get_variable('wte', [params["n_vocab"], params["n_embd"]],  # Text encoding
                                   initializer=tf.random_normal_initializer(stddev=0.02, dtype=tf.bfloat16), dtype=tf.bfloat16)
        else:
            wpe = tf1.get_variable('wpe', [params["n_ctx"], params["n_embd"]],  # Position encoding
                                   initializer=tf.random_normal_initializer(stddev=0.01))
            wte = tf1.get_variable('wte', [params["n_vocab"], params["n_embd"]],  # Text encoding
                                   initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]

        wpe = dropout(wpe, params["embed_dropout"], train)
        wte = dropout(wte, params["embed_dropout"], train)

        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        ## We keep the representation as a 2D tensor to avoid re-shaping it back and
        ## forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
        ## the GPU/CPU but may not be free on the TPU, so we want to minimize them to
        ## help the optimizer.
        batch_size, seq_length, hidden_size = shape_list(h)
        h = tf.reshape(h, [batch_size * seq_length, hidden_size])

        # Transformer
        presents = []
        activations = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * params["n_layer"]
        assert len(pasts) == params["n_layer"]
        checkpoint = params.get('memory_saving_gradients', False) if isinstance(params, dict) else (
            False if 'memory_saving_gradients' not in params else params['memory_saving_gradients'])
        every = 1 if 'memory_saving_checkpoints' not in params else params['memory_saving_checkpoints']
        use_recompute = bool(int(os.environ.get('GRADIENT_CHECKPOINTING', '0')))

        def make_layer_fns(layer, layer_past):
            """Return ``(block0, block1)`` bound to one specific layer.

            ``grad`` is invoked when gradients are built, i.e. after the
            layer loop has finished. Closing over the loop variables
            directly would make every layer's backward pass recompute the
            *last* layer, so the per-layer values are bound here as
            function arguments instead.
            """
            def block0(x):
                # NOTE(review): this re-enters ``scope`` inside the enclosing
                # ``variable_scope(scope)``; kept exactly as in the original.
                with tf1.variable_scope(scope, reuse=reuse):
                    x1, present = block(x, 'h%d' % layer, past=layer_past, params=params, attn=attn,
                                        train=train, batch_size=batch, seq_length=sequence)
                    presents.append(present)
                    return x1

            @tf.custom_gradient
            def block1(input):
                def grad(dy, variables=None):
                    # dy is d(output)/d(loss).
                    # variables contains the tensors used to calculate
                    # d(param)/d(loss).
                    # first, we use stop_gradient to ensure that the
                    # forward pass is completely disconnected.
                    input0 = tf.stop_gradient(input)
                    # then, we use the disconnected input to recalculate
                    # the output for this layer.
                    output0 = block0(input0)
                    # now that we have the output, we need to calculate
                    # d(input)/d(output) * d(output)/d(loss), i.e. chain rule:
                    result = tf.gradients(output0, input0, dy)
                    if variables is not None:
                        n_params = sum([np.prod(v.shape.as_list()) for v in variables])
                        tf.logging.info("---------")
                        tf.logging.info("%s (%s parameters): %s",
                                        "block1_grad variables for layer %s" % layer,
                                        n_params, pps(variables))
                        # ditto for d(param)/d(output) * d(output)/d(loss)
                        return result, tf.gradients(output0, variables, dy)
                    return result
                output = block0(input)
                return output, grad

            return block0, block1

        for layer, layer_past in enumerate(pasts):
            block0, block1 = make_layer_fns(layer, layer_past)
            # ``every`` is either an int stride or a container of layer
            # indices; the conditional keeps the membership test away from
            # ints (``layer in 2`` raises TypeError).
            if checkpoint and (layer % every == 0 if isinstance(every, int) else layer in every):
                tf.logging.info('checkpointing layer %d', layer)
                tf.add_to_collection('checkpoints', h)
            if use_recompute:
                h = block1(h)
            else:
                h = block0(h)
            activations.append(h)
        results['present'] = tf.stack(presents, axis=1)
        results['activations'] = activations
        h = norm(h, 'ln_f', params=params)

        # Output projection shares weights with the token embedding.
        h_flat = tf.reshape(h, [batch*sequence, params["n_embd"]])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, params["n_vocab"]])
        results['logits'] = logits
        return results
def sample_sequence(*, params, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
    """Autoregressively sample tokens and return context plus samples.

    Exactly one of ``start_token`` and ``context`` must be given; a
    ``start_token`` is expanded to a ``[batch_size, 1]`` context.

    The loop runs for ``length - params["text_len"]`` iterations.
    NOTE(review): the PREDICT branch in model_fns.py already passes
    ``params['length'] - params['text_len']`` as ``length``, so ``text_len``
    appears to be subtracted twice on that path — confirm this is intended.

    Returns:
        int tensor of tokens: ``context`` with one sampled token appended per
        loop iteration along axis 1.
    """
    exclusive_msg = 'Specify exactly one of start_token and context!'
    if start_token is not None:
        assert context is None, exclusive_msg
        context = tf.fill([batch_size, 1], start_token)
    else:
        assert context is not None, exclusive_msg

    length = length - params["text_len"]

    def step(params, tokens, past=None):
        """Run the model on ``tokens`` and return float32 logits and presents."""
        model_kwargs = dict(params=params, X=tokens, past=past, reuse=tf.AUTO_REUSE)
        if params["precision"] == 'bfloat16':
            with tf.contrib.tpu.bfloat16_scope():
                lm_output = gpt2.model(**model_kwargs)
            # Sampling below is done in float32.
            lm_output["logits"] = tf.cast(lm_output["logits"], tf.float32)
        else:
            lm_output = gpt2.model(**model_kwargs)

        vocab_logits = lm_output['logits'][:, :, :params["n_vocab"]]
        new_presents = lm_output['present']
        new_presents.set_shape(gpt2.past_shape(params=params, batch_size=batch_size))
        return {
            'logits': vocab_logits,
            'presents': new_presents,
        }

    with tf.name_scope('sample_sequence'):
        # Prime the key/value cache with everything but the last context
        # token; that last token is fed as the first loop input.
        primed = step(params, context[:, :-1])

        def body(past, prev, output):
            step_out = step(params, prev[:, tf.newaxis], past=past)
            last_logits = step_out['logits'][:, -1, :] / tf.to_float(temperature)
            last_logits = top_k_logits(last_logits, k=top_k)
            drawn = tf.multinomial(last_logits, num_samples=1, output_dtype=tf.int32)
            grown_past = tf.concat([past, step_out['presents']], axis=-2)
            next_prev = tf.squeeze(drawn, axis=[1])
            grown_output = tf.concat([output, drawn], axis=1)
            return [grown_past, next_prev, grown_output]

        def cond(*args):
            # Termination is governed solely by ``maximum_iterations``.
            return True

        loop_state = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length,
            loop_vars=[
                primed['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(gpt2.past_shape(params=params, batch_size=batch_size)),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
            ],
            back_prop=False,
        )
        tokens = loop_state[2]

        return tokens
#model_dir=gs://danbooru-euw4a/checkpoint/test1558m-67 50 | #tpu=tpu-euw4a-67 51 | # 52 | #params=1.5B_adam.json 53 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-65 54 | #tpu=tpu-euw4a-65 55 | # 56 | #params=1.5B_adam.json 57 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-70 58 | #tpu=tpu-euw4a-70 59 | # 60 | #params=1.5B_adam.json 61 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-71 62 | #tpu=tpu-euw4a-71 63 | #restore_dir=gs://gpt-2/models/1558M 64 | # 65 | #params=1.5B_adam.json 66 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-78 67 | #tpu=tpu-euw4a-78 68 | #restore_dir=gs://danbooru-euw4a/models/1558M 69 | # 70 | #params=1.5B.json 71 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-76 72 | #tpu=tpu-euw4a-76 73 | ##restore_dir=gs://danbooru-euw4a/models/1558M 74 | #unset restore_dir 75 | # 76 | #params=1.5B.json 77 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-69 78 | #tpu=tpu-euw4a-69 79 | #restore_dir=gs://danbooru-euw4a/models/1558M 80 | ##unset restore_dir 81 | ##gsutil -m rm -rf "${model_dir}" 82 | # 83 | #params=1.5B.json 84 | #model_dir=gs://danbooru-euw4a/checkpoint/test1558m-77 85 | #tpu=tpu-euw4a-77 86 | #restore_dir=gs://danbooru-euw4a/models/1558M 87 | #unset restore_dir 88 | # 89 | #params=117M.json 90 | #model_dir=gs://danbooru-euw4a/checkpoint/test117m-71-2 91 | #tpu=tpu-euw4a-71 92 | #restore_dir=gs://danbooru-euw4a/models/117M 93 | #dataset="--dataset combined-pgpf-ftfy.txt.npz" 94 | ##unset restore_dir 95 | # 96 | #params=117M.json 97 | #model_dir=gs://danbooru-euw4a/checkpoint/test117m-76 98 | #tpu=tpu-euw4a-76 99 | ##restore_dir=gs://danbooru-euw4a/models/117M 100 | ##restore_trainable="--restore_trainable_variables true" 101 | #restore_dir="${model_dir}" 102 | #dataset="--dataset combined-pgpf-ftfy.txt.npz" 103 | ##unset restore_dir 104 | ##gsutil -m rm -rf "${model_dir}" 105 | # 106 | 107 | export TPU_NAME="${TPU_NAME:-tpu-v3-128-euw4a-50}" 108 | export TPU_CORES=128 109 | params=117M.json 110 | 
#!/bin/bash
# Launch (and keep relaunching) a GPT-2 training run on a TPU.
#
# Every setting below can be overridden from the environment; the values
# after ":-" are the defaults for this run. Extra command-line arguments are
# forwarded to main_gpt2.py.
#   DEV              - if non-empty, run once under pdb instead of looping.
#   TPU_CORES        - core count; parsed from TPU_NAME when unset.
#   TPU_NO_RECREATE  - if non-empty, do not recreate the TPU between runs.
set -ex
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-/tfk/lib}"
export TPU_HOST=${TPU_HOST:-10.255.128.3}
export TPU_NAME="${TPU_NAME:-tpu-v3-8-euw4a-201}"

export RUN_ID="${RUN_ID:-a}"
export RUN_NAME="${RUN_NAME:-astra01}"
export RUN_DESC="${RUN_DESC:-1558M astra run}"
tmux-set-title "${RUN_NAME}/${RUN_ID} ${TPU_NAME}"
export MODEL_DIR="${MODEL_DIR:-gs://dota-euw4a/runs/gpt-2/${RUN_NAME}/${RUN_ID}/}"
export MODEL_DIR="$(printf '%s' "${MODEL_DIR}" | sed 's/\/$//')" # normalize model dir; ensure it does *not* end with a slash
export GIN_CONFIG="cfg/${RUN_NAME}.gin"


export MODEL_NAME="${MODEL_NAME:-1558M}"
export DATASET="${DATASET:-gs://dota-euw4a/datasets/uberset_v1.tok16}"
# NOTE(review): this default is overwritten unconditionally further down, so
# a RESTORE_DIR supplied by the caller is currently ignored — confirm intent.
export RESTORE_DIR="${RESTORE_DIR:-gs://dota-euw4a/models/gpt-2/${MODEL_NAME}}"



# Hour-granularity timestamp used in local and cloud log file names.
date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')"
logfile="logs/${RUN_NAME}-${RUN_ID}-${date}.txt"
cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}-${RUN_ID}.txt"
cloud_description_file="${MODEL_DIR}/description.txt"
mkdir -p logs

# From here on DATASET and RESTORE_DIR hold complete flag strings; they are
# expanded unquoted in the python command lines so they split into words.
export DATASET="--dataset ${DATASET}"
#export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true"
export RESTORE_DIR="--restore_dir ${MODEL_DIR} --restore_trainable_variables true"
export RUN_DESC="
name: ${RUN_NAME}/${RUN_ID}
date: ${date}
tpu: ${TPU_NAME}
model_dir: ${MODEL_DIR}
dataset: ${DATASET}
model_name: ${MODEL_NAME}

${RUN_DESC}"

printf "%s" "${RUN_DESC}"

#pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g'


export TPU_SPLIT_COMPILE_AND_EXECUTE=1
export TF_TPU_WATCHDOG_TIMEOUT=1800

if [ -z "$TPU_CORES" ]
then
  # Extract the number that follows "tpu-v2-" / "tpu-v3-" in the TPU name.
  cores="$(echo $TPU_NAME | sed 's/^tpu-v[23][-]\([0-9]*\).*$/\1/g')"
  if [ -z "$cores" ]
  then
    1>&2 echo "Failed to parse TPU core count from $TPU_NAME"
    exit 1
  fi
  export TPU_CORES=$cores
fi


if [ ! -z "${DEV}" ]
then
  # Development mode: single run under pdb; exec replaces this shell, so the
  # exit below is only reached if exec itself fails.
  exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@"
  exit -1
fi


# Production mode: train in 4-hour slices forever, restoring from MODEL_DIR
# each time.
while true; do
  echo "Saving description to ${cloud_description_file} ..."
  printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}"

  echo "Starting production training run in 10s ..."
  sleep 10

  # Output goes to the local log (appended), to stderr, and to the cloud log.
  timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}"
  if [ ! -z "$TPU_NO_RECREATE" ]
  then
    echo "Not recreating TPU. Waiting 30s."
    sleep 30
  else
    echo "Recreating TPU in 30."
    sleep 30
    # sudo pip3 install -U tpudiepie
    pu recreate "$TPU_NAME" --yes --retry 300
  fi
done
slash 15 | export GIN_CONFIG="cfg/${RUN_NAME}.gin" 16 | 17 | 18 | export MODEL="${MODEL:-GPT2}" 19 | export MODEL_NAME="${MODEL_NAME:-117M}" 20 | export DATASET="${DATASET:-gs://tpu-usc1/datasets/novels.tok16}" 21 | export RESTORE_DIR="${RESTORE_DIR:-gs://tpu-usc1/models/gpt-2/${MODEL_NAME}}" 22 | 23 | export WRAPPER="${WRAPPER:-wrapper.py}" 24 | 25 | 26 | 27 | date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')" 28 | logfile="logs/${RUN_NAME}-${RUN_ID}-${date}.txt" 29 | cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}-${RUN_ID}.txt" 30 | cloud_description_file="${MODEL_DIR}/description.txt" 31 | mkdir -p logs 32 | 33 | export DATASET="--dataset ${DATASET}" 34 | #export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true" 35 | export RESTORE_DIR="--restore_dir ${MODEL_DIR} --restore_trainable_variables true" 36 | export RUN_DESC=" 37 | name: ${RUN_NAME}/${RUN_ID} 38 | date: ${date} 39 | tpu: ${TPU_NAME} 40 | model_dir: ${MODEL_DIR} 41 | dataset: ${DATASET} 42 | model_name: ${MODEL_NAME} 43 | 44 | ${RUN_DESC}" 45 | 46 | printf "%s" "${RUN_DESC}" 47 | 48 | #pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g' 49 | 50 | 51 | export TPU_SPLIT_COMPILE_AND_EXECUTE=1 52 | export TF_TPU_WATCHDOG_TIMEOUT=1800 53 | 54 | if [ -z "$TPU_CORES" ] 55 | then 56 | cores="$(echo $TPU_NAME | sed 's/^tpu-v[23][-]\([0-9]*\).*$/\1/g')" 57 | if [ -z "$cores" ] 58 | then 59 | 1>&2 echo "Failed to parse TPU core count from $TPU_NAME" 60 | exit 1 61 | fi 62 | export TPU_CORES=$cores 63 | fi 64 | 65 | 66 | if [ ! -z "${DEV}" ] 67 | then 68 | exec python3 -m pdb -c continue $WRAPPER main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 69 | exit -1 70 | fi 71 | 72 | 73 | while true; do 74 | echo "Saving description to ${cloud_description_file} ..." 
75 | printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}" 76 | 77 | echo "Starting production training run in 10s ..." 78 | sleep 10 79 | 80 | timeout --signal=SIGKILL 4h python3 $WRAPPER main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}" 81 | if [ ! -z "$TPU_NO_RECREATE" ] 82 | then 83 | echo "Not recreating TPU. Waiting 30s." 84 | sleep 30 85 | else 86 | echo "Recreating TPU in 30." 87 | sleep 30 88 | # sudo pip3 install -U tpudiepie 89 | pu recreate "$TPU_NAME" --yes --retry 300 90 | fi 91 | done 92 | -------------------------------------------------------------------------------- /runs/train-chess-345m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-/tfk/lib}" 4 | export TPU_HOST=${TPU_HOST:-10.254.128.2} 5 | export TPU_NAME="${TPU_NAME:-tpu-v2-512-usc1a-2}" 6 | 7 | export RUN_ID="${RUN_ID:-b}" 8 | export RUN_NAME="${RUN_NAME:-chessrun03}" 9 | export RUN_DESC="${RUN_DESC:-345M test run}" 10 | tmux-set-title "${RUN_NAME}/${RUN_ID} ${TPU_NAME}" 11 | export MODEL_DIR="${MODEL_DIR:-gs://tpu-usc1/runs/gpt-2/${RUN_NAME}/${RUN_ID}/}" 12 | export MODEL_DIR="$(printf '%s' "${MODEL_DIR}" | sed 's/\/$//')" # normalize model dir; ensure it does *not* end with a slash 13 | export GIN_CONFIG="cfg/${RUN_NAME}.gin" 14 | 15 | 16 | export MODEL_NAME="${MODEL_NAME:-chess345m}" 17 | export DATASET="${DATASET:-gs://tpu-usc1/datasets/chess/kingbase2019gpt-16g.txt.tok16}" 18 | #export RESTORE_DIR="${RESTORE_DIR:-gs://tpu-usc1/models/gpt-2/${MODEL_NAME}}" 19 | 20 | 21 | 22 | date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')" 23 | logfile="logs/${RUN_NAME}-${RUN_ID}-${date}.txt" 24 | 
cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}-${RUN_ID}.txt" 25 | cloud_description_file="${MODEL_DIR}/description.txt" 26 | mkdir -p logs 27 | 28 | export DATASET="--dataset ${DATASET}" 29 | #export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true" 30 | export RESTORE_DIR="--restore_dir ${MODEL_DIR} --restore_trainable_variables true" 31 | export RUN_DESC=" 32 | name: ${RUN_NAME}/${RUN_ID} 33 | date: ${date} 34 | tpu: ${TPU_NAME} 35 | model_dir: ${MODEL_DIR} 36 | dataset: ${DATASET} 37 | model_name: ${MODEL_NAME} 38 | 39 | ${RUN_DESC}" 40 | 41 | printf "%s" "${RUN_DESC}" 42 | 43 | #pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g' 44 | 45 | 46 | export TPU_SPLIT_COMPILE_AND_EXECUTE=1 47 | export TF_TPU_WATCHDOG_TIMEOUT=1800 48 | 49 | if [ -z "$TPU_CORES" ] 50 | then 51 | cores="$(echo $TPU_NAME | sed 's/^tpu-v[23][-]\([0-9]*\).*$/\1/g')" 52 | if [ -z "$cores" ] 53 | then 54 | 1>&2 echo "Failed to parse TPU core count from $TPU_NAME" 55 | exit 1 56 | fi 57 | export TPU_CORES=$cores 58 | fi 59 | 60 | 61 | if [ ! -z "${DEV}" ] 62 | then 63 | exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 64 | exit -1 65 | fi 66 | 67 | 68 | while true; do 69 | echo "Saving description to ${cloud_description_file} ..." 70 | printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}" 71 | 72 | echo "Starting production training run in 10s ..." 73 | sleep 10 74 | 75 | timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}" 76 | if [ ! -z "$TPU_NO_RECREATE" ] 77 | then 78 | echo "Not recreating TPU. Waiting 30s." 79 | sleep 30 80 | else 81 | echo "Recreating TPU in 30." 
#!/bin/bash
# Launch, and keep relaunching, a GPT-2 "chess" training run on a TPU.
#
# Every exported setting below is a default that can be overridden from the
# environment; extra command-line arguments are forwarded to main_gpt2.py.
#   DEV=<non-empty>             run once under pdb instead of the restart loop
#   TPU_CORES=<n>               skip parsing the core count out of TPU_NAME
#   TPU_NO_RECREATE=<non-empty> do not recreate the TPU between restarts
set -ex
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-/tfk/lib}"
export TPU_HOST="${TPU_HOST:-10.254.128.2}"
export TPU_NAME="${TPU_NAME:-tpu-v2-256-usc1a-0}"

export RUN_ID="${RUN_ID:-a}"
export RUN_NAME="${RUN_NAME:-chessrun00}"
export RUN_DESC="${RUN_DESC:-117M test run}"
tmux-set-title "${RUN_NAME}/${RUN_ID} ${TPU_NAME}"
export MODEL_DIR="${MODEL_DIR:-gs://tpu-usc1/runs/gpt-2/${RUN_NAME}/${RUN_ID}/}"
export MODEL_DIR="${MODEL_DIR%/}" # normalize model dir; ensure it does *not* end with a slash
export GIN_CONFIG="cfg/${RUN_NAME}.gin"


export MODEL_NAME="${MODEL_NAME:-chess}"
export DATASET="${DATASET:-gs://tpu-usc1/datasets/chess/kingbase2019gpt-5g.txt.tok16}"
#export RESTORE_DIR="${RESTORE_DIR:-gs://tpu-usc1/models/gpt-2/${MODEL_NAME}}"



date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')"
logfile="logs/${RUN_NAME}-${RUN_ID}-${date}.txt"
cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}-${RUN_ID}.txt"
cloud_description_file="${MODEL_DIR}/description.txt"
mkdir -p logs

# DATASET and RESTORE_DIR are turned into ready-made "--flag value" strings;
# they are expanded unquoted further down on purpose so they split into words.
export DATASET="--dataset ${DATASET}"
#export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true"
export RESTORE_DIR="--restore_dir ${MODEL_DIR} --restore_trainable_variables true"
export RUN_DESC="
name: ${RUN_NAME}/${RUN_ID}
date: ${date}
tpu: ${TPU_NAME}
model_dir: ${MODEL_DIR}
dataset: ${DATASET}
model_name: ${MODEL_NAME}

${RUN_DESC}"

printf "%s" "${RUN_DESC}"

#pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g'


export TPU_SPLIT_COMPILE_AND_EXECUTE=1
export TF_TPU_WATCHDOG_TIMEOUT=1800

if [ -z "$TPU_CORES" ]
then
  # "sed -n ... p" prints only when the pattern matched. Without it a TPU name
  # that does not look like tpu-v2-<cores>-... or tpu-v3-<cores>-... is echoed
  # back unchanged, the emptiness check below never fires, and the raw name
  # ends up being passed as --num_cores.
  cores="$(printf '%s\n' "$TPU_NAME" | sed -n 's/^tpu-v[23][-]\([0-9][0-9]*\).*$/\1/p')"
  if [ -z "$cores" ]
  then
    1>&2 echo "Failed to parse TPU core count from $TPU_NAME"
    exit 1
  fi
  export TPU_CORES=$cores
fi


if [ -n "${DEV}" ]
then
  # exec replaces this shell with the debugger session, so nothing after this
  # line runs (a failed exec terminates a non-interactive shell as well).
  exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@"
fi


while true; do
  echo "Saving description to ${cloud_description_file} ..."
  printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}"

  echo "Starting production training run in 10s ..."
  sleep 10

  # Each training attempt is hard-killed after 4 hours; its combined output is
  # appended to the local log, echoed to stderr and streamed to the cloud log.
  timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}"
  if [ -n "$TPU_NO_RECREATE" ]
  then
    echo "Not recreating TPU. Waiting 30s."
    sleep 30
  else
    echo "Recreating TPU in 30."
    sleep 30
    # sudo pip3 install -U tpudiepie
    pu recreate "$TPU_NAME" --yes
  fi
done
's/\x1b\[[0-9;]*m//g' 44 | 45 | 46 | export TPU_SPLIT_COMPILE_AND_EXECUTE=1 47 | export TF_TPU_WATCHDOG_TIMEOUT=1800 48 | 49 | if [ -z "$TPU_CORES" ] 50 | then 51 | cores="$(echo $TPU_NAME | sed 's/^tpu-v[23][-]\([0-9]*\).*$/\1/g')" 52 | if [ -z "$cores" ] 53 | then 54 | 1>&2 echo "Failed to parse TPU core count from $TPU_NAME" 55 | exit 1 56 | fi 57 | export TPU_CORES=$cores 58 | fi 59 | 60 | 61 | if [ ! -z "${DEV}" ] 62 | then 63 | exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 64 | exit -1 65 | fi 66 | 67 | 68 | while true; do 69 | echo "Saving description to ${cloud_description_file} ..." 70 | printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}" 71 | 72 | echo "Starting production training run in 10s ..." 73 | sleep 10 74 | 75 | timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}" 76 | if [ ! -z "$TPU_NO_RECREATE" ] 77 | then 78 | echo "Not recreating TPU. Waiting 30s." 79 | sleep 30 80 | else 81 | echo "Recreating TPU in 30." 
#!/bin/bash
# Launch, and keep relaunching, a GPT-2 1558M "novels" training run on a TPU.
#
# Every exported setting below is a default that can be overridden from the
# environment; extra command-line arguments are forwarded to main_gpt2.py.
#   DEV=<non-empty>             run once under pdb instead of the restart loop
#   TPU_CORES=<n>               skip parsing the core count out of TPU_NAME
#   TPU_NO_RECREATE=<non-empty> do not recreate the TPU between restarts
set -ex
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-/tfk/lib}"
export TPU_HOST="${TPU_HOST:-10.254.128.2}"
export TPU_NAME="${TPU_NAME:-tpu-v3-8-usc1a-201}"

export RUN_ID="${RUN_ID:-a}"
export RUN_NAME="${RUN_NAME:-novels02}"
export RUN_DESC="${RUN_DESC:-1558M novels run}"
tmux-set-title "${RUN_NAME}/${RUN_ID} ${TPU_NAME}"
export MODEL_DIR="${MODEL_DIR:-gs://tpu-usc1/runs/gpt-2/${RUN_NAME}/${RUN_ID}/}"
export MODEL_DIR="${MODEL_DIR%/}" # normalize model dir; ensure it does *not* end with a slash
export GIN_CONFIG="cfg/${RUN_NAME}.gin"


export MODEL_NAME="${MODEL_NAME:-1558M}"
export DATASET="${DATASET:-gs://tpu-usc1/datasets/novels.tok16}"
export RESTORE_DIR="${RESTORE_DIR:-gs://tpu-usc1/models/gpt-2/${MODEL_NAME}}"



date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')"
logfile="logs/${RUN_NAME}-${RUN_ID}-${date}.txt"
cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}-${RUN_ID}.txt"
cloud_description_file="${MODEL_DIR}/description.txt"
mkdir -p logs

# DATASET and RESTORE_DIR are turned into ready-made "--flag value" strings;
# they are expanded unquoted further down on purpose so they split into words.
export DATASET="--dataset ${DATASET}"
# NOTE: the RESTORE_DIR value computed above (including one passed in from the
# environment) is replaced here, so the run always restores from MODEL_DIR.
# Swap the two lines below to restore from RESTORE_DIR instead.
#export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true"
export RESTORE_DIR="--restore_dir ${MODEL_DIR} --restore_trainable_variables true"
export RUN_DESC="
name: ${RUN_NAME}/${RUN_ID}
date: ${date}
tpu: ${TPU_NAME}
model_dir: ${MODEL_DIR}
dataset: ${DATASET}
model_name: ${MODEL_NAME}

${RUN_DESC}"

printf "%s" "${RUN_DESC}"

#pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g'


export TPU_SPLIT_COMPILE_AND_EXECUTE=1
export TF_TPU_WATCHDOG_TIMEOUT=1800

if [ -z "$TPU_CORES" ]
then
  # "sed -n ... p" prints only when the pattern matched. Without it a TPU name
  # that does not look like tpu-v2-<cores>-... or tpu-v3-<cores>-... is echoed
  # back unchanged, the emptiness check below never fires, and the raw name
  # ends up being passed as --num_cores.
  cores="$(printf '%s\n' "$TPU_NAME" | sed -n 's/^tpu-v[23][-]\([0-9][0-9]*\).*$/\1/p')"
  if [ -z "$cores" ]
  then
    1>&2 echo "Failed to parse TPU core count from $TPU_NAME"
    exit 1
  fi
  export TPU_CORES=$cores
fi


if [ -n "${DEV}" ]
then
  # exec replaces this shell with the debugger session, so nothing after this
  # line runs (a failed exec terminates a non-interactive shell as well).
  exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@"
fi


while true; do
  echo "Saving description to ${cloud_description_file} ..."
  printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}"

  echo "Starting production training run in 10s ..."
  sleep 10

  # Each training attempt is hard-killed after 4 hours; its combined output is
  # appended to the local log, echoed to stderr and streamed to the cloud log.
  timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}"
  if [ -n "$TPU_NO_RECREATE" ]
  then
    echo "Not recreating TPU. Waiting 30s."
    sleep 30
  else
    echo "Recreating TPU in 30."
    sleep 30
    # sudo pip3 install -U tpudiepie
    pu recreate "$TPU_NAME" --yes --retry 300
  fi
done
44 | printf "%s" "${RUN_DESC}" 45 | 46 | #pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g' 47 | 48 | 49 | export TPU_SPLIT_COMPILE_AND_EXECUTE=1 50 | export TF_TPU_WATCHDOG_TIMEOUT=1800 51 | 52 | if [ -z "$TPU_CORES" ] 53 | then 54 | cores="$(echo $TPU_NAME | sed 's/^tpu-v[23][-]\([0-9]*\).*$/\1/g')" 55 | if [ -z "$cores" ] 56 | then 57 | 1>&2 echo "Failed to parse TPU core count from $TPU_NAME" 58 | exit 1 59 | fi 60 | export TPU_CORES=$cores 61 | fi 62 | 63 | 64 | if [ ! -z "${DEV}" ] 65 | then 66 | exec python3 -m pdb -c continue $WRAPPER main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 67 | exit -1 68 | fi 69 | 70 | 71 | while true; do 72 | echo "Saving description to ${cloud_description_file} ..." 73 | printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}" 74 | 75 | echo "Starting production training run in 10s ..." 76 | sleep 10 77 | 78 | timeout --signal=SIGKILL 4h python3 $WRAPPER main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}" 79 | if [ ! -z "$TPU_NO_RECREATE" ] 80 | then 81 | echo "Not recreating TPU. Waiting 30s." 82 | sleep 30 83 | else 84 | echo "Recreating TPU in 30." 
#!/bin/bash
# Launch, and keep relaunching, a GPT-2 117M training run on the "tensorflow"
# dataset, restoring trainable variables from RESTORE_DIR.
#
# Every exported setting below is a default that can be overridden from the
# environment; extra command-line arguments are forwarded to main_gpt2.py.
#   DEV=<non-empty>             run once under pdb instead of the restart loop
#   TPU_CORES=<n>               skip parsing the core count out of TPU_NAME
#   TPU_NO_RECREATE=<non-empty> do not recreate the TPU between restarts
set -ex
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-/tfk/lib}"
export TPU_HOST="${TPU_HOST:-10.255.128.2}"
export TPU_NAME="${TPU_NAME:-tpu-v3-128-euw4a-50}"

export RUN_ID="${RUN_ID:-a}"
export RUN_NAME="${RUN_NAME:-gpt2run00}"
export RUN_DESC="${RUN_DESC:-117M test run}"
tmux-set-title "${RUN_NAME}/${RUN_ID} ${TPU_NAME}"
export MODEL_DIR="${MODEL_DIR:-gs://dota-euw4a/runs/gpt-2/${RUN_NAME}/${RUN_ID}/}"
export MODEL_DIR="${MODEL_DIR%/}" # normalize model dir; ensure it does *not* end with a slash
export GIN_CONFIG="cfg/${RUN_NAME}.gin"


export MODEL_NAME="${MODEL_NAME:-117M}"
export DATASET="${DATASET:-gs://dota-euw4a/datasets/tensorflow.tok16}"
export RESTORE_DIR="${RESTORE_DIR:-gs://dota-euw4a/models/gpt-2/${MODEL_NAME}}"



date="$(python3 -c 'import datetime; print(datetime.datetime.now().strftime("%Y-%m-%d-%H"))')"
logfile="logs/${RUN_NAME}-${date}.txt"
cloud_log_file="${MODEL_DIR}/logs-${date}-${RUN_NAME}.txt"
cloud_description_file="${MODEL_DIR}/description.txt"
mkdir -p logs

# DATASET and RESTORE_DIR are turned into ready-made "--flag value" strings;
# they are expanded unquoted further down on purpose so they split into words.
export DATASET="--dataset ${DATASET}"
export RESTORE_DIR="--restore_dir ${RESTORE_DIR} --restore_trainable_variables true"
export RUN_DESC="
name: ${RUN_NAME}/${RUN_ID}
date: ${date}
tpu: ${TPU_NAME}
model_dir: ${MODEL_DIR}
dataset: ${DATASET}
model_name: ${MODEL_NAME}

${RUN_DESC}"

printf "%s" "${RUN_DESC}"

#pu list -s -t $TPU_NAME | sed 's/\x1b\[[0-9;]*m//g'


export TPU_SPLIT_COMPILE_AND_EXECUTE=1
export TF_TPU_WATCHDOG_TIMEOUT=1800

# Honour a caller-supplied TPU_CORES, as the sibling run scripts do; only parse
# the core count out of the TPU name when it was not given.
if [ -z "$TPU_CORES" ]
then
  # "sed -n ... p" prints only when the pattern matched. Without it a TPU name
  # that does not look like tpu-v2-<cores>-... or tpu-v3-<cores>-... is echoed
  # back unchanged, the emptiness check below never fires, and the raw name
  # ends up being passed as --num_cores.
  cores="$(printf '%s\n' "$TPU_NAME" | sed -n 's/^tpu-v[23][-]\([0-9][0-9]*\).*$/\1/p')"
  if [ -z "$cores" ]
  then
    1>&2 echo "Failed to parse TPU core count from $TPU_NAME"
    exit 1
  fi
  export TPU_CORES=$cores
fi


if [ -n "${DEV}" ]
then
  # exec replaces this shell with the debugger session, so nothing after this
  # line runs (a failed exec terminates a non-interactive shell as well).
  exec python3 -m pdb -c continue wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@"
fi


while true; do
  echo "Saving description to ${cloud_description_file} ..."
  printf "%s" "${RUN_DESC}" | gsutil cp - "${cloud_description_file}"

  echo "Starting production training run in 10s ..."
  sleep 10

  # Each training attempt is hard-killed after 4 hours; its combined output is
  # appended to the local log, echoed to stderr and streamed to the cloud log.
  timeout --signal=SIGKILL 4h python3 wrapper.py main_gpt2.py --tpu "${TPU_NAME}" --model_dir "${MODEL_DIR}" ${RESTORE_DIR} --params "${MODEL_NAME}.json" --num_cores "${TPU_CORES}" ${DATASET} "$@" 2>&1 | tee -a "${logfile}" | tee /dev/fd/2 | gsutil cp - "${cloud_log_file}"
  if [ -n "$TPU_NO_RECREATE" ]
  then
    echo "Not recreating TPU. Waiting 30s."
    sleep 30
  else
    echo "Recreating TPU in 30."
    sleep 30
    # sudo pip3 install -U tpudiepie
    pu recreate "$TPU_NAME" --yes
  fi
done
# https://www.tensorflow.org/guide/data_performance#reproducing_the_figures
"""
Helpers for recording and drawing input-pipeline timelines.

This dataset provides samples of shape [[2, 1], [2, 2], [2, 3]] and of type
[tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32]. Each sample is:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Where:

  Open and Read are step identifiers
  t0 is the timestamp when the corresponding step started
  d is the time spent in the corresponding step
  i is the instance index
  e is the epoch index (number of times the dataset has been iterated)
  s is the sample index
"""

import itertools
from collections import OrderedDict, defaultdict

import numpy as np
import time

try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt
except ImportError:
    # Recording a timeline needs no plotting backend; only draw_timeline()
    # (and the __main__ demo) require matplotlib.
    mpl = None
    plt = None


def now():
    """Return the current time.perf_counter() reading."""
    return time.perf_counter()


def wait(secs):
    """Sleep for `secs` seconds and return (start_timestamp, secs)."""
    t0 = now()
    time.sleep(secs)
    return t0, secs


class OrderedDefaultDict(OrderedDict):
    """An OrderedDict that creates missing values like collections.defaultdict."""

    def __init__(self, default_factory=None, *args, **kwargs):
        # in python3 you can omit the args to super
        super(OrderedDefaultDict, self).__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __missing__(self, key):
        # Without a factory behave like a plain dict (and like defaultdict):
        # a missing key is a KeyError, not "'NoneType' object is not callable".
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = value = self.default_factory()
        return value


class TimelineStep:
    """The recorded executions of one named step."""

    def __init__(self):
        self.times = []   # (time_start, time_spent) per execution
        self.values = []  # (instance_index, epoch_index, sample_index) per execution

    def add(self, time_start, time_spent, instance_index=0, epoch_index=0, sample_index=-1):
        """Append one execution of this step."""
        self.times += [(time_start, time_spent)]
        self.values += [(instance_index, epoch_index, sample_index)]


class Timeline:
    """Collects step executions, keeping steps in first-seen order."""

    def __init__(self):
        self.steps = OrderedDefaultDict(default_factory=TimelineStep)

    def add(self, step_name, time_start, time_spent, instance_index=0, epoch_index=0, sample_index=-1):
        """Record one execution of `step_name`, creating the step on first use."""
        self.steps[step_name].add(time_start=time_start, time_spent=time_spent, instance_index=instance_index, epoch_index=epoch_index, sample_index=sample_index)

    def get_timeline(self):
        """Flatten the recording into the dict that draw_timeline() consumes.

        Returns:
          {'steps': [(name_bytes,), ...], 'times': [(t0, d), ...],
           'values': [(i, e, s), ...]} with one entry per recorded execution,
          grouped by step.
        """
        steps = []
        times = []
        values = []
        for step, ts in self.steps.items():
            step = step.encode('utf8')
            for t, v in zip(ts.times, ts.values):
                steps += [tuple([step])]
                times += [t]
                values += [v]
        return {'steps': steps, 'times': times, 'values': values}


def test_timeline():
    """Build a Timeline with 10 instances of Open/Read/Map/Train of random duration."""
    tim = Timeline()
    t0 = now()
    for i in range(10):
        t = t0 + 2*i
        for phase in 'Open Read Map Train'.split():
            d = np.random.uniform()
            tim.add(phase, t, d, i)
            t += d
    return tim


def make_test_timeline(step_secs=0.3, gap_secs=0.1):
    """Time one real "Open" and one real "Read" step and return them as a timeline.

    Args:
      step_secs: how long each of the two steps sleeps.
      gap_secs: idle time between the two steps.

    Returns:
      A timeline dict with two entries, in the layout draw_timeline() expects.
    """
    i = 0
    e = 0
    s = 0
    # The accumulators were never initialised before, so the first `+=`
    # raised UnboundLocalError on every call.
    steps = []
    times = []
    values = []

    t0 = now()
    time.sleep(step_secs)
    d = now() - t0
    # 1-tuples (not bare strings) so that np.array(steps) has shape (N, 1),
    # which is what draw_timeline() indexes.
    steps += [("Open",)]
    times += [(t0, d)]
    values += [(i, e, -1)]

    time.sleep(gap_secs)
    t0 = now()
    time.sleep(step_secs)
    d = now() - t0
    steps += [("Read",)]
    times += [(t0, d)]
    values += [(i, e, s)]
    return {'steps': steps, 'times': times, 'values': values}


def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    """Draw one horizontal bar track per step of `timeline`.

    Args:
      timeline: dict with 'steps', 'times' and 'values' sequences (see the
        module docstring). The three entries are replaced in place by their
        numpy array equivalents.
      title: figure title; also used to derive the file name when `save` is set.
      width: minimum extent of the time axis, in seconds.
      annotate: if true, label each bar with its (i, e, s) values.
      save: if true, write the figure to "<title, lowercased, spaces -> _>.svg".

    Raises:
      ImportError: matplotlib is not installed.
      ValueError: the timeline contains no valid entry.
    """
    if plt is None:
        raise ImportError("draw_timeline requires matplotlib")

    # convert to numpy
    # float64 rather than float32: perf_counter() readings can be large, and
    # float32 would quantise short durations before min_time is subtracted.
    timeline['steps'] = np.array(timeline['steps'], dtype=np.bytes_)
    timeline['times'] = np.array(timeline['times'], dtype=np.float64)
    timeline['values'] = np.array(timeline['values'], dtype=np.int32)
    all_steps = timeline['steps'].reshape(-1, 1)
    all_times = timeline['times'].reshape(-1, 2)

    # Remove invalid entries (negative times, or empty steps) from the timelines
    valid_mask = np.logical_and(all_times[:, 0] > 0, all_steps[:, 0] != b'')
    steps = all_steps[valid_mask]
    times = all_times[valid_mask]
    values = timeline['values'][valid_mask]
    if times.shape[0] == 0:
        raise ValueError("timeline contains no valid entries")

    # Get a set of different steps, ordered by the first time they are encountered.
    # (np.stack-ing the names with the indices would turn the indices into byte
    # strings, and argsort would then order them lexicographically.)
    step_ids, first_seen = np.unique(steps, return_index=True)
    step_ids = step_ids[np.argsort(first_seen)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:, 0].min()
    times[:, 0] = (times[:, 0] - min_time)
    end = max(width, (times[:, 0] + times[:, 1]).max() + 0.01)

    # pyplot.get_cmap is available in both old and current Matplotlib releases;
    # matplotlib.cm.get_cmap has been removed from recent ones.
    cmap = plt.get_cmap("plasma")
    plt.close()
    # squeeze=False always yields a 2-D array of axes, so a timeline with a
    # single step works too.
    fig, axs = plt.subplots(len(step_ids), sharex=True, squeeze=False, gridspec_kw={'hspace': 0})
    axs = axs[:, 0]
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = steps[:, 0] == step
        serie = np.unique(times[entries_mask], axis=0)
        # NOTE(review): `serie` is sorted and de-duplicated while `annotations`
        # keeps recording order, so labels line up with bars only when the
        # entries were recorded in increasing, distinct start-time order.
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, _duration) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l, v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")


if __name__ == '__main__':
    tim = test_timeline()
    draw_timeline(tim.get_timeline(), 'test')
    plt.show()
def _int64_feature(value):
  """Build an int64 `tf.train.Feature` for an Example proto.

  Args:
    value: a list of ints, or a single value that is wrapped into a
      one-element list.

  Returns:
    A `tf.train.Feature` whose `int64_list` holds the value(s).
  """
  values = value if isinstance(value, list) else [value]
  int64_list = tf.train.Int64List(value=values)
  return tf.train.Feature(int64_list=int64_list)
@staticmethod
def set_shapes(image_size, channels, transpose_input, train_batch_size, batch_size, num_cores, features, labels):
  """Statically set the batch_size dimension.

  Args:
    image_size: height and width used for the static image shape.
    channels: number of image channels.
    transpose_input: if true, the batch dimension is expected in one of the
      last two positions and the images are flattened to 1-D afterwards.
    train_batch_size: with `num_cores`, selects which transposed layout
      is asserted (per-core batch greater than 8 puts the batch dimension last).
    batch_size: static batch size merged into the image and label shapes.
    num_cores: see `train_batch_size`.
    features: either the images tensor itself, or a dict holding it under
      the "images" key (the dict entry is updated in place).
    labels: labels tensor; its static shape is set to [batch_size].

  Returns:
    A `(features, labels)` tuple, with `features` of the same kind
    (tensor or dict) that was passed in.
  """
  features_is_dict = isinstance(features, dict)
  images = features["images"] if features_is_dict else features

  # Pick the static layout to merge into the images' current shape.
  if not transpose_input:
    image_shape = [batch_size, image_size, image_size, channels]
  elif train_batch_size // num_cores > 8:
    image_shape = [image_size, image_size, channels, batch_size]
  else:
    image_shape = [image_size, image_size, batch_size, channels]
  images.set_shape(images.get_shape().merge_with(tf.TensorShape(image_shape)))
  if transpose_input:
    images = tf.reshape(images, [-1])

  labels.set_shape(labels.get_shape().merge_with(
      tf.TensorShape([batch_size])))

  if not features_is_dict:
    return images, labels
  features["images"] = images
  return features, labels
105 | """ 106 | keys_to_features = { 107 | 'image/encoded': tf.FixedLenFeature((), tf.string, ''), 108 | 'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'), 109 | 'image/class/label': tf.FixedLenFeature([], tf.int64, -1), 110 | 'image/class/embedding': tf.VarLenFeature(tf.float32), 111 | 'image/width': tf.FixedLenFeature([], tf.int64, -1), 112 | 'image/height': tf.FixedLenFeature([], tf.int64, -1), 113 | 'image/filename': tf.FixedLenFeature([], tf.string, ''), 114 | 'image/class/text': tf.FixedLenFeature([], tf.string, ''), 115 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 116 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 117 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 118 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 119 | 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64), 120 | } 121 | 122 | parsed = tf.parse_single_example(value, keys_to_features) 123 | parsed['image/hash'] = tf.raw_ops.Fingerprint(data=[parsed['image/encoded']], method='farmhash64')[0] 124 | identifier = tf.abs(tf.bitcast(parsed['image/hash'], tf.int64)) 125 | image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) 126 | image = tf.io.decode_image(image_bytes, 3) 127 | image.set_shape(tf.TensorShape([None, None, 3])) 128 | 129 | # Subtract one so that labels are in [0, 1000). 
130 | label = tf.cast( 131 | tf.reshape(parsed['image/class/label'], shape=[]), dtype=tf.int32) - 0 132 | 133 | embedding = parsed['image/class/embedding'].values 134 | 135 | embedding = tf.cond( 136 | tf.math.greater(tf.shape(embedding)[0], 0), 137 | lambda: embedding, 138 | lambda: tf.one_hot(label, 1000)) 139 | 140 | return { 141 | 'id': identifier, 142 | 'image': image, 143 | 'label': label, 144 | 'embedding': embedding, 145 | 'parsed': parsed, 146 | } 147 | 148 | @staticmethod 149 | def get_current_host(params): 150 | # TODO(dehao): Replace the following with params['context'].current_host 151 | if 'context' in params: 152 | return params['context'].current_input_fn_deployment()[1] 153 | elif 'dataset_index' in params: 154 | return params['dataset_index'] 155 | else: 156 | return 0 157 | 158 | @staticmethod 159 | def get_num_hosts(params): 160 | if 'context' in params: 161 | return params['context'].num_hosts 162 | elif 'dataset_index' in params: 163 | return params['dataset_num_shards'] 164 | else: 165 | return 1 166 | 167 | @staticmethod 168 | def get_num_cores(params): 169 | return 8 * ImageNet.get_num_hosts(params) 170 | 171 | @staticmethod 172 | def make_dataset(file_patterns, index, num_hosts, 173 | seed=None, shuffle_filenames=False, 174 | num_parallel_calls = 64, 175 | filter_fn=None, 176 | parse_fn=None, 177 | batch_size=None, 178 | cache_image_data=False, 179 | cache_decoded_image=False): 180 | 181 | if shuffle_filenames: 182 | assert seed is not None 183 | 184 | # For multi-host training, we want each hosts to always process the same 185 | # subset of files. Each host only sees a subset of the entire dataset, 186 | # allowing us to cache larger datasets in memory. 
187 | if False: 188 | file_patterns = [x.strip() for x in file_patterns.split(',') if len(x.strip()) > 0] 189 | dataset = None 190 | for pattern in file_patterns: 191 | x = tf.data.Dataset.list_files(pattern, shuffle=shuffle_filenames, seed=seed) 192 | dataset = x if dataset is None else dataset.concatenate(x) 193 | dataset = dataset.shard(num_hosts, index) 194 | 195 | def fetch_dataset(filename): 196 | buffer_size = 8 * 1024 * 1024 # 8 MiB per file 197 | dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) 198 | return dataset 199 | 200 | # Read the data from disk in parallel 201 | dataset = dataset.apply( 202 | tf.contrib.data.parallel_interleave( 203 | fetch_dataset, cycle_length=num_parallel_calls, sloppy=True)) 204 | else: 205 | # filenames = [] 206 | # for pattern in file_patterns: 207 | # files = tf.io.gfile.glob(pattern) 208 | # if len(files) <= 0: 209 | # raise ValueError("Pattern matched no files: {}".format(pattern)) 210 | # filenames.append(files) 211 | dataset = tputil.tf_sharded_datasets(file_patterns, num_hosts=num_hosts, current_host=index) 212 | 213 | dataset = dataset.map( 214 | ImageNet.dataset_parser_static, 215 | num_parallel_calls=num_parallel_calls) 216 | 217 | if filter_fn is not None: 218 | raise NotImplementedError() 219 | #dataset = dataset.filter(filter_fn) 220 | 221 | if cache_image_data: 222 | assert cache_decoded_image == False 223 | dataset = dataset.cache() 224 | 225 | if parse_fn is not None: 226 | dataset = dataset.map( 227 | parse_fn, 228 | num_parallel_calls=num_parallel_calls) 229 | if cache_decoded_image: 230 | dataset = dataset.cache() 231 | 232 | if batch_size is not None: 233 | dataset = dataset.repeat() 234 | dataset = dataset.batch(batch_size) 235 | def set_batch_size(arg): 236 | if isinstance(arg, (tuple, list)): 237 | for x in arg: 238 | set_batch_size(x) 239 | elif isinstance(arg, dict): 240 | for k, v in arg.items(): 241 | set_batch_size(v) 242 | elif isinstance(arg, tf.Tensor): 243 | shape = 
arg.shape.as_list() 244 | if len(shape) > 0 and shape[0] is None: 245 | shape[0] = batch_size 246 | arg.set_shape(shape) 247 | return arg 248 | dataset = dataset.map( 249 | set_batch_size, 250 | num_parallel_calls=num_parallel_calls) 251 | return dataset 252 | 253 | 254 | 255 | from tensorflow.python.framework.errors_impl import OutOfRangeError 256 | 257 | def iterate_dataset(dataset, n = -1, session=None): 258 | if session is None: 259 | session = tf.get_default_session() 260 | iterator = dataset.make_one_shot_iterator() 261 | next_element = iterator.get_next() 262 | while n != 0: 263 | n -= 1 264 | try: 265 | yield session.run(next_element) 266 | except OutOfRangeError: 267 | return 268 | -------------------------------------------------------------------------------- /tflex_tpu_topology.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import base64 4 | 5 | from tensorflow.python.tpu import tpu as tpu_ops 6 | from tensorflow.compiler.tf2xla.python import xla 7 | from tensorflow.compiler.tf2xla.ops import gen_xla_ops 8 | from tensorflow.python.tpu import tpu_strategy_util 9 | from tensorflow.python.tpu import device_assignment as device_assignment_lib 10 | from tensorflow.python.tpu import topology as topology_lib 11 | from tensorflow.contrib.cluster_resolver import TPUClusterResolver as BaseTPUClusterResolver 12 | from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib 13 | 14 | 15 | _TOPOLOGY_CACHE_FILENAME = '.tpu_topology_cache.json' 16 | 17 | 18 | class Context(): 19 | pass 20 | 21 | 22 | if 'api' not in globals(): 23 | api = Context() 24 | api.topology = None 25 | api.topology_cache = {} 26 | try: 27 | with open(_TOPOLOGY_CACHE_FILENAME, 'r') as f: 28 | api.topology_cache = json.load(f) 29 | except FileNotFoundError: 30 | pass 31 | 32 | 33 | def cached_topology(name=None): 34 | if name is None: 35 | name = os.environ.get('TPU_NAME', '') 36 | result = 
def get_topology(cluster_resolver=None):
    """Return the TPU topology, initializing the TPU system on a cache miss.

    The topology is looked up in the on-disk cache first. On a miss the TPU
    system is initialized through `cluster_resolver`, and the serialized
    topology is stored in the cache file for later runs.

    Args:
      cluster_resolver: optional resolver; when None, `get_cluster_resolver`
        builds one.

    Returns:
      The topology object, also stored on `api.topology`.
    """
    # Use the same key for the cache read and the cache write. Previously
    # the write indexed os.environ['TPU_NAME'] directly, so with an explicit
    # resolver and no TPU_NAME set it raised KeyError after the TPU system
    # had already been initialized.
    tpu_name = os.environ.get('TPU_NAME', '')
    api.topology = cached_topology(tpu_name)
    if api.topology is None:
        cluster_resolver = get_cluster_resolver(cluster_resolver)
        api.topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)
        api.topology_cache[tpu_name] = base64.b64encode(
            api.topology.serialized()).decode('utf8')
        with open(_TOPOLOGY_CACHE_FILENAME, 'w') as f:
            json.dump(api.topology_cache, f)
    return api.topology
class BatchNormalization(tf.layers.BatchNormalization):
    """Batch Normalization layer that supports cross replica computation on TPU.

    This class extends the keras.BatchNormalization implementation by
    supporting cross replica means and variances. The base class
    implementation only computes moments based on mini-batch per replica
    (TPU core).

    For detailed information of arguments and implementation, refer to:
    https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

    Attributes:
      fused: if `None` or `True`, use a faster, fused implementation if
        possible. If `False`, use the system recommended implementation.
      cross_replica_average_fn: A function takes a tensor and outputs the
        mean value across all the replicas. Currently, only TPU version
        supports this feature. If specified, fused must be `False`.
    """

    def __init__(self, fused=None, cross_replica_average_fn=None, **kwargs):
        """Validate the arguments, then construct the base layer.

        Raises:
          ValueError: if `fused` is truthy while `cross_replica_average_fn`
            is given.
        """
        # Reject the invalid combination before the base layer is built,
        # rather than after it has already been constructed.
        # NOTE(review): fused=None is accepted here; the base layer may then
        # still choose a fused kernel that presumably bypasses `_moments` --
        # confirm against the TensorFlow version in use.
        if fused and cross_replica_average_fn is not None:
            raise ValueError('fused must be `False` when sepcifying'
                             ' cross_replica_average_fn')
        kwargs['fused'] = fused
        super(BatchNormalization, self).__init__(**kwargs)
        self.cross_replica_average_fn = cross_replica_average_fn

    def _moments(self, inputs, reduction_axes, keep_dims):
        """Return (mean, variance), averaged across replicas when configured."""
        shard_mean, shard_variance = super(BatchNormalization, self)._moments(
            inputs, reduction_axes, keep_dims=keep_dims)
        if not self.cross_replica_average_fn:
            return (shard_mean, shard_variance)
        # Uses the definition of Var[X] = E[X^2] - E[X]^2.
        shard_square_of_mean = tf.math.square(shard_mean)
        shard_mean_of_square = shard_variance + shard_square_of_mean
        group_mean = self.cross_replica_average_fn(shard_mean)
        group_mean_of_square = self.cross_replica_average_fn(shard_mean_of_square)
        group_variance = group_mean_of_square - tf.math.square(group_mean)
        return (group_mean, group_variance)

    def add_weight(self, name, **kws):
        """Create the layer weight through `tft.globalvar`.

        The 'experimental_autocast' keyword is removed before forwarding.
        """
        print('add_weight', name, kws)
        # Discard the keyword; it is not forwarded to globalvar.
        kws.pop('experimental_autocast', None)
        return tft.globalvar(name, **kws)
def main(unused_argv):
    """Parse the Gin configuration, then build and run the BigGAN train loop.

    Steps: load Gin config files and bindings, create a `TrainRunner` from
    `train_flags.run_config()`, initialize it with the input and model
    functions defined below, train, and shut down.

    Args:
      unused_argv: positional arguments from `absl.app.run`; ignored.
    """
    logging.info("Gin config: %s\nGin bindings: %s",
                 FLAGS.gin_config, FLAGS.gin_bindings)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    cfg = train_flags.run_config()
    pp(cfg)
    trunner = train_runner.TrainRunner(
        iterations=cfg.iterations_per_loop, train_steps=cfg.train_steps)

    # A synthetic-token `input_fn` used to be defined here and was
    # immediately shadowed by the definition below; that dead definition
    # has been removed.
    def input_fn(params):
        """Build the per-host training dataset of {'reals': (image, label)}."""
        info = train_runner.get_input_info(params)
        pp(['input_fn.params', params])
        pp(['input_fn.info', info])
        seed = params.get('seed', None)

        def parse_fn(example):
            """Scale the image to [-1, 1], resize with padding, derive a label."""
            pp(['parse_fn.input', example])
            target_image_resolution = train_flags.options().resolution
            target_image_shape = [target_image_resolution,
                                  target_image_resolution]
            image = ((example['image'] / 255) - 0.5) * 2.0
            image = tf.image.resize_image_with_pad(
                image, target_image_shape[1], target_image_shape[0],
                method=tf.image.ResizeMethod.AREA)
            # The class id is derived from the image hash id, not from the
            # record's stored label.
            label = tf.mod(example['id'], 1000)
            return {'reals': (image, label)}

        dset = tfjpg_parser.ImageNet.make_dataset(
            params['dataset'],
            info.current_host,
            info.num_hosts,
            seed=seed,
            shuffle_filenames=False,
            filter_fn=None,
            parse_fn=parse_fn,
            batch_size=params['batch_per_core'],
            cache_image_data=True,
        )
        pp(['training_dataset', dset])
        return dset

    def create_train_op(input, labels, params):
        """Build generator/discriminator graphs and one combined train op.

        Returns:
          `(train_op, loss)` where `loss` is generator plus discriminator
          loss.

        Raises:
          ValueError: if a trainable variable belongs to neither network.
        """
        assert labels is None
        reals, reals_class_id = input['reals']
        pp(['input', input])
        pp(['reals', reals])
        pp(['reals_class_id', reals_class_id])
        pp(['params', params])
        mdl = BigGAN.GAN()
        BigGAN.instance = mdl
        dim_z = mdl.gan.generator.dim_z
        nclasses = mdl.gan.discriminator.n_class
        # Unpacking four values also checks that `reals` is rank 4.
        N, _, _, _ = reals.shape.as_list()
        fakes_z, fakes_class_id = utils.prepare_z_y(
            G_batch_size=N, dim_z=dim_z, nclasses=nclasses)
        reals_y = tf.one_hot(reals_class_id, nclasses)
        fakes_y = tf.one_hot(fakes_class_id, nclasses)
        fakes = mdl.gan.generator(fakes_z, fakes_y)
        reals_D = mdl.gan.discriminator(reals, reals_y)
        fakes_D = mdl.gan.discriminator(fakes, fakes_y)
        global_step = tflex.get_or_create_global_step()

        T_vars = tf.trainable_variables()
        G_vars = [x for x in T_vars
                  if x.name.startswith('Generator/')
                  or x.name.startswith('linear/w:')]
        D_vars = [x for x in T_vars
                  if x.name.startswith('Discriminator/')
                  or x.name.startswith('linear/w:')]
        leftover_vars = [x for x in T_vars
                         if x not in G_vars and x not in D_vars]
        if len(leftover_vars) > 0:
            # Previously this path dropped into pdb, which stalls a
            # non-interactive job. Report the offending variables instead.
            raise ValueError(
                "Unexpected trainable variables: %s" % pps(leftover_vars))

        def should_train_variable(v):
            return True

        train_vars = [v for v in tf.trainable_variables()
                      if should_train_variable(v)]
        non_train_vars = [v for v in tf.trainable_variables()
                          if not should_train_variable(v)]
        other_vars = [v for v in tf.global_variables()
                      if v not in train_vars and v not in non_train_vars]
        local_vars = [v for v in tf.local_variables()]

        def paramcount(vs):
            """Total number of scalar parameters across `vs`."""
            return sum([np.prod(v.shape.as_list()) for v in vs])

        def logvars(variables, label, print_variables=False):
            """Log the parameter count of `variables`, optionally listing them."""
            if print_variables:
                tf.logging.info("%s (%s parameters): %s", label,
                                paramcount(variables), pps(variables))
            else:
                tf.logging.info("%s (%s parameters)", label,
                                paramcount(variables))
            return variables

        tf.logging.info("Training %d parameters (%.2fM) out of %d parameters (%.2fM)" % (
            paramcount(train_vars), paramcount(train_vars)/(1024.0*1024.0),
            paramcount(tf.trainable_variables()), paramcount(tf.trainable_variables())/(1024.0*1024.0),
        ))

        tf.logging.info("---------")
        tf.logging.info("Variable details:")
        logvars(train_vars, "trainable variables", print_variables=True)
        logvars(non_train_vars, "non-trainable variables", print_variables=True)
        logvars(other_vars, "other global variables", print_variables=True)
        logvars(local_vars, "other local variables", print_variables=True)

        tf.logging.info("---------")
        tf.logging.info("Variable summary:")
        logvars(train_vars, "trainable variables")
        logvars(non_train_vars, "non-trainable variables")
        logvars(other_vars, "other global variables")
        logvars(local_vars, "other local variables")

        G_loss = losses.generator_loss(fakes_D)
        D_loss_real, D_loss_fake = losses.discriminator_loss(reals_D, fakes_D)
        D_loss = D_loss_real + D_loss_fake
        loss = G_loss + D_loss
        optimizer = tf.train.AdamOptimizer()
        if params['use_tpu']:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        # To update batchnorm, if present.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        pp(['tf.GraphKeys.UPDATE_OPS', update_ops])
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, var_list=T_vars, global_step=global_step)
        return train_op, loss

    def model_fn(input, labels, mode, params):
        """Return the estimator spec for TRAIN mode; other modes are unsupported."""
        pp(['model_fn.mode', mode])
        if mode != tf.estimator.ModeKeys.TRAIN:
            # Previously this path dropped into pdb before raising.
            raise NotImplementedError()
        train_op, loss = create_train_op(input, labels, params)
        if params['use_tpu']:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode, loss=loss, train_op=train_op)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    params = train_flags.options()
    trunner.initialize(input_fn, model_fn, params)
    tf.logging.info('trunner.initialize(): Done. Training...')
    trunner.train()
    tf.logging.info('trunner.train(): Done. Shutting down...')
    trunner.shutdown()
    tf.logging.info('trunner.shutdown(): Done.')
CPU and GPU)')) 13 | 14 | # Cloud TPU Cluster Resolvers 15 | flags.DEFINE_string( 16 | 'tpu', default=None, 17 | help='The Cloud TPU to use for training. This should be either the name ' 18 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.') 19 | 20 | flags.DEFINE_string( 21 | 'master', 22 | default=None, 23 | help='The Cloud TPU to use for training. This should be either the name ' 24 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.') 25 | 26 | flags.DEFINE_string('tpu_job_name', default=None, help='The tpu worker name.') 27 | 28 | flags.DEFINE_string( 29 | 'gcp_project', default=None, 30 | help='Project name for the Cloud TPU-enabled project. If not specified, we ' 31 | 'will attempt to automatically detect the GCE project from metadata.') 32 | 33 | flags.DEFINE_string( 34 | 'tpu_zone', default=None, 35 | help='GCE zone where the Cloud TPU is located in. If not specified, we ' 36 | 'will attempt to automatically detect the GCE project from metadata.') 37 | 38 | # Model specific flags 39 | flags.DEFINE_string( 40 | 'data_dir', default=FAKE_DATA_DIR, 41 | help=('The directory where the ImageNet input data is stored. 
Please see' 42 | ' the README.md for the expected data format.')) 43 | 44 | flags.DEFINE_string( 45 | 'model_dir', default=None, 46 | help=('The directory where the model and training/evaluation summaries are' 47 | ' stored.')) 48 | 49 | flags.DEFINE_string( 50 | 'restore_dir', default=None, 51 | help=('The directory where the model should be restored from')) 52 | 53 | flags.DEFINE_bool( 54 | 'restore_trainable_variables', default=True, 55 | help=('Only restore trainable variables')) 56 | 57 | flags.DEFINE_string( 58 | 'params', default=None, 59 | help=('The file to read model parameters from')) 60 | 61 | flags.DEFINE_string( 62 | 'dataset', default=None, 63 | help=('The file to read examples from')) 64 | 65 | flags.DEFINE_string( 66 | 'export_dataset', default=None, 67 | help=('Export the dataset as a .tfrecord file')) 68 | 69 | flags.DEFINE_integer( 70 | 'resnet_depth', default=50, 71 | help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,' 72 | ' 200}. ResNet-18 and 34 use the pre-activation residual blocks' 73 | ' without bottleneck layers. The other models use pre-activation' 74 | ' bottleneck layers. Deeper models require more training time and' 75 | ' more memory and may require reducing --train_batch_size to prevent' 76 | ' running out of memory.')) 77 | 78 | flags.DEFINE_string( 79 | 'mode', default='in_memory_eval', 80 | help='One of {"train_and_eval", "train", "eval"}.') 81 | 82 | flags.DEFINE_integer( 83 | 'train_steps', default=112590, 84 | help=('The number of steps to use for training. Default is 112590 steps' 85 | ' which is approximately 90 epochs at batch size 1024. 
This flag' 86 | ' should be adjusted according to the --train_batch_size flag.')) 87 | 88 | flags.DEFINE_integer( 89 | 'train_batch_size', default=1024, help='Batch size for training.') 90 | 91 | flags.DEFINE_integer( 92 | 'eval_batch_size', default=1024, help='Batch size for evaluation.') 93 | 94 | flags.DEFINE_integer( 95 | 'num_train_images', default=1281167, help='Size of training data set.') 96 | 97 | flags.DEFINE_integer( 98 | 'num_eval_images', default=50000, help='Size of evaluation data set.') 99 | 100 | flags.DEFINE_integer( 101 | 'num_label_classes', default=1000, help='Number of classes, at least 2') 102 | 103 | flags.DEFINE_integer( 104 | 'steps_per_eval', default=1251, 105 | help=('Controls how often evaluation is performed. Since evaluation is' 106 | ' fairly expensive, it is advised to evaluate as infrequently as' 107 | ' possible (i.e. up to --train_steps, which evaluates the model only' 108 | ' after finishing the entire training regime).')) 109 | 110 | flags.DEFINE_integer( 111 | 'eval_timeout', 112 | default=None, 113 | help='Maximum seconds between checkpoints before evaluation terminates.') 114 | 115 | flags.DEFINE_bool( 116 | 'skip_host_call', 117 | default=True, 118 | help=('Skip the host_call which is executed every training step. This is' 119 | ' generally used for generating training summaries (train loss,' 120 | ' learning rate, etc...). When --skip_host_call=false, there could' 121 | ' be a performance drop if host_call function is slow and cannot' 122 | ' keep up with the TPU-side computation.')) 123 | 124 | flags.DEFINE_integer( 125 | 'iterations_per_loop', default=1251, 126 | help=('Number of steps to run on TPU before outfeeding metrics to the CPU.' 127 | ' If the number of iterations in the loop would exceed the number of' 128 | ' train steps, the loop will exit before reaching' 129 | ' --iterations_per_loop. 
The larger this value is, the higher the' 130 | ' utilization on the TPU.')) 131 | 132 | flags.DEFINE_integer( 133 | 'num_parallel_calls', default=64, 134 | help=('Cycle length of the parallel interleave in tf.data.dataset.')) 135 | 136 | flags.DEFINE_integer( 137 | 'num_prefetch_threads', 138 | default=16, 139 | help=('Number of prefetch threads in CPU for the input pipeline')) 140 | 141 | flags.DEFINE_bool( 142 | 'prefetch_depth_auto_tune', 143 | default=True, 144 | help=('Number of prefetch threads in CPU for the input pipeline')) 145 | 146 | flags.DEFINE_integer( 147 | 'num_cores', default=8, 148 | help=('Number of TPU cores. For a single TPU device, this is 8 because each' 149 | ' TPU has 4 chips each with 2 cores.')) 150 | 151 | flags.DEFINE_string( 152 | 'bigtable_project', None, 153 | 'The Cloud Bigtable project. If None, --gcp_project will be used.') 154 | flags.DEFINE_string( 155 | 'bigtable_instance', None, 156 | 'The Cloud Bigtable instance to load data from.') 157 | flags.DEFINE_string( 158 | 'bigtable_table', 'imagenet', 159 | 'The Cloud Bigtable table to load data from.') 160 | flags.DEFINE_string( 161 | 'bigtable_train_prefix', 'train_', 162 | 'The prefix identifying training rows.') 163 | flags.DEFINE_string( 164 | 'bigtable_eval_prefix', 'validation_', 165 | 'The prefix identifying evaluation rows.') 166 | flags.DEFINE_string( 167 | 'bigtable_column_family', 'tfexample', 168 | 'The column family storing TFExamples.') 169 | flags.DEFINE_string( 170 | 'bigtable_column_qualifier', 'example', 171 | 'The column name storing TFExamples.') 172 | 173 | flags.DEFINE_string( 174 | 'data_format', default='channels_last', 175 | help=('A flag to override the data format used in the model. The value' 176 | ' is either channels_first or channels_last. To run the network on' 177 | ' CPU or TPU, channels_last should be used. 
For GPU, channels_first' 178 | ' will improve performance.')) 179 | 180 | # TODO(chrisying): remove this flag once --transpose_tpu_infeed flag is enabled 181 | # by default for TPU 182 | flags.DEFINE_bool( 183 | 'transpose_input', default=True, 184 | help='Use TPU double transpose optimization') 185 | 186 | flags.DEFINE_string( 187 | 'export_dir', 188 | default=None, 189 | help=('The directory where the exported SavedModel will be stored.')) 190 | 191 | flags.DEFINE_string( 192 | 'precision', default='bfloat16', 193 | help=('Precision to use; one of: {bfloat16, float32}')) 194 | 195 | flags.DEFINE_float( 196 | 'base_learning_rate', default=0.1, 197 | help=('Base learning rate when train batch size is 256.')) 198 | 199 | flags.DEFINE_float( 200 | 'momentum', default=0.9, 201 | help=('Momentum parameter used in the MomentumOptimizer.')) 202 | 203 | flags.DEFINE_float( 204 | 'weight_decay', default=1e-4, 205 | help=('Weight decay coefficiant for l2 regularization.')) 206 | 207 | flags.DEFINE_float( 208 | 'label_smoothing', default=0.0, 209 | help=('Label smoothing parameter used in the softmax_cross_entropy')) 210 | 211 | flags.DEFINE_integer('log_step_count_steps', 64, 'The number of steps at ' 212 | 'which the global step information is logged.') 213 | 214 | flags.DEFINE_bool('enable_lars', 215 | default=False, 216 | help=('Enable LARS optimizer for large batch training.')) 217 | 218 | flags.DEFINE_float('poly_rate', default=0.0, 219 | help=('Set LARS/Poly learning rate.')) 220 | 221 | flags.DEFINE_bool( 222 | 'use_cache', default=True, help=('Enable cache for training input.')) 223 | flags.DEFINE_bool( 224 | 'cache_decoded_image', default=False, help=('Cache decoded images.')) 225 | 226 | flags.DEFINE_bool( 227 | 'use_async_checkpointing', default=False, help=('Enable async checkpoint')) 228 | flags.DEFINE_float( 229 | 'stop_threshold', default=0.759, help=('Stop threshold for MLPerf.')) 230 | 231 | flags.DEFINE_bool( 232 | 'use_eval_runner', default=True, 
@gin.configurable
def run_config(*,
               iterations_per_loop,
               save_checkpoints_steps,
               train_steps=-1,
               **kwargs):
    """Collect run-level settings into a `tflex.Dictator`.

    All arguments are keyword-only and are normally supplied through Gin
    bindings.

    Args:
      iterations_per_loop: required; stored under the same name.
      save_checkpoints_steps: required; stored under the same name.
      train_steps: stored under the same name; defaults to -1.
      **kwargs: any further settings, stored after the named ones.

    Returns:
      A `tflex.Dictator` holding every setting.
    """
    settings = dict(
        iterations_per_loop=iterations_per_loop,
        save_checkpoints_steps=save_checkpoints_steps,
        train_steps=train_steps,
    )
    settings.update(kwargs)
    return tflex.Dictator(**settings)
1 | import tensorflow as tf 2 | 3 | 4 | # # A highly simplified convenience class for sampling from distributions 5 | # # One could also use PyTorch's inbuilt distributions package. 6 | # # Note that this class requires initialization to proceed as 7 | # # x = Distribution(torch.randn(size)) 8 | # # x.init_distribution(dist_type, **dist_kwargs) 9 | # # x = x.to(device,dtype) 10 | # # This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 11 | # class Distribution(torch.Tensor): 12 | # # Init the params of the distribution 13 | # def init_distribution(self, dist_type, **kwargs): 14 | # self.dist_type = dist_type 15 | # self.dist_kwargs = kwargs 16 | # if self.dist_type == 'normal': 17 | # self.mean, self.var = kwargs['mean'], kwargs['var'] 18 | # elif self.dist_type == 'categorical': 19 | # self.num_categories = kwargs['num_categories'] 20 | 21 | # def sample_(self): 22 | # if self.dist_type == 'normal': 23 | # self.normal_(self.mean, self.var) 24 | # elif self.dist_type == 'categorical': 25 | # self.random_(0, self.num_categories) 26 | # # return self.variable 27 | 28 | # # Silly hack: overwrite the to() method to wrap the new object 29 | # # in a distribution as well 30 | # def to(self, *args, **kwargs): 31 | # new_obj = Distribution(self) 32 | # new_obj.init_distribution(self.dist_type, **self.dist_kwargs) 33 | # new_obj.data = super().to(*args, **kwargs) 34 | # return new_obj 35 | 36 | 37 | # # Convenience function to prepare a z and y vector 38 | # def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', 39 | # fp16=False,z_var=1.0): 40 | # z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) 41 | # z_.init_distribution('normal', mean=0, var=z_var) 42 | # z_ = z_.to(device,torch.float16 if fp16 else torch.float32) 43 | 44 | # if fp16: 45 | # z_ = z_.half() 46 | 47 | # y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) 48 | # y_.init_distribution('categorical',num_categories=nclasses) 49 
def distribution(dist_type, shape, *, seed=None, **kwargs):
    """Sample a tensor of the given `shape` from a named distribution.

    Args:
      dist_type: 'normal' or 'categorical'.
      shape: shape handed to the TensorFlow sampler.
      seed: when None the stateful `tf.random.*` sampler is used; otherwise
        the value is passed as `seed` to the `tf.random.stateless_*` sampler.
      **kwargs: for 'normal', forwarded to the sampler. For 'categorical',
        must contain `num_categories`, which is passed as `maxval` of an
        int64 uniform draw starting at 0; other entries are ignored.

    Returns:
      The sampled tensor.

    Raises:
      KeyError: 'categorical' was requested without `num_categories`.
      NotImplementedError: `dist_type` is not one of the supported names.
    """
    stateless = seed is not None

    if dist_type == 'normal':
        if stateless:
            return tf.random.stateless_normal(shape=shape, seed=seed, **kwargs)
        return tf.random.normal(shape=shape, **kwargs)

    if dist_type == 'categorical':
        upper = kwargs.pop('num_categories')
        if stateless:
            return tf.random.stateless_uniform(
                shape=shape, seed=seed, minval=0, maxval=upper, dtype=tf.int64)
        return tf.random.uniform(
            shape=shape, minval=0, maxval=upper, dtype=tf.int64)

    raise NotImplementedError()
def my_unpickle(fb0):
    """Unpickle a torch checkpoint stream without importing torch.

    Tensor-rebuild calls found in the stream are replaced by zero-filled
    numpy arrays of the recorded shape; the actual data is expected to be
    copied into those arrays later by the caller.

    NOTE: names that are not special-cased are resolved through the regular
    `pickle` lookup, so this must only be used on trusted files.

    Args:
      fb0: binary file-like object positioned at the start of the pickle.

    Returns:
      A tuple `(obj, key_prelookup)`. `obj` is the unpickled object.
      `key_prelookup` maps each storage key to
      `(storage_type, obj_size, array, shape, strides)`, where `array` is the
      zero-filled placeholder that also appears inside `obj`.
    """
    key_prelookup = {}

    class HackTensor:
        # Stands in for torch._utils._rebuild_tensor_v2: returns a numpy
        # placeholder instead of a torch tensor.
        def __new__(cls, *args):
            ident, storage_type, obj_key, _location, obj_size = args[0][0:5]
            assert ident == 'storage'

            assert prod(args[2]) == obj_size
            ret = np.zeros(args[2], dtype=storage_type)
            key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
            return ret

    class HackParameter:
        # Stands in for torch._utils._rebuild_parameter; yields None.
        def __new__(cls, *args):
            pass

    class Dummy:
        # Placeholder for any class that cannot be imported.
        pass

    storage_dtypes = {
        'FloatStorage': np.float32,
        'LongStorage': np.int64,
        'HalfStorage': np.float16,
    }
    rebuilders = {
        '_rebuild_tensor_v2': HackTensor,
        '_rebuild_parameter': HackParameter,
    }

    class MyPickle(pickle.Unpickler):
        def find_class(self, module, name):
            if name in storage_dtypes:
                return storage_dtypes[name]
            if module == "torch._utils" and name in rebuilders:
                return rebuilders[name]
            # Everything else, including unrecognised torch._utils helpers
            # (which previously fell through and returned None), goes through
            # the normal lookup and degrades to Dummy when unavailable.
            try:
                return pickle.Unpickler.find_class(self, module, name)
            except Exception:
                return Dummy

        def persistent_load(self, pid):
            # Persistent ids are passed through untouched.
            return pid

    return MyPickle(fb0).load(), key_prelookup
def fake_torch_load(b0):
    """Load a torch checkpoint held in memory into numpy arrays.

    Data starting with b"PK" is handed to `fake_torch_load_zipped`. Anything
    else is read as: three pickles that are skipped, the object pickle
    (decoded by `my_unpickle`), a pickled list of storage keys, and then,
    for every key in that order, an 8-byte element count followed by the
    raw element data.

    NOTE: this unpickles its input; only use it on trusted data.

    Args:
      b0: the complete file contents as bytes.

    Returns:
      The unpickled object, with its placeholder arrays filled in.
    """
    import io
    import struct

    # convert it to a file
    fb0 = io.BytesIO(b0)

    if b0[0:2] == b"\x50\x4b":
        return fake_torch_load_zipped(fb0)

    # skip three junk pickles
    for _ in range(3):
        pickle.load(fb0)

    ret, key_prelookup = my_unpickle(fb0)

    # create key_lookup: put the storages in the order the data follows in
    key_lookup = pickle.load(fb0)
    key_real = [None] * len(key_lookup)
    for k, v in key_prelookup.items():
        key_real[key_lookup.index(k)] = v

    # read in the actual data
    for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
        ll = struct.unpack("Q", fb0.read(8))[0]
        assert ll == obj_size
        # Ask numpy for the element size instead of a fixed table, so every
        # dtype my_unpickle can produce (including float16) is handled.
        bytes_size = np.dtype(storage_type).itemsize
        mydat = fb0.read(ll * bytes_size)
        np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))

        # numpy stores its strides in bytes
        real_strides = tuple([x * bytes_size for x in np_strides])
        np_array.strides = real_strides

    return ret
# `logging` was never imported in this file, so the fallback branch below
# raised NameError (which `except ImportError` does not catch) whenever
# cloud_tpu_client was missing. Import it here so the fallback really works.
import logging

# Prefer the standalone Cloud TPU client; otherwise use the copy bundled with
# TensorFlow; if neither can be imported, leave `client` as None.
try:
    from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
    try:
        logging.debug(
            'Falling back to TensorFlow client; we recommended you install the Cloud '
            'TPU client directly with pip install cloud-tpu-client.')
        from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top
    except ImportError:
        client = None
def cluster_spec(self):
    """Return the parent's cluster spec, rerouted and trimmed per job.

    Every address of every job is passed through `reroute` with this
    resolver's host. Each job then keeps its first address plus a window of
    the following ones: the window starts `node_offset` entries after the
    first address and the job ends up with at most `node_count` addresses
    (all remaining ones when `node_count` is unset).

    Returns:
      A `server_lib.ClusterSpec` built from the trimmed address lists. Its
      cluster def is also printed.
    """
    spec = super(TPUClusterResolver, self).cluster_spec()
    offset = self._node_offset or 0
    trimmed = {}
    for job, addrs in spec.as_dict().items():
        addrs = [reroute(ip, host=self._host) for ip in addrs]
        # Size the window from this job's own address list. The previous code
        # read a loop variable left over from an earlier loop, which failed on
        # an empty spec and used the last job's length for every job.
        count = self._node_count or len(addrs)
        trimmed[job] = addrs[:1] + addrs[offset + 1:offset + count]
    spec2 = server_lib.ClusterSpec(trimmed)
    print(spec2.as_cluster_def())
    return spec2
def _fetch_cloud_tpu_metadata(cls, *args, **kws):
    """Call the original `_fetch_cloud_tpu_metadata`, retrying refused connections.

    The saved original implementation is invoked with the same arguments.
    If it raises an exception whose text contains
    '[Errno 111] Connection refused', the call is retried after a one second
    pause, with no upper bound on the number of attempts. Any other
    exception propagates unchanged.

    Returns:
      Whatever the original implementation returns.
    """
    import time

    while True:
        try:
            return __fetch_cloud_tpu_metadata(cls, *args, **kws)
        except Exception as e:
            if '[Errno 111] Connection refused' not in str(e):
                # Bare raise keeps the original traceback intact.
                raise
            # retry
            time.sleep(1.0)
vector of size 4 with positive " 171 | "entries; got {}".format(self._mesh_shape)) 172 | 173 | if proto.num_tasks < 0: 174 | raise ValueError("`num_tasks` must be >= 0; got {}".format( 175 | proto.num_tasks)) 176 | if proto.num_tpu_devices_per_task < 0: 177 | raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format( 178 | proto.num_tpu_devices_per_task)) 179 | 180 | expected_coordinates_size = ( 181 | proto.num_tasks * proto.num_tpu_devices_per_task * len( 182 | proto.mesh_shape)) 183 | if len(proto.device_coordinates) != expected_coordinates_size: 184 | raise ValueError("`device_coordinates` must have shape num_tasks ({}) * " 185 | "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); " 186 | "got shape {}".format(proto.num_tasks, 187 | proto.num_tpu_devices_per_task, 188 | proto.mesh_shape, 189 | len(proto.device_coordinates))) 190 | 191 | coords = np.array(proto.device_coordinates, dtype=np.int32) 192 | if any(coords < 0): 193 | raise ValueError("`device_coordinates` must be >= 0") 194 | coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task, 195 | len(proto.mesh_shape))) 196 | self._device_coordinates = coords 197 | 198 | 199 | __invert_topology = topology_lib.Topology._invert_topology 200 | 201 | def _invert_topology(self): 202 | """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" 203 | if len(self.mesh_shape) != 4 or any(self.mesh_shape < 1): 204 | return __invert_topology(self) 205 | tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32) 206 | devices = np.full(list(self.mesh_shape), -1, dtype=np.int32) 207 | for task in range(self.device_coordinates.shape[0]): 208 | for device in range(self.device_coordinates.shape[1]): 209 | x, y, z, core = self.device_coordinates[task, device, :] 210 | tasks[x, y, z, core] = task 211 | devices[x, y, z, core] = device 212 | return tasks, devices 213 | 214 | 215 | @contextmanager 216 | def patch_tensorflow(): 217 | with 
def patch_tensorflow_interactive():
    """Apply the TensorFlow patches without a `with` block.

    Enters the context manager returned by `patch_tensorflow()` and does not
    exit it, then switches gin into interactive mode.

    Returns:
      The entered context manager.
    """
    active_patch = patch_tensorflow()
    # Entered by hand; nothing in this function exits it.
    active_patch.__enter__()
    gin.enter_interactive_mode()
    return active_patch
# Script entry point.
#
# With no command-line argument: patch TensorFlow, build a session config,
# resolve the TPU master from TPU_NAME or TPU_IP, open an interactive session
# and define a few topology helpers as module globals.
# With an argument: treat it as a Python file and execute it inside this
# module's namespace with the patches applied.
if __name__ == '__main__':
    _tf_patch = patch_tensorflow_interactive()
    if len(sys.argv) <= 1:
        from tensorflow.core.protobuf import config_pb2
        import tensorflow as tf
        tf1 = tf.compat.v1
        tf.compat.v1.logging.set_verbosity('DEBUG')
        import numpy as np
        #session_config = config_pb2.ConfigProto(allow_soft_placement=True, isolate_session_state=True)
        rpc_options = config_pb2.RPCOptions()
        # Setting cache_rpc_response to true will enable sender side caching of
        # response for RecvTensorAsync and RecvBufAsync to allow receiver to retry
        # requests. This is only necessary when the network fabric is experiencing a
        # significant error rate. Without it we'll fail a step on an network error,
        # while with it we'll be able to complete long steps (like complex
        # initializations) in the face of some network errors during RecvTensor.
        rpc_options.cache_rpc_response = True

        rewriter_config = rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            disable_meta_optimizer=True,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            fail_on_optimizer_errors=True,
        )

        graph_options = config_pb2.GraphOptions(
            rewrite_options=rewriter_config,
            place_pruned_graph=True,
            infer_shapes=True,
        )

        session_config = config_pb2.ConfigProto(
            graph_options=graph_options,
            allow_soft_placement=True,
            isolate_session_state=False,
            # rpc_options was previously built and then dropped; attach it so
            # the caching configured above actually takes effect.
            rpc_options=rpc_options,
        )
        # share variables across sessions on TPUs
        session_config.experimental.share_session_state_in_clusterspec_propagation = True

        # TODO: research this. What does it do?
        # session_config.share_cluster_devices_in_session = True

        master = None
        res = None
        cluster_spec = None
        cluster_def = None
        job_names = None
        master_job = 'worker'
        try:
            if 'TPU_NAME' in os.environ:
                res = TPUClusterResolver(os.environ['TPU_NAME'])
                master = res.get_master()
                cluster_spec = res.cluster_spec()
                if cluster_spec:
                    cluster_def = cluster_spec.as_cluster_def()
                    session_config.cluster_def.CopyFrom(cluster_def)
                    job_names = set([job.name for job in cluster_def.job])
                    assert len(job_names) == 1
                    master_job = cluster_def.job[0].name
            elif 'TPU_IP' in os.environ:
                master = os.environ['TPU_IP'].replace('grpc://', '')
                if ':' not in master:
                    master = master + ':8470'
                master = 'grpc://' + master
        except Exception:
            # Best effort: report the failure and continue with whatever was
            # resolved so far. `Exception` (not a bare except) so Ctrl-C and
            # SystemExit still abort the script.
            import traceback
            traceback.print_exc()
        graph = tf.Graph()
        sess = tf.compat.v1.InteractiveSession(master, graph=graph, config=session_config)
        devices = sess.list_devices()
        cores = sorted([x.name for x in devices if ':TPU:' in x.name])
        num_cores = len(cores)
        print(cluster_def)
        print('cores: %d ip: %s' % (num_cores, master))
        r = sess.run
        from importlib import reload
        import tf_tools as tft
        from tensorflow.python.tpu import tpu as tpu_ops
        from tensorflow.compiler.tf2xla.python import xla
        from tensorflow.compiler.tf2xla.ops import gen_xla_ops
        from tensorflow.python.tpu import tpu_strategy_util
        from tensorflow.python.tpu import device_assignment as device_assignment_lib
        from tensorflow.python.tpu import topology as topology_lib
        tpu_topology = None
        # Serialized topologies keyed by TPU name, persisted in topology.cache.
        topology_cache = {}
        try:
            with open('topology.cache', 'r') as f:
                topology_cache = json.load(f)
        except FileNotFoundError:
            pass

        def cached_topology(name=None):
            """Return the cached Topology for `name` (default: $TPU_NAME), or None."""
            if name is None:
                name = os.environ.get('TPU_NAME', None)
            result = topology_cache.get(name, None)
            if result is not None:
                serialized = base64.b64decode(result)
                return topology_lib.Topology(serialized=serialized)

        def get_topology():
            """Return the TPU topology, initializing the TPU system on a cache miss."""
            global tpu_topology
            tpu_topology = cached_topology()
            if tpu_topology is None:
                tpu_topology = tpu_strategy_util.initialize_tpu_system(res)
                # Use the same lookup as cached_topology(); indexing
                # os.environ directly raised KeyError when TPU_NAME is unset.
                name = os.environ.get('TPU_NAME', None)
                if name is not None:
                    topology_cache.update({name: base64.b64encode(tpu_topology.serialized()).decode('utf8')})
                    with open('topology.cache', 'w') as f:
                        f.write(json.dumps(topology_cache))
            return tpu_topology

        def get_task_and_cores_to_replicas():
            return device_assignment_lib._compute_task_and_cores_to_replicas(tpu_topology.device_coordinates, tpu_topology)

        def get_core_assignment(*core_ids):
            return device_assignment_lib.DeviceAssignment(get_topology(), [[get_topology().device_coordinates[0][i]] for i in core_ids])

        def get_device_assignment(num_replicas, computation_shape=None, topology=None):
            if topology is None:
                topology = get_topology()
            if computation_shape is None:
                computation_shape = [1, 1, 1, 2]
            device_assignment = tf.tpu.experimental.DeviceAssignment.build(topology, computation_shape=computation_shape, num_replicas=num_replicas)
            return device_assignment

        tpu_topology = cached_topology()
    else:
        # Run the given file in this module's namespace, shifting argv so the
        # script sees itself as argv[0].
        filename = sys.argv[1]
        sys.argv = sys.argv[1:]
        with open(filename) as f:
            source = f.read()
        code = compile(source, filename, 'exec')
        exec(code, globals(), globals())