├── lib ├── .gitignore ├── cosernn │ ├── __init__.py │ ├── helpers.py │ └── models.py └── setup.py ├── requirements.txt ├── CODE_OF_CONDUCT.md ├── README.md ├── CONTRIBUTING.md ├── scripts ├── generate_data.py └── train.py └── LICENSE /lib/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1 2 | tensorflow-gpu>=1.14,<2 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | We feel that a welcoming community is important and we ask that you follow Spotify's 2 | [Open Source Code of Conduct](https://github.com/spotify/code-of-conduct/blob/master/code-of-conduct.md) 3 | in all interactions with the community. 4 | -------------------------------------------------------------------------------- /lib/cosernn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Spotify AB 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | 14 | __version__ = "0.1.0" 15 | __description__ = "CoSeRNN reference implementation" 16 | __uri__ = "https://github.com/spotify-research/cosernn" 17 | __author__ = "Casper Hansen" 18 | __author_email__ = "casper.hanzen@gmail.com" 19 | __maintainer__ = "Lucas Maystre" 20 | __maintainer_email__ = "lucasm@spotify.com" 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cosernn 2 | 3 | A reference implementation of the CoSeRNN model for contextual music 4 | recommendation, presented in the following paper: 5 | 6 | > Casper Hansen, Christian Hansen, Lucas Maystre, Rishabh Mehrotra, Brian 7 | > Brost, Federico Tomasi, Mounia Lalmas. _[Contextual and Sequential User 8 | > Embeddings for Large-Scale Music 9 | > Recommendation](https://dl.acm.org/doi/10.1145/3383313.3412248)_, RecSys 10 | > 2020. 11 | 12 | 13 | ## Getting Started 14 | 15 | Our implementation requires Python 3.7 and TensorFlow 1.x. To run the code, you 16 | will need a CUDA-enabled GPU. 17 | 18 | To get started, simply follow these steps: 19 | 20 | - Clone the repo locally with: `git clone 21 | https://github.com/spotify-research/cosernn.git` 22 | - Move to the repository with: `cd cosernn` 23 | - Install the dependencies: `pip install -r requirements.txt` 24 | - Install the package: `pip install -e lib/` 25 | 26 | Generate data using 27 | 28 | python scripts/generate_data.py 29 | 30 | Train the CoSeRNN model using 31 | 32 | python scripts/train.py path/to/records 33 | 34 | 35 | ## Support 36 | 37 | Create a [new issue](https://github.com/spotify-research/cosernn/issues/new) 38 | 39 | 40 | ## Contributing 41 | 42 | We feel that a welcoming community is important and we ask that you follow Spotify's 43 | [Open Source Code of Conduct](https://github.com/spotify/code-of-conduct/blob/master/code-of-conduct.md) 44 | in all interactions with the community. 
45 | 46 | 47 | ## Authors 48 | 49 | - [Casper Hansen](mailto:casper.hanzen@gmail.com) 50 | - [Lucas Maystre](mailto:lucasm@spotify.com) 51 | 52 | A full list of [contributors](https://github.com/spotify-research/cosernn/graphs/contributors?type=a) can 53 | be found on GitHub. 54 | 55 | Follow [@SpotifyResearch](https://twitter.com/SpotifyResearch) on Twitter for 56 | updates. 57 | 58 | 59 | ## License 60 | 61 | Copyright 2020 Spotify, Inc. 62 | 63 | Licensed under the Apache License, Version 2.0: 64 | https://www.apache.org/licenses/LICENSE-2.0 65 | 66 | 67 | ## Security Issues? 68 | 69 | Please report sensitive security issues via Spotify's bug-bounty program 70 | (https://hackerone.com/spotify) rather than GitHub. 71 | -------------------------------------------------------------------------------- /lib/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2020 Spotify AB 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import codecs 17 | import os 18 | import re 19 | 20 | from setuptools import setup 21 | 22 | 23 | HERE = os.path.abspath(os.path.dirname(__file__)) 24 | 25 | 26 | ##### 27 | # Helper functions 28 | ##### 29 | def read(*filenames, **kwargs): 30 | """ 31 | Build an absolute path from ``*filenames``, and return contents of 32 | resulting file. Defaults to UTF-8 encoding. 
33 | """ 34 | encoding = kwargs.get("encoding", "utf-8") 35 | sep = kwargs.get("sep", "\n") 36 | buf = [] 37 | for fl in filenames: 38 | with codecs.open(os.path.join(HERE, fl), "rb", encoding) as f: 39 | buf.append(f.read()) 40 | return sep.join(buf) 41 | 42 | 43 | def find_meta(meta): 44 | """Extract __*meta*__ from META_FILE.""" 45 | re_str = r"^__{meta}__ = ['\"]([^'\"]*)['\"]".format(meta=meta) 46 | meta_match = re.search(re_str, META_FILE, re.M) 47 | if meta_match: 48 | return meta_match.group(1) 49 | raise RuntimeError("Unable to find __{meta}__ string.".format(meta=meta)) 50 | 51 | 52 | ##### 53 | # Project-specific constants 54 | ##### 55 | NAME = "cosernn" 56 | PACKAGE_NAME = "cosernn" 57 | META_PATH = os.path.join(PACKAGE_NAME, "__init__.py") 58 | CLASSIFIERS = [ 59 | "Development Status :: 3 - Alpha", 60 | "Programming Language :: Python :: 3.7", 61 | "Intended Audience :: Science/Research", 62 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 63 | ] 64 | META_FILE = read(META_PATH) 65 | 66 | 67 | setup( 68 | name=NAME, 69 | version=find_meta("version"), 70 | description=find_meta("description"), 71 | url=find_meta("uri"), 72 | author=find_meta("author"), 73 | author_email=find_meta("author_email"), 74 | maintainer=find_meta("maintainer"), 75 | maintainer_email=find_meta("maintainer_email"), 76 | packages=[ 77 | "cosernn", 78 | ], 79 | include_package_data=True, 80 | classifiers=CLASSIFIERS, 81 | zip_safe=False, 82 | install_requires=[ 83 | "numpy", 84 | "tensorflow-gpu", 85 | ], 86 | ) 87 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to get patches from you! 4 | 5 | ## Workflow 6 | 7 | We follow the [GitHub Flow 8 | Workflow](https://guides.github.com/introduction/flow/) 9 | 10 | Below is an Example. 11 | 12 | 1. Fork the project 13 | 1. 
Check out the `master` branch 14 | 1. Create a feature branch 15 | 1. Write code and tests for your change 16 | 1. From your branch, make a pull request against 17 | `spotify-research/cosernn/master` 18 | 1. Work with repo maintainers to get your change reviewed 19 | 1. Wait for your change to be pulled into `spotify-research/cosernn/master` 20 | 1. Delete your feature branch 21 | 22 | ## Style 23 | 24 | We follow `black` style. 25 | 26 | ## Issues 27 | 28 | When creating an issue please try to adhere to the following format: 29 | 30 | cosernn: One line summary of the issue (less than 72 characters) 31 | 32 | ### Expected behavior 33 | 34 | As concisely as possible, describe the expected behavior. 35 | 36 | ### Actual behavior 37 | 38 | As concisely as possible, describe the observed behavior. 39 | 40 | ### Steps to reproduce the behavior 41 | 42 | List all relevant steps to reproduce the observed behavior. 43 | 44 | ## Pull Requests 45 | 46 | 47 | Comments should be formatted to a width no greater than 80 columns. 48 | 49 | Files should be exempt of trailing spaces. 50 | 51 | We adhere to a specific format for commit messages. Please write your commit 52 | messages along these guidelines. Please keep the line width no greater than 80 53 | columns (You can use `fmt -n -p -w 80` to accomplish this). 54 | 55 | cosernn: One line description of your change (less than 72 characters) 56 | 57 | Problem 58 | 59 | Explain the context and why you're making that change. What is the problem 60 | you're trying to solve? In some cases there is not a problem and this can be 61 | thought of being the motivation for your change. 62 | 63 | Solution 64 | 65 | Describe the modifications you've done. 66 | 67 | Result 68 | 69 | What will change as a result of your pull request? Note that sometimes this 70 | section is unnecessary because it is self-explanatory based on the solution. 
71 | 72 | Some important notes regarding the summary line: 73 | 74 | * Describe what was done; not the result 75 | * Use the active voice 76 | * Use the present tense 77 | * Capitalize properly 78 | * Do not end in a period — this is a title/subject 79 | * Prefix the subject with its scope 80 | 81 | ## Code Review 82 | 83 | The repository on GitHub is kept in sync with an internal repository at 84 | Spotify. For the most part this process should be transparent to the project 85 | users, but it does have some implications for how pull requests are merged into 86 | the codebase. 87 | 88 | When you submit a pull request on GitHub, it will be reviewed by the project 89 | community (both inside and outside of Spotify), and once the changes are 90 | approved, your commits will be brought into Spotify's internal system for 91 | additional testing. Once the changes are merged internally, they will be pushed 92 | back to GitHub with the next sync. 93 | 94 | This process means that the pull request will not be merged in the usual way. 95 | Instead a member of the project team will post a message in the pull request 96 | thread when your changes have made their way back to GitHub, and the pull 97 | request will be closed. 98 | The changes in the pull request will be collapsed into a single commit, but the 99 | authorship metadata will be preserved. 100 | 101 | ## Documentation 102 | 103 | We also welcome improvements to the project documentation or to the existing 104 | docs. Please file an 105 | [issue](https://github.com/spotify-research/cosernn/issues). 106 | 107 | # License 108 | 109 | By contributing your code, you agree to license your contribution under the 110 | terms of the APLv2: 111 | https://github.com/spotify-research/cosernn/blob/master/LICENSE 112 | 113 | # Code of Conduct 114 | 115 | Read our [Code of Conduct](CODE_OF_CONDUCT.md) for the project. 
116 | -------------------------------------------------------------------------------- /scripts/generate_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Spotify AB 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | 14 | import argparse 15 | import numpy as np 16 | import os 17 | import tensorflow as tf 18 | 19 | 20 | def create_int_feature(values): 21 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 22 | return f 23 | 24 | 25 | def create_float_feature(values): 26 | f = tf.train.Feature(float_list=tf.train.FloatList(value=values)) 27 | return f 28 | 29 | 30 | def create_numpy_matrix(size, above_val=0.5, type=np.int64, add=1): 31 | v = ((np.random.random(size)>above_val).astype(int) + add).astype(type) 32 | return v 33 | 34 | 35 | def main(num_samples, repeats, out): 36 | try: 37 | os.mkdir(out) 38 | except Exception as e: 39 | pass 40 | for repeat in range(repeats): 41 | writer = tf.io.TFRecordWriter(f"{out}/random_record{repeat}") 42 | for i in range(num_samples): 43 | features = { 44 | 'day_hist': create_int_feature(create_numpy_matrix(400)), #tf.FixedLenFeature([400], tf.int64), 45 | 'hour_hist': create_int_feature(create_numpy_matrix(400)), #tf.FixedLenFeature([400], tf.int64), 46 | 'minute_hist': create_int_feature(create_numpy_matrix(400)), #tf.FixedLenFeature([400], tf.int64), 47 | 'device_hist': create_int_feature(create_numpy_matrix(400)), 
#tf.FixedLenFeature([400], tf.int64), 48 | 49 | 'hist_avgs': create_int_feature(create_numpy_matrix((400,10)).flatten()), #tf.VarLenFeature(tf.int64), 50 | 'histavgs_feedback': create_int_feature(create_numpy_matrix((400,10)).flatten()), #tf.VarLenFeature(tf.int64), 51 | 'histavgs_skip': create_int_feature(create_numpy_matrix((400,10)).flatten()), #tf.VarLenFeature(tf.int64), 52 | 'histavgs_listen':create_int_feature(create_numpy_matrix((400,10)).flatten()), #tf.VarLenFeature(tf.int64), 53 | 54 | 'hist_avgs_shape': create_int_feature(create_numpy_matrix((400,10)).shape), #tf.FixedLenFeature([2], tf.int64), 55 | 56 | 'user': create_int_feature(create_numpy_matrix(1)),#tf.FixedLenFeature([1], tf.int64), 57 | 'mask': create_int_feature(create_numpy_matrix(400, above_val=-1, add=0)), #tf.FixedLenFeature([400], tf.int64), 58 | 'mask_split': create_int_feature(create_numpy_matrix(400, above_val=-1, add=0)), #tf.FixedLenFeature([400], tf.int64), 59 | 'mask_split_above': create_int_feature(create_numpy_matrix(400, above_val=-1, add=0)), #tf.FixedLenFeature([400], tf.int64), 60 | 61 | 'time_since_last_session': create_float_feature(create_numpy_matrix(400, type=np.float32)), #tf.FixedLenFeature([400], tf.float32), 62 | 'top_context': create_int_feature(create_numpy_matrix(400)), #tf.FixedLenFeature([400], tf.int64), 63 | 'number_of_tracks_in_sessions': create_int_feature(create_numpy_matrix(400)+10), #tf.FixedLenFeature([400], tf.int64), 64 | 'session_start_time': create_int_feature(create_numpy_matrix(400)) #tf.FixedLenFeature([400], tf.int64) 65 | } 66 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 67 | writer.write(tf_example.SerializeToString()) 68 | writer.close() 69 | 70 | 71 | def _parse_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--n-samples", type=int, default=100) 74 | parser.add_argument("--n-repeats", type=int, default=10) 75 | parser.add_argument("--out", default="./records") 76 | return 
parser.parse_args() 77 | 78 | 79 | if __name__ == "__main__": 80 | args = _parse_args() 81 | main(args.n_samples, args.n_repeats, args.out) 82 | -------------------------------------------------------------------------------- /lib/cosernn/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Spotify AB 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | PLATFORMS = {'mobile':0, 'desktop':1, 'speaker':2, 'web':3, 'tablet':4, 'tv':5, 'remaining':6} 18 | SPOTIFY_TOP_CONTEXT = { 19 | 'nan': 0, 20 | 'play_queue': 1, 21 | 'personalized_playlist': 2, 22 | 'mix': 3, 23 | 'catalog': 4, 24 | 'editorial_playlist': 5, 25 | 'radio': 6, 26 | 'algotorial_playlist': 7, 27 | 'user_playlist': 8, 28 | 'sounds_of': 9, 29 | 'charts': 10, 30 | 'user_collection': 11 31 | } 32 | day_count = 7 33 | hour_count = 24 34 | minute_count = 4 35 | device_count = len(PLATFORMS) 36 | top_context_count = len(SPOTIFY_TOP_CONTEXT) 37 | 38 | def extract_fn_onerow(data_record): 39 | features = { 40 | 'day_hist': tf.FixedLenFeature([400], tf.int64), 41 | 'hour_hist': tf.FixedLenFeature([400], tf.int64), 42 | 'minute_hist': tf.FixedLenFeature([400], tf.int64), 43 | 'device_hist': tf.FixedLenFeature([400], tf.int64), 44 | 45 | 'hist_avgs': tf.VarLenFeature(tf.int64), 46 | 'histavgs_feedback': tf.VarLenFeature(tf.int64), 47 | 'histavgs_skip': 
tf.VarLenFeature(tf.int64), 48 | 'histavgs_listen': tf.VarLenFeature(tf.int64), 49 | 50 | 'hist_avgs_shape': tf.FixedLenFeature([2], tf.int64), 51 | 52 | 'user': tf.FixedLenFeature([1], tf.int64), 53 | 'mask': tf.FixedLenFeature([400], tf.int64), 54 | 'mask_split': tf.FixedLenFeature([400], tf.int64), 55 | 'mask_split_above': tf.FixedLenFeature([400], tf.int64), 56 | 57 | 58 | 'time_since_last_session': tf.FixedLenFeature([400], tf.float32), 59 | 'top_context': tf.FixedLenFeature([400], tf.int64), 60 | 'number_of_tracks_in_sessions': tf.FixedLenFeature([400], tf.int64), 61 | 'session_start_time': tf.FixedLenFeature([400], tf.int64) 62 | } 63 | 64 | sample = tf.parse_single_example(data_record, features) 65 | 66 | for dyntype in ['hist_avgs','histavgs_skip','histavgs_listen','histavgs_feedback']: 67 | sample[dyntype] = tf.sparse.to_dense(sample[dyntype]) 68 | sample[dyntype] = tf.reshape(sample[dyntype], sample['hist_avgs_shape']) 69 | 70 | for v in features.keys():# ['day_hist','hour_hist','minute_hist','device_hist','day_now','hour_now','minute_now','device_now','hist_avgs','target_avg','target_skip']: 71 | if 'mask' not in v and 'time_since_last_session' not in v and 'session_start_time' not in v: 72 | sample[v] = tf.cast(sample[v], tf.int32) 73 | 74 | sample['mask'] = tf.cast(sample['mask'], tf.float32) 75 | sample['mask_split'] = tf.cast(sample['mask_split'], tf.float32) 76 | sample['mask_split_above'] = tf.cast(sample['mask_split_above'], tf.float32) 77 | 78 | #sample['time_since_last_session'] = tf.cast(sample['time_since_last_session'], tf.float32) 79 | 80 | sample['number_of_tracks_in_sessions'] = tf.cast(sample['number_of_tracks_in_sessions'], tf.float32) 81 | #sample['session_start_time'] = tf.cast(sample['session_start_time'], tf.float32) 82 | 83 | 84 | sample['user'] = tf.squeeze(sample['user'], -1) 85 | 86 | return tuple([sample[f] for f in ['day_hist','hour_hist','minute_hist', 87 | 'device_hist','hist_avgs','mask','mask_split','user', 88 | 
'histavgs_skip','histavgs_listen','histavgs_feedback', 89 | 'mask_split_above', 'time_since_last_session', 'top_context', 'number_of_tracks_in_sessions', 'session_start_time']]) # 90 | 91 | 92 | def make_dataset_generator_onerow(sess, handle, pargs, dataset_paths, is_test): 93 | maxlen = pargs['max_length'] 94 | 95 | output_t = tuple([tf.int32, tf.int32, tf.int32, tf.int32, # history list: day, hour, minute, device 96 | tf.int32, # history list: average session vector 97 | tf.float32, # mask 98 | tf.float32, #mask split 99 | tf.int32, # user 100 | tf.int32, 101 | tf.int32, 102 | tf.int32, 103 | tf.float32, # mask split above 104 | tf.float32, # time since last sess 105 | tf.int32, # top context 106 | tf.float32, 107 | tf.int64 108 | ]) 109 | 110 | filenames = dataset_paths 111 | print(len(filenames), 'tfrecord files') 112 | 113 | output_s = tuple( 114 | [tf.TensorShape([None, maxlen]), tf.TensorShape([None, maxlen]), tf.TensorShape([None, maxlen]), tf.TensorShape([None, maxlen]), 115 | tf.TensorShape([None, maxlen, 10]), 116 | tf.TensorShape([None, maxlen]), 117 | tf.TensorShape([None, maxlen]), 118 | tf.TensorShape([None, ]), 119 | tf.TensorShape([None, maxlen, 10]), 120 | tf.TensorShape([None, maxlen, 10]), 121 | tf.TensorShape([None, maxlen, 10]), 122 | tf.TensorShape([None, maxlen]), 123 | tf.TensorShape([None, maxlen]), 124 | tf.TensorShape([None, maxlen]), 125 | tf.TensorShape([None, maxlen]), 126 | tf.TensorShape([None, maxlen]) 127 | ]) 128 | 129 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 130 | if not is_test: 131 | dataset = dataset.repeat() 132 | dataset = dataset.shuffle(1000) 133 | dataset = dataset.flat_map(tf.data.TFRecordDataset) 134 | dataset = dataset.map(extract_fn_onerow, num_parallel_calls=3) 135 | if not is_test: 136 | dataset = dataset.shuffle(1000) 137 | 138 | if not is_test: 139 | dataset = dataset.batch(pargs["batch_size"]) 140 | else: 141 | dataset = dataset.batch(500) 142 | 143 | if not is_test: 144 | dataset = 
dataset.prefetch(1) 145 | else: 146 | dataset = dataset.prefetch(10) 147 | iterator = dataset.make_initializable_iterator() #tf.compat.v1.data.make_initializable_iterator(dataset) # 148 | 149 | if handle is not None: 150 | generic_iter = tf.data.Iterator.from_string_handle(handle, output_t, output_s) 151 | specific_handle = sess.run(iterator.string_handle()) 152 | else: 153 | generic_iter = None 154 | specific_handle = None 155 | 156 | return specific_handle, iterator, generic_iter 157 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Spotify AB 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | 14 | import argparse 15 | import glob 16 | import numpy as np 17 | import os 18 | import tensorflow as tf 19 | import time 20 | 21 | from cosernn.models import CoSeRNN 22 | from cosernn.helpers import make_dataset_generator_onerow as make_dataset_generator 23 | 24 | 25 | LASTTRACKS = [10, 25, 50, 100] 26 | LASTSESSIONS = [10-1, 20-1, 30-1, 50-1] 27 | RANK_L2 = False 28 | 29 | 30 | def as_matrix(config): 31 | return [[k, str(w)] for k, w in config.items()] 32 | 33 | 34 | def main(): 35 | global item_vects, id2item, user_vects, id2user, random_tracks, random_sessions 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("path_to_records") 38 | parser.add_argument("--rnnsize", default=400, type=int) 39 | parser.add_argument("--batch_size", default=256, type=int) 40 | parser.add_argument("--lr", default=0.0005, type=float) 41 | parser.add_argument("--decay_rate", default=0.99999, type=float) 42 | 43 | parser.add_argument("--context_var", default=0, type=int) 44 | parser.add_argument("--target_var", default=0, type=int) 45 | parser.add_argument('--dropout', default=0.0, type=float) 46 | parser.add_argument('--layersize', default=200, type=int) 47 | parser.add_argument('--add_train_noise', default=0, type=int) 48 | 49 | parser.add_argument('--imul', default=0.05, type=float) 50 | parser.add_argument('--MEMSIZE', default=2000, type=int) 51 | parser.add_argument('--mem_top_k', default=2, type=int) 52 | parser.add_argument('--eval_mem_switch', default=-1, type=int) 53 | 54 | parser.add_argument('--gpu', default=1, type=int) 55 | parser.add_argument('--POOLSIZE', default=10, type=int) 56 | 57 | parser.add_argument('--inputs_chosen', default='0,1,2,3,4', type=str) 58 | parser.add_argument('--session_rep', default='0,1', type=str) 59 | parser.add_argument('--pred_type', default=3, type=int) 60 | 61 | parser.add_argument('--inputs_chosen_options', default='[inputs, sampled_inputs_contexts, number_of_tracks_in_sessions, time_since_last_session, top_context, 
top_context_future]', type=str) 62 | parser.add_argument('--session_rep_options', default='[hist_rnn_both, hist_rnn_skip, hist_rnn_listen]', type=str) 63 | parser.add_argument('--pred_type_options', default='[rnn_pred, user_pred, rnn+user, softmax rnn+user]', type=str) 64 | 65 | parser.add_argument('--restorefile', default=None, type=str) 66 | parser.add_argument('--baselineEval', default=None, type=str) 67 | parser.add_argument('--max_length', default=400, type=int) 68 | args = parser.parse_args() 69 | 70 | args.context_var = args.context_var > 0.5 71 | args.target_var = args.target_var > 0.5 72 | args.add_train_noise = args.add_train_noise > 0.5 73 | args.gpu = args.gpu > 0.5 74 | 75 | args.inputs_chosen = [int(v) for v in args.inputs_chosen.split(',')] 76 | args.session_rep = [int(v) for v in args.session_rep.split(',')] 77 | args = vars(args) 78 | 79 | tfrecords = glob.glob(os.path.join(args["path_to_records"], "*")) 80 | num_items = 100 81 | num_users = 100 82 | 83 | tf.reset_default_graph() 84 | print('making session.....') 85 | with tf.Session() as sess: 86 | 87 | handle = tf.placeholder(tf.string, shape=[], name="handle_for_iterator") 88 | training_handle, train_iter, gen_iter = make_dataset_generator(sess, handle, args, tfrecords, 0) 89 | 90 | val_handle, val_iter, _ = make_dataset_generator(sess, handle, args, tfrecords, 1) 91 | test_handle, test_iter, _ = make_dataset_generator(sess, handle, args, tfrecords, 1) 92 | 93 | sample = gen_iter.get_next() 94 | mask = sample[11] 95 | 96 | model = CoSeRNN(sample, args['rnnsize'], args['batch_size']) 97 | 98 | item_emb, item_embedding_placeholder, item_embedding_init = model._make_embedding(num_items, 40, 'item_emb', trainable=False) 99 | user_emb, user_embedding_placeholder, user_embedding_init = model._make_embedding(num_users, 40, 'user_emb', trainable=False) 100 | user_f_emb, user_f_embedding_placeholder, user_f_embedding_init = model._make_embedding( num_users, 2, 'user_f_emb', trainable=False) 101 | 102 | 
mm_multiplier = tf.placeholder(tf.float32, name='mm_multiplier') 103 | is_training = tf.placeholder(tf.bool, name="is_training") 104 | is_single_sample = tf.placeholder(tf.bool,name="is_single_sample") # set to True when the dataset with a one-hot mask instead of multi-hot 105 | update_index = tf.placeholder(tf.int32, shape=[None, ], name='update_index') 106 | 107 | 108 | debugval, loss_single, loss_multi, single_pred, multi_pred, \ 109 | avgcos_single, avgcos_multi, targetmodel, histvalsmodel, context_now_org, summary_op, summary_op_val, summary_op_test, pVar_val, update_pVar_val, \ 110 | pVar_test, update_pVar_test, avgcos_nvm, session_item_scores, avgtargets, \ 111 | mem_updates, mem_key, mem_sess, mem_error, combined_alpha_rnn_mm, combined_alpha_spotify_other, error_neighbourhood, combined_single = model.make_network( 112 | item_emb, user_emb, user_f_emb, is_training, is_single_sample, update_index, args, mm_multiplier) 113 | 114 | 115 | step = tf.Variable(0, trainable=False) 116 | lr = tf.train.exponential_decay(args["lr"], 117 | step, 118 | 10000, 119 | args["decay_rate"], 120 | staircase=True, name="lr") 121 | optimizer = tf.train.AdamOptimizer(learning_rate=lr, name="Adam") 122 | train_step = optimizer.minimize(loss_multi, global_step=step) 123 | 124 | init = tf.global_variables_initializer() 125 | sess.run(init) 126 | sess.run(train_iter.initializer) 127 | 128 | # In the experiments, we initialize with specific embeddings. 
129 | #sess.run(item_embedding_init, feed_dict={item_embedding_placeholder: item_matrix}) 130 | #sess.run(user_embedding_init, feed_dict={user_embedding_placeholder: user_matrix}) 131 | #sess.run(user_f_embedding_init, feed_dict={user_f_embedding_placeholder: user_feature_matrix}) 132 | 133 | loss = loss_multi 134 | avgcos = avgcos_multi 135 | 136 | eval_every = int(250 * args['imul']) 137 | patience_count = 0 138 | patience_max = 10 139 | iter_count = 0 140 | 141 | train_avgcos = [] 142 | times = [] 143 | losses_train = [] 144 | 145 | runer4ever = True 146 | single_sample_val = False 147 | mm_multiplier_val = 0 148 | memory_index_batch = np.arange(args['batch_size']) % args['MEMSIZE'] 149 | 150 | print('start training') 151 | while runer4ever: 152 | start = time.time() 153 | fd = {handle: training_handle, is_training: True, is_single_sample: single_sample_val, update_index: memory_index_batch.astype(np.float32), mm_multiplier: mm_multiplier_val} 154 | debugvalval, avgcosval,losses, _, maskval = sess.run([debugval, avgcos,loss, train_step, mask], feed_dict=fd) 155 | times.append(time.time() - start) 156 | 157 | train_avgcos += [np.mean(avgcosval[np.abs(avgcosval) > 0])] 158 | losses_train += [np.mean(losses)] 159 | iter_count += 1 160 | 161 | if iter_count % eval_every == 0: 162 | print("iteration", iter_count, 'avg time', np.mean(times)) 163 | print('train, avg_cosine:', np.mean(train_avgcos), "loss:", np.mean(losses_train)) 164 | 165 | train_avgcos = [] 166 | times = [] 167 | losses_train = [] 168 | 169 | if patience_count > patience_max: 170 | runer4ever = False 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. 
Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /lib/cosernn/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Spotify AB 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | 14 | import tensorflow as tf 15 | 16 | from cosernn.helpers import ( 17 | day_count, hour_count, minute_count, device_count, top_context_count) 18 | 19 | 20 | class CoSeRNN(): 21 | def __init__(self, sample, rnnsize, batch_size): 22 | self.rnnsize = rnnsize 23 | self.batch_size = batch_size 24 | self.sample = sample 25 | 26 | def _make_rnn_gpu(self, input, num_layers, num_units, name, namespace="default", maxlen=None, init_state=None): 27 | ''' 28 | RNN implementation using a CudnnLSTM 29 | :param input: a time series sinput 30 | :param num_layers: number of layers 31 | :param num_units: number of hidden units 32 | :param name: component name 33 | :param namespace: namespace name 34 | :param maxlen: maximum length of the input 35 | :param init_state: specific initital state 36 | :return: Output of the LSTM 37 | ''' 38 | input = tf.transpose(input, [1, 0, 2]) 39 | with tf.variable_scope(namespace, reuse=tf.AUTO_REUSE) as scope: 40 | rnn = tf.contrib.cudnn_rnn.CudnnLSTM( 41 | num_layers=num_layers, 42 | num_units=num_units, 43 | dtype=tf.float32, 44 | name=name + "_rnn") 45 | 46 | if init_state is None: 47 | output, _ = rnn(input) 48 | output = tf.transpose(output, [1, 0, 2]) 49 | return output 50 | 51 | i = [] 52 | o = [] 53 | i2 = [] 54 | o2 = [] 55 | for j in range(num_layers): 56 | internal = tf.layers.dense(init_state, num_units, name=name + "_internal_hist" + str(j)) 57 | out = tf.layers.dense(init_state, num_units, name=name + "_out_hist" + str(j)) 58 | internal = tf.expand_dims(internal, axis=0) 59 | out = tf.expand_dims(out, axis=0) 60 | 61 | i.append(internal) 62 | o.append(out) 63 | 64 | internal2 = tf.layers.dense(init_state, num_units, name=name + "_internal2_hist" + str(j)) 65 | out2 = tf.layers.dense(init_state, num_units, name=name + "_out2_hist" + str(j)) 66 | internal2 = tf.expand_dims(internal2, axis=0) 67 | out2 = tf.expand_dims(out2, axis=0) 68 | 69 | i2.append(internal2) 70 | o2.append(out2) 71 | 72 | internal = tf.concat(i, axis=0) 73 | 
out = tf.concat(o, axis=0) 74 | 75 | internal2 = tf.concat(i2, axis=0) 76 | out2 = tf.concat(o2, axis=0) 77 | 78 | internal_final = internal 79 | out_final = out 80 | 81 | initial_state = (internal_final, out_final) 82 | 83 | output, _ = rnn(input, initial_state=initial_state) 84 | output = tf.transpose(output, [1, 0, 2]) 85 | 86 | return output 87 | 88 | def _extract_vals(self, item_emb, user_emb, user_f_emb): 89 | ''' 90 | Extract a list of feature tensors from the generator. 91 | :param item_emb: embedding of all tracks 92 | :param user_emb: embedding of all users 93 | :param user_f_emb: embedding of user features for all users 94 | :return: list of feature tensors 95 | ''' 96 | day_onehot, hour_onehot, minute_onehot, device_onehot, hist_avgs, mask, \ 97 | mask_split, user, hist_skipped, hist_listened, histavgs_feedback, mask_split_above, time_since_last_session, top_context, number_of_tracks_in_sessions = self.sample[0], self.sample[1],self.sample[2], self.sample[3],self.sample[4],\ 98 | self.sample[5],self.sample[6],self.sample[7],self.sample[8],self.sample[9],self.sample[10],self.sample[11], self.sample[12], self.sample[13], self.sample[14] 99 | 100 | mask_split_above = tf.cast(mask_split_above, tf.float32) 101 | 102 | day_onehot = tf.one_hot(day_onehot, day_count) 103 | hour_onehot = tf.one_hot(hour_onehot, hour_count) 104 | minute_onehot = tf.one_hot(minute_onehot, minute_count) 105 | device_onehot = tf.one_hot(device_onehot, device_count) 106 | 107 | EPS = 0.00001 108 | avoid_div_by_zero = tf.cast(tf.expand_dims(mask_split_above,-1) < 0.5, tf.float32) * EPS # 109 | avoid_div_by_zero_set_to_zero = tf.cast(tf.expand_dims(mask_split_above,-1) > 0.5, tf.float32) # 110 | 111 | hist_divider = tf.reduce_sum(tf.cast(hist_avgs>0, tf.float32), -1) 112 | hist_avgs = tf.nn.embedding_lookup(item_emb, hist_avgs) 113 | 114 | session_item_embs = hist_avgs 115 | session_item_feedback = histavgs_feedback 116 | 117 | hist_avgs = tf.reduce_sum(hist_avgs, axis=-2) / 
tf.expand_dims(hist_divider, -1) 118 | hist_avgs = hist_avgs / tf.expand_dims(tf.norm(hist_avgs + avoid_div_by_zero, ord=2, axis=-1), -1) * avoid_div_by_zero_set_to_zero 119 | 120 | hist_context = tf.concat([day_onehot, hour_onehot, device_onehot], axis=2) # 121 | 122 | hist_skipped_mask = tf.cast(hist_skipped > 0, tf.float32) 123 | hist_listened_mask = tf.cast(hist_listened > 0, tf.float32) 124 | 125 | hist_skipped = tf.nn.embedding_lookup(item_emb, hist_skipped) 126 | hist_listened = tf.nn.embedding_lookup(item_emb, hist_listened) 127 | 128 | hist_skip_divider = tf.reduce_sum(hist_skipped_mask, -1) 129 | skip_is0 = tf.cast(hist_skip_divider < 0.5, tf.float32) 130 | skip_is1 = tf.cast(hist_skip_divider > 0.5, tf.float32) 131 | hist_skipped = tf.reduce_sum(hist_skipped, axis=-2) / tf.expand_dims(hist_skip_divider + skip_is0*EPS, -1) # we add a small number since we can have sessions with no skips 132 | hist_skipped = hist_skipped / tf.expand_dims(tf.norm(hist_skipped + tf.expand_dims(skip_is0*EPS, -1), ord=2, axis=-1), -1) * tf.expand_dims(skip_is1, -1) 133 | 134 | hist_listened_divider = tf.reduce_sum(hist_listened_mask, -1) 135 | listen_is0 = tf.cast(hist_listened_divider < 0.5, tf.float32) 136 | listen_is1 = tf.cast(hist_listened_divider > 0.5, tf.float32) 137 | hist_listened = tf.reduce_sum(hist_listened, axis=-2) / tf.expand_dims(hist_listened_divider + listen_is0*EPS, -1) 138 | hist_listened = hist_listened / tf.expand_dims(tf.norm(hist_listened + tf.expand_dims(listen_is0*EPS, -1), ord=2, axis=-1), -1) * tf.expand_dims(listen_is1, -1) 139 | 140 | inputs = hist_avgs[:, :-1] 141 | inputs_contexts = hist_context[:, 1:] 142 | inputs_mask_loss = mask_split_above[:, 1:] 143 | inputs_mask = mask[:, 1:] 144 | targets = hist_avgs[:, 1:] 145 | 146 | inputs_skip = hist_skipped[:, :-1] 147 | inputs_listen = hist_listened[:, :-1] 148 | 149 | targets_skip = hist_skipped[:, 1:] 150 | targets_listen = hist_listened[:, 1:] 151 | 152 | session_item_embs = session_item_embs[:, 
1:] 153 | session_item_feedback = session_item_feedback[:, 1:] 154 | 155 | time_since_last_session = time_since_last_session[:, 1:] # time since last session when predicting the current 156 | top_context_future = top_context[:, 1:] 157 | top_context = top_context[:, :-1] # top context belong to the 'previous' session, similar to inputs 158 | top_context = tf.one_hot(top_context, top_context_count) 159 | top_context_future = tf.one_hot(top_context_future, top_context_count) 160 | 161 | time_since_last_session = tf.expand_dims(time_since_last_session, -1) 162 | 163 | number_of_tracks_in_sessions = number_of_tracks_in_sessions[:, :-1] 164 | number_of_tracks_in_sessions = tf.expand_dims(number_of_tracks_in_sessions, -1) 165 | 166 | user_emb_vals = tf.nn.embedding_lookup(user_emb, user) 167 | user_f_emb_vals = tf.nn.embedding_lookup(user_f_emb, user) 168 | 169 | return inputs, inputs_contexts, inputs_mask, inputs_mask_loss, targets, \ 170 | inputs_skip, inputs_listen, targets_skip, targets_listen, session_item_embs, session_item_feedback, user_emb_vals, time_since_last_session, top_context, number_of_tracks_in_sessions, user_f_emb_vals, top_context_future 171 | 172 | def _make_loss(self, preds, target, mask): 173 | ''' 174 | Implements the cosine loss from the paper (assuming length normalized preds and targets) 175 | :param preds: predictions (normalized to unit length) 176 | :param target: targets (normalized to unit length) 177 | :param mask: potential mask if certain predictions are to be ignored 178 | :return: the cosine loss 179 | ''' 180 | loss = (1.0 - tf.reduce_sum(target * preds, -1)) * mask 181 | return loss 182 | 183 | def _make_2_layers(self, logits, units, name_add=''): 184 | ''' 185 | Implementation of 2 fully connected layers 186 | :param logits: the input to the layers 187 | :param units: number of hidden units 188 | :param name_add: unique name for the layers 189 | :return: Output of the last layer 190 | ''' 191 | output = tf.layers.dense(logits, 
units, name=name_add+"_l1", reuse=tf.AUTO_REUSE, activation='relu') 192 | output = tf.layers.dense(output, units, name=name_add+"_l2", reuse=tf.AUTO_REUSE, activation='relu') 193 | output = tf.nn.dropout(output, keep_prob=self.keep_prob) 194 | return output 195 | 196 | def _embed_context(self, context, name, size, variational, activation="linear"): 197 | ''' 198 | Embedding of the context features (either deterministic or sampled) 199 | :param context: context features 200 | :param name: unique name 201 | :param size: number of hidden units 202 | :param variational: True/False for using sampling or not 203 | :param activation: activation function 204 | :return: embedding of the context features 205 | ''' 206 | mu = tf.layers.dense(context, size, name="mu_"+name, reuse=tf.AUTO_REUSE) 207 | std = tf.layers.dense(context, size, name="std_"+name, reuse=tf.AUTO_REUSE, activation='sigmoid') 208 | eps_std = tf.cond(self.is_training, lambda:1.0, lambda: 0.0) 209 | eps = tf.random.normal(tf.shape(std), dtype=tf.float32, mean=0., stddev=eps_std, name='epsilon') 210 | z = mu + tf.exp(std / 2) * eps 211 | 212 | if not variational: 213 | z = tf.layers.dense(context, size, name="trans_init_"+name, activation=activation, reuse=tf.AUTO_REUSE) 214 | 215 | return z 216 | 217 | def _sample_target(self, input, name, variational, name_add=''): 218 | ''' 219 | Transforms the input using either a fully connected layer or through sampling. 
220 | :param input: input tensor 221 | :param name: unique name 222 | :param variational: True/False depending on if output should be deterministic or sampling based 223 | :param name_add: unique name 224 | :return: Output of the transformation of the input, which is used as the predicted target in CoSeRNN 225 | ''' 226 | mu = tf.layers.dense(input, 40, name=name_add+"_mu_"+name, reuse=tf.AUTO_REUSE) 227 | std = tf.layers.dense(input, 40, name=name_add+"_std_"+name, reuse=tf.AUTO_REUSE, activation='sigmoid') 228 | eps_std = tf.cond(self.is_training, lambda:1.0, lambda: 0.0) 229 | eps = tf.random_normal(tf.shape(std), dtype=tf.float32, mean=0., stddev=eps_std, name= name_add+'_epsilon') 230 | z = mu + tf.exp(std / 2) * eps 231 | 232 | if not variational: 233 | z = tf.layers.dense(input, 40, name=name_add+"_target_output"+name, reuse=tf.AUTO_REUSE) 234 | 235 | return z 236 | 237 | def _make_embedding(self, vocab_size, embedding_size, name, trainable=False, inittype=tf.random.normal): 238 | ''' 239 | Initialize a embedding matrix 240 | :param vocab_size: number of unique elements to embed 241 | :param embedding_size: size of embedding 242 | :param name: unique name 243 | :param trainable: True/False to make trainable or not 244 | :return: embedding matrix 245 | ''' 246 | W = tf.Variable(inittype(shape=[vocab_size, embedding_size]), 247 | trainable=trainable, name=name) 248 | embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_size]) 249 | embedding_init = W.assign(embedding_placeholder) 250 | 251 | return W, embedding_placeholder, embedding_init 252 | 253 | def smooth_cosine_similarity(self, session_emb, sess_all_representations): 254 | ''' 255 | Computes a smooth cosine value for the memory part of the model (not used in the RecSys paper). 256 | :param session_emb: embedding of a session 257 | :param sess_all_representations: embedding matrix of all keys (i.e. 
session embeddings) in the memory module 258 | :return: the cosine similarity to the closest key in the memory module 259 | ''' 260 | sess_all_representations = tf.tile(tf.expand_dims(sess_all_representations, axis=0), multiples=[tf.shape(session_emb)[0], 1,1]) 261 | session_emb = tf.expand_dims(session_emb, axis=2) 262 | inner_product = tf.matmul(sess_all_representations, session_emb) 263 | k_norm = tf.sqrt(tf.reduce_sum(tf.square(session_emb), axis=1, keep_dims=True)) 264 | M_norm = tf.sqrt(tf.reduce_sum(tf.square(sess_all_representations), axis=2, keep_dims=True)) 265 | norm_product = M_norm * k_norm 266 | similarity = tf.squeeze(inner_product / (norm_product + 1e-8), axis=2) 267 | 268 | return similarity 269 | 270 | def make_network(self, item_emb, user_emb, user_f_emb, is_training, is_single_sample, update_index, args, mm_multiplier): 271 | ''' 272 | Creates the trainable network model 273 | :param item_emb: embedding matrix of items 274 | :param user_emb: embedding matrix of users 275 | :param user_f_emb: embedding matrix of user features 276 | :param is_training: True/False depending on if Train/Test 277 | :param is_single_sample: True/False if the batch has a single sample only 278 | :param update_index: variable for the current index to be updated in the memory module (not used in the RecSys paper) 279 | :param args: model configuration dictionary 280 | :return: relevant tensorflow variables for training and tensorboard visualization 281 | ''' 282 | self.dropout = args['dropout'] 283 | self.is_training = is_training 284 | self.keep_prob = tf.cond(self.is_training, lambda: 1 - self.dropout, lambda: 1.0) 285 | 286 | if args['add_train_noise']: 287 | self.add_train_noise = tf.cond(self.is_training, lambda: 1.0, lambda: 0.0) 288 | else: 289 | self.add_train_noise = tf.cond(self.is_training, lambda: 0.0, lambda: 0.0) 290 | 291 | inputs, inputs_contexts, mask, inputs_mask_loss, targets, \ 292 | inputs_skip, inputs_listen, targets_skip, targets_listen, 
session_item_embs, \ 293 | session_item_feedback, user_vals, time_since_last_session, top_context, number_of_tracks_in_sessions, user_f_emb_vals, top_context_future = self._extract_vals(item_emb, user_emb, user_f_emb) 294 | 295 | inputs_org = inputs # save the original inputs 296 | targets = targets_listen # we focus on predicting listened tracks during model training. 297 | 298 | user_f_emb_vals = tf.layers.dense(user_f_emb_vals, args['rnnsize'], activation='relu', name='init_features', reuse=tf.AUTO_REUSE) 299 | 300 | def context_rnn_run(inputs): 301 | choices = [inputs, inputs_contexts, number_of_tracks_in_sessions, time_since_last_session, top_context, top_context_future] 302 | chosen = [choices[cc] for cc in args['inputs_chosen']] # default = '0,1,2,3,4' 303 | inputs_combined = tf.concat(chosen, axis=-1) 304 | inputs_combined = self._embed_context(inputs_combined, 'context_sampler', args['rnnsize'], args['context_var']) 305 | 306 | if args['gpu']: 307 | hist_avgs_rnn = self._make_rnn_gpu(inputs_combined, 1, args['rnnsize'], 'histavg_rnn', 308 | maxlen=tf.reduce_sum(mask, 1) , init_state=user_f_emb_vals) 309 | else: 310 | exit(-1) 311 | 312 | return hist_avgs_rnn 313 | 314 | hist_rnn_listen = context_rnn_run(inputs_listen) 315 | hist_rnn_skip = context_rnn_run(inputs_skip) 316 | hist_rnn_both = context_rnn_run(inputs_org) 317 | 318 | rep_choices = [hist_rnn_both, hist_rnn_skip, hist_rnn_listen] 319 | rep_chosen = [rep_choices[cc] for cc in args['session_rep']] # default = '0,1' 320 | hist_avgs_rnn = tf.concat(rep_chosen, axis=-1) 321 | 322 | avoid_div_by_zero = tf.cast(tf.expand_dims(inputs_mask_loss, -1) < 0.5, tf.float32) * 0.0000001 323 | rnn_pred = self._make_2_layers(hist_avgs_rnn, args['layersize'], name_add='1') 324 | rnn_pred = self._sample_target(rnn_pred, 'target_sampler', args['target_var'], name_add='1') 325 | rnn_pred_org = rnn_pred 326 | 327 | debugval = rnn_pred_org 328 | # Memory Module code start - this is not used in the RecSys paper. 
329 | # Here we have: 1) The Spotify embedding and the output from the RNN model. 330 | # Use hist_avgs_rnn, at the correct index (the only 1 on the inputs_mask_loss mask), as the key for similarity matching 331 | 332 | user_pred = user_vals / (tf.expand_dims(tf.norm(user_vals, ord=2, axis=1), -1)) 333 | key_emb, key_emb_ph, key_emb_init = self._make_embedding(args['MEMSIZE'], hist_avgs_rnn.shape[-1], 'key_emb', trainable=False, inittype=tf.zeros) 334 | session_emb, session_emb_ph, session_emb_init = self._make_embedding(args['MEMSIZE'], 40, 'session_emb', trainable=False, inittype=tf.zeros) 335 | error_emb, error_emb_ph, error_emb_init = self._make_embedding(args['MEMSIZE'], 1, 'error_emb', trainable=False, inittype=tf.zeros) 336 | 337 | if args['eval_mem_switch'] > -0.01: 338 | current_index = tf.cast(tf.argmax(inputs_mask_loss, -1), tf.int32) 339 | 340 | indexinto = tf.range(args['batch_size']) 341 | current_index = tf.concat((tf.expand_dims(indexinto, -1), tf.expand_dims(current_index, -1)), -1) 342 | 343 | current_target = tf.gather_nd(targets, current_index, name='current-target') 344 | current_key = tf.gather_nd(hist_avgs_rnn, current_index, name='current-key') 345 | 346 | # find most similar key compared to current_target 347 | cos_similarity = self.smooth_cosine_similarity(current_key, key_emb) # [batch, n_session] 348 | neigh_sim, neigh_num = tf.nn.top_k(cos_similarity, k=args['mem_top_k']) # [batch_size, memory_size] 349 | print(cos_similarity, neigh_sim, neigh_num) 350 | session_neighborhood = tf.nn.embedding_lookup(session_emb, neigh_num) # [batch_size, memory_size, memory_dim] 351 | key_neighbourhood = tf.nn.embedding_lookup(key_emb, neigh_num) 352 | error_neighbourhood = tf.squeeze(tf.nn.embedding_lookup(error_emb, neigh_num), -1) 353 | 354 | neigh_sim = tf.expand_dims(tf.nn.softmax(neigh_sim), axis=2) 355 | 356 | mm_pred = tf.reduce_sum(neigh_sim * session_neighborhood, axis=1) 357 | mm_attn_w_key = tf.reduce_sum(neigh_sim * key_neighbourhood, 
axis=1) 358 | 359 | # Memory Module code end 360 | mm_pred = mm_pred 361 | rnn_pred = tf.gather_nd(rnn_pred, current_index, name='rnn-pred') 362 | 363 | mm_pred = tf.cond(tf.reduce_sum(mm_pred) > 0.01, lambda : mm_pred / (tf.expand_dims(tf.norm(mm_pred, ord=2, axis=1), -1)), lambda : tf.zeros(mm_pred.shape)) 364 | rnn_pred = rnn_pred / (tf.expand_dims(tf.norm(rnn_pred, ord=2, axis=1), -1)) 365 | 366 | # combine mm_pred and rnn_pred, based on their "keys" 367 | combined_keys = tf.concat((error_neighbourhood, mm_attn_w_key, current_key, mm_attn_w_key - current_key, tf.expand_dims(tf.reduce_sum(mm_attn_w_key*current_key, -1), 1)), axis=-1) #), axis=-1) 368 | 369 | combined_single = tf.layers.dense(combined_keys, 3, name='MEMORYNETWORK_alpha_combine_rnn_mm') 370 | combined_multiplier_exp = tf.expand_dims([0, 0, (-100 + 100*mm_multiplier)], axis=0) 371 | combined_multiplier_not_exp = tf.expand_dims([1, 1, mm_multiplier], axis=0) 372 | combined_single = combined_single + combined_multiplier_exp 373 | combined_single = tf.nn.softmax(combined_single, axis=-1) 374 | 375 | combined_single_0 = tf.expand_dims(combined_single[:, 0], -1) 376 | combined_single_1 = tf.expand_dims(combined_single[:, 1], -1) 377 | combined_single_2 = tf.expand_dims(combined_single[:, 2], -1) 378 | 379 | combined_pred_total = combined_single_0*user_pred + combined_single_1*rnn_pred + combined_single_2*mm_pred 380 | combined_pred_total = combined_pred_total / (tf.expand_dims(tf.norm(combined_pred_total, ord=2, axis=1), -1)) 381 | 382 | combined_pred_total = tf.expand_dims(combined_pred_total, 1) 383 | 384 | multi_alpha = hist_avgs_rnn 385 | multi_alpha = tf.layers.dense(multi_alpha, 2, name='multi_alpha') 386 | multi_alpha = tf.nn.softmax(multi_alpha, axis=-1) 387 | 388 | multi_alpha_0 = tf.expand_dims(multi_alpha[:, :, 0], -1) 389 | multi_alpha_1 = tf.expand_dims(multi_alpha[:, :, 1], -1) 390 | 391 | # [rnn_pred, user_pred, rnn + user, softmax rnn + user] 392 | if args['pred_type'] == 0: 393 | 
# --- Fragment: tail of the CoSeRNN graph-construction method. Its `def` line
# and the opening branch of the `if args['pred_type'] == ...:` chain sit above
# this chunk. This part mixes the long-term user embedding with the sequential
# RNN prediction, builds masked losses and cosine metrics, the memory-module
# write ops, TensorBoard summaries, and returns all ops/placeholders the
# training script fetches. ---

# pred_type selects how user_pred (long-term) and rnn_pred_org (sequential)
# are combined. (pred_type == 0 presumably means RNN-only — the `if` is
# off-screen; TODO confirm.)
    multi_pred = rnn_pred_org
elif args['pred_type'] == 1:
    # User embedding only; `0*rnn_pred_org` keeps the broadcasted shape.
    multi_pred = tf.expand_dims(user_pred, 1) + 0*rnn_pred_org
elif args['pred_type'] == 2:
    # Unweighted sum of user and RNN predictions.
    multi_pred = tf.expand_dims(user_pred, 1) + rnn_pred_org
elif args['pred_type'] == 3: # default
    # Learned mix via the attention weights multi_alpha_0 / multi_alpha_1.
    multi_pred = multi_alpha_0 * tf.expand_dims(user_pred, 1) + multi_alpha_1 * rnn_pred_org
# L2-normalize along axis 2 so downstream dot products are cosine
# similarities; avoid_div_by_zero guards against an all-zero prediction.
multi_pred = multi_pred / (tf.expand_dims(tf.norm(multi_pred + avoid_div_by_zero, ord=2, axis=2), -1))

# A negative eval_mem_switch disables the external memory module: the final
# prediction falls back to the memory-free user+RNN mix.
# NOTE(review): when eval_mem_switch >= -0.01, combined_pred_total and
# combined_single must have been assigned in the off-screen code above —
# otherwise the references below would raise NameError; confirm.
if args['eval_mem_switch'] < -0.01:
    combined_pred_total = multi_pred
    combined_single = multi_alpha

# Per-position losses (normalized by the mask below): `loss` is on the final
# combined prediction, `loss_multi` on the memory-free user+RNN prediction.
loss_multi = self._make_loss(multi_pred, targets, inputs_mask_loss)
loss = self._make_loss(combined_pred_total, targets, inputs_mask_loss)

# NOTE(review): leftover debug print — this runs at graph-build time and
# prints Tensor objects, not values; probably safe to remove.
print(loss, loss_multi)

# Average over valid (unmasked) positions only.
loss_multi = tf.reduce_sum(loss_multi) / tf.reduce_sum(inputs_mask_loss)
loss = tf.reduce_sum(loss) / tf.reduce_sum(inputs_mask_loss)

# Masked average cosine similarity between predictions and target session
# embeddings (both sides are unit-normalized, so the dot product is a cosine).
org_avgcos = tf.reduce_sum(targets * combined_pred_total, -1) * inputs_mask_loss
avgcos = tf.reduce_sum(org_avgcos) / tf.reduce_sum(inputs_mask_loss)

# Same metric for the memory-free prediction.
org_avgcos_multi = tf.reduce_sum(targets * multi_pred, -1) * inputs_mask_loss
avgcos_multi = tf.reduce_sum(org_avgcos_multi) / tf.reduce_sum(inputs_mask_loss)

# compute dot product against items in session
tmp = tf.expand_dims(combined_pred_total, axis=-2)
session_item_scores = tf.reduce_sum(tmp * session_item_embs, axis=-1)

# After reading, insert the current (key, target) into the module, where
# again target needs to be indexed by the correct index.
if args['eval_mem_switch'] > -0.01:
    # Memory enabled: write the current key/target into the memory tables
    # (TF1 ref-variable `tf.scatter_update`).
    key_emb_update = tf.scatter_update(key_emb, update_index, current_key)
    session_emb_update = tf.scatter_update(session_emb, update_index, current_target)
    # Build (batch_index, 0) pairs so gather_nd picks the first-position
    # prediction of every batch element.
    current_index_single = tf.cast(tf.zeros(args['batch_size']), tf.int32)
    indexinto_single = tf.range(args['batch_size'])
    current_index_single = tf.concat((tf.expand_dims(indexinto_single, -1), tf.expand_dims(current_index_single, -1)), -1)
    current_pred = tf.gather_nd(combined_pred_total, current_index_single, name='current_pred')
    # Cosine of the current prediction vs. its target, stored as the memory
    # entry's error signal.
    current_cosine = tf.expand_dims(tf.reduce_sum(current_pred * current_target, -1), -1)
    error_emb_update = tf.scatter_update(error_emb, update_index, current_cosine)
else:
    # Memory disabled: plain-int placeholders stand in for the update ops so
    # the caller can still fetch every element of the return tuple.
    key_emb_update = 0
    session_emb_update = 0
    error_emb_update = 0
    # NOTE(review): all three fall back to multi_alpha_0 — `_1`/`_2` look
    # like a copy-paste of `_0`; confirm whether multi_alpha_1/_2 were
    # intended here.
    combined_single_0 = multi_alpha_0
    combined_single_1 = multi_alpha_0
    combined_single_2 = multi_alpha_0
    error_neighbourhood = 0

# tensorboard vars
sum1 = tf.summary.scalar("train cosine sim", avgcos) # tf.reduce_sum(avgcos))
sum1a = tf.summary.scalar("train cosine rnn+user", avgcos_multi) # tf.reduce_sum(avgcos))

sum2 = tf.summary.scalar("train loss", tf.reduce_mean(loss))
sum2a = tf.summary.scalar("train loss rnn+user", tf.reduce_mean(loss_multi))

multi_alpha_sum = tf.summary.scalar('multi_alpha--0', tf.reduce_mean(multi_alpha_0))

# NOTE(review): in the memory-enabled path combined_single_0/1/2 must come
# from the off-screen code above — confirm they are defined there.
combined_single_sum_0 = tf.summary.scalar('combined_single_sum--0', tf.reduce_mean(combined_single_0))
combined_single_sum_1 = tf.summary.scalar('combined_single_sum--1', tf.reduce_mean(combined_single_1))
combined_single_sum_2 = tf.summary.scalar('combined_single_sum--2', tf.reduce_mean(combined_single_2))

summary_op = tf.summary.merge([sum1, sum2, sum1a, sum2a, multi_alpha_sum, combined_single_sum_0, combined_single_sum_1, combined_single_sum_2])

# Val/test metrics are computed outside the graph and fed back in through
# placeholder -> variable -> summary triples so they appear in TensorBoard.
# ("tensorbardVar" [sic] — the shared name is auto-uniquified by tf.Variable.)
pVar_val_loss = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_val_loss = tmp.assign(pVar_val_loss)
sum3 = tf.summary.scalar("val loss", tmp)

pVar_val_avgcos = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_val_avgcos = tmp.assign(pVar_val_avgcos)
sum4 = tf.summary.scalar("val cosine sim", tmp)


pVar_test_loss = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_test_loss = tmp.assign(pVar_test_loss)
sum5 = tf.summary.scalar("test loss", tmp)

pVar_test_avgcos = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_test_avgcos = tmp.assign(pVar_test_avgcos)
sum6 = tf.summary.scalar("test cosine sim", tmp)

# stuff for val mr and mrr
pVar_val_mrr = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_val_mrr = tmp.assign(pVar_val_mrr)
sum7 = tf.summary.scalar("val mrr", tmp)

pVar_val_mr = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_val_mr = tmp.assign(pVar_val_mr)
sum8 = tf.summary.scalar("val mr", tmp)

# stuff for val mr and mrr
# NOTE(review): comment above says "val" but these are the *test*-set
# MRR/MR placeholders.
pVar_test_mrr = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_test_mrr = tmp.assign(pVar_test_mrr)
sum9 = tf.summary.scalar("test mrr", tmp)

pVar_test_mr = tf.placeholder(tf.float32, [])
tmp = tf.Variable(0, dtype=tf.float32, name="tensorbardVar")
update_pVar_test_mr = tmp.assign(pVar_test_mr)
sum10 = tf.summary.scalar("test mr", tmp)

summary_op_val = tf.summary.merge([sum3, sum4, sum7, sum8])
summary_op_test = tf.summary.merge([sum5, sum6, sum9, sum10])

# Everything the training script needs, positionally.
# NOTE(review): `loss` appears at positions 2, 8 and 10 and `avgcos` at 9
# and 18 — confirm the caller's unpacking really expects these duplicates.
return (
    debugval,
    loss,
    loss_multi,
    combined_pred_total,
    multi_pred,
    org_avgcos,
    org_avgcos_multi,
    loss,
    avgcos,
    loss,
    summary_op,
    summary_op_val,
    summary_op_test,
    [
        pVar_val_loss,
        pVar_val_avgcos,
        pVar_val_mrr,
        pVar_val_mr
    ],
    [
        update_pVar_val_loss,
        update_pVar_val_avgcos,
        update_pVar_val_mrr,
        update_pVar_val_mr
    ],
    [
        pVar_test_loss,
        pVar_test_avgcos,
        pVar_test_mrr,
        pVar_test_mr
    ],
    [
        update_pVar_test_loss,
        update_pVar_test_avgcos,
        update_pVar_test_mrr,
        update_pVar_test_mr
    ],
    avgcos,
    session_item_scores,
    targets,
    [
        key_emb_update,
        session_emb_update,
        error_emb_update
    ],
    [
        key_emb,
        key_emb_ph,
        key_emb_init
    ],
    [
        session_emb,
        session_emb_ph,
        session_emb_init
    ],
    [
        error_emb,
        error_emb_ph,
        error_emb_init
    ],
    combined_single_0,
    combined_single_1,
    error_neighbourhood,
    combined_single,
)