├── requirements.txt ├── llorma_g ├── configs.py ├── local.py ├── batch.py ├── pre_trainer.py ├── anchor.py ├── model.py └── trainer.py ├── llorma_p ├── configs.py ├── batch.py ├── trainer.py ├── pre_trainer.py ├── anchor.py ├── local.py └── model.py ├── train.py ├── base ├── train.py ├── rprop.py ├── dataset.py └── memory_saving_gradients.py ├── README.md ├── LICENSE └── .gitignore /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.4.0 2 | sklearn==0.0 3 | scipy==1.1.0 4 | -------------------------------------------------------------------------------- /llorma_g/configs.py: -------------------------------------------------------------------------------- 1 | GPU_MEMORY_FRAC = 0.95 2 | N_SHOT = 0 3 | 4 | N_ANCHOR = 100 5 | 6 | PRE_RANK = 5 7 | PRE_LEARNING_RATE = 2e-4 8 | PRE_LAMBDA = 10 9 | 10 | RANK = 20 11 | LEARNING_RATE = 1e-2 12 | LAMBDA = 1e-3 13 | BATCH_SIZE = 1000 14 | 15 | USE_CACHE = True 16 | -------------------------------------------------------------------------------- /llorma_p/configs.py: -------------------------------------------------------------------------------- 1 | GPU_MEMORY_FRAC = 0.95 2 | N_SHOT = 0 3 | 4 | N_ANCHOR = 50 5 | 6 | PRE_RANK = 5 7 | PRE_LEARNING_RATE = 1e-4 8 | PRE_LAMBDA = 10 9 | 10 | LOCAL_RANK = 20 11 | LOCAL_LEARNING_RATE = 1e-2 12 | LOCAL_LAMBDA = 1e-3 13 | 14 | BATCH_SIZE = 1000 15 | 16 | USE_CACHE = True 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from base.dataset import DatasetManager 2 | 3 | from llorma_p.trainer import main as llorma_parallel_train 4 | from llorma_g.trainer import main as llorma_global_train 5 | 6 | if __name__ == '__main__': 7 | # kind = DatasetManager.KIND_MOVIELENS_100K 8 | kind = DatasetManager.KIND_MOVIELENS_1M 9 | # kind = DatasetManager.KIND_MOVIELENS_10M 10 | # kind = DatasetManager.KIND_MOVIELENS_20M 11 | 12 | # llorma_parallel_train(kind) 13 | llorma_global_train(kind) 14 | -------------------------------------------------------------------------------- /base/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def cosine_decay_learning_rate(learning_rate, 7 | global_step, 8 | decay_steps=200, 9 | alpha=0.01): 10 | # Implemented manually because this schedule is unavailable in tensorflow==1.4.0.
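# Note: global_step is taken modulo decay_steps below, so the schedule restarts from the peak learning rate every decay_steps steps (cosine annealing with warm restarts) rather than decaying once to alpha * learning_rate.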
11 | global_step = tf.cast(global_step, tf.int64) 12 | step = tf.cast(tf.mod(global_step, decay_steps), tf.float32) 13 | cosine_decay = 0.5 * (1.0 + tf.cos(math.pi * step / decay_steps)) 14 | decayed = (1 - alpha) * cosine_decay + alpha 15 | return learning_rate * decayed 16 | -------------------------------------------------------------------------------- /llorma_g/local.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | 4 | import numpy as np 5 | 6 | from .configs import * 7 | 8 | 9 | class LocalModel: 10 | def __init__(self, session, models, anchor_idx, anchor_manager, 11 | batch_manager): 12 | self.session = session 13 | self.models = models 14 | self.batch_manager = batch_manager 15 | self.anchor_idx = anchor_idx 16 | self.anchor_manager = anchor_manager 17 | 18 | print('>> update k in anchor_idx [{}].'.format(anchor_idx)) 19 | self.train_k = anchor_manager.get_train_k(anchor_idx) 20 | self.valid_k = anchor_manager.get_valid_k(anchor_idx) 21 | self.test_k = anchor_manager.get_test_k(anchor_idx) 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLORMA-tensorflow 2 | 3 | * This repository is a TensorFlow implementation of "Local Low-rank Matrix Approximation". 4 | * I implemented two versions of LLORMA: [Parallel LLORMA (ICML'13)](https://static.googleusercontent.com/media/research.google.com/ko//pubs/archive/45235.pdf) and [Global LLORMA (JMLR'16)](http://jmlr.org/papers/volume17/14-301/14-301.pdf). 5 | * I have increased the batch size for performance. If you want the same results as the original paper, please set the batch size to 1. 6 | * I referred to the code from https://github.com/jnhwkim/PREA/tree/master/src/main/java/prea/recommender/llorma.
7 | 8 | * Folder description 9 | * llorma_p: Parallel LLORMA (ICML'13) 10 | * llorma_g: Global LLORMA (JMLR'16) 11 | 12 | * Dependency: `python3.5`; see `requirements.txt` 13 | 14 | * How to run 15 | ``` 16 | python train.py 17 | ``` 18 | -------------------------------------------------------------------------------- /llorma_p/batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from base.dataset import DatasetManager 4 | from .configs import * 5 | 6 | 7 | class BatchManager: 8 | def __init__(self, kind): 9 | self.kind = kind 10 | dataset_manager = DatasetManager(kind, N_SHOT) 11 | self.train_data = np.concatenate( 12 | [ 13 | dataset_manager.get_train_data(), 14 | dataset_manager.get_valid_data() 15 | ], 16 | axis=0) 17 | self.test_data = dataset_manager.get_test_data() 18 | 19 | self.n_user = int( 20 | max(np.max(self.train_data[:, 0]), np.max(self.test_data[:, 21 | 0]))) + 1 22 | self.n_item = int( 23 | max(np.max(self.train_data[:, 1]), np.max(self.test_data[:, 24 | 1]))) + 1 25 | self.mu = np.mean(self.train_data[:, 2]) 26 | self.std = np.std(self.train_data[:, 2]) 27 | -------------------------------------------------------------------------------- /llorma_g/batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from base.dataset import DatasetManager 4 | from .configs import * 5 | 6 | 7 | class BatchManager: 8 | def __init__(self, kind): 9 | self.kind = kind 10 | dataset_manager = DatasetManager(kind, N_SHOT) 11 | self.train_data = dataset_manager.get_train_data() 12 | self.valid_data = dataset_manager.get_valid_data() 13 | self.test_data = dataset_manager.get_test_data() 14 | 15 | self.n_user = int( 16 | max( 17 | np.max(self.train_data[:, 0]), 18 | np.max(self.valid_data[:, 0]), np.max(self.test_data[:, 19 | 0]))) + 1 20 | self.n_item = int( 21 | max( 22 | np.max(self.train_data[:, 1]), 23 | np.max(self.valid_data[:, 1]), np.max(self.test_data[:, 24 | 1]))) + 1 25 | self.mu = np.mean(self.train_data[:, 2]) 26 | self.std = np.std(self.train_data[:, 2]) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Joonyoung Yi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # 107 | logs/ 108 | data/ 109 | tmp/ 110 | .venv-cpu/ 111 | z_s_col.npy 112 | z_s_row.npy 113 | movielens-* 114 | -------------------------------------------------------------------------------- /llorma_p/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from . 
import pre_trainer 10 | from .anchor import AnchorManager 11 | from .batch import BatchManager 12 | from .configs import * 13 | from .local import LocalModel 14 | from .model import init_models 15 | 16 | 17 | def __init_session(): 18 | # gpu_options = tf.GPUOptions( 19 | # per_process_gpu_memory_fraction=GPU_MEMORY_FRAC) 20 | # gpu_config = tf.ConfigProto(gpu_options=gpu_options) 21 | # session = tf.Session(config=gpu_config) 22 | 23 | config = tf.ConfigProto() 24 | config.gpu_options.allow_growth = True 25 | 26 | session = tf.Session(config=config) 27 | session.run(tf.global_variables_initializer()) 28 | return session 29 | 30 | 31 | def __get_rmse(local_models, batch_manager, key='train'): 32 | r_hats = np.stack( 33 | [ 34 | getattr(local_model, '{}_r_hat'.format(key)) 35 | for local_model in local_models 36 | ], 37 | axis=1) 38 | ks = np.stack( 39 | [ 40 | getattr(local_model, '{}_k'.format(key)) 41 | for local_model in local_models 42 | ], 43 | axis=1) 44 | sum_ks = np.sum(ks, axis=1) 45 | sum_r_hats = np.sum(np.multiply(r_hats, ks), axis=1) 46 | r_hat = np.divide(sum_r_hats, sum_ks)  # numpy divide: these are ndarrays; 0/0 yields NaN, patched below 47 | r_hat[np.isnan(r_hat)] = 3 48 | 49 | r = getattr(batch_manager, '{}_data'.format(key))[:, 2] 50 | rmse = np.sqrt(np.mean(np.square(r_hat - r))) 51 | return rmse 52 | 53 | 54 | def _get_rmses(local_models, batch_manager): 55 | train_rmse = __get_rmse(local_models, batch_manager, key='train') 56 | test_rmse = __get_rmse(local_models, batch_manager, key='test') 57 | return train_rmse, test_rmse 58 | 59 | 60 | def _train(kind): 61 | row_latent_init, col_latent_init = pre_trainer.get_p_and_q( 62 | kind, use_cache=USE_CACHE) 63 | 64 | batch_manager = BatchManager(kind) 65 | models = init_models(batch_manager) 66 | 67 | session = __init_session() 68 | anchor_manager = AnchorManager( 69 | session, 70 | models, 71 | batch_manager, 72 | row_latent_init, 73 | col_latent_init, ) 74 | local_models = [ 75 | LocalModel(session, models, anchor_idx, anchor_manager, batch_manager) 76 | for anchor_idx in range(N_ANCHOR) 77 | ] 78 | 79 | for local_idx, local_model in enumerate(local_models): 80 | local_model.train() 81 | 82 | train_rmse, test_rmse = _get_rmses(local_models[:local_idx + 1], batch_manager) 83 | 84 | print(">> LOCAL [{:3d}] {:.5f}, {:.5f}\n".format(local_idx, train_rmse, 85 | test_rmse)) 86 | 87 | 88 | def main(kind): 89 | _train(kind) 90 | -------------------------------------------------------------------------------- /llorma_g/pre_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from .batch import BatchManager 10 | from .configs import * 11 | from .model import init_models_for_pre_train 12 | 13 | 14 | def _validate(session, batch_manager, models): 15 | valid_rmse = session.run( 16 | models['rmse'], 17 | feed_dict={ 18 | models['u']: batch_manager.valid_data[:, 0], 19 | models['i']: batch_manager.valid_data[:, 1], 20 | models['r']: batch_manager.valid_data[:, 2] 21 | }) 22 | 23 | test_rmse = session.run( 24 | models['rmse'], 25 | feed_dict={ 26 | models['u']: batch_manager.test_data[:, 0], 27 | models['i']: batch_manager.test_data[:, 1], 28 | models['r']: batch_manager.test_data[:, 2] 29 | }) 30 | 31 | return valid_rmse, test_rmse 32 | 33 | 34 | def get_p_and_q(kind, use_cache=True): 35 | if use_cache: 36 | try: 37 | p = np.load('llorma_g/{}-p.npy'.format(kind)) 38 | q = np.load('llorma_g/{}-q.npy'.format(kind)) 39 | return p, q 40 | except: 41 |
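# Cache miss (or unreadable files): fall through and pre-train p and q from scratch below.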
print('>> There is no cached p and q.') 42 | 43 | batch_manager = BatchManager(kind) 44 | models = init_models_for_pre_train(batch_manager) 45 | 46 | session = tf.Session() 47 | session.run(tf.global_variables_initializer()) 48 | 49 | min_valid_rmse = float("Inf") 50 | min_valid_iter = 0 51 | final_test_rmse = float("Inf") 52 | 53 | random_model_idx = random.randint(0, 1000000) 54 | 55 | file_path = "tmp/model-{}.ckpt".format(random_model_idx) 56 | 57 | u = batch_manager.train_data[:, 0] 58 | i = batch_manager.train_data[:, 1] 59 | r = batch_manager.train_data[:, 2] 60 | 61 | saver = tf.train.Saver() 62 | for iter in range(1000000): 63 | for train_op in models['train_ops']: 64 | _, loss, train_rmse = session.run( 65 | (train_op, models['loss'], models['rmse']), 66 | feed_dict={models['u']: u, 67 | models['i']: i, 68 | models['r']: r}) 69 | 70 | valid_rmse, test_rmse = _validate(session, batch_manager, models) 71 | 72 | if valid_rmse < min_valid_rmse: 73 | min_valid_rmse = valid_rmse 74 | min_valid_iter = iter 75 | final_test_rmse = test_rmse 76 | saver.save(session, file_path) 77 | 78 | if iter >= min_valid_iter + 100: 79 | break 80 | 81 | print('>> ITER:', 82 | "{:3d}".format(iter), "{:3f}, {:3f} {:3f} / {:3f}".format( 83 | train_rmse, valid_rmse, test_rmse, final_test_rmse)) 84 | 85 | saver.restore(session, file_path) 86 | p, q = session.run( 87 | (models['p'], models['q']), 88 | feed_dict={ 89 | models['u']: batch_manager.train_data[:, 0], 90 | models['i']: batch_manager.train_data[:, 1], 91 | models['r']: batch_manager.train_data[:, 2] 92 | }) 93 | np.save('llorma_g/{}-p.npy'.format(kind), p) 94 | np.save('llorma_g/{}-q.npy'.format(kind), q) 95 | 96 | session.close() 97 | return p, q 98 | -------------------------------------------------------------------------------- /llorma_p/pre_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from .batch import BatchManager 10 | from .configs import * 11 | from .model import init_models_for_pre_train 12 | 13 | 14 | def _validate(session, batch_manager, models): 15 | valid_rmse = session.run( 16 | models['rmse'], 17 | feed_dict={ 18 | models['u']: batch_manager.valid_data[:, 0], 19 | models['i']: batch_manager.valid_data[:, 1], 20 | models['r']: batch_manager.valid_data[:, 2] 21 | }) 22 | 23 | test_rmse = session.run( 24 | models['rmse'], 25 | feed_dict={ 26 | models['u']: batch_manager.test_data[:, 0], 27 | models['i']: batch_manager.test_data[:, 1], 28 | models['r']: batch_manager.test_data[:, 2] 29 | }) 30 | 31 | return valid_rmse, test_rmse 32 | 33 | 34 | def get_p_and_q(kind, use_cache=True): 35 | if use_cache: 36 | try: 37 | p = np.load('llorma_p/{}-p.npy'.format(kind)) 38 | q = np.load('llorma_p/{}-q.npy'.format(kind)) 39 | return p, q 40 | except: 41 | print('>> There is no cached p and q.') 42 | 43 | batch_manager = BatchManager(kind) 44 | models = init_models_for_pre_train(batch_manager) 45 | 46 | gpu_options = tf.GPUOptions( 47 | per_process_gpu_memory_fraction=GPU_MEMORY_FRAC) 48 | gpu_config = tf.ConfigProto(gpu_options=gpu_options) 49 | 50 | session = tf.Session(config=gpu_config) 51 | session.run(tf.global_variables_initializer()) 52 | 53 | min_valid_rmse = float("Inf") 54 | min_valid_iter = 0 55 | final_test_rmse = float("Inf") 56 | 57 | random_model_idx = random.randint(0, 1000000) 58 | 59 | file_path = "tmp/model-{}.ckpt".format(random_model_idx) 60 | 61 | u = 
batch_manager.train_data[:, 0] 62 | i = batch_manager.train_data[:, 1] 63 | r = batch_manager.train_data[:, 2] 64 | 65 | saver = tf.train.Saver() 66 | for iter in range(1000000): 67 | for train_op in models['train_ops']: 68 | _, loss, train_rmse = session.run( 69 | (train_op, models['loss'], models['rmse']), 70 | feed_dict={models['u']: u, 71 | models['i']: i, 72 | models['r']: r}) 73 | 74 | valid_rmse, test_rmse = _validate(session, batch_manager, models) 75 | 76 | if valid_rmse < min_valid_rmse: 77 | min_valid_rmse = valid_rmse 78 | min_valid_iter = iter 79 | final_test_rmse = test_rmse 80 | saver.save(session, file_path) 81 | 82 | if iter >= min_valid_iter + 100: 83 | break 84 | 85 | print('>> ITER:', 86 | "{:3d}".format(iter), "{:3f}, {:3f} {:3f} / {:3f}".format( 87 | train_rmse, valid_rmse, test_rmse, final_test_rmse)) 88 | 89 | saver.restore(session, file_path) 90 | p, q = session.run( 91 | (models['p'], models['q']), 92 | feed_dict={ 93 | models['u']: batch_manager.train_data[:, 0], 94 | models['i']: batch_manager.train_data[:, 1], 95 | models['r']: batch_manager.train_data[:, 2] 96 | }) 97 | np.save('llorma_p/{}-p.npy'.format(kind), p) 98 | np.save('llorma_p/{}-q.npy'.format(kind), q) 99 | 100 | session.close() 101 | return p, q 102 | -------------------------------------------------------------------------------- /llorma_p/anchor.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from sklearn.preprocessing import normalize 5 | 6 | from .configs import * 7 | 8 | 9 | def _init_anchor_points(train_data, row_k, col_k): 10 | train_user_ids = train_data[:, 0].astype(np.int64) 11 | train_item_ids = train_data[:, 1].astype(np.int64) 12 | 13 | anchor_idxs = [] 14 | while len(anchor_idxs) < N_ANCHOR: 15 | anchor_idx = random.randint(0, train_data.shape[0] - 1) 16 | if anchor_idx in anchor_idxs: 17 | continue 18 | 19 | anchor_row = train_data[anchor_idx] 20 | user_id = int(anchor_row[0]) 21 | item_id = int(anchor_row[1]) 22 | 23 | k = np.multiply(row_k[user_id][train_user_ids], 24 | col_k[item_id][train_item_ids]) 25 | sum_a_of_anchor = np.sum(k) 26 | if sum_a_of_anchor < 1: 27 | continue 28 | 29 | print('>> %10d\t%d' % (anchor_idx, sum_a_of_anchor)) 30 | anchor_idxs.append(anchor_idx) 31 | 32 | return anchor_idxs 33 | 34 | 35 | def _get_distance_matrix(latent): 36 | _normalized_latent = normalize(latent, axis=1) 37 | # print(_normalized_latent.shape) 38 | 39 | cos = np.matmul(_normalized_latent, _normalized_latent.T) 40 | cos = np.clip(cos, -1, 1) 41 | d = np.arccos(cos) 42 | assert np.count_nonzero(np.isnan(d)) == 0 43 | return d 44 | 45 | 46 | def _get_k_from_distance(d): 47 | m = np.zeros(d.shape) 48 | m[d < 0.8] = 1 49 | return np.multiply(np.subtract(np.ones(d.shape), np.square(d)), m) 50 | 51 | 52 | def _get_ks_from_latents(row_latent, col_latent): 53 | 54 | # for i in range(row_latent.shape[0]): 55 | # print(row_latent[i][:4]) 56 | # 57 | # assert False 58 | row_d = _get_distance_matrix(row_latent) 59 | col_d = _get_distance_matrix(col_latent) 60 | 61 | row_k = _get_k_from_distance(row_d) 62 | col_k = _get_k_from_distance(col_d) 63 | 64 | return row_k, col_k 65 | 66 | 67 | class AnchorManager: 68 | def __init__( 69 | self, 70 | session, 71 | models, 72 | batch_manager, 73 | row_latent_init, 74 | col_latent_init, ): 75 | 76 | train_data = batch_manager.train_data 77 | 78 | row_latent = row_latent_init 79 | col_latent = col_latent_init 80 | 81 | row_k, col_k = _get_ks_from_latents(row_latent, col_latent) 
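# row_k[a][b] = (1 - d(a, b)**2) if the arccos distance d(a, b) < 0.8, else 0: an Epanechnikov-style smoothing kernel over the pre-trained latent vectors, so each anchor draws weight only from nearby users/items.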
82 | 83 | anchor_idxs = _init_anchor_points(train_data, row_k, col_k) 84 | assert len(anchor_idxs) == N_ANCHOR 85 | # print(anchor_idxs) 86 | anchor_points = train_data[anchor_idxs] 87 | 88 | self.train_data = train_data 89 | self.test_data = batch_manager.test_data 90 | 91 | self.anchor_idxs = anchor_idxs 92 | self.anchor_points = anchor_points 93 | 94 | self.row_k = row_k 95 | self.col_k = col_k 96 | 97 | def _get_k(self, anchor_idx, data): 98 | row_k = self.row_k 99 | col_k = self.col_k 100 | anchor_point = self.anchor_points[anchor_idx] 101 | 102 | user_id = int(anchor_point[0]) 103 | item_id = int(anchor_point[1]) 104 | 105 | user_ids = data[:, 0].astype(np.int64) 106 | item_ids = data[:, 1].astype(np.int64) 107 | 108 | return np.multiply(row_k[user_id][user_ids], col_k[item_id][item_ids]) 109 | 110 | def get_train_k(self, anchor_idx): 111 | return self._get_k(anchor_idx, self.train_data) 112 | 113 | def get_test_k(self, anchor_idx): 114 | return self._get_k(anchor_idx, self.test_data) 115 | -------------------------------------------------------------------------------- /llorma_g/anchor.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from sklearn.preprocessing import normalize 5 | 6 | from .configs import * 7 | 8 | 9 | def _init_anchor_points(train_data, row_k, col_k): 10 | train_user_ids = train_data[:, 0].astype(np.int64) 11 | train_item_ids = train_data[:, 1].astype(np.int64) 12 | 13 | anchor_idxs = [] 14 | while len(anchor_idxs) < N_ANCHOR: 15 | anchor_idx = random.randint(0, train_data.shape[0] - 1) 16 | if anchor_idx in anchor_idxs: 17 | continue 18 | 19 | anchor_row = train_data[anchor_idx] 20 | user_id = int(anchor_row[0]) 21 | item_id = int(anchor_row[1]) 22 | 23 | k = np.multiply(row_k[user_id][train_user_ids], 24 | col_k[item_id][train_item_ids]) 25 | sum_a_of_anchor = np.sum(k) 26 | if sum_a_of_anchor < 1: 27 | continue 28 | 29 | print('>> %10d\t%d' % (anchor_idx, sum_a_of_anchor)) 30 | anchor_idxs.append(anchor_idx) 31 | 32 | return anchor_idxs 33 | 34 | 35 | def _get_distance_matrix(latent): 36 | _normalized_latent = normalize(latent, axis=1) 37 | # print(_normalized_latent.shape) 38 | 39 | cos = np.matmul(_normalized_latent, _normalized_latent.T) 40 | cos = np.clip(cos, -1, 1) 41 | d = np.arccos(cos) 42 | assert np.count_nonzero(np.isnan(d)) == 0 43 | return d 44 | 45 | 46 | def _get_k_from_distance(d): 47 | m = np.zeros(d.shape) 48 | m[d < 0.8] = 1 49 | return np.multiply(np.subtract(np.ones(d.shape), np.square(d)), m) 50 | 51 | 52 | def _get_ks_from_latents(row_latent, col_latent): 53 | 54 | # for i in range(row_latent.shape[0]): 55 | # print(row_latent[i][:4]) 56 | # 57 | # assert False 58 | row_d = _get_distance_matrix(row_latent) 59 | col_d = _get_distance_matrix(col_latent) 60 | 61 | row_k = _get_k_from_distance(row_d) 62 | col_k = _get_k_from_distance(col_d) 63 | 64 | return row_k, col_k 65 | 66 | 67 | class AnchorManager: 68 | def __init__( 69 | self, 70 | session, 71 | models, 72 | batch_manager, 73 | row_latent_init, 74 | col_latent_init, ): 75 | 76 | train_data = batch_manager.train_data 77 | 78 | row_latent = row_latent_init 79 | col_latent = col_latent_init 80 | 81 | row_k, col_k = _get_ks_from_latents(row_latent, col_latent) 82 | 83 | anchor_idxs = _init_anchor_points(train_data, row_k, col_k) 84 | assert len(anchor_idxs) == N_ANCHOR 85 | # print(anchor_idxs) 86 | anchor_points = train_data[anchor_idxs] 87 | 88 | self.train_data = train_data 89 | self.valid_data = 
batch_manager.valid_data 90 | self.test_data = batch_manager.test_data 91 | 92 | self.anchor_idxs = anchor_idxs 93 | self.anchor_points = anchor_points 94 | 95 | self.row_k = row_k 96 | self.col_k = col_k 97 | 98 | def _get_k(self, anchor_idx, data): 99 | row_k = self.row_k 100 | col_k = self.col_k 101 | anchor_point = self.anchor_points[anchor_idx] 102 | 103 | user_id = int(anchor_point[0]) 104 | item_id = int(anchor_point[1]) 105 | 106 | user_ids = data[:, 0].astype(np.int64) 107 | item_ids = data[:, 1].astype(np.int64) 108 | 109 | return np.multiply(row_k[user_id][user_ids], col_k[item_id][item_ids]) 110 | 111 | def get_train_k(self, anchor_idx): 112 | return self._get_k(anchor_idx, self.train_data) 113 | 114 | def get_valid_k(self, anchor_idx): 115 | return self._get_k(anchor_idx, self.valid_data) 116 | 117 | def get_test_k(self, anchor_idx): 118 | return self._get_k(anchor_idx, self.test_data) 119 | -------------------------------------------------------------------------------- /llorma_g/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | 5 | import numpy as np 6 | 7 | from .configs import * 8 | from base.dataset import DatasetManager 9 | from base.rprop import RPropOptimizer 10 | 11 | 12 | def _create_p_or_q_variable(n, rank, batch_manager): 13 | mu = batch_manager.mu 14 | std = batch_manager.std 15 | 16 | _mu = math.sqrt(mu / rank) 17 | _std = math.sqrt((math.sqrt(mu * mu + std * std) - mu) / rank) 18 | return tf.Variable( 19 | tf.truncated_normal([n, rank], _mu, _std, dtype=tf.float64)) 20 | 21 | 22 | def init_models_for_pre_train(batch_manager): 23 | n_row, n_col = batch_manager.n_user, batch_manager.n_item 24 | 25 | u = tf.placeholder(tf.int64, [None], name='u') 26 | i = tf.placeholder(tf.int64, [None], name='i') 27 | r = tf.placeholder(tf.float64, [None], name='r') 28 | 29 | # init weights 30 | mu = batch_manager.mu 31 | std = batch_manager.std 32 | p = _create_p_or_q_variable(n_row, PRE_RANK, batch_manager) 33 | q = _create_p_or_q_variable(n_col, PRE_RANK, batch_manager) 34 | 35 | p_lookup = tf.nn.embedding_lookup(p, u) 36 | q_lookup = tf.nn.embedding_lookup(q, i) 37 | r_hat = tf.reduce_sum(tf.multiply(p_lookup, q_lookup), 1) 38 | 39 | reg_loss = tf.add_n( 40 | [tf.reduce_sum(tf.square(p)), 41 | tf.reduce_sum(tf.square(q))]) 42 | loss = tf.reduce_sum(tf.square(r - r_hat)) + PRE_LAMBDA * reg_loss 43 | rmse = tf.sqrt(tf.reduce_mean(tf.square(r - r_hat))) 44 | 45 | optimizer = tf.train.MomentumOptimizer(PRE_LEARNING_RATE, 0.9) 46 | # optimizer = tf.train.GradientDescentOptimizer(PRE_LEARNING_RATE) 47 | train_ops = [ 48 | optimizer.minimize(loss, var_list=[p]), 49 | optimizer.minimize(loss, var_list=[q]) 50 | ] 51 | 52 | return { 53 | 'u': u, 54 | 'i': i, 55 | 'r': r, 56 | 'train_ops': train_ops, 57 | 'loss': loss, 58 | 'rmse': rmse, 59 | 'p': p, 60 | 'q': q, 61 | } 62 | 63 | 64 | def _get_train_op(optimizer, loss, var_list): 65 | gvs = optimizer.compute_gradients(loss, var_list=var_list) 66 | # capped_gvs = [(tf.clip_by_value(grad, -100.0, 100.0), var) 67 | # for grad, var in gvs] 68 | capped_gvs = gvs 69 | train_op = optimizer.apply_gradients(capped_gvs) 70 | return train_op 71 | 72 | 73 | def init_models(batch_manager): 74 | n_row, n_col = batch_manager.n_user, batch_manager.n_item 75 | 76 | u = tf.placeholder(tf.int64, [None], name='u') 77 | i = tf.placeholder(tf.int64, [None], name='i') 78 | r = tf.placeholder(tf.float64, [None], name='r') 79 | k = tf.placeholder(tf.float64, [None, 
N_ANCHOR], name='k') 80 | k_sum = tf.reduce_sum(k, axis=1) 81 | 82 | # init weights 83 | ps, qs, losses, r_hats = [], [], [], [] 84 | for anchor_idx in range(N_ANCHOR): 85 | p = _create_p_or_q_variable(n_row, RANK, batch_manager) 86 | q = _create_p_or_q_variable(n_col, RANK, batch_manager) 87 | ps.append(p) 88 | qs.append(q) 89 | 90 | p_lookup = tf.nn.embedding_lookup(p, u) 91 | q_lookup = tf.nn.embedding_lookup(q, i) 92 | r_hat = tf.reduce_sum(tf.multiply(p_lookup, q_lookup), axis=1) 93 | r_hats.append(r_hat) 94 | 95 | r_hat = tf.reduce_sum(tf.multiply(k, tf.stack(r_hats, axis=1)), axis=1) 96 | r_hat = tf.where(tf.greater(k_sum, 1e-2), r_hat, tf.ones_like(r_hat) * 3) 97 | rmse = tf.sqrt(tf.reduce_mean(tf.square(r - r_hat))) 98 | 99 | optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE) 100 | loss = tf.reduce_sum(tf.square(r_hat - r)) + LAMBDA * tf.reduce_sum( 101 | [tf.reduce_sum(tf.square(p_or_q)) for p_or_q in ps + qs]) 102 | train_ops = [ 103 | _get_train_op(optimizer, loss, [p, q]) for p, q in zip(ps, qs) 104 | ] 105 | 106 | return { 107 | 'u': u, 108 | 'i': i, 109 | 'r': r, 110 | 'k': k, 111 | 'train_ops': train_ops, 112 | 'rmse': rmse, 113 | } 114 | -------------------------------------------------------------------------------- /llorma_p/local.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | 4 | import numpy as np 5 | 6 | from .configs import * 7 | 8 | 9 | def _create_p_or_q(n, rank, batch_manager): 10 | mu = batch_manager.mu 11 | std = batch_manager.std 12 | 13 | _mu = math.sqrt(mu / rank) 14 | _std = math.sqrt((math.sqrt(mu * mu + std * std) - mu) / rank) 15 | return np.random.normal(_mu, _std, [n, rank]) 16 | 17 | 18 | def _assign_p_and_q(session, models, p, q): 19 | p_assign_op = models['local_p'].assign(p) 20 | q_assign_op = models['local_q'].assign(q) 21 | 22 | session.run((p_assign_op, q_assign_op)) 23 | 24 | 25 | class LocalModel: 26 | def __init__(self, session, models, anchor_idx, anchor_manager, 27 | batch_manager): 28 | self.session = session 29 | self.models = models 30 | self.batch_manager = batch_manager 31 | self.anchor_idx = anchor_idx 32 | self.anchor_manager = anchor_manager 33 | self.p = _create_p_or_q(batch_manager.n_user, LOCAL_RANK, 34 | batch_manager) 35 | self.q = _create_p_or_q(batch_manager.n_item, LOCAL_RANK, 36 | batch_manager) 37 | 38 | print('>> update k in anchor_idx [{}].'.format(anchor_idx)) 39 | self.train_k = anchor_manager.get_train_k(anchor_idx) 40 | self.test_k = anchor_manager.get_test_k(anchor_idx) 41 | 42 | _assign_p_and_q(session, models, self.p, self.q) 43 | self._update_r_hats() 44 | 45 | def _update_r_hats(self): 46 | session = self.session 47 | models = self.models 48 | batch_manager = self.batch_manager 49 | 50 | train_r_hat = session.run( 51 | models['local_r_hat'], 52 | feed_dict={ 53 | models['u']: batch_manager.train_data[:, 0], 54 | models['i']: batch_manager.train_data[:, 1], 55 | models['r']: batch_manager.train_data[:, 2], 56 | }) 57 | 58 | test_r_hat = session.run( 59 | models['local_r_hat'], 60 | feed_dict={ 61 | models['u']: batch_manager.test_data[:, 0], 62 | models['i']: batch_manager.test_data[:, 1], 63 | models['r']: batch_manager.test_data[:, 2], 64 | }) 65 | 66 | self.train_r_hat = train_r_hat 67 | self.test_r_hat = test_r_hat 68 | 69 | def train(self): 70 | session = self.session 71 | models = self.models 72 | batch_manager = self.batch_manager 73 | anchor_idx = self.anchor_idx 74 | anchor_manager = self.anchor_manager 75 | p = self.p 76 
| q = self.q 77 | train_k = self.train_k 78 | 79 | _assign_p_and_q(session, models, p, q) 80 | 81 | train_data = batch_manager.train_data 82 | prev_train_rmse = 5.0 83 | sum_batch_sse = 0.0 84 | n_batch = 0 85 | for iter in range(1, 100 + 1): 86 | for m in range(0, train_data.shape[0], BATCH_SIZE): 87 | end_m = min(m + BATCH_SIZE, train_data.shape[0]) 88 | u = train_data[m:end_m, 0] 89 | i = train_data[m:end_m, 1] 90 | r = train_data[m:end_m, 2] 91 | k = train_k[m:end_m] 92 | sse, _ = session.run( 93 | (models['local_sse'], models['local_train_op']), 94 | feed_dict={ 95 | models['u']: u, 96 | models['i']: i, 97 | models['r']: r, 98 | models['k']: k, 99 | }) 100 | sum_batch_sse += sse 101 | n_batch += u.shape[0] 102 | 103 | train_rmse = math.sqrt(sum_batch_sse / n_batch) 104 | if iter % 10 == 0: 105 | print(' - ITER [{:3d}]'.format(iter), train_rmse) 106 | 107 | if abs(prev_train_rmse - train_rmse) < 1e-4: 108 | break 109 | prev_train_rmse = train_rmse 110 | p, q = session.run((models['local_p'], models['local_q'])) 111 | sum_batch_sse, n_batch = 0, 0 112 | 113 | self.p, self.q = p, q 114 | _assign_p_and_q(session, models, p, q) 115 | 116 | self._update_r_hats() 117 | 118 | self.train_k = train_k 119 | -------------------------------------------------------------------------------- /llorma_p/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | 5 | import numpy as np 6 | 7 | from .configs import * 8 | from base.dataset import DatasetManager 9 | from base.rprop import RPropOptimizer 10 | 11 | 12 | def init_models_for_pre_train(batch_manager): 13 | def _create_p_or_q_variable(n, rank, batch_manager): 14 | # TODO: should be pulled out of this function to module level. 15 | mu = batch_manager.mu 16 | std = batch_manager.std 17 | 18 | _mu = math.sqrt(mu / rank) 19 | _std = math.sqrt((math.sqrt(mu * mu + std * std) - mu) / rank) 20 | return tf.Variable(tf.truncated_normal([n, rank], _mu, _std)) 21 | 22 | n_row, n_col = batch_manager.n_user, batch_manager.n_item 23 | 24 | u = tf.placeholder(tf.int64, [None], name='u') 25 | i = tf.placeholder(tf.int64, [None], name='i') 26 | r = tf.placeholder(tf.float32, [None], name='r') 27 | 28 | # init weights 29 | mu = batch_manager.mu 30 | std = batch_manager.std 31 | p = _create_p_or_q_variable(n_row, PRE_RANK, batch_manager) 32 | q = _create_p_or_q_variable(n_col, PRE_RANK, batch_manager) 33 | 34 | p_lookup = tf.nn.embedding_lookup(p, u) 35 | q_lookup = tf.nn.embedding_lookup(q, i) 36 | r_hat = tf.reduce_sum(tf.multiply(p_lookup, q_lookup), 1) 37 | 38 | reg_loss = tf.add_n( 39 | [tf.reduce_sum(tf.square(p)), 40 | tf.reduce_sum(tf.square(q))]) 41 | loss = tf.reduce_sum(tf.square(r - r_hat)) + PRE_LAMBDA * reg_loss 42 | rmse = tf.sqrt(tf.reduce_mean(tf.square(r - r_hat))) 43 | 44 | optimizer = tf.train.MomentumOptimizer(PRE_LEARNING_RATE, 0.9) 45 | # optimizer = tf.train.GradientDescentOptimizer(PRE_LEARNING_RATE) 46 | train_ops = [ 47 | optimizer.minimize(loss, var_list=[p]), 48 | optimizer.minimize(loss, var_list=[q]) 49 | ] 50 | 51 | return { 52 | 'u': u, 53 | 'i': i, 54 | 'r': r, 55 | 'train_ops': train_ops, 56 | 'loss': loss, 57 | 'rmse': rmse, 58 | 'p': p, 59 | 'q': q, 60 | } 61 | 62 | 63 | def init_models(batch_manager): 64 | n_row, n_col = batch_manager.n_user, batch_manager.n_item 65 | 66 | u = tf.placeholder(tf.int64, [None], name='u') 67 | i = tf.placeholder(tf.int64, [None], name='i') 68 | r = tf.placeholder(tf.float32, [None], name='r') 69 | k = tf.placeholder(tf.float32, [None], name='k') 70 | 71 | #
train_u = tf.constant( 72 | # batch_manager.train_data[:, 0], dtype=tf.int64, name='train_u') 73 | # train_i = tf.constant( 74 | # batch_manager.train_data[:, 1], dtype=tf.int64, name='train_i') 75 | # train_r = tf.constant( 76 | # batch_manager.train_data[:, 2], dtype=tf.float32, name='train_r') 77 | 78 | # row_vecs = tf.SparseTensor( 79 | # indices=tf.stack([train_u, train_i], axis=1), 80 | # values=train_r, 81 | # dense_shape=[n_row, n_col]) 82 | # col_vecs = tf.SparseTensor( 83 | # indices=tf.stack([train_i, train_u], axis=1), 84 | # values=train_r, 85 | # dense_shape=[n_col, n_row]) 86 | 87 | # init weights 88 | local_p = tf.Variable(tf.zeros([n_row, LOCAL_RANK])) 89 | local_q = tf.Variable(tf.zeros([n_col, LOCAL_RANK])) 90 | 91 | local_p_lookup = tf.nn.embedding_lookup(local_p, u) 92 | local_q_lookup = tf.nn.embedding_lookup(local_q, i) 93 | local_r_hat = tf.reduce_sum( 94 | tf.multiply(local_p_lookup, local_q_lookup), axis=1) 95 | local_loss = tf.reduce_sum( 96 | tf.square(r - local_r_hat) * k) + LOCAL_LAMBDA * tf.add_n([ 97 | tf.reduce_sum(tf.square(local_p)), 98 | tf.reduce_sum(tf.square(local_q)) 99 | ]) 100 | local_sse = tf.reduce_sum(tf.square(r - local_r_hat)) 101 | 102 | # local_optimizer = tf.train.MomentumOptimizer(LOCAL_LEARNING_RATE, 0.9) 103 | # _optimizer = tf.train.MomentumOptimizer(LEARNING_RATE, 0.9) 104 | # _optimizer = tf.train.AdamOptimizer(LEARNING_RATE) 105 | local_optimizer = tf.train.GradientDescentOptimizer(LOCAL_LEARNING_RATE) 106 | local_train_op = local_optimizer.minimize( 107 | local_loss, var_list=[local_p, local_q]) 108 | 109 | return { 110 | 'u': u, 111 | 'i': i, 112 | 'r': r, 113 | 'k': k, 114 | 'local_p': local_p, 115 | 'local_q': local_q, 116 | 'local_r_hat': local_r_hat, 117 | 'local_train_op': local_train_op, 118 | 'local_sse': local_sse, 119 | } 120 | -------------------------------------------------------------------------------- /llorma_g/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from . 
import pre_trainer 10 | from .anchor import AnchorManager 11 | from .batch import BatchManager 12 | from .configs import * 13 | from .local import LocalModel 14 | from .model import init_models 15 | 16 | 17 | def __init_session(): 18 | # gpu_options = tf.GPUOptions( 19 | # per_process_gpu_memory_fraction=GPU_MEMORY_FRAC) 20 | # gpu_config = tf.ConfigProto(gpu_options=gpu_options) 21 | # session = tf.Session(config=gpu_config) 22 | 23 | config = tf.ConfigProto() 24 | config.gpu_options.allow_growth = True 25 | 26 | session = tf.Session(config=config) 27 | session.run(tf.global_variables_initializer()) 28 | return session 29 | 30 | 31 | def _get_k(local_models, kind='train'): 32 | k = np.stack( 33 | [ 34 | getattr(local_model, '{}_k'.format(kind)) 35 | for local_model in local_models 36 | ], 37 | axis=1) 38 | k = np.clip(k, 0.0, 1.0) 39 | k = np.divide(k, np.sum(k, axis=1, keepdims=1)) 40 | k[np.isnan(k)] = 0 41 | return k 42 | 43 | 44 | def _validate( 45 | session, 46 | models, 47 | batch_manager, 48 | valid_k, 49 | test_k, ): 50 | valid_rmse = session.run( 51 | models['rmse'], 52 | feed_dict={ 53 | models['u']: batch_manager.valid_data[:, 0], 54 | models['i']: batch_manager.valid_data[:, 1], 55 | models['r']: batch_manager.valid_data[:, 2], 56 | models['k']: valid_k, 57 | }) 58 | 59 | test_rmse = session.run( 60 | models['rmse'], 61 | feed_dict={ 62 | models['u']: batch_manager.test_data[:, 0], 63 | models['i']: batch_manager.test_data[:, 1], 64 | models['r']: batch_manager.test_data[:, 2], 65 | models['k']: test_k, 66 | }) 67 | 68 | return valid_rmse, test_rmse 69 | 70 | 71 | def _train(kind): 72 | row_latent_init, col_latent_init = pre_trainer.get_p_and_q( 73 | kind, use_cache=USE_CACHE) 74 | 75 | batch_manager = BatchManager(kind) 76 | models = init_models(batch_manager) 77 | 78 | session = __init_session() 79 | anchor_manager = AnchorManager( 80 | session, 81 | models, 82 | batch_manager, 83 | row_latent_init, 84 | col_latent_init, ) 85 | local_models = [ 86 | LocalModel(session, models, anchor_idx, anchor_manager, batch_manager) 87 | for anchor_idx in range(N_ANCHOR) 88 | ] 89 | 90 | train_k = _get_k(local_models, kind='train') 91 | valid_k = _get_k(local_models, kind='valid') 92 | test_k = _get_k(local_models, kind='test') 93 | 94 | min_valid_rmse = float("Inf") 95 | min_valid_iter = 0 96 | final_test_rmse = float("Inf") 97 | start_time = time.time() 98 | 99 | batch_rmses = [] 100 | train_data = batch_manager.train_data 101 | 102 | for iter in range(10000000): 103 | for m in range(0, train_data.shape[0], BATCH_SIZE): 104 | end_m = min(m + BATCH_SIZE, train_data.shape[0]) 105 | u = train_data[m:end_m, 0] 106 | i = train_data[m:end_m, 1] 107 | r = train_data[m:end_m, 2] 108 | k = train_k[m:end_m, :] 109 | results = session.run( 110 | [models['rmse']] + models['train_ops'], 111 | feed_dict={ 112 | models['u']: u, 113 | models['i']: i, 114 | models['r']: r, 115 | models['k']: k, 116 | }) 117 | batch_rmses.append(results[0]) 118 | 119 | if m % (BATCH_SIZE * 100) == 0: 120 | print(' - ', results[:1]) 121 | 122 | if iter % 1 == 0: 123 | valid_rmse, test_rmse = _validate(session, models, batch_manager, 124 | valid_k, test_k) 125 | if valid_rmse < min_valid_rmse: 126 | min_valid_rmse = valid_rmse 127 | min_valid_iter = iter 128 | final_test_rmse = test_rmse 129 | 130 | batch_rmse = sum(batch_rmses) / len(batch_rmses) 131 | batch_rmses = [] 132 | print(' - ITER{:4d}:'.format(iter), 133 | "{:.5f}, {:.5f} {:.5f} / {:.5f}".format( 134 | batch_rmse, valid_rmse, test_rmse, final_test_rmse)) 135 | 
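# Note: unlike pre-training, this loop has no early-stopping break; final_test_rmse records the test RMSE at the best validation iteration while training runs until the process is interrupted.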
136 | 137 | def main(kind): 138 | _train(kind) 139 | -------------------------------------------------------------------------------- /base/rprop.py: -------------------------------------------------------------------------------- 1 | """ 2 | RProp (Resilient Backpropagation) for TensorFlow. 3 | This code is forked from "https://raw.githubusercontent.com/dirkweissenborn/genie-kb/master/rprop.py". 4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | from tensorflow.python.framework import ops 12 | from tensorflow.python.training import optimizer 13 | 14 | 15 | class RPropOptimizer(optimizer.Optimizer): 16 | """ 17 | Optimizer that implements the RProp algorithm. 18 | """ 19 | 20 | def __init__(self, 21 | stepsize=0.1, 22 | etaplus=1.2, 23 | etaminus=0.5, 24 | stepsizemax=50.0, 25 | stepsizemin=1e-6, 26 | use_locking=False, 27 | name="RProp"): 28 | super(RPropOptimizer, self).__init__(use_locking, name) 29 | self._stepsize = stepsize 30 | self._etaplus = etaplus 31 | self._etaminus = etaminus 32 | self._stepsizemax = stepsizemax 33 | self._stepsizemin = stepsizemin 34 | 35 | def _create_slots(self, var_list): 36 | ''' 37 | :param var_list: 38 | :return: 39 | ''' 40 | # Create the slot variables on the same device as each 41 | # variable. 42 | 43 | # Slots: per-element step sizes, last applied deltas, and previous gradients. 44 | for v in var_list: 45 | self._get_or_make_slot( 46 | v, 47 | tf.ones([v.get_shape().num_elements()], dtype=tf.float32) * 48 | self._stepsize, 49 | "step", 50 | self._name, ) 51 | self._get_or_make_slot( 52 | v, 53 | tf.zeros([v.get_shape().num_elements()], dtype=tf.float32), 54 | "delta", 55 | self._name, ) 56 | self._get_or_make_slot( 57 | v, 58 | tf.zeros([v.get_shape().num_elements()], dtype=tf.float32), 59 | "grad", 60 | self._name, ) 61 | 62 | def _apply_dense(self, grad, var): 63 | grad_slot = self.get_slot(var, "grad") 64 | step_slot = self.get_slot(var, "step") 65 | delta_slot = self.get_slot(var, "delta") 66 | 67 | grad = tf.reshape(grad, [-1]) 68 | sign = tf.cast(tf.sign(grad_slot * grad), tf.int64) 69 | with tf.control_dependencies([sign]): 70 | grad = grad_slot.assign(grad) 71 | 72 | p_indices = tf.where(tf.equal(sign, 1)) # positive indices 73 | m_indices = tf.where(tf.equal(sign, -1)) # minus indices 74 | z_indices = tf.where(tf.equal(sign, 0)) # zero indices 75 | 76 | step_p_update = tf.expand_dims( 77 | tf.minimum( 78 | tf.gather_nd(step_slot, p_indices) * self._etaplus, 79 | self._stepsizemax), 1) 80 | step_m_update = tf.expand_dims( 81 | tf.maximum( 82 | tf.gather_nd(step_slot, m_indices) * self._etaminus, 83 | self._stepsizemin), 1) 84 | step_z_update = tf.expand_dims(tf.gather_nd(step_slot, z_indices), 1) 85 | with tf.control_dependencies( 86 | [step_p_update, step_m_update, step_z_update]): 87 | step = tf.scatter_update(step_slot, p_indices, step_p_update) 88 | step = tf.scatter_update(step, m_indices, step_m_update) 89 | step = tf.scatter_update(step, z_indices, step_z_update) 90 | step = step_slot.assign(step) 91 | 92 | delta_p_update = tf.expand_dims( 93 | tf.gather_nd(tf.sign(grad) * step, p_indices), 1) 94 | delta_z_update = tf.expand_dims( 95 | tf.gather_nd(tf.sign(grad) * step, z_indices), 1) 96 | with tf.control_dependencies([delta_p_update, delta_z_update]): 97 | delta = tf.scatter_update(delta_slot, p_indices, delta_p_update) 98 | delta = tf.scatter_update(delta, z_indices, delta_z_update) 99 | delta = delta_slot.assign(delta) 100
| 101 | with tf.control_dependencies([sign]): 102 | grad = tf.scatter_update(grad, m_indices, 103 | tf.zeros_like(m_indices, tf.float32)) 104 | grad = grad_slot.assign(grad) 105 | 106 | up = tf.reshape(delta, var.get_shape()) 107 | var_update = var.assign_sub(up, use_locking=self._use_locking) 108 | 109 | return tf.group(*[var_update, step, delta, grad]) 110 | 111 | def _apply_sparse(self, grad, var): 112 | raise NotImplementedError("RProp should be used only in batch_mode.") 113 | -------------------------------------------------------------------------------- /base/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | 6 | # from ..configs import * 7 | 8 | 9 | def _make_dir_if_not_exists(path): 10 | if not os.path.exists(path): 11 | os.mkdir(path) 12 | 13 | 14 | class DatasetManager: 15 | KIND_MOVIELENS_100K = 'movielens-100k' 16 | KIND_MOVIELENS_1M = 'movielens-1m' 17 | KIND_MOVIELENS_10M = 'movielens-10m' 18 | KIND_MOVIELENS_20M = 'movielens-20m' 19 | KIND_NETFLIX = 'netflix' 20 | 21 | KIND_OBJECTS = ( \ 22 | (KIND_MOVIELENS_100K, 'http://files.grouplens.org/datasets/movielens/ml-100k.zip'), \ 23 | (KIND_MOVIELENS_1M, 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'), \ 24 | (KIND_MOVIELENS_10M, 'http://files.grouplens.org/datasets/movielens/ml-10m.zip'), \ 25 | (KIND_MOVIELENS_20M, 'http://files.grouplens.org/datasets/movielens/ml-20m.zip'), \ 26 | (KIND_NETFLIX, None) 27 | ) 28 | 29 | def _set_kind_and_url(self, kind): 30 | self.kind = kind 31 | for k, url in self.KIND_OBJECTS: 32 | if k == kind: 33 | self.url = url 34 | return True 35 | raise NotImplementedError() 36 | 37 | def _download_data_if_not_exists(self): 38 | if not os.path.exists('data/{}'.format(self.kind)): 39 | os.system('wget {url} -O data/{kind}.zip'.format( 40 | url=self.url, kind=self.kind)) 41 | os.system( 42 | 'unzip data/{kind}.zip -d data/{kind}/'.format(kind=self.kind)) 43 | 44 | def __init_data(self, detail_path, delimiter, header=False): 45 | current_u = 0 46 | u_dict = {} 47 | current_i = 0 48 | i_dict = {} 49 | 50 | data = [] 51 | with open('data/{}{}'.format(self.kind, detail_path), 'r') as f: 52 | if header: 53 | f.readline() 54 | 55 | for line in f: 56 | cols = line.strip().split(delimiter) 57 | assert len(cols) == 4 58 | # cols = [float(c) for c in cols] 59 | user_id = cols[0] 60 | item_id = cols[1] 61 | r = float(cols[2]) 62 | t = int(cols[3]) 63 | 64 | u = u_dict.get(user_id, None) 65 | if u is None: 66 | u_dict[user_id] = current_u 67 | u = current_u 68 | current_u += 1 69 | 70 | i = i_dict.get(item_id, None) 71 | if i is None: 72 | # print(current_i) 73 | i_dict[item_id] = current_i 74 | i = current_i 75 | current_i += 1 76 | 77 | data.append((u, i, r, t)) 78 | f.close() 79 | 80 | data = np.array(data) 81 | np.save('data/{}/data.npy'.format(self.kind), data) 82 | 83 | def _init_data(self): 84 | if self.kind == self.KIND_MOVIELENS_100K: 85 | self.__init_data('/ml-100k/u.data', '\t') 86 | elif self.kind == self.KIND_MOVIELENS_1M: 87 | self.__init_data('/ml-1m/ratings.dat', '::') 88 | elif self.kind == self.KIND_MOVIELENS_10M: 89 | self.__init_data('/ml-10M100K/ratings.dat', '::') 90 | elif self.kind == self.KIND_MOVIELENS_20M: 91 | self.__init_data('/ml-20m/ratings.csv', ',', header=True) 92 | else: 93 | raise NotImplementedError() 94 | 95 | def _load_base_data(self): 96 | return np.load('data/{}/data.npy'.format(self.kind)) 97 | 98 | def _split_data(self): 99 | data = self.data 100 | n_shot = 
self.n_shot 101 | np.random.shuffle(data) 102 | 103 | if self.n_shot == -1: 104 | # When n_shot is -1, make the setting sparser: split all ratings 1:9 into train and test sets. 105 | n_train = int(data.shape[0] * 0.1) 106 | n_valid = int(n_train * 0.9) 107 | 108 | train_data = data[:n_valid] 109 | valid_data = data[n_valid:n_train] 110 | test_data = data[n_train:] 111 | 112 | np.save(self._get_npy_path('train'), train_data) 113 | np.save(self._get_npy_path('valid'), valid_data) 114 | np.save(self._get_npy_path('test'), test_data) 115 | 116 | elif self.n_shot == 0: 117 | # When n_shot is 0, split all ratings 9:1 into train and test sets, as other algorithms do. 118 | n_train = int(data.shape[0] * 0.9) 119 | n_valid = int(n_train * 0.98) 120 | 121 | train_data = data[:n_valid] 122 | valid_data = data[n_valid:n_train] 123 | test_data = data[n_train:] 124 | 125 | np.save(self._get_npy_path('train'), train_data) 126 | np.save(self._get_npy_path('valid'), valid_data) 127 | np.save(self._get_npy_path('test'), test_data) 128 | 129 | else: 130 | # First, hold out 20% of all users as test users. 131 | test_user_ids = random.sample( 132 | list(range(self.n_user)), self.n_user // 5) 133 | 134 | train_data = [] 135 | test_data = [] 136 | count_dict = {} 137 | for i in range(data.shape[0]): 138 | row = data[i] 139 | user_id = int(row[0]) 140 | if user_id in test_user_ids: 141 | count = count_dict.get(user_id, 0) 142 | if count < n_shot: 143 | train_data.append(row) 144 | else: 145 | test_data.append(row) 146 | count_dict[user_id] = count + 1 147 | else: 148 | train_data.append(row) 149 | 150 | train_data = np.array(train_data) 151 | n_valid = int(train_data.shape[0] * 0.98) 152 | train_data, valid_data = train_data[:n_valid], train_data[n_valid:] 153 | 154 | np.save(self._get_npy_path('train'), train_data) 155 | np.save(self._get_npy_path('valid'), valid_data) 156 | 157 | test_data = np.array(test_data) 158 | np.save(self._get_npy_path('test'), test_data) 159 | 160 | def _get_npy_path(self, split_kind): 161 | return 'data/{}/shot-{}/{}.npy'.format(self.kind, self.n_shot, 162 | split_kind) 163 | 164 | def __init__(self, kind, n_shot=0): 165 | assert type(n_shot) == int and n_shot >= -1 166 | 167 | _make_dir_if_not_exists('data') 168 | self._set_kind_and_url(kind) 169 | self._download_data_if_not_exists() 170 | self.n_shot = n_shot 171 | 172 | # If no cleaned-up npy file exists yet, build one. 173 | if not os.path.exists('data/{}/data.npy'.format(kind)): 174 | self._init_data() 175 | self.data = self._load_base_data() 176 | 177 | _make_dir_if_not_exists( 178 | 'data/{}/shot-{}'.format(self.kind, self.n_shot)) 179 | 180 | self.n_user = int(np.max(self.data[:, 0])) + 1 181 | self.n_item = int(np.max(self.data[:, 1])) + 1 182 | self.n_row = self.n_user 183 | self.n_col = self.n_item 184 | 185 | # Split the data if no split files exist yet.
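# The split is cached as data/<kind>/shot-<n_shot>/{train,valid,test}.npy (see _get_npy_path) and reused on later runs.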
186 | if not os.path.exists( 187 | self._get_npy_path('train')) or not os.path.exists( 188 | self._get_npy_path('valid')) or not os.path.exists( 189 | self._get_npy_path('test')): 190 | self._split_data() 191 | 192 | self.train_data = np.load(self._get_npy_path('train')) 193 | self.valid_data = np.load(self._get_npy_path('valid')) 194 | self.test_data = np.load(self._get_npy_path('test')) 195 | 196 | def get_train_data(self): 197 | return self.train_data 198 | 199 | def get_valid_data(self): 200 | return self.valid_data 201 | 202 | def get_test_data(self): 203 | return self.test_data 204 | 205 | 206 | # if __name__ == '__main__': 207 | # kind = DatasetManager.KIND_MOVIELENS_100K 208 | # kind = DatasetManager.KIND_MOVIELENS_1M 209 | # kind = DatasetManager.KIND_MOVIELENS_10M 210 | # kind = DatasetManager.KIND_MOVIELENS_20M 211 | # dataset_manager = DatasetManager(kind) 212 | -------------------------------------------------------------------------------- /base/memory_saving_gradients.py: -------------------------------------------------------------------------------- 1 | from toposort import toposort 2 | import contextlib 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow.contrib.graph_editor as ge 6 | import time 7 | import sys 8 | sys.setrecursionlimit(10000) 9 | # refers back to current module if we decide to split helpers out 10 | util = sys.modules[__name__] 11 | 12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated" 13 | setattr(tf.GraphKeys, "VARIABLES", "variables") 14 | 15 | # save original gradients since tf.gradient could be monkey-patched to point 16 | # to our version 17 | from tensorflow.python.ops import gradients as tf_gradients_lib 18 | tf_gradients = tf_gradients_lib.gradients 19 | 20 | MIN_CHECKPOINT_NODE_SIZE=1024 # use lower value during testing 21 | 22 | # specific versions we can use to do process-wide replacement of tf.gradients 23 | def gradients_speed(ys, xs, grad_ys=None, **kwargs): 24 | return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs) 25 | 26 | def gradients_memory(ys, xs, grad_ys=None, **kwargs): 27 | return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs) 28 | 29 | def gradients_collection(ys, xs, grad_ys=None, **kwargs): 30 | return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs) 31 | 32 | def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): 33 | ''' 34 | Authors: Tim Salimans & Yaroslav Bulatov 35 | 36 | memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" 37 | by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) 38 | 39 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients 40 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 41 | 42 | 'checkpoints' can either be 43 | - a list consisting of tensors from the forward pass of the neural net 44 | that we should re-use when calculating the gradients in the backward pass 45 | all other tensors that do not appear in this list will be re-computed 46 | - a string specifying how this list should be determined. currently we support 47 | - 'speed': checkpoint all outputs of convolutions and matmuls. 
these ops are usually the most expensive, 48 | so checkpointing them maximizes the running speed 49 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) 50 | - 'memory': try to minimize the memory usage 51 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) 52 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint 53 | ''' 54 | 55 | # print("Calling memsaving gradients with", checkpoints) 56 | if not isinstance(ys,list): 57 | ys = [ys] 58 | if not isinstance(xs,list): 59 | xs = [xs] 60 | 61 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], 62 | inclusive=True) 63 | 64 | debug_print("bwd_ops: %s", bwd_ops) 65 | 66 | # forward ops are all ops that are candidates for recomputation 67 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], 68 | inclusive=True, 69 | within_ops=bwd_ops) 70 | debug_print("fwd_ops: %s", fwd_ops) 71 | 72 | # exclude ops with no inputs 73 | fwd_ops = [op for op in fwd_ops if op.inputs] 74 | 75 | # don't recompute xs, remove variables 76 | xs_ops = _to_ops(xs) 77 | fwd_ops = [op for op in fwd_ops if not op in xs_ops] 78 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] 79 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] 80 | fwd_ops = [op for op in fwd_ops if not '/read' in op.name] 81 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors 82 | ts_all = [t for t in ts_all if '/read' not in t.name] 83 | ts_all = set(ts_all) - set(xs) - set(ys) 84 | 85 | # construct list of tensors to checkpoint during forward pass, if not 86 | # given as input 87 | if type(checkpoints) is not list: 88 | if checkpoints == 'collection': 89 | checkpoints = tf.get_collection('checkpoints') 90 | 91 | elif checkpoints == 'speed': 92 | # checkpoint all expensive ops to maximize running speed 93 | checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') 94 | 95 | elif checkpoints == 'memory': 96 | 97 | # remove very small tensors and some weird ops 98 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually 99 | try: 100 | return [int(e if e.value is not None else 64) for e in t] 101 | except: 102 | return [0] # unknown shape 103 | ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] 104 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name] 105 | ts_all = [t for t in ts_all if 'entropy' not in t.name] 106 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] 107 | ts_all = [t for t in ts_all if 'Switch' not in t.name] 108 | ts_all = [t for t in ts_all if 'dropout' not in t.name] 109 | # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 110 | ts_all = [t for t in ts_all if 'Cast' not in t.name] 111 | 112 | # filter out all tensors that are inputs of the backward graph 113 | with util.capture_ops() as bwd_ops: 114 | tf_gradients(ys, xs, grad_ys, **kwargs) 115 | 116 | bwd_inputs = [t for op in bwd_ops for t in op.inputs] 117 | # list of tensors in forward graph that is in input to bwd graph 118 | ts_filtered = list(set(bwd_inputs).intersection(ts_all)) 119 | debug_print("Using tensors %s", ts_filtered) 120 | 121 | # try two slightly different ways of getting bottlenecks tensors 122 | # to checkpoint 123 | for ts in [ts_filtered, ts_all]: 124 | 125 | # get all bottlenecks in the graph 126 | bottleneck_ts = [] 127 | for t in ts: 128 | b = set(ge.get_backward_walk_ops(t.op, inclusive=True, 
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False,
                                                    within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops %s",
                                    [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, so exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoint nodes: %s",
                    format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from the checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoint nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary+xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # copy part of the graph below the current checkpoint node, stopping at
        # other checkpoint nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other+xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other+xs)
        debug_print("with boundary backprop substitutions %s", substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])


    return d_xs
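
# Illustrative usage sketch (added for documentation; not part of the original
# OpenAI gradient-checkpointing code). `_example_usage`, `loss`, `params` and
# `hidden` are hypothetical names: `loss` is assumed to be a scalar loss tensor,
# `params` a list of trainable variables, and `hidden` an optional forward-pass
# tensor worth keeping in memory.
def _example_usage(loss, params, hidden=None):
    # 1) automatic bottleneck selection: a drop-in replacement for
    #    tf.gradients(loss, params)
    grads = gradients_memory(loss, params)

    # 2) explicit selection via the 'checkpoints' graph collection: mark a
    #    forward-pass tensor to be kept instead of recomputed
    if hidden is not None:
        tf.add_to_collection('checkpoints', hidden)
        grads = gradients_collection(loss, params)

    # 3) process-wide monkey-patch, so existing optimizer code (which calls
    #    tf.gradients internally) picks up the memory-saving version:
    # tf.__dict__["gradients"] = gradients_memory

    return grads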
def tf_toposort(ts, within_ops=None):
    all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops)

    deps = {}
    for op in all_ops:
        for o in op.outputs:
            deps[o] = set(op.inputs)
    sorted_ts = toposort(deps)

    # only keep the tensors from our original list
    ts_sorted_lists = []
    for l in sorted_ts:
        keep = list(set(l).intersection(ts))
        if keep:
            ts_sorted_lists.append(keep)

    return ts_sorted_lists

def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
    bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
    ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts])
    return list(ops)

@contextlib.contextmanager
def capture_ops():
    """Context manager to capture ops created in the block.
    with capture_ops() as ops:
        # create some ops
    print(ops)  # => prints ops created.
    """

    micros = int(time.time()*10**6)
    scope_name = str(micros)
    op_list = []
    with tf.name_scope(scope_name):
        yield op_list

    g = tf.get_default_graph()
    op_list.extend(ge.select_ops(scope_name+"/.*", graph=g))
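
# Small numeric sketch (added for clarity; `_example_checkpoint_spacing` is a
# hypothetical helper, not part of the original module). It mirrors the
# "save an approximately optimal number ~ sqrt(N)" rule inside gradients():
# e.g. with 40 sorted bottlenecks and N = 100 filtered tensors,
# step = ceil(40 / sqrt(100)) = 4, so roughly every 4th bottleneck (about
# sqrt(N) of them) is kept, and the rest are recomputed in the backward pass.
def _example_checkpoint_spacing(sorted_bottlenecks, n_filtered):
    if len(sorted_bottlenecks) <= np.ceil(np.sqrt(n_filtered)):
        return sorted_bottlenecks
    step = int(np.ceil(len(sorted_bottlenecks) / np.sqrt(n_filtered)))
    return sorted_bottlenecks[step::step]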
def _to_op(tensor_or_op):
    if hasattr(tensor_or_op, "op"):
        return tensor_or_op.op
    return tensor_or_op

def _to_ops(iterable):
    if not _is_iterable(iterable):
        return iterable
    return [_to_op(i) for i in iterable]

def _is_iterable(o):
    try:
        _ = iter(o)
    except Exception:
        return False
    return True

DEBUG_LOGGING = False
def debug_print(s, *args):
    """Like logger.log, but also replaces all TensorFlow ops/tensors with their
    names. Sensitive to the value of DEBUG_LOGGING; see enable_debug/disable_debug.

    Usage:
        debug_print("see tensors %s for %s", tensorlist, [1,2,3])
    """

    if DEBUG_LOGGING:
        formatted_args = [format_ops(arg) for arg in args]
        print("DEBUG " + s % tuple(formatted_args))

def format_ops(ops, sort_outputs=True):
    """Helper method for printing ops. Converts Tensor/Operation op to op.name,
    rest to str(op)."""

    if hasattr(ops, '__iter__') and not isinstance(ops, str):
        l = [(op.name if hasattr(op, "name") else str(op)) for op in ops]
        if sort_outputs:
            return sorted(l)
        return l
    else:
        return ops.name if hasattr(ops, "name") else str(ops)

def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before):
    for op in wait_to_do_ops:
        ci = [i for i in inputs_to_do_before
              if op.control_inputs is None or i not in op.control_inputs]
        ge.add_control_inputs(op, ci)
--------------------------------------------------------------------------------