├── README.md
├── data
│   ├── ml-100k
│   │   ├── README
│   │   ├── u.data
│   │   ├── u.item
│   │   └── u.user
│   └── ml-1m
│       ├── movies.dat
│       ├── ratings.dat
│       └── users.dat
├── data_loader
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── avazu.cpython-36.pyc
│   │   ├── criteo.cpython-36.pyc
│   │   ├── data_loader.cpython-36.pyc
│   │   └── movielens.cpython-36.pyc
│   ├── avazu.py
│   ├── criteo.py
│   ├── data_loader.py
│   └── movielens.py
├── engine.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── factorizer.cpython-36.pyc
│   │   ├── modules.cpython-36.pyc
│   │   └── pep_embedding.cpython-36.pyc
│   ├── factorizer.py
│   ├── modules.py
│   └── pep_embedding.py
├── train_avazu.py
├── train_avazu_retrain.py
├── train_criteo.py
├── train_criteo_retrain.py
├── train_ml-1m.py
├── train_ml-1m_retrain.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── evaluate.cpython-35.pyc
    │   ├── evaluate.cpython-36.pyc
    │   ├── performance_optimization.cpython-35.pyc
    │   ├── train.cpython-35.pyc
    │   └── train.cpython-36.pyc
    ├── evaluate.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # Learnable Embedding Sizes for Recommender Systems
2 | This repository contains the PyTorch implementation of the ICLR 2021 paper [*Learnable Embedding Sizes for Recommender Systems*](https://arxiv.org/abs/2101.07577).
3 | Please check our paper for more details about our work if you are interested.
4 | 
5 | ## Usage
6 | Follow the steps below to run our code:
7 | 
8 | ### 1. Install torchfm
9 | `pip install torchfm`
10 | 
11 | For more information about torchfm, please see:
12 | 
13 | <https://github.com/rixwew/pytorch-fm>
14 | 
15 | ### 2. Download datasets
16 | We provide the MovieLens-1M dataset in `data/ml-1m`. If you want to run PEP on the Criteo and Avazu datasets,
17 | you need to download them from [Criteo](https://www.kaggle.com/c/criteo-display-ad-challenge) and [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction).
18 | 
19 | ### 3. Put the data in `data/criteo` or `data/avazu`
20 | Raw data should be stored at the following paths:
21 | 
22 | `data/criteo/train.txt`
23 | 
24 | `data/avazu/train`
25 | 
26 | ### 4. Specify the hyper-parameters
27 | 
28 | For learning embedding sizes, the hyper-parameters are in `train_[dataset].py`.
29 | 
30 | For retraining with the learned embedding sizes, the hyper-parameters are in `train_[dataset]_retrain.py`.
31 | 
32 | ### 5. Learn embedding sizes
33 | 
34 | Run `train_[dataset].py` to learn embedding sizes. The learned embeddings will be saved in
35 | `tmp/embedding/fm/[alias]/`, named by their number of parameters.
36 | 
37 | ### 6. Retrain the pruned embedding
38 | 
39 | Run `train_[dataset]_retrain.py` to retrain the pruned embedding table. You need to specify which embedding table to retrain via the hyper-parameter `retrain_emb_param`.
40 | 
41 | ## Requirements
42 | + Python 3
43 | + PyTorch 1.1.0
44 | 
45 | ## Citation
46 | If you find this repo useful, please kindly cite our paper.
47 | ```
48 | @inproceedings{liu2021learnable,
49 |   title={Learnable Embedding Sizes for Recommender Systems},
50 |   author={Siyi Liu and Chen Gao and Yihong Chen and Depeng Jin and Yong Li},
51 |   booktitle={International Conference on Learning Representations},
52 |   year={2021},
53 |   url={https://openreview.net/forum?id=vQzcqQWIS0q}
54 | }
55 | ```
56 | 
57 | ## Acknowledgment
58 | The structure of this code is largely based on [lambda-opt](https://github.com/yihong-chen/lambda-opt).
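
## How the learned sizes emerge (illustrative sketch)

PEP reparameterizes the embedding table `V` with a soft threshold: the embedding the model actually uses is `sign(V) * max(|V| - g(s), 0)`, where the scalar `s` is trained jointly with the recommendation loss and `g` is the sigmoid function. Entries whose magnitude stays below `g(s)` become exactly zero, so each feature ends up with its own effective embedding size. The snippet below is an editor's minimal sketch of this idea with a single global threshold (a simplification; the reference implementation is `models/pep_embedding.py`, and the paper also discusses finer-grained thresholds):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftThresholdEmbedding(nn.Module):
    """Embedding table with a learnable soft-threshold pruning mask."""

    def __init__(self, num_features, latent_dim, init_s=-15.0):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_features, latent_dim))
        nn.init.xavier_uniform_(self.weight)
        # Learnable threshold parameter; sigmoid(-15) is ~0, so nothing is
        # pruned at initialization and sparsity emerges during training.
        self.s = nn.Parameter(torch.tensor(init_s))

    def forward(self, x):
        g = torch.sigmoid(self.s)
        # Soft thresholding: V_hat = sign(V) * max(|V| - g(s), 0)
        pruned = torch.sign(self.weight) * F.relu(self.weight.abs() - g)
        return F.embedding(x, pruned)

    @torch.no_grad()
    def sparsity(self):
        """Fraction of embedding entries currently pruned to exactly zero."""
        return (self.weight.abs() <= torch.sigmoid(self.s)).float().mean().item()
```

Training such a module with the usual task loss drives `sparsity()` up over time; step 5 above saves snapshots of the pruned table at different parameter budgets, and step 6 retrains the surviving entries.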
59 | 
--------------------------------------------------------------------------------
/data/ml-100k/README:
--------------------------------------------------------------------------------
1 | SUMMARY & USAGE LICENSE
2 | =============================================
3 | 
4 | MovieLens data sets were collected by the GroupLens Research Project
5 | at the University of Minnesota.
6 | 
7 | This data set consists of:
8 | 	* 100,000 ratings (1-5) from 943 users on 1682 movies.
9 | 	* Each user has rated at least 20 movies.
10 | 	* Simple demographic info for the users (age, gender, occupation, zip)
11 | 
12 | The data was collected through the MovieLens web site
13 | (movielens.umn.edu) during the seven-month period from September 19th,
14 | 1997 through April 22nd, 1998. This data has been cleaned up - users
15 | who had fewer than 20 ratings or did not have complete demographic
16 | information were removed from this data set. Detailed descriptions of
17 | the data file can be found at the end of this file.
18 | 
19 | Neither the University of Minnesota nor any of the researchers
20 | involved can guarantee the correctness of the data, its suitability
21 | for any particular purpose, or the validity of results based on the
22 | use of the data set. The data set may be used for any research
23 | purposes under the following conditions:
24 | 
25 |      * The user may not state or imply any endorsement from the
26 |        University of Minnesota or the GroupLens Research Group.
27 | 
28 |      * The user must acknowledge the use of the data set in
29 |        publications resulting from the use of the data set
30 |        (see below for citation information).
31 | 
32 |      * The user may not redistribute the data without separate
33 |        permission.
34 | 
35 |      * The user may not use this information for any commercial or
36 |        revenue-bearing purposes without first obtaining permission
37 |        from a faculty member of the GroupLens Research Project at the
38 |        University of Minnesota.
39 | 
40 | If you have any further questions or comments, please contact GroupLens
41 | <grouplens-info@cs.umn.edu>.
42 | 
43 | CITATION
44 | ==============================================
45 | 
46 | To acknowledge use of the dataset in publications, please cite the
47 | following paper:
48 | 
49 | F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets:
50 | History and Context. ACM Transactions on Interactive Intelligent
51 | Systems (TiiS) 5, 4, Article 19 (December 2015), 19 pages.
52 | DOI=http://dx.doi.org/10.1145/2827872
53 | 
54 | 
55 | ACKNOWLEDGEMENTS
56 | ==============================================
57 | 
58 | Thanks to Al Borchers for cleaning up this data and writing the
59 | accompanying scripts.
60 | 
61 | PUBLISHED WORK THAT HAS USED THIS DATASET
62 | ==============================================
63 | 
64 | Herlocker, J., Konstan, J., Borchers, A., Riedl, J.. An Algorithmic
65 | Framework for Performing Collaborative Filtering. Proceedings of the
66 | 1999 Conference on Research and Development in Information
67 | Retrieval. Aug. 1999.
68 | 
69 | FURTHER INFORMATION ABOUT THE GROUPLENS RESEARCH PROJECT
70 | ==============================================
71 | 
72 | The GroupLens Research Project is a research group in the Department
73 | of Computer Science and Engineering at the University of Minnesota.
74 | Members of the GroupLens Research Project are involved in many
75 | research projects related to the fields of information filtering,
76 | collaborative filtering, and recommender systems. The project is led
77 | by professors John Riedl and Joseph Konstan. The project began to
78 | explore automated collaborative filtering in 1992, but is most well
79 | known for its world wide trial of an automated collaborative filtering
80 | system for Usenet news in 1996. The technology developed in the
81 | Usenet trial formed the base for the formation of Net Perceptions,
82 | Inc., which was founded by members of GroupLens Research. Since then
83 | the project has expanded its scope to research overall information
84 | filtering solutions, integrating in content-based methods as well as
85 | improving current collaborative filtering technology.
86 | 
87 | Further information on the GroupLens Research project, including
88 | research publications, can be found at the following web site:
89 | 
90 | http://www.grouplens.org/
91 | 
92 | GroupLens Research currently operates a movie recommender based on
93 | collaborative filtering:
94 | 
95 | http://www.movielens.org/
96 | 
97 | DETAILED DESCRIPTIONS OF DATA FILES
98 | ==============================================
99 | 
100 | Here are brief descriptions of the data.
101 | 
102 | ml-data.tar.gz -- Compressed tar file. To rebuild the u data files do this:
103 |     gunzip ml-data.tar.gz
104 |     tar xvf ml-data.tar
105 |     mku.sh
106 | 
107 | u.data -- The full u data set, 100000 ratings by 943 users on 1682 items.
108 |           Each user has rated at least 20 movies. Users and items are
109 |           numbered consecutively from 1. The data is randomly
110 |           ordered. This is a tab separated list of
111 |           user id | item id | rating | timestamp.
112 |           The time stamps are unix seconds since 1/1/1970 UTC
113 | 
114 | u.info -- The number of users, items, and ratings in the u data set.
115 | 
116 | u.item -- Information about the items (movies); this is a tab separated
117 |           list of
118 |           movie id | movie title | release date | video release date |
119 |           IMDb URL | unknown | Action | Adventure | Animation |
120 |           Children's | Comedy | Crime | Documentary | Drama | Fantasy |
121 |           Film-Noir | Horror | Musical | Mystery | Romance | Sci-Fi |
122 |           Thriller | War | Western |
123 |           The last 19 fields are the genres, a 1 indicates the movie
124 |           is of that genre, a 0 indicates it is not; movies can be in
125 |           several genres at once.
126 |           The movie ids are the ones used in the u.data data set.
127 | 
128 | u.genre -- A list of the genres.
129 | 
130 | u.user -- Demographic information about the users; this is a tab
131 |           separated list of
132 |           user id | age | gender | occupation | zip code
133 |           The user ids are the ones used in the u.data data set.
134 | 
135 | u.occupation -- A list of the occupations.
136 | 
137 | u1.base -- The data sets u1.base and u1.test through u5.base and u5.test
138 | u1.test    are 80%/20% splits of the u data into training and test data.
139 | u2.base    Each of u1, ..., u5 has disjoint test sets; this is for
140 | u2.test    5 fold cross validation (where you repeat your experiment
141 | u3.base    with each training and test set and average the results).
142 | u3.test    These data sets can be generated from u.data by mku.sh.
143 | u4.base
144 | u4.test
145 | u5.base
146 | u5.test
147 | 
148 | ua.base -- The data sets ua.base, ua.test, ub.base, and ub.test
149 | ua.test    split the u data into a training set and a test set with
150 | ub.base    exactly 10 ratings per user in the test set. The sets
151 | ub.test    ua.test and ub.test are disjoint. These data sets can
152 |            be generated from u.data by mku.sh.
153 | 154 | allbut.pl -- The script that generates training and test sets where 155 | all but n of a users ratings are in the training data. 156 | 157 | mku.sh -- A shell script to generate all the u data sets from u.data. 158 | -------------------------------------------------------------------------------- /data/ml-100k/u.item: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data/ml-100k/u.item -------------------------------------------------------------------------------- /data/ml-100k/u.user: -------------------------------------------------------------------------------- 1 | 1|24|M|technician|85711 2 | 2|53|F|other|94043 3 | 3|23|M|writer|32067 4 | 4|24|M|technician|43537 5 | 5|33|F|other|15213 6 | 6|42|M|executive|98101 7 | 7|57|M|administrator|91344 8 | 8|36|M|administrator|05201 9 | 9|29|M|student|01002 10 | 10|53|M|lawyer|90703 11 | 11|39|F|other|30329 12 | 12|28|F|other|06405 13 | 13|47|M|educator|29206 14 | 14|45|M|scientist|55106 15 | 15|49|F|educator|97301 16 | 16|21|M|entertainment|10309 17 | 17|30|M|programmer|06355 18 | 18|35|F|other|37212 19 | 19|40|M|librarian|02138 20 | 20|42|F|homemaker|95660 21 | 21|26|M|writer|30068 22 | 22|25|M|writer|40206 23 | 23|30|F|artist|48197 24 | 24|21|F|artist|94533 25 | 25|39|M|engineer|55107 26 | 26|49|M|engineer|21044 27 | 27|40|F|librarian|30030 28 | 28|32|M|writer|55369 29 | 29|41|M|programmer|94043 30 | 30|7|M|student|55436 31 | 31|24|M|artist|10003 32 | 32|28|F|student|78741 33 | 33|23|M|student|27510 34 | 34|38|F|administrator|42141 35 | 35|20|F|homemaker|42459 36 | 36|19|F|student|93117 37 | 37|23|M|student|55105 38 | 38|28|F|other|54467 39 | 39|41|M|entertainment|01040 40 | 40|38|M|scientist|27514 41 | 41|33|M|engineer|80525 42 | 42|30|M|administrator|17870 43 | 43|29|F|librarian|20854 44 | 44|26|M|technician|46260 45 | 45|29|M|programmer|50233 46 | 46|27|F|marketing|46538 47 | 47|53|M|marketing|07102 48 | 48|45|M|administrator|12550 49 | 49|23|F|student|76111 50 | 50|21|M|writer|52245 51 | 51|28|M|educator|16509 52 | 52|18|F|student|55105 53 | 53|26|M|programmer|55414 54 | 54|22|M|executive|66315 55 | 55|37|M|programmer|01331 56 | 56|25|M|librarian|46260 57 | 57|16|M|none|84010 58 | 58|27|M|programmer|52246 59 | 59|49|M|educator|08403 60 | 60|50|M|healthcare|06472 61 | 61|36|M|engineer|30040 62 | 62|27|F|administrator|97214 63 | 63|31|M|marketing|75240 64 | 64|32|M|educator|43202 65 | 65|51|F|educator|48118 66 | 66|23|M|student|80521 67 | 67|17|M|student|60402 68 | 68|19|M|student|22904 69 | 69|24|M|engineer|55337 70 | 70|27|M|engineer|60067 71 | 71|39|M|scientist|98034 72 | 72|48|F|administrator|73034 73 | 73|24|M|student|41850 74 | 74|39|M|scientist|T8H1N 75 | 75|24|M|entertainment|08816 76 | 76|20|M|student|02215 77 | 77|30|M|technician|29379 78 | 78|26|M|administrator|61801 79 | 79|39|F|administrator|03755 80 | 80|34|F|administrator|52241 81 | 81|21|M|student|21218 82 | 82|50|M|programmer|22902 83 | 83|40|M|other|44133 84 | 84|32|M|executive|55369 85 | 85|51|M|educator|20003 86 | 86|26|M|administrator|46005 87 | 87|47|M|administrator|89503 88 | 88|49|F|librarian|11701 89 | 89|43|F|administrator|68106 90 | 90|60|M|educator|78155 91 | 91|55|M|marketing|01913 92 | 92|32|M|entertainment|80525 93 | 93|48|M|executive|23112 94 | 94|26|M|student|71457 95 | 95|31|M|administrator|10707 96 | 96|25|F|artist|75206 97 | 97|43|M|artist|98006 98 | 98|49|F|executive|90291 99 | 
99|20|M|student|63129 100 | 100|36|M|executive|90254 101 | 101|15|M|student|05146 102 | 102|38|M|programmer|30220 103 | 103|26|M|student|55108 104 | 104|27|M|student|55108 105 | 105|24|M|engineer|94043 106 | 106|61|M|retired|55125 107 | 107|39|M|scientist|60466 108 | 108|44|M|educator|63130 109 | 109|29|M|other|55423 110 | 110|19|M|student|77840 111 | 111|57|M|engineer|90630 112 | 112|30|M|salesman|60613 113 | 113|47|M|executive|95032 114 | 114|27|M|programmer|75013 115 | 115|31|M|engineer|17110 116 | 116|40|M|healthcare|97232 117 | 117|20|M|student|16125 118 | 118|21|M|administrator|90210 119 | 119|32|M|programmer|67401 120 | 120|47|F|other|06260 121 | 121|54|M|librarian|99603 122 | 122|32|F|writer|22206 123 | 123|48|F|artist|20008 124 | 124|34|M|student|60615 125 | 125|30|M|lawyer|22202 126 | 126|28|F|lawyer|20015 127 | 127|33|M|none|73439 128 | 128|24|F|marketing|20009 129 | 129|36|F|marketing|07039 130 | 130|20|M|none|60115 131 | 131|59|F|administrator|15237 132 | 132|24|M|other|94612 133 | 133|53|M|engineer|78602 134 | 134|31|M|programmer|80236 135 | 135|23|M|student|38401 136 | 136|51|M|other|97365 137 | 137|50|M|educator|84408 138 | 138|46|M|doctor|53211 139 | 139|20|M|student|08904 140 | 140|30|F|student|32250 141 | 141|49|M|programmer|36117 142 | 142|13|M|other|48118 143 | 143|42|M|technician|08832 144 | 144|53|M|programmer|20910 145 | 145|31|M|entertainment|V3N4P 146 | 146|45|M|artist|83814 147 | 147|40|F|librarian|02143 148 | 148|33|M|engineer|97006 149 | 149|35|F|marketing|17325 150 | 150|20|F|artist|02139 151 | 151|38|F|administrator|48103 152 | 152|33|F|educator|68767 153 | 153|25|M|student|60641 154 | 154|25|M|student|53703 155 | 155|32|F|other|11217 156 | 156|25|M|educator|08360 157 | 157|57|M|engineer|70808 158 | 158|50|M|educator|27606 159 | 159|23|F|student|55346 160 | 160|27|M|programmer|66215 161 | 161|50|M|lawyer|55104 162 | 162|25|M|artist|15610 163 | 163|49|M|administrator|97212 164 | 164|47|M|healthcare|80123 165 | 165|20|F|other|53715 166 | 166|47|M|educator|55113 167 | 167|37|M|other|L9G2B 168 | 168|48|M|other|80127 169 | 169|52|F|other|53705 170 | 170|53|F|healthcare|30067 171 | 171|48|F|educator|78750 172 | 172|55|M|marketing|22207 173 | 173|56|M|other|22306 174 | 174|30|F|administrator|52302 175 | 175|26|F|scientist|21911 176 | 176|28|M|scientist|07030 177 | 177|20|M|programmer|19104 178 | 178|26|M|other|49512 179 | 179|15|M|entertainment|20755 180 | 180|22|F|administrator|60202 181 | 181|26|M|executive|21218 182 | 182|36|M|programmer|33884 183 | 183|33|M|scientist|27708 184 | 184|37|M|librarian|76013 185 | 185|53|F|librarian|97403 186 | 186|39|F|executive|00000 187 | 187|26|M|educator|16801 188 | 188|42|M|student|29440 189 | 189|32|M|artist|95014 190 | 190|30|M|administrator|95938 191 | 191|33|M|administrator|95161 192 | 192|42|M|educator|90840 193 | 193|29|M|student|49931 194 | 194|38|M|administrator|02154 195 | 195|42|M|scientist|93555 196 | 196|49|M|writer|55105 197 | 197|55|M|technician|75094 198 | 198|21|F|student|55414 199 | 199|30|M|writer|17604 200 | 200|40|M|programmer|93402 201 | 201|27|M|writer|E2A4H 202 | 202|41|F|educator|60201 203 | 203|25|F|student|32301 204 | 204|52|F|librarian|10960 205 | 205|47|M|lawyer|06371 206 | 206|14|F|student|53115 207 | 207|39|M|marketing|92037 208 | 208|43|M|engineer|01720 209 | 209|33|F|educator|85710 210 | 210|39|M|engineer|03060 211 | 211|66|M|salesman|32605 212 | 212|49|F|educator|61401 213 | 213|33|M|executive|55345 214 | 214|26|F|librarian|11231 215 | 215|35|M|programmer|63033 216 | 216|22|M|engineer|02215 217 
| 217|22|M|other|11727 218 | 218|37|M|administrator|06513 219 | 219|32|M|programmer|43212 220 | 220|30|M|librarian|78205 221 | 221|19|M|student|20685 222 | 222|29|M|programmer|27502 223 | 223|19|F|student|47906 224 | 224|31|F|educator|43512 225 | 225|51|F|administrator|58202 226 | 226|28|M|student|92103 227 | 227|46|M|executive|60659 228 | 228|21|F|student|22003 229 | 229|29|F|librarian|22903 230 | 230|28|F|student|14476 231 | 231|48|M|librarian|01080 232 | 232|45|M|scientist|99709 233 | 233|38|M|engineer|98682 234 | 234|60|M|retired|94702 235 | 235|37|M|educator|22973 236 | 236|44|F|writer|53214 237 | 237|49|M|administrator|63146 238 | 238|42|F|administrator|44124 239 | 239|39|M|artist|95628 240 | 240|23|F|educator|20784 241 | 241|26|F|student|20001 242 | 242|33|M|educator|31404 243 | 243|33|M|educator|60201 244 | 244|28|M|technician|80525 245 | 245|22|M|student|55109 246 | 246|19|M|student|28734 247 | 247|28|M|engineer|20770 248 | 248|25|M|student|37235 249 | 249|25|M|student|84103 250 | 250|29|M|executive|95110 251 | 251|28|M|doctor|85032 252 | 252|42|M|engineer|07733 253 | 253|26|F|librarian|22903 254 | 254|44|M|educator|42647 255 | 255|23|M|entertainment|07029 256 | 256|35|F|none|39042 257 | 257|17|M|student|77005 258 | 258|19|F|student|77801 259 | 259|21|M|student|48823 260 | 260|40|F|artist|89801 261 | 261|28|M|administrator|85202 262 | 262|19|F|student|78264 263 | 263|41|M|programmer|55346 264 | 264|36|F|writer|90064 265 | 265|26|M|executive|84601 266 | 266|62|F|administrator|78756 267 | 267|23|M|engineer|83716 268 | 268|24|M|engineer|19422 269 | 269|31|F|librarian|43201 270 | 270|18|F|student|63119 271 | 271|51|M|engineer|22932 272 | 272|33|M|scientist|53706 273 | 273|50|F|other|10016 274 | 274|20|F|student|55414 275 | 275|38|M|engineer|92064 276 | 276|21|M|student|95064 277 | 277|35|F|administrator|55406 278 | 278|37|F|librarian|30033 279 | 279|33|M|programmer|85251 280 | 280|30|F|librarian|22903 281 | 281|15|F|student|06059 282 | 282|22|M|administrator|20057 283 | 283|28|M|programmer|55305 284 | 284|40|M|executive|92629 285 | 285|25|M|programmer|53713 286 | 286|27|M|student|15217 287 | 287|21|M|salesman|31211 288 | 288|34|M|marketing|23226 289 | 289|11|M|none|94619 290 | 290|40|M|engineer|93550 291 | 291|19|M|student|44106 292 | 292|35|F|programmer|94703 293 | 293|24|M|writer|60804 294 | 294|34|M|technician|92110 295 | 295|31|M|educator|50325 296 | 296|43|F|administrator|16803 297 | 297|29|F|educator|98103 298 | 298|44|M|executive|01581 299 | 299|29|M|doctor|63108 300 | 300|26|F|programmer|55106 301 | 301|24|M|student|55439 302 | 302|42|M|educator|77904 303 | 303|19|M|student|14853 304 | 304|22|F|student|71701 305 | 305|23|M|programmer|94086 306 | 306|45|M|other|73132 307 | 307|25|M|student|55454 308 | 308|60|M|retired|95076 309 | 309|40|M|scientist|70802 310 | 310|37|M|educator|91711 311 | 311|32|M|technician|73071 312 | 312|48|M|other|02110 313 | 313|41|M|marketing|60035 314 | 314|20|F|student|08043 315 | 315|31|M|educator|18301 316 | 316|43|F|other|77009 317 | 317|22|M|administrator|13210 318 | 318|65|M|retired|06518 319 | 319|38|M|programmer|22030 320 | 320|19|M|student|24060 321 | 321|49|F|educator|55413 322 | 322|20|M|student|50613 323 | 323|21|M|student|19149 324 | 324|21|F|student|02176 325 | 325|48|M|technician|02139 326 | 326|41|M|administrator|15235 327 | 327|22|M|student|11101 328 | 328|51|M|administrator|06779 329 | 329|48|M|educator|01720 330 | 330|35|F|educator|33884 331 | 331|33|M|entertainment|91344 332 | 332|20|M|student|40504 333 | 333|47|M|other|V0R2M 334 | 
334|32|M|librarian|30002 335 | 335|45|M|executive|33775 336 | 336|23|M|salesman|42101 337 | 337|37|M|scientist|10522 338 | 338|39|F|librarian|59717 339 | 339|35|M|lawyer|37901 340 | 340|46|M|engineer|80123 341 | 341|17|F|student|44405 342 | 342|25|F|other|98006 343 | 343|43|M|engineer|30093 344 | 344|30|F|librarian|94117 345 | 345|28|F|librarian|94143 346 | 346|34|M|other|76059 347 | 347|18|M|student|90210 348 | 348|24|F|student|45660 349 | 349|68|M|retired|61455 350 | 350|32|M|student|97301 351 | 351|61|M|educator|49938 352 | 352|37|F|programmer|55105 353 | 353|25|M|scientist|28480 354 | 354|29|F|librarian|48197 355 | 355|25|M|student|60135 356 | 356|32|F|homemaker|92688 357 | 357|26|M|executive|98133 358 | 358|40|M|educator|10022 359 | 359|22|M|student|61801 360 | 360|51|M|other|98027 361 | 361|22|M|student|44074 362 | 362|35|F|homemaker|85233 363 | 363|20|M|student|87501 364 | 364|63|M|engineer|01810 365 | 365|29|M|lawyer|20009 366 | 366|20|F|student|50670 367 | 367|17|M|student|37411 368 | 368|18|M|student|92113 369 | 369|24|M|student|91335 370 | 370|52|M|writer|08534 371 | 371|36|M|engineer|99206 372 | 372|25|F|student|66046 373 | 373|24|F|other|55116 374 | 374|36|M|executive|78746 375 | 375|17|M|entertainment|37777 376 | 376|28|F|other|10010 377 | 377|22|M|student|18015 378 | 378|35|M|student|02859 379 | 379|44|M|programmer|98117 380 | 380|32|M|engineer|55117 381 | 381|33|M|artist|94608 382 | 382|45|M|engineer|01824 383 | 383|42|M|administrator|75204 384 | 384|52|M|programmer|45218 385 | 385|36|M|writer|10003 386 | 386|36|M|salesman|43221 387 | 387|33|M|entertainment|37412 388 | 388|31|M|other|36106 389 | 389|44|F|writer|83702 390 | 390|42|F|writer|85016 391 | 391|23|M|student|84604 392 | 392|52|M|writer|59801 393 | 393|19|M|student|83686 394 | 394|25|M|administrator|96819 395 | 395|43|M|other|44092 396 | 396|57|M|engineer|94551 397 | 397|17|M|student|27514 398 | 398|40|M|other|60008 399 | 399|25|M|other|92374 400 | 400|33|F|administrator|78213 401 | 401|46|F|healthcare|84107 402 | 402|30|M|engineer|95129 403 | 403|37|M|other|06811 404 | 404|29|F|programmer|55108 405 | 405|22|F|healthcare|10019 406 | 406|52|M|educator|93109 407 | 407|29|M|engineer|03261 408 | 408|23|M|student|61755 409 | 409|48|M|administrator|98225 410 | 410|30|F|artist|94025 411 | 411|34|M|educator|44691 412 | 412|25|M|educator|15222 413 | 413|55|M|educator|78212 414 | 414|24|M|programmer|38115 415 | 415|39|M|educator|85711 416 | 416|20|F|student|92626 417 | 417|27|F|other|48103 418 | 418|55|F|none|21206 419 | 419|37|M|lawyer|43215 420 | 420|53|M|educator|02140 421 | 421|38|F|programmer|55105 422 | 422|26|M|entertainment|94533 423 | 423|64|M|other|91606 424 | 424|36|F|marketing|55422 425 | 425|19|M|student|58644 426 | 426|55|M|educator|01602 427 | 427|51|M|doctor|85258 428 | 428|28|M|student|55414 429 | 429|27|M|student|29205 430 | 430|38|M|scientist|98199 431 | 431|24|M|marketing|92629 432 | 432|22|M|entertainment|50311 433 | 433|27|M|artist|11211 434 | 434|16|F|student|49705 435 | 435|24|M|engineer|60007 436 | 436|30|F|administrator|17345 437 | 437|27|F|other|20009 438 | 438|51|F|administrator|43204 439 | 439|23|F|administrator|20817 440 | 440|30|M|other|48076 441 | 441|50|M|technician|55013 442 | 442|22|M|student|85282 443 | 443|35|M|salesman|33308 444 | 444|51|F|lawyer|53202 445 | 445|21|M|writer|92653 446 | 446|57|M|educator|60201 447 | 447|30|M|administrator|55113 448 | 448|23|M|entertainment|10021 449 | 449|23|M|librarian|55021 450 | 450|35|F|educator|11758 451 | 451|16|M|student|48446 452 | 
452|35|M|administrator|28018 453 | 453|18|M|student|06333 454 | 454|57|M|other|97330 455 | 455|48|M|administrator|83709 456 | 456|24|M|technician|31820 457 | 457|33|F|salesman|30011 458 | 458|47|M|technician|Y1A6B 459 | 459|22|M|student|29201 460 | 460|44|F|other|60630 461 | 461|15|M|student|98102 462 | 462|19|F|student|02918 463 | 463|48|F|healthcare|75218 464 | 464|60|M|writer|94583 465 | 465|32|M|other|05001 466 | 466|22|M|student|90804 467 | 467|29|M|engineer|91201 468 | 468|28|M|engineer|02341 469 | 469|60|M|educator|78628 470 | 470|24|M|programmer|10021 471 | 471|10|M|student|77459 472 | 472|24|M|student|87544 473 | 473|29|M|student|94708 474 | 474|51|M|executive|93711 475 | 475|30|M|programmer|75230 476 | 476|28|M|student|60440 477 | 477|23|F|student|02125 478 | 478|29|M|other|10019 479 | 479|30|M|educator|55409 480 | 480|57|M|retired|98257 481 | 481|73|M|retired|37771 482 | 482|18|F|student|40256 483 | 483|29|M|scientist|43212 484 | 484|27|M|student|21208 485 | 485|44|F|educator|95821 486 | 486|39|M|educator|93101 487 | 487|22|M|engineer|92121 488 | 488|48|M|technician|21012 489 | 489|55|M|other|45218 490 | 490|29|F|artist|V5A2B 491 | 491|43|F|writer|53711 492 | 492|57|M|educator|94618 493 | 493|22|M|engineer|60090 494 | 494|38|F|administrator|49428 495 | 495|29|M|engineer|03052 496 | 496|21|F|student|55414 497 | 497|20|M|student|50112 498 | 498|26|M|writer|55408 499 | 499|42|M|programmer|75006 500 | 500|28|M|administrator|94305 501 | 501|22|M|student|10025 502 | 502|22|M|student|23092 503 | 503|50|F|writer|27514 504 | 504|40|F|writer|92115 505 | 505|27|F|other|20657 506 | 506|46|M|programmer|03869 507 | 507|18|F|writer|28450 508 | 508|27|M|marketing|19382 509 | 509|23|M|administrator|10011 510 | 510|34|M|other|98038 511 | 511|22|M|student|21250 512 | 512|29|M|other|20090 513 | 513|43|M|administrator|26241 514 | 514|27|M|programmer|20707 515 | 515|53|M|marketing|49508 516 | 516|53|F|librarian|10021 517 | 517|24|M|student|55454 518 | 518|49|F|writer|99709 519 | 519|22|M|other|55320 520 | 520|62|M|healthcare|12603 521 | 521|19|M|student|02146 522 | 522|36|M|engineer|55443 523 | 523|50|F|administrator|04102 524 | 524|56|M|educator|02159 525 | 525|27|F|administrator|19711 526 | 526|30|M|marketing|97124 527 | 527|33|M|librarian|12180 528 | 528|18|M|student|55104 529 | 529|47|F|administrator|44224 530 | 530|29|M|engineer|94040 531 | 531|30|F|salesman|97408 532 | 532|20|M|student|92705 533 | 533|43|M|librarian|02324 534 | 534|20|M|student|05464 535 | 535|45|F|educator|80302 536 | 536|38|M|engineer|30078 537 | 537|36|M|engineer|22902 538 | 538|31|M|scientist|21010 539 | 539|53|F|administrator|80303 540 | 540|28|M|engineer|91201 541 | 541|19|F|student|84302 542 | 542|21|M|student|60515 543 | 543|33|M|scientist|95123 544 | 544|44|F|other|29464 545 | 545|27|M|technician|08052 546 | 546|36|M|executive|22911 547 | 547|50|M|educator|14534 548 | 548|51|M|writer|95468 549 | 549|42|M|scientist|45680 550 | 550|16|F|student|95453 551 | 551|25|M|programmer|55414 552 | 552|45|M|other|68147 553 | 553|58|M|educator|62901 554 | 554|32|M|scientist|62901 555 | 555|29|F|educator|23227 556 | 556|35|F|educator|30606 557 | 557|30|F|writer|11217 558 | 558|56|F|writer|63132 559 | 559|69|M|executive|10022 560 | 560|32|M|student|10003 561 | 561|23|M|engineer|60005 562 | 562|54|F|administrator|20879 563 | 563|39|F|librarian|32707 564 | 564|65|M|retired|94591 565 | 565|40|M|student|55422 566 | 566|20|M|student|14627 567 | 567|24|M|entertainment|10003 568 | 568|39|M|educator|01915 569 | 569|34|M|educator|91903 570 | 
570|26|M|educator|14627 571 | 571|34|M|artist|01945 572 | 572|51|M|educator|20003 573 | 573|68|M|retired|48911 574 | 574|56|M|educator|53188 575 | 575|33|M|marketing|46032 576 | 576|48|M|executive|98281 577 | 577|36|F|student|77845 578 | 578|31|M|administrator|M7A1A 579 | 579|32|M|educator|48103 580 | 580|16|M|student|17961 581 | 581|37|M|other|94131 582 | 582|17|M|student|93003 583 | 583|44|M|engineer|29631 584 | 584|25|M|student|27511 585 | 585|69|M|librarian|98501 586 | 586|20|M|student|79508 587 | 587|26|M|other|14216 588 | 588|18|F|student|93063 589 | 589|21|M|lawyer|90034 590 | 590|50|M|educator|82435 591 | 591|57|F|librarian|92093 592 | 592|18|M|student|97520 593 | 593|31|F|educator|68767 594 | 594|46|M|educator|M4J2K 595 | 595|25|M|programmer|31909 596 | 596|20|M|artist|77073 597 | 597|23|M|other|84116 598 | 598|40|F|marketing|43085 599 | 599|22|F|student|R3T5K 600 | 600|34|M|programmer|02320 601 | 601|19|F|artist|99687 602 | 602|47|F|other|34656 603 | 603|21|M|programmer|47905 604 | 604|39|M|educator|11787 605 | 605|33|M|engineer|33716 606 | 606|28|M|programmer|63044 607 | 607|49|F|healthcare|02154 608 | 608|22|M|other|10003 609 | 609|13|F|student|55106 610 | 610|22|M|student|21227 611 | 611|46|M|librarian|77008 612 | 612|36|M|educator|79070 613 | 613|37|F|marketing|29678 614 | 614|54|M|educator|80227 615 | 615|38|M|educator|27705 616 | 616|55|M|scientist|50613 617 | 617|27|F|writer|11201 618 | 618|15|F|student|44212 619 | 619|17|M|student|44134 620 | 620|18|F|writer|81648 621 | 621|17|M|student|60402 622 | 622|25|M|programmer|14850 623 | 623|50|F|educator|60187 624 | 624|19|M|student|30067 625 | 625|27|M|programmer|20723 626 | 626|23|M|scientist|19807 627 | 627|24|M|engineer|08034 628 | 628|13|M|none|94306 629 | 629|46|F|other|44224 630 | 630|26|F|healthcare|55408 631 | 631|18|F|student|38866 632 | 632|18|M|student|55454 633 | 633|35|M|programmer|55414 634 | 634|39|M|engineer|T8H1N 635 | 635|22|M|other|23237 636 | 636|47|M|educator|48043 637 | 637|30|M|other|74101 638 | 638|45|M|engineer|01940 639 | 639|42|F|librarian|12065 640 | 640|20|M|student|61801 641 | 641|24|M|student|60626 642 | 642|18|F|student|95521 643 | 643|39|M|scientist|55122 644 | 644|51|M|retired|63645 645 | 645|27|M|programmer|53211 646 | 646|17|F|student|51250 647 | 647|40|M|educator|45810 648 | 648|43|M|engineer|91351 649 | 649|20|M|student|39762 650 | 650|42|M|engineer|83814 651 | 651|65|M|retired|02903 652 | 652|35|M|other|22911 653 | 653|31|M|executive|55105 654 | 654|27|F|student|78739 655 | 655|50|F|healthcare|60657 656 | 656|48|M|educator|10314 657 | 657|26|F|none|78704 658 | 658|33|M|programmer|92626 659 | 659|31|M|educator|54248 660 | 660|26|M|student|77380 661 | 661|28|M|programmer|98121 662 | 662|55|M|librarian|19102 663 | 663|26|M|other|19341 664 | 664|30|M|engineer|94115 665 | 665|25|M|administrator|55412 666 | 666|44|M|administrator|61820 667 | 667|35|M|librarian|01970 668 | 668|29|F|writer|10016 669 | 669|37|M|other|20009 670 | 670|30|M|technician|21114 671 | 671|21|M|programmer|91919 672 | 672|54|F|administrator|90095 673 | 673|51|M|educator|22906 674 | 674|13|F|student|55337 675 | 675|34|M|other|28814 676 | 676|30|M|programmer|32712 677 | 677|20|M|other|99835 678 | 678|50|M|educator|61462 679 | 679|20|F|student|54302 680 | 680|33|M|lawyer|90405 681 | 681|44|F|marketing|97208 682 | 682|23|M|programmer|55128 683 | 683|42|M|librarian|23509 684 | 684|28|M|student|55414 685 | 685|32|F|librarian|55409 686 | 686|32|M|educator|26506 687 | 687|31|F|healthcare|27713 688 | 688|37|F|administrator|60476 689 
| 689|25|M|other|45439 690 | 690|35|M|salesman|63304 691 | 691|34|M|educator|60089 692 | 692|34|M|engineer|18053 693 | 693|43|F|healthcare|85210 694 | 694|60|M|programmer|06365 695 | 695|26|M|writer|38115 696 | 696|55|M|other|94920 697 | 697|25|M|other|77042 698 | 698|28|F|programmer|06906 699 | 699|44|M|other|96754 700 | 700|17|M|student|76309 701 | 701|51|F|librarian|56321 702 | 702|37|M|other|89104 703 | 703|26|M|educator|49512 704 | 704|51|F|librarian|91105 705 | 705|21|F|student|54494 706 | 706|23|M|student|55454 707 | 707|56|F|librarian|19146 708 | 708|26|F|homemaker|96349 709 | 709|21|M|other|N4T1A 710 | 710|19|M|student|92020 711 | 711|22|F|student|15203 712 | 712|22|F|student|54901 713 | 713|42|F|other|07204 714 | 714|26|M|engineer|55343 715 | 715|21|M|technician|91206 716 | 716|36|F|administrator|44265 717 | 717|24|M|technician|84105 718 | 718|42|M|technician|64118 719 | 719|37|F|other|V0R2H 720 | 720|49|F|administrator|16506 721 | 721|24|F|entertainment|11238 722 | 722|50|F|homemaker|17331 723 | 723|26|M|executive|94403 724 | 724|31|M|executive|40243 725 | 725|21|M|student|91711 726 | 726|25|F|administrator|80538 727 | 727|25|M|student|78741 728 | 728|58|M|executive|94306 729 | 729|19|M|student|56567 730 | 730|31|F|scientist|32114 731 | 731|41|F|educator|70403 732 | 732|28|F|other|98405 733 | 733|44|F|other|60630 734 | 734|25|F|other|63108 735 | 735|29|F|healthcare|85719 736 | 736|48|F|writer|94618 737 | 737|30|M|programmer|98072 738 | 738|35|M|technician|95403 739 | 739|35|M|technician|73162 740 | 740|25|F|educator|22206 741 | 741|25|M|writer|63108 742 | 742|35|M|student|29210 743 | 743|31|M|programmer|92660 744 | 744|35|M|marketing|47024 745 | 745|42|M|writer|55113 746 | 746|25|M|engineer|19047 747 | 747|19|M|other|93612 748 | 748|28|M|administrator|94720 749 | 749|33|M|other|80919 750 | 750|28|M|administrator|32303 751 | 751|24|F|other|90034 752 | 752|60|M|retired|21201 753 | 753|56|M|salesman|91206 754 | 754|59|F|librarian|62901 755 | 755|44|F|educator|97007 756 | 756|30|F|none|90247 757 | 757|26|M|student|55104 758 | 758|27|M|student|53706 759 | 759|20|F|student|68503 760 | 760|35|F|other|14211 761 | 761|17|M|student|97302 762 | 762|32|M|administrator|95050 763 | 763|27|M|scientist|02113 764 | 764|27|F|educator|62903 765 | 765|31|M|student|33066 766 | 766|42|M|other|10960 767 | 767|70|M|engineer|00000 768 | 768|29|M|administrator|12866 769 | 769|39|M|executive|06927 770 | 770|28|M|student|14216 771 | 771|26|M|student|15232 772 | 772|50|M|writer|27105 773 | 773|20|M|student|55414 774 | 774|30|M|student|80027 775 | 775|46|M|executive|90036 776 | 776|30|M|librarian|51157 777 | 777|63|M|programmer|01810 778 | 778|34|M|student|01960 779 | 779|31|M|student|K7L5J 780 | 780|49|M|programmer|94560 781 | 781|20|M|student|48825 782 | 782|21|F|artist|33205 783 | 783|30|M|marketing|77081 784 | 784|47|M|administrator|91040 785 | 785|32|M|engineer|23322 786 | 786|36|F|engineer|01754 787 | 787|18|F|student|98620 788 | 788|51|M|administrator|05779 789 | 789|29|M|other|55420 790 | 790|27|M|technician|80913 791 | 791|31|M|educator|20064 792 | 792|40|M|programmer|12205 793 | 793|22|M|student|85281 794 | 794|32|M|educator|57197 795 | 795|30|M|programmer|08610 796 | 796|32|F|writer|33755 797 | 797|44|F|other|62522 798 | 798|40|F|writer|64131 799 | 799|49|F|administrator|19716 800 | 800|25|M|programmer|55337 801 | 801|22|M|writer|92154 802 | 802|35|M|administrator|34105 803 | 803|70|M|administrator|78212 804 | 804|39|M|educator|61820 805 | 805|27|F|other|20009 806 | 806|27|M|marketing|11217 807 | 
807|41|F|healthcare|93555 808 | 808|45|M|salesman|90016 809 | 809|50|F|marketing|30803 810 | 810|55|F|other|80526 811 | 811|40|F|educator|73013 812 | 812|22|M|technician|76234 813 | 813|14|F|student|02136 814 | 814|30|M|other|12345 815 | 815|32|M|other|28806 816 | 816|34|M|other|20755 817 | 817|19|M|student|60152 818 | 818|28|M|librarian|27514 819 | 819|59|M|administrator|40205 820 | 820|22|M|student|37725 821 | 821|37|M|engineer|77845 822 | 822|29|F|librarian|53144 823 | 823|27|M|artist|50322 824 | 824|31|M|other|15017 825 | 825|44|M|engineer|05452 826 | 826|28|M|artist|77048 827 | 827|23|F|engineer|80228 828 | 828|28|M|librarian|85282 829 | 829|48|M|writer|80209 830 | 830|46|M|programmer|53066 831 | 831|21|M|other|33765 832 | 832|24|M|technician|77042 833 | 833|34|M|writer|90019 834 | 834|26|M|other|64153 835 | 835|44|F|executive|11577 836 | 836|44|M|artist|10018 837 | 837|36|F|artist|55409 838 | 838|23|M|student|01375 839 | 839|38|F|entertainment|90814 840 | 840|39|M|artist|55406 841 | 841|45|M|doctor|47401 842 | 842|40|M|writer|93055 843 | 843|35|M|librarian|44212 844 | 844|22|M|engineer|95662 845 | 845|64|M|doctor|97405 846 | 846|27|M|lawyer|47130 847 | 847|29|M|student|55417 848 | 848|46|M|engineer|02146 849 | 849|15|F|student|25652 850 | 850|34|M|technician|78390 851 | 851|18|M|other|29646 852 | 852|46|M|administrator|94086 853 | 853|49|M|writer|40515 854 | 854|29|F|student|55408 855 | 855|53|M|librarian|04988 856 | 856|43|F|marketing|97215 857 | 857|35|F|administrator|V1G4L 858 | 858|63|M|educator|09645 859 | 859|18|F|other|06492 860 | 860|70|F|retired|48322 861 | 861|38|F|student|14085 862 | 862|25|M|executive|13820 863 | 863|17|M|student|60089 864 | 864|27|M|programmer|63021 865 | 865|25|M|artist|11231 866 | 866|45|M|other|60302 867 | 867|24|M|scientist|92507 868 | 868|21|M|programmer|55303 869 | 869|30|M|student|10025 870 | 870|22|M|student|65203 871 | 871|31|M|executive|44648 872 | 872|19|F|student|74078 873 | 873|48|F|administrator|33763 874 | 874|36|M|scientist|37076 875 | 875|24|F|student|35802 876 | 876|41|M|other|20902 877 | 877|30|M|other|77504 878 | 878|50|F|educator|98027 879 | 879|33|F|administrator|55337 880 | 880|13|M|student|83702 881 | 881|39|M|marketing|43017 882 | 882|35|M|engineer|40503 883 | 883|49|M|librarian|50266 884 | 884|44|M|engineer|55337 885 | 885|30|F|other|95316 886 | 886|20|M|student|61820 887 | 887|14|F|student|27249 888 | 888|41|M|scientist|17036 889 | 889|24|M|technician|78704 890 | 890|32|M|student|97301 891 | 891|51|F|administrator|03062 892 | 892|36|M|other|45243 893 | 893|25|M|student|95823 894 | 894|47|M|educator|74075 895 | 895|31|F|librarian|32301 896 | 896|28|M|writer|91505 897 | 897|30|M|other|33484 898 | 898|23|M|homemaker|61755 899 | 899|32|M|other|55116 900 | 900|60|M|retired|18505 901 | 901|38|M|executive|L1V3W 902 | 902|45|F|artist|97203 903 | 903|28|M|educator|20850 904 | 904|17|F|student|61073 905 | 905|27|M|other|30350 906 | 906|45|M|librarian|70124 907 | 907|25|F|other|80526 908 | 908|44|F|librarian|68504 909 | 909|50|F|educator|53171 910 | 910|28|M|healthcare|29301 911 | 911|37|F|writer|53210 912 | 912|51|M|other|06512 913 | 913|27|M|student|76201 914 | 914|44|F|other|08105 915 | 915|50|M|entertainment|60614 916 | 916|27|M|engineer|N2L5N 917 | 917|22|F|student|20006 918 | 918|40|M|scientist|70116 919 | 919|25|M|other|14216 920 | 920|30|F|artist|90008 921 | 921|20|F|student|98801 922 | 922|29|F|administrator|21114 923 | 923|21|M|student|E2E3R 924 | 924|29|M|other|11753 925 | 925|18|F|salesman|49036 926 | 
926|49|M|entertainment|01701 927 | 927|23|M|programmer|55428 928 | 928|21|M|student|55408 929 | 929|44|M|scientist|53711 930 | 930|28|F|scientist|07310 931 | 931|60|M|educator|33556 932 | 932|58|M|educator|06437 933 | 933|28|M|student|48105 934 | 934|61|M|engineer|22902 935 | 935|42|M|doctor|66221 936 | 936|24|M|other|32789 937 | 937|48|M|educator|98072 938 | 938|38|F|technician|55038 939 | 939|26|F|student|33319 940 | 940|32|M|administrator|02215 941 | 941|20|M|student|97229 942 | 942|48|F|librarian|78209 943 | 943|22|M|student|77841 944 | 
--------------------------------------------------------------------------------
/data/ml-1m/movies.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data/ml-1m/movies.dat
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__init__.py
--------------------------------------------------------------------------------
/data_loader/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data_loader/__pycache__/avazu.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__pycache__/avazu.cpython-36.pyc
--------------------------------------------------------------------------------
/data_loader/__pycache__/criteo.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__pycache__/criteo.cpython-36.pyc
--------------------------------------------------------------------------------
/data_loader/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data_loader/__pycache__/movielens.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/data_loader/__pycache__/movielens.cpython-36.pyc
--------------------------------------------------------------------------------
/data_loader/avazu.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import struct
3 | from collections import defaultdict
4 | from pathlib import Path
5 | 
6 | import lmdb
7 | import numpy as np
8 | import torch.utils.data
9 | from tqdm import tqdm
10 | import pandas as pd
11 | 
12 | 
13 | class AvazuDataset(torch.utils.data.Dataset):
14 |     """
15 |     Avazu Click-Through Rate Prediction Dataset
16 | 
17 |     Dataset preparation
18 |         Remove the infrequent features (appearing in fewer than threshold instances) and treat them as a single feature
19 | 
20 |     :param dataset_path: avazu train path
21 |     :param cache_path: lmdb cache path
22 |     :param rebuild_cache: If True, lmdb cache is refreshed
23 |     :param min_threshold: infrequent feature threshold
24 | 
25 |     Reference
26 |         https://www.kaggle.com/c/avazu-ctr-prediction
27 |     """
28 | 
29 |     def __init__(self, dataset_path=None, cache_path='.avazu', rebuild_cache=False, min_threshold=10, num_blocks=8):
30 |         self.NUM_FEATS = 22
31 |         self.min_threshold = min_threshold
32 |         self.num_blocks = num_blocks
33 |         if rebuild_cache or not Path(cache_path).exists():
34 |             shutil.rmtree(cache_path, ignore_errors=True)
35 |             if dataset_path is None:
36 |                 raise ValueError('create cache: failed: dataset_path is None')
37 |             self.__build_cache(dataset_path, cache_path)
38 |         self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
39 |         with self.env.begin(write=False) as txn:
40 |             self.length = txn.stat()['entries'] - 1
41 |             self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
42 | 
43 |     def __getitem__(self, index):
44 |         with self.env.begin(write=False) as txn:
45 |             np_array = np.frombuffer(
46 |                 txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.int64)  # np.long was removed in NumPy >= 1.24
47 |         return np_array[1:], np_array[0]
48 | 
49 |     def __len__(self):
50 |         return self.length
51 | 
52 |     def __build_cache(self, path, cache_path):
53 |         feat_mapper, defaults, field_dims = self.__get_feat_mapper(path)
54 |         with lmdb.open(cache_path, map_size=int(1e11)) as env:
55 |             with env.begin(write=True) as txn:
56 |                 txn.put(b'field_dims', field_dims.tobytes())
57 |             for buffer in self.__yield_buffer(path, feat_mapper, defaults):
58 |                 with env.begin(write=True) as txn:
59 |                     for key, value in buffer:
60 |                         txn.put(key, value)
61 | 
62 |     def __get_feat_mapper(self, path):
63 |         feat_cnts = defaultdict(lambda: defaultdict(int))
64 |         new_feat_cnts = defaultdict(lambda: defaultdict(int))
65 |         with open(path) as f:
66 |             f.readline()
67 |             pbar = tqdm(f, mininterval=1, smoothing=0.1)
68 |             pbar.set_description('Create avazu dataset cache: counting features')
69 |             for line in pbar:
70 |                 values = line.rstrip('\n').split(',')
71 |                 if len(values) != self.NUM_FEATS + 2:
72 |                     continue
73 |                 for i in range(1, self.NUM_FEATS + 1):
74 |                     feat_cnts[i][values[i + 1]] += 1
75 |         feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
76 |         feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
77 |         defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
78 | 
79 |         for field, sub_dict in feat_cnts.items():
80 |             for key in list(sub_dict.keys()):
81 |                 if sub_dict[key] < self.min_threshold:
82 |                     sub_dict['default'] += 1
83 |                 else:
84 |                     new_feat_cnts[field][feat_mapper[field][key]] = sub_dict[key]
85 | 
86 |             if sub_dict['default'] != 0:
87 |                 new_feat_cnts[field][len(feat_mapper[field])] = sub_dict['default']
88 |         field_dims = self.__get_field_dims(new_feat_cnts)
89 |         return feat_mapper, defaults, field_dims
90 | 
91 |     def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
92 |         item_idx = 0
93 |         buffer = list()
94 |         with open(path) as f:
95 |             f.readline()
96 |             pbar = tqdm(f, mininterval=1, smoothing=0.1)
97 |             pbar.set_description('Create avazu dataset cache: setup lmdb')
98 |             for line in pbar:
99 |                 values = line.rstrip('\n').split(',')
100 |                 if len(values) != self.NUM_FEATS + 2:
101 |                     continue
102 |                 np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
103 |                 np_array[0] = int(values[1])
104 |                 for i in range(1, self.NUM_FEATS + 1):
105 |                     np_array[i] = feat_mapper[i].get(values[i+1], defaults[i])
106 |                 buffer.append((struct.pack('>I', item_idx), np_array.tobytes()))
107 |                 item_idx += 1
108 |                 if item_idx % buffer_size == 0:
109 |                     yield buffer
110 |                     buffer.clear()
111 |             yield buffer
112 | 
113 |     def __get_field_dims(self, data):
114 |         all_freq = None
115 |         index_offset = 0
116 |         field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
117 |         for i, col in enumerate(data.keys()):
118 |             freq = pd.Series(data[col]).sort_values(ascending=False)
119 |             freq.index = freq.index + index_offset
120 |             if all_freq is None:
121 |                 all_freq = freq
122 |             else:
123 |                 all_freq = pd.concat([all_freq, freq], axis=0)
124 |             index_offset += len(freq)
125 |             field_dims[i] = len(freq)
126 | 
127 |         return field_dims
128 | 
--------------------------------------------------------------------------------
/data_loader/criteo.py:
--------------------------------------------------------------------------------
1 | import math
2 | import shutil
3 | import struct
4 | from collections import defaultdict
5 | from functools import lru_cache
6 | from pathlib import Path
7 | 
8 | import lmdb
9 | import numpy as np
10 | import torch.utils.data
11 | from tqdm import tqdm
12 | import pandas as pd
13 | 
14 | 
15 | class CriteoDataset(torch.utils.data.Dataset):
16 |     """
17 |     Criteo Display Advertising Challenge Dataset
18 | 
19 |     Data preparation:
20 |         * Remove the infrequent features (appearing in fewer than threshold instances) and treat them as a single feature
21 |         * Discretize numerical values by the log2 transformation proposed by the winner of the Criteo Competition
22 | 
23 |     :param dataset_path: criteo train.txt path.
24 |     :param cache_path: lmdb cache path.
25 |     :param rebuild_cache: If True, lmdb cache is refreshed.
26 |     :param min_threshold: infrequent feature threshold.
27 | 
28 |     Reference:
29 |         https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
30 |         https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
31 |     """
32 | 
33 |     def __init__(self, dataset_path=None, cache_path='.criteo', rebuild_cache=False, min_threshold=10,
34 |                  category_only=False):  # category
35 |         self.NUM_FEATS = 39
36 |         self.NUM_INT_FEATS = 13
37 |         self.min_threshold = min_threshold
38 |         self.category_only = category_only
39 |         self.item_idx = 0
40 |         if rebuild_cache or not Path(cache_path).exists():
41 |             shutil.rmtree(cache_path, ignore_errors=True)
42 |             if dataset_path is None:
43 |                 raise ValueError('create cache: failed: dataset_path is None')
44 |             self.__build_cache(dataset_path, cache_path)
45 |         self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
46 |         print(self.item_idx)
47 |         with self.env.begin(write=False) as txn:
48 |             stat = txn.stat()
49 |             self.length = stat['entries'] - 1
50 |             self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
51 | 
52 |     def __getitem__(self, index):
53 |         with self.env.begin(write=False) as txn:
54 |             name = struct.pack('>I', index)
55 |             stream = txn.get(name)
56 |             if stream is None:
57 |                 print("None")
58 |                 print(index)
59 |             np_array = np.frombuffer(stream, dtype=np.uint32).astype(dtype=np.int64)  # np.long was removed in NumPy >= 1.24
60 |         if self.category_only:
61 |             return np_array[1 + self.NUM_INT_FEATS:], np_array[0]
62 |         else:
63 |             return np_array[1:], np_array[0]
64 | 
65 |     def __len__(self):
66 |         return self.length
67 | 
68 |     def __build_cache(self, path, cache_path):
69 |         feat_mapper, defaults, field_dims = self.__get_feat_mapper(path)
70 |         with lmdb.open(cache_path, map_size=int(1e11)) as env:
71 |             with env.begin(write=True) as txn:
72 |                 txn.put(b'field_dims', field_dims.tobytes())
73 |             for buffer in self.__yield_buffer(path, feat_mapper, defaults):
74 |                 with env.begin(write=True) as txn:
75 |                     for key, value in buffer:
76 |                         txn.put(key, value)
77 | 
78 |     def __get_feat_mapper(self, path):
79 |         feat_cnts = defaultdict(lambda: defaultdict(int))
80 |         new_feat_cnts = defaultdict(lambda: defaultdict(int))
81 |         with open(path) as f:
82 |             pbar = tqdm(f, mininterval=1, smoothing=0.1)
83 |             pbar.set_description('Create criteo dataset cache: counting features')
84 |             for line in pbar:
85 |                 values = line.rstrip('\n').split('\t')
86 |                 if len(values) != self.NUM_FEATS + 1:
87 |                     continue
88 |                 for i in range(1, self.NUM_INT_FEATS + 1):
89 |                     feat_cnts[i][convert_numeric_feature(values[i])] += 1
90 |                 for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
91 |                     feat_cnts[i][values[i]] += 1
92 | 
93 |         feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
94 |         feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
95 |         defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
96 | 
97 |         for field, sub_dict in feat_cnts.items():
98 |             for key in list(sub_dict.keys()):
99 |                 if sub_dict[key] < self.min_threshold:
100 |                     sub_dict['default'] += 1
101 |                 else:
102 |                     new_feat_cnts[field][feat_mapper[field][key]] = sub_dict[key]
103 |             if sub_dict['default'] != 0:
104 |                 new_feat_cnts[field][len(feat_mapper[field])] = sub_dict['default']
105 |         field_dims = self.__get_field_dims(new_feat_cnts)
106 |         return feat_mapper, defaults, field_dims
107 | 
108 |     def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
109 |         item_idx = 0
110 |         buffer = list()
111 |         with open(path) as f:
112 |             pbar = tqdm(f, mininterval=1, smoothing=0.1)
113 |             pbar.set_description('Create criteo dataset cache: setup lmdb')
114 |             for line in pbar:
115 |                 values = line.rstrip('\n').split('\t')
116 |                 if len(values) != self.NUM_FEATS + 1:
117 |                     continue
118 |                 np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
119 |                 np_array[0] = int(values[0])
120 |                 for i in range(1, self.NUM_INT_FEATS + 1):
121 |                     np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
122 |                 for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
123 |                     np_array[i] = feat_mapper[i].get(values[i], defaults[i])
124 |                 name = struct.pack('>I', item_idx)
125 |                 if name is None:
126 |                     print("None")
127 |                 buffer.append((name, np_array.tobytes()))
128 |                 item_idx += 1
129 |                 if item_idx % buffer_size == 0:
130 |                     yield buffer
131 |                     buffer.clear()
132 | 
133 |             self.item_idx = item_idx
134 |             yield buffer
135 | 
136 |     def __get_field_dims(self, data):
137 |         all_freq = None
138 |         index_offset = 0
139 |         field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
140 |         for i, col in enumerate(data.keys()):
141 |             freq = pd.Series(data[col]).sort_values(ascending=False)
142 |             freq.index = freq.index + index_offset
143 |             if all_freq is None:
144 |                 all_freq = freq
145 |             else:
146 |                 all_freq = pd.concat([all_freq, freq], axis=0)
147 |             index_offset += len(freq)
148 |             field_dims[i] = len(freq)
149 | 
150 |         return field_dims
151 | 
152 | 
153 | @lru_cache(maxsize=None)
154 | def convert_numeric_feature(val: str):
155 |     if val == '':
156 |         return 'NULL'
157 |     v = int(val)
158 |     if v > 2:
159 |         return str(int(math.log(v) ** 2))
160 |     else:
161 |         return str(v - 2)
162 | 
--------------------------------------------------------------------------------
/data_loader/data_loader.py:
--------------------------------------------------------------------------------
1 | import math
2 | from datetime import datetime
3 | 
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | 
8 | from data_loader.movielens import MovieLensDataset
9 | from data_loader.avazu import AvazuDataset
10 | from data_loader.criteo import CriteoDataset
11 | 
12 | 
13 | def setup_generator(opt):
14 |     """Choose different type of sampler for MF & FM"""
15 |     if opt['factorizer'] == 'fm':
16 |         return FMGenerator(opt)
17 |     else:
18 |         raise NotImplementedError
19 | 
20 | 
21 | class FMGenerator(object):
22 |     def __init__(self, opt):
23 |         data_path = opt['data_path']
24 |         data_type = opt['data_type']
25 |         category_only = opt['category_only']
26 |         rebuild_cache = opt['rebuild_cache']
27 |         self.batch_size_train = opt.get('batch_size_train')
28 |         self.batch_size_valid = opt.get('batch_size_valid')
29 |         self.batch_size_test = opt.get('batch_size_test')
30 | 
31 |         self.opt = opt
32 | 
33 |         if data_type == 'criteo':
34 |             dataset = CriteoDataset(data_path+'train.txt', data_path+'cache', rebuild_cache=rebuild_cache)
35 |         elif data_type == 'avazu':
36 |             dataset = AvazuDataset(data_path+'train', data_path+'cache', rebuild_cache=rebuild_cache)
37 |         elif data_type == 'ml-1m':
38 |             dataset = MovieLensDataset(data_path, data_type)
39 |         else:
40 |             raise RuntimeError("Invalid data type: {}".format(data_type))
41 | 
42 |         train_length = int(len(dataset) * 0.8)
43 |         valid_length = int(len(dataset) * 0.1)
44 |         test_length = len(dataset) - train_length - valid_length
45 |         self.train_data, self.valid_data, self.test_data = torch.utils.data.random_split(
46 |             dataset, (train_length, valid_length, test_length))
47 | 
48 |         self._train_epoch = iter([])
49 |         self._valid_epoch = iter([])
50 |         self._test_epoch = iter([])
51 | 
52 |         self.num_batches_train = math.ceil(len(self.train_data) / self.batch_size_train)
53 |         self.num_batches_valid = math.ceil(len(self.valid_data) / self.batch_size_valid)
54 |         self.num_batches_test = math.ceil(len(self.test_data) / self.batch_size_test)
55 |         if data_type == 'criteo' and category_only:
56 |             self.field_dims = dataset.field_dims[13:]
57 |         else:
58 |             self.field_dims = dataset.field_dims
59 | 
60 |         print('\tNum of train records: {}'.format(len(self.train_data)))
61 |         print('\tNum of valid records: {}'.format(len(self.valid_data)))
62 |         print('\tNum of test records: {}'.format(len(self.test_data)))
63 |         print('\tNum of fields: {}'.format(len(self.field_dims)))
64 |         print('\tNum of features: {}'.format(sum(self.field_dims)))
65 | 
66 |     @property
67 |     def train_epoch(self):
68 |         """list of training batches"""
69 |         return self._train_epoch
70 | 
71 |     @train_epoch.setter
72 |     def train_epoch(self, new_epoch):
73 |         self._train_epoch = new_epoch
74 | 
75 |     @property
76 |     def valid_epoch(self):
77 |         """list of validation batches"""
78 |         return self._valid_epoch
79 | 
80 |     @valid_epoch.setter
81 |     def valid_epoch(self, new_epoch):
82 |         self._valid_epoch = new_epoch
83 | 
84 |     @property
85 |     def test_epoch(self):
86 |         """list of test batches"""
87 |         return self._test_epoch
88 | 
89 |     @test_epoch.setter
90 |     def test_epoch(self, new_epoch):
91 |         self._test_epoch = new_epoch
92 | 
93 |     def get_epoch(self, type):
94 |         """
95 |         return:
96 |             list, an epoch of batchified samples of type=['train', 'valid', 'test']
97 |         """
98 |         if type == 'train':
99 |             return self.train_epoch
100 | 
101 |         if type == 'valid':
102 |             return self.valid_epoch
103 | 
104 |         if type == 'test':
105 |             return self.test_epoch
106 | 
107 |     def get_sample(self, type):
108 |         """get training sample or validation sample"""
109 |         epoch = self.get_epoch(type)
110 | 
111 |         try:
112 |             sample = next(epoch)
113 |         except StopIteration:
114 |             self.set_epoch(type)
115 |             epoch = self.get_epoch(type)
116 |             sample = next(epoch)
117 |             if self.opt['load_in_queue']:
118 |                 # continue to queue
119 |                 self.cont_queue(type)
120 | 
121 |         return sample
122 | 
123 |     def set_epoch(self, type):
124 |         """setup batches of type = [training, validation, testing]"""
125 |         # print('\tSetting epoch {}'.format(type))
126 |         start = datetime.now()
127 |         if type == 'train':
128 |             loader = DataLoader(self.train_data,
129 |                                 batch_size=self.batch_size_train,
130 |                                 shuffle=True, pin_memory=False)
131 |             self.train_epoch = iter(loader)
132 |             num_batches = len(self.train_epoch)
133 |         elif type == 'valid':
134 | 
135 |             loader = DataLoader(self.valid_data,
136 |                                 batch_size=self.batch_size_valid,
137 |                                 shuffle=True, pin_memory=False)
138 |             self.valid_epoch = iter(loader)
139 |             num_batches = len(self.valid_epoch)
140 |         elif type == 'test':
141 | 
142 |             loader = DataLoader(self.test_data,
143 |                                 batch_size=self.batch_size_test,
144 |                                 shuffle=False, pin_memory=False)
145 |             self.test_epoch = iter(loader)
146 |             num_batches = len(self.test_epoch)
147 |         end = datetime.now()
148 |         # print('\tFinish setting epoch {}, num_batches {}, time {} mins'.format(type,
149 |         #                                                                        num_batches,
150 |         #                                                                        (end - start).total_seconds() / 60))
151 | 
152 | 
153 | 
154 | 
--------------------------------------------------------------------------------
/data_loader/movielens.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | from sklearn.preprocessing import LabelEncoder
5 | 
6 | import math
7 | 
8 | from torch.utils.data import Dataset
9 | 
10 | 
11 | class MovieLensDataset(Dataset):
12 |     def __init__(self, data_path, data_type):
13 |         self.label_encoder = LabelEncoder()
14 |         self.data, self.labels, self.field_dims = self.load_fm_dataset(data_path, data_type)
15 | 
16 |     def load_fm_dataset(self, data_path, data_type):
17 |         print('Reconstructing {} data from {}'.format(data_type, data_path))
18 |         if data_type == 'ml-100k':
19 |             header = ['user_id', 'age', 'gender', 'occupation', 'zip_code']
20 |             df_user = pd.read_csv(data_path + 'u.user', sep='|', names=header)
21 |             df_user['age'] = pd.cut(df_user['age'], [0, 17, 24, 34, 44, 49, 55, 100],
22 |                                     labels=['under 18', '18-24', '25-34', '35-44', '45-49', '50-55', 'Age-56+'])
23 | 
24 |             for col in df_user.columns[1:]:
25 |                 df_user[col] = self.label_encoder.fit_transform(df_user[col])
26 | 
27 |             header = ['item_id', 'title', 'release_date', 'video_release_date', 'IMDb_URL',
28 |                       'unknown', 'Action', 'Adventure', 'Animation', 'Children', 'Comedy',
29 |                       'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror',
30 |                       'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']
31 |             df_item = pd.read_csv(data_path + 'u.item', sep='|', names=header, encoding="ISO-8859-1")
32 |             df_item_genre = df_item[['unknown', 'Action', 'Adventure', 'Animation', 'Children', 'Comedy',
33 |                                      'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror',
34 |                                      'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']]
35 | 
36 |             df_item.drop(['unknown', 'Action', 'Adventure', 'Animation', 'Children', 'Comedy',
37 |                           'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror',
38 |                           'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'],
39 |                          axis=1, inplace=True)
40 |             df_item['release_date'] = df_item['release_date'].str.split('-').str[-1]
41 |             df_item = df_item.fillna(0)
42 |             df_item['release_date'] = df_item['release_date'].apply(lambda x: int(x))
43 | 
44 |             df_item = df_item.drop(columns=['title', 'video_release_date', 'IMDb_URL'])
45 |             df_item = pd.concat([df_item, df_item_genre], axis=1)
46 |             for col in df_item.columns[1:]:
47 |                 df_item[col] = self.label_encoder.fit_transform(df_item[col])
48 | 
49 |             header = ['user_id', 'item_id', 'rating', 'timestamp']
50 |             df_data = pd.read_csv(data_path + 'u.data', sep='\t', names=header)
51 |         else:
52 |             header = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
53 |             df_user = pd.read_csv(data_path + 'users.dat', sep='::', names=header, engine='python')
54 |             for col in df_user.columns[1:]:
55 |                 df_user[col] = self.label_encoder.fit_transform(df_user[col])
56 | 
57 |             header = ['item_id', 'title', 'genres']
58 |             df_item = pd.read_csv(data_path + 'movies.dat', sep='::',
59 |                                   names=header, encoding="ISO-8859-1", engine='python')
60 | 
61 |             year = df_item.title.str[-5:-1].apply(lambda x: int(x))
62 |             df_item['years'] = self.label_encoder.fit_transform(year)
63 |             df_genres = df_item.genres.str.split('|').str.join('|').str.get_dummies()
64 |             df_genres['genre'] = ''
65 |             for col in df_genres.columns:
66 |                 df_genres['genre'] += df_genres[col].map(str)
67 |                 if col != 'genre':
68 |                     del df_genres[col]
69 |             df_genres['genre'] = self.label_encoder.fit_transform(df_genres['genre'])
70 |             df_item.drop(['title', 'genres'], axis=1, inplace=True)
71 |             df_item = pd.concat([df_item, df_genres], axis=1)
72 | 
73 |             header = ['user_id', 'item_id', 'rating', 'timestamp']
74 | 
75 |             df_data = pd.read_csv(data_path + 'ratings.dat', sep='::', names=header, engine='python')
76 |             df_data['timestamp'] = df_data['timestamp'] - df_data['timestamp'].min()
df_data['timestamp'].min() 77 | df_data['timestamp'] = df_data['timestamp'].apply(self.convert_numeric_feature) 78 | df_data['timestamp'] = self.label_encoder.fit_transform(df_data['timestamp']) 79 | 80 | df_data = df_data[(~df_data['rating'].isin([3]))].reset_index(drop=True) 81 | df_data['rating'] = df_data.rating.apply(lambda x: 1 if int(x) > 3 else 0) 82 | 83 | df_data = df_data.merge(df_user, on='user_id', how='left') 84 | df_data = df_data.merge(df_item, on='item_id', how='left') 85 | 86 | data, labels = df_data.iloc[:, 3:], df_data['rating'] 87 | 88 | field_dims = data.nunique() 89 | 90 | return data.values, labels.values, field_dims.values 91 | 92 | def convert_numeric_feature(self, val): 93 | if val == '': 94 | return 'NULL' 95 | v = int(val) 96 | if v > 2: 97 | return str(int(math.log(v) ** 2)) 98 | else: 99 | return str(v - 2) 100 | 101 | def __getitem__(self, index): 102 | return self.data[index], self.labels[index] 103 | 104 | def __len__(self): 105 | return len(self.data) 106 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from datetime import datetime 5 | from argparse import ArgumentParser 6 | from tensorboardX import SummaryWriter 7 | 8 | from models.factorizer import setup_factorizer 9 | from data_loader.data_loader import setup_generator 10 | from utils.evaluate import evaluate_fm 11 | 12 | 13 | def setup_args(parser=None): 14 | """ Set up arguments for the Engine 15 | 16 | return: 17 | ArgumentParser populated with the Engine's default arguments 18 | """ 19 | if parser is None: 20 | parser = ArgumentParser() 21 | data = parser.add_argument_group('Data') 22 | engine = parser.add_argument_group('Engine Arguments') 23 | factorize = parser.add_argument_group('Factorizer Arguments') 24 | matrix_factorize = parser.add_argument_group('MF Arguments') 25 | regularize = parser.add_argument_group('Regularizer Arguments') 26 | log = parser.add_argument_group('Tensorboard Arguments') 27 | 28 | engine.add_argument('--alias', default='experiment', 29 | help='Name for the experiment') 30 | engine.add_argument('--seed', default=42, type=int) 31 | 32 | data.add_argument('--data-type', default='ml-1m', help='type of the dataset') 33 | data.add_argument('--data-path', default='./data/{data_type}/') 34 | data.add_argument('--train-test-freq-bd', help='freq-wise split: user-frequency bound between train and test') 35 | data.add_argument('--train-valid-freq-bd', help='freq-wise split: user-frequency bound between train and valid') 36 | data.add_argument('--batch-size-train', default=1) 37 | data.add_argument('--batch-size-valid', default=1) 38 | data.add_argument('--batch-size-test', default=1) 39 | data.add_argument('--device-ids-test', default=[0], help='devices used for multi-processing evaluate') 40 | 41 | regularize.add_argument('--max-steps', default=1e8) 42 | regularize.add_argument('--use-cuda', default=True) 43 | regularize.add_argument('--device-id', default=0, help='Training Devices') 44 | 45 | factorize.add_argument('--factorizer', default='fm', help='Type of the Factorization Model') 46 | factorize.add_argument('--latent-dim', default=8) 47 | 48 | type_opt = 'fm' 49 | matrix_factorize.add_argument('--{}-optimizer'.format(type_opt), default='sgd') 50 | matrix_factorize.add_argument('--{}-lr'.format(type_opt), default=1e-3) 51 | matrix_factorize.add_argument('--{}-grad-clip'.format(type_opt), default=1) 52 | 53 | log.add_argument('--log-interval', default=1)
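# NOTE: each train_*.py below overrides these defaults through
# parser.set_defaults(...); a minimal sketch of that flow:
#
#   parser = setup_args()
#   parser.set_defaults(data_type='ml-1m', batch_size_train=1024)
#   opt = vars(parser.parse_args(args=[]))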
54 | log.add_argument('--tensorboard', default='./tmp/runs') 55 | log.add_argument('--early-stop', default=None) 56 | log.add_argument('--display-interval', default=100) 57 | return parser 58 | 59 | 60 | class Engine(object): 61 | """Engine wrapping the training & evaluation 62 | of adaptive regularized matrix factorization 63 | """ 64 | 65 | def __init__(self, opt): 66 | self._opt = opt 67 | self._opt['data_path'] = self._opt['data_path'].format(data_type=self._opt['data_type']) 68 | self._sampler = setup_generator(opt) 69 | 70 | self._opt['field_dims'] = self._sampler.field_dims 71 | 72 | self._opt['emb_save_path'] = self._opt['emb_save_path'].format( 73 | factorizer=self._opt['factorizer'], 74 | data_type=self._opt['data_type'], 75 | alias=self._opt['alias'], 76 | num_parameter='{num_parameter}' 77 | ) 78 | if 'retrain_emb_param' in opt: 79 | self.retrain = True 80 | if opt['re_init']: 81 | self._opt['alias'] += '_reinitTrue' 82 | else: 83 | self._opt['alias'] += '_reinitFalse' 84 | self._opt['alias'] += '_retrain_emb_param{}'.format(opt['retrain_emb_param']) 85 | else: 86 | self.retrain = False 87 | self.candidate_p = self._opt.get('candidate_p') 88 | self._opt['eval_res_path'] = self._opt['eval_res_path'].format( 89 | factorizer=self._opt['factorizer'], 90 | data_type=self._opt['data_type'], 91 | alias=self._opt['alias'], 92 | epoch_idx='{epoch_idx}' 93 | ) 94 | self._factorizer = setup_factorizer(opt) 95 | self._opt['tensorboard'] = self._opt['tensorboard'].format( 96 | factorizer=self._opt['factorizer'], 97 | data_type=self._opt['data_type'], 98 | ) 99 | self._writer = SummaryWriter(log_dir='{}/{}'.format(self._opt['tensorboard'], opt['alias'])) 100 | self._writer.add_text('option', str(opt), 0) 101 | self._mode = None 102 | self.early_stop = self._opt.get('early_stop') 103 | 104 | 105 | @property 106 | def mode(self): 107 | return self._mode 108 | 109 | @mode.setter 110 | def mode(self, new_mode): 111 | assert new_mode in ['complete', 'partial', None] # training a complete trajectory or a partial trajectory 112 | self._mode = new_mode 113 | 114 | def save_pruned_embedding(self, param, step_idx): 115 | max_candidate_p = max(self.candidate_p) 116 | if max_candidate_p == 0: 117 | print("All target parameter budgets reached, stop pruning.") 118 | exit(0) 119 | else: 120 | if param <= max_candidate_p: 121 | embedding = self._factorizer.model.get_embedding() 122 | emb_save_path = self._opt['emb_save_path'].format(num_parameter=param) 123 | emb_save_dir, _ = os.path.split(emb_save_path) 124 | if not os.path.exists(emb_save_dir): 125 | os.makedirs(emb_save_dir) 126 | np.save(emb_save_path, embedding) 127 | max_idx = self.candidate_p.index(max(self.candidate_p)) 128 | self.candidate_p[max_idx] = 0 129 | print("*" * 80) 130 | print("Reached target parameter: {}, saved embedding with size: {}".format(max_candidate_p, param)) 131 | print("*" * 80) 132 | elif step_idx == 0: 133 | embedding = self._factorizer.model.get_embedding() 134 | emb_save_path = self._opt['emb_save_path'].format(num_parameter='initial_embedding') 135 | emb_save_dir, _ = os.path.split(emb_save_path) 136 | if not os.path.exists(emb_save_dir): 137 | os.makedirs(emb_save_dir) 138 | np.save(emb_save_path, embedding) 139 | print("*" * 80) 140 | print("Saved the initial embedding table") 141 | print("*" * 80) 142 | 143 | def train_an_episode(self, max_steps, episode_idx=''): 144 | """Train a feature-based recommendation model""" 145 | assert self.mode in ['partial', 'complete'] 146 | 147 | print('-' * 80) 148 | print('[{} episode {} 
starts!]'.format(self.mode, episode_idx)) 149 | print('Initializing ...') 150 | self._factorizer.init_episode() 151 | 152 | log_interval = self._opt.get('log_interval') 153 | eval_interval = self._opt.get('eval_interval') 154 | display_interval = self._opt.get('display_interval') 155 | 156 | status = dict() 157 | flag, test_flag, valid_flag = 0, 0, 0 158 | valid_mf_loss, train_mf_loss = np.inf, np.inf 159 | best_valid_result = {"AUC": [0, 0], "LogLoss": [np.inf, 0]} 160 | best_test_result = {"AUC": [0, 0], "LogLoss": [np.inf, 0]} 161 | epoch_start = datetime.now() 162 | for step_idx in range(int(max_steps)): 163 | # Prepare status for current step 164 | status['done'] = False 165 | status['sampler'] = self._sampler 166 | train_mf_loss = self._factorizer.update(self._sampler) 167 | status['train_mf_loss'] = train_mf_loss 168 | 169 | # Logging & Evaluate on the Evaluate Set 170 | if self.mode == 'complete' and step_idx % log_interval == 0: 171 | epoch_idx = int(step_idx / self._sampler.num_batches_train) 172 | sparsity, params = self._factorizer.model.calc_sparsity() 173 | if not self.retrain: 174 | self.save_pruned_embedding(params, step_idx) 175 | self._writer.add_scalar('train/step_wise/mf_loss', train_mf_loss, step_idx) 176 | self._writer.add_scalar('train/step_wise/sparsity', sparsity, step_idx) 177 | 178 | if step_idx % display_interval == 0: 179 | print('[Epoch {}|Step {}|Flag {}|Sparsity {:.4f}|Params {}]'.format(epoch_idx, 180 | step_idx % self._sampler.num_batches_train, 181 | flag, sparsity, params)) 182 | 183 | if step_idx % self._sampler.num_batches_train == 0: 184 | threshold = self._factorizer.model.get_threshold() 185 | 186 | self._writer.add_histogram('threshold/epoch_wise/threshold', threshold, epoch_idx) 187 | self._writer.add_scalar('train/epoch_wise/sparsity', sparsity, epoch_idx) 188 | self._writer.add_scalar('train/epoch_wise/params', params, epoch_idx) 189 | 190 | if (step_idx % self._sampler.num_batches_train == 0) and (epoch_idx % eval_interval == 0) and self.retrain: 191 | print('Evaluate on test ...') 192 | start = datetime.now() 193 | eval_res_path = self._opt['eval_res_path'].format(epoch_idx=epoch_idx) 194 | eval_res_dir, _ = os.path.split(eval_res_path) 195 | if not os.path.exists(eval_res_dir): 196 | os.makedirs(eval_res_dir) 197 | 198 | use_cuda = self._opt['use_cuda'] 199 | logloss, auc = evaluate_fm(self._factorizer, self._sampler, use_cuda) 200 | self._writer.add_scalar('test/epoch_wise/metron_auc', auc, epoch_idx) 201 | self._writer.add_scalar('test/epoch_wise/metron_logloss', logloss, epoch_idx) 202 | if logloss < best_test_result['LogLoss'][0]: 203 | best_test_result['LogLoss'][0] = logloss 204 | best_test_result['LogLoss'][1] = epoch_idx 205 | if auc > best_test_result['AUC'][0]: 206 | best_test_result['AUC'][0] = auc 207 | best_test_result['AUC'][1] = epoch_idx 208 | test_flag = 0 209 | else: 210 | test_flag += 1 211 | pd.Series(best_test_result).to_csv(eval_res_path) 212 | print("*" * 80) 213 | print("Test AUC: {:4f} | Logloss: {:4f}".format(auc, logloss)) 214 | end = datetime.now() 215 | print('Evaluate Time {} minutes'.format((end - start).total_seconds() / 60)) 216 | epoch_end = datetime.now() 217 | dur = (epoch_end - epoch_start).total_seconds() / 60 218 | epoch_start = datetime.now() 219 | print('[Epoch {:4d}] train MF loss: {:04.8f}, ' 220 | 'valid loss: {:04.8f}, time {:04.8f} minutes'.format(epoch_idx, 221 | train_mf_loss, 222 | valid_mf_loss, 223 | dur)) 224 | print("*"*80) 225 | 226 | flag = test_flag 227 | if self.early_stop is not None 
and flag >= self.early_stop: 228 | print("Early stopping the training process") 229 | print("Best performance on test data: ", best_test_result) 230 | print("Best performance on valid data: ", best_valid_result) 231 | self._writer.add_text('best_valid_result', str(best_valid_result), 0) 232 | self._writer.add_text('best_test_result', str(best_test_result), 0) 233 | exit() 234 | 235 | def train(self): 236 | self.mode = 'complete' 237 | self.train_an_episode(self._opt['max_steps']) 238 | 239 | 240 | if __name__ == '__main__': 241 | opt = vars(setup_args().parse_args()) 242 | engine = Engine(opt) 243 | engine.train() 244 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/models/__init__.py -------------------------------------------------------------------------------- /models/factorizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import BCEWithLogitsLoss 3 | from torch.optim.lr_scheduler import ExponentialLR 4 | from copy import deepcopy 5 | 6 | from models.modules import LR, FM, DeepFM, AutoInt 7 | from utils.train import use_cuda, use_optimizer, get_grad_norm 8 | 9 | 10 | def setup_factorizer(opt): 11 | new_opt = deepcopy(opt) 12 | for k, v in opt.items(): 13 | if k.startswith('fm_'): 14 | new_opt[k[3:]] = v 15 | return FMFactorizer(new_opt) 16 | 17 | 18 | class Factorizer(object): 19 | def __init__(self, opt): 20 | self.opt = opt 21 | self.clip = opt.get('grad_clip') 22 | self.use_cuda = opt.get('use_cuda') 23 | self.batch_size_test = opt.get('batch_size_test') 24 | self.l2_penalty = opt['l2_penalty'] 25 | 26 | self.criterion = BCEWithLogitsLoss(reduction='sum') 27 | 28 | self.model = None 29 | self.optimizer = None 30 | self.scheduler = None 31 | 32 | self.param_grad = None 33 | self.optim_status = None 34 | 35 | self.prev_param = None 36
| self.param = None 37 | 38 | self._train_step_idx = None 39 | self._train_episode_idx = None 40 | 41 | @property 42 | def train_step_idx(self): 43 | return self._train_step_idx 44 | 45 | @train_step_idx.setter 46 | def train_step_idx(self, new_step_idx): 47 | self._train_step_idx = new_step_idx 48 | 49 | @property 50 | def train_episode_idx(self): 51 | return self._train_episode_idx 52 | 53 | @train_episode_idx.setter 54 | def train_episode_idx(self, new_episode_idx): 55 | self._train_episode_idx = new_episode_idx 56 | 57 | def get_grad_norm(self): 58 | assert hasattr(self, 'model') 59 | return get_grad_norm(self.model) 60 | 61 | def get_emb_dims(self): 62 | return self.model.get_emb_dims() 63 | 64 | def update(self, sampler): 65 | if (self.train_step_idx > 0) and (self.train_step_idx % sampler.num_batches_train == 0): 66 | self.scheduler.step() 67 | 68 | self.train_step_idx += 1 69 | 70 | self.model.train() 71 | self.optimizer.zero_grad() 72 | 73 | 74 | class FMFactorizer(Factorizer): 75 | def __init__(self, opt): 76 | super(FMFactorizer, self).__init__(opt) 77 | self.opt = opt 78 | if opt['model'] == 'linear': 79 | self.model = LR(opt) 80 | elif opt['model'] == 'fm': 81 | self.model = FM(opt) 82 | elif opt['model'] == 'deepfm': 83 | self.model = DeepFM(opt) 84 | elif opt['model'] == 'autoint': 85 | self.model = AutoInt(opt) 86 | else: 87 | raise ValueError("Invalid FM model type: {}".format(opt['model'])) 88 | 89 | if self.use_cuda: 90 | use_cuda(True, opt['device_id']) 91 | self.model.cuda() 92 | 93 | self.optimizer = use_optimizer(self.model, opt) 94 | self.scheduler = ExponentialLR(self.optimizer, gamma=opt['lr_exp_decay']) 95 | 96 | def init_episode(self): 97 | opt = self.opt 98 | if opt['model'] == 'linear': 99 | self.model = LR(opt) 100 | elif opt['model'] == 'fm': 101 | self.model = FM(opt) 102 | elif opt['model'] == 'deepfm': 103 | self.model = DeepFM(opt) 104 | elif opt['model'] == 'autoint': 105 | self.model = AutoInt(opt) 106 | else: 107 | raise ValueError("Invalid FM model type: {}".format(opt['model'])) 108 | 109 | self._train_step_idx = 0 110 | if self.use_cuda: 111 | use_cuda(True, opt['device_id']) 112 | self.model.cuda() 113 | self.optimizer = use_optimizer(self.model, opt) 114 | self.scheduler = ExponentialLR(self.optimizer, gamma=opt['lr_exp_decay']) 115 | 116 | def update(self, sampler): 117 | """ 118 | update FM model parameters 119 | """ 120 | super(FMFactorizer, self).update(sampler) 121 | data, labels = sampler.get_sample('train') 122 | if self.use_cuda: 123 | data, labels = data.cuda(), labels.cuda() 124 | prob_preference = self.model.forward(data) 125 | non_reg_loss = self.criterion(prob_preference, labels.float()) / (data.size()[0]) 126 | l2_loss = self.model.l2_penalty(data, self.l2_penalty) / (data.size()[0]) 127 | loss = non_reg_loss + l2_loss 128 | loss.backward() 129 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 130 | self.optimizer.step() 131 | return loss.item() 132 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchfm.layer import FactorizationMachine, FeaturesLinear, MultiLayerPerceptron 5 | import numpy as np 6 | 7 | from models.pep_embedding import PEPEmbedding 8 | 9 | 10 | class LR(torch.nn.Module): 11 | def __init__(self, opt): 12 | super(LR, self).__init__() 13 | self.use_cuda = opt.get('use_cuda') 14 
| self.field_dims = opt['field_dims'] 15 | self.linear = FeaturesLinear(self.field_dims) # linear part 16 | 17 | def forward(self, x): 18 | """Compute Score""" 19 | score = self.linear.forward(x) 20 | return score.squeeze(1) 21 | 22 | def l2_penalty(self, x, lamb): 23 | return 0 24 | 25 | def calc_sparsity(self): 26 | return 0, 0 27 | 28 | def get_threshold(self): 29 | return 0 30 | 31 | def get_embedding(self): 32 | return np.zeros(1) 33 | 34 | 35 | class FM(torch.nn.Module): 36 | """Factorization Machines""" 37 | 38 | def __init__(self, opt): 39 | super(FM, self).__init__() 40 | self.use_cuda = opt.get('use_cuda') 41 | self.latent_dim = opt['latent_dim'] 42 | self.field_dims = opt['field_dims'] 43 | 44 | self.feature_num = sum(self.field_dims) 45 | self.embedding = PEPEmbedding(opt) 46 | self.linear = FeaturesLinear(self.field_dims) # linear part 47 | self.fm = FactorizationMachine(reduce_sum=True) 48 | print("BackBone Embedding Parameters: ", self.feature_num * self.latent_dim) 49 | 50 | def forward(self, x): 51 | linear_score = self.linear.forward(x) 52 | xv = self.embedding(x) 53 | fm_score = self.fm.forward(xv) 54 | score = linear_score + fm_score 55 | return score.squeeze(1) 56 | 57 | def l2_penalty(self, x, lamb): 58 | xv = self.embedding(x) 59 | xv_sq = xv.pow(2) 60 | xv_penalty = xv_sq * lamb 61 | xv_penalty = xv_penalty.sum() 62 | return xv_penalty 63 | 64 | def calc_sparsity(self): 65 | base = self.feature_num * self.latent_dim 66 | non_zero_values = torch.nonzero(self.embedding.sparse_v).size(0) 67 | percentage = 1 - (non_zero_values / base) 68 | return percentage, non_zero_values 69 | 70 | def get_threshold(self): 71 | return self.embedding.g(self.embedding.s) 72 | 73 | def get_embedding(self): 74 | return self.embedding.sparse_v.detach().cpu().numpy() 75 | 76 | 77 | class DeepFM(FM): 78 | def __init__(self, opt): 79 | super(DeepFM, self).__init__(opt) 80 | self.embed_output_dim = len(self.field_dims) * self.latent_dim 81 | self.mlp_dims = opt['mlp_dims'] 82 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, self.mlp_dims, dropout=0.2) 83 | 84 | def forward(self, x): 85 | linear_score = self.linear.forward(x) 86 | xv = self.embedding(x) 87 | fm_score = self.fm.forward(xv) 88 | dnn_score = self.mlp.forward(xv.view(-1, self.embed_output_dim)) 89 | score = linear_score + fm_score + dnn_score 90 | return score.squeeze(1) 91 | 92 | 93 | class AutoInt(DeepFM): 94 | def __init__(self, opt): 95 | super(AutoInt, self).__init__(opt) 96 | self.has_residual = opt['has_residual'] 97 | self.full_part = opt['full_part'] 98 | self.atten_embed_dim = opt['atten_embed_dim'] 99 | self.num_heads = opt['num_heads'] 100 | self.num_layers = opt['num_layers'] 101 | self.att_dropout = opt['att_dropout'] 102 | 103 | self.atten_output_dim = len(self.field_dims) * self.atten_embed_dim 104 | self.dnn_input_dim = len(self.field_dims) * self.latent_dim 105 | 106 | self.atten_embedding = torch.nn.Linear(self.latent_dim, self.atten_embed_dim) 107 | self.self_attns = torch.nn.ModuleList([ 108 | torch.nn.MultiheadAttention(self.atten_embed_dim, self.num_heads, dropout=self.att_dropout) for _ in range(self.num_layers) 109 | ]) 110 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 111 | if self.has_residual: 112 | self.V_res_embedding = torch.nn.Linear(self.latent_dim, self.atten_embed_dim) 113 | 114 | def forward(self, x): 115 | xv = self.embedding(x) 116 | score = self.autoint_layer(xv) 117 | if self.full_part: 118 | dnn_score = self.mlp.forward(xv.view(-1, self.embed_output_dim)) 119 | score = 
dnn_score + score 120 | 121 | return score.squeeze(1) 122 | 123 | def autoint_layer(self, xv): 124 | """Multi-head self-attention layer""" 125 | atten_x = self.atten_embedding(xv) # bs, field_num, atten_dim 126 | cross_term = atten_x.transpose(0, 1) # field_num, bs, atten_dim 127 | for self_attn in self.self_attns: 128 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 129 | cross_term = cross_term.transpose(0, 1) # bs, field_num, atten_dim 130 | if self.has_residual: 131 | V_res = self.V_res_embedding(xv) 132 | cross_term += V_res 133 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) # bs, field_num * atten_dim 134 | output = self.attn_fc(cross_term) 135 | return output 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /models/pep_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class PEPEmbedding(nn.Module): 8 | def __init__(self, opt): 9 | super(PEPEmbedding, self).__init__() 10 | self.use_cuda = opt.get('use_cuda') 11 | self.threshold_type = opt['threshold_type'] 12 | self.latent_dim = opt['latent_dim'] 13 | self.field_dims = opt['field_dims'] 14 | self.feature_num = sum(opt['field_dims']) 15 | self.field_num = len(opt['field_dims']) 16 | self.g_type = opt['g_type'] 17 | self.gk = opt['gk'] 18 | init = opt['threshold_init'] 19 | self.retrain = False 20 | self.mask = None 21 | 22 | self.g = torch.sigmoid 23 | self.s = self.init_threshold(init) 24 | self.offsets = np.array((0, *np.cumsum(self.field_dims)[:-1]), dtype=np.long) 25 | 26 | self.v = torch.nn.Parameter(torch.rand(self.feature_num, self.latent_dim)) 27 | torch.nn.init.xavier_uniform_(self.v) 28 | 29 | if 'retrain_emb_param' in opt: 30 | self.retrain = True 31 | self.init_retrain(opt) 32 | print("Retrain epoch {}".format(opt['retrain_emb_param'])) 33 | 34 | self.sparse_v = self.v.data 35 | 36 | def init_retrain(self, opt): 37 | retrain_emb_param = opt['retrain_emb_param'] 38 | sparse_emb = np.load(opt['emb_save_path'].format(num_parameter=retrain_emb_param)+'.npy') 39 | sparse_emb = torch.from_numpy(sparse_emb) 40 | mask = torch.abs(torch.sign(sparse_emb)) 41 | if opt['re_init']: 42 | init_emb = torch.nn.Parameter(torch.rand(self.feature_num, self.latent_dim)) 43 | torch.nn.init.xavier_uniform_(init_emb) 44 | else: 45 | init_emb = np.load(opt['emb_save_path'].format(num_parameter='initial_embedding') + '.npy') 46 | init_emb = torch.from_numpy(init_emb) 47 | 48 | init_emb = init_emb * mask 49 | self.v = torch.nn.Parameter(init_emb) 50 | self.mask = mask 51 | self.gk = 0 52 | if self.use_cuda: 53 | self.mask = self.mask.cuda() 54 | 55 | def init_threshold(self, init): 56 | if self.threshold_type == 'global': 57 | s = nn.Parameter(init * torch.ones(1)) 58 | elif self.threshold_type == 'dimension': 59 | s = nn.Parameter(init * torch.ones([self.latent_dim])) 60 | elif self.threshold_type == 'feature': 61 | s = nn.Parameter(init * torch.ones([self.feature_num, 1])) 62 | elif self.threshold_type == 'field': 63 | s = nn.Parameter(init * torch.ones([self.field_num, 1])) 64 | elif self.threshold_type == 'feature_dim': 65 | s = nn.Parameter(init * torch.ones([self.feature_num, self.latent_dim])) 66 | elif self.threshold_type == 'field_dim': 67 | s = nn.Parameter(init * torch.ones([self.field_num, self.latent_dim])) 68 | else: 69 | raise ValueError('Invalid threshold_type: 
{}'.format(self.threshold_type)) 70 | return s 71 | 72 | def soft_threshold(self, v, s): 73 | if s.size(0) == self.field_num: # field-wise lambda 74 | field_v = torch.split(v, tuple(self.field_dims)) 75 | concat_v = [] 76 | for i, v in enumerate(field_v): 77 | v = torch.sign(v) * torch.relu(torch.abs(v) - (self.g(s[i]) * self.gk)) 78 | concat_v.append(v) 79 | 80 | concat_v = torch.cat(concat_v, dim=0) 81 | return concat_v 82 | else: 83 | return torch.sign(v) * torch.relu(torch.abs(v) - (self.g(s) * self.gk)) 84 | 85 | def forward(self, x): 86 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 87 | self.sparse_v = self.soft_threshold(self.v, self.s) 88 | if self.retrain: 89 | self.sparse_v = self.sparse_v * self.mask 90 | xv = F.embedding(x, self.sparse_v) 91 | 92 | return xv 93 | -------------------------------------------------------------------------------- /train_avazu.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | import numpy as np 6 | 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 8 | torch.backends.cudnn.enabled = True 9 | 10 | if __name__ == '__main__': 11 | parser = setup_args() 12 | parser.set_defaults( 13 | alias='test', 14 | tensorboard='./tmp/runs/{factorizer}/{data_type}', 15 | ########## 16 | ## data ## 17 | ########## 18 | data_type='avazu', 19 | data_path='./data/{data_type}/', 20 | load_in_queue=False, 21 | category_only=False, 22 | rebuild_cache=False, 23 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 24 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 25 | ###################### 26 | ## train/test split ## 27 | ###################### 28 | train_test_split='lro', 29 | test_ratio=0.1, 30 | valid_ratio=1/9, 31 | ########################## 32 | ## Devices & Efficiency ## 33 | ########################## 34 | use_cuda=True, 35 | early_stop=5, 36 | log_interval=1, # 816 37 | eval_interval=1, 38 | display_interval=2000, # 10 epochs between 2 evaluations 39 | device_ids_test=[0], 40 | device_id=0, 41 | batch_size_train=1024, 42 | batch_size_valid=1024, 43 | batch_size_test=1024, 44 | ########### 45 | ## Model ## 46 | ########### 47 | factorizer='fm', 48 | model='fm', 49 | fm_lr=1e-3, 50 | # Deep 51 | mlp_dims=[400, 400, 400], 52 | # AutoInt 53 | has_residual=True, 54 | full_part=True, 55 | num_heads=2, 56 | num_layers=3, 57 | att_dropout=0, 58 | atten_embed_dim=64, 59 | # optimizer setting 60 | fm_optimizer='adam', 61 | fm_amsgrad=False, 62 | fm_eps=1e-8, 63 | fm_l2_regularization=1e-5, 64 | fm_betas=(0.9, 0.999), 65 | fm_grad_clip=100, # 0.1 66 | fm_lr_exp_decay=1, 67 | l2_penalty=0, 68 | ######### 69 | ## PEP ## 70 | ######### 71 | latent_dim=24, 72 | threshold_type='feature_dim', 73 | g_type='sigmoid', 74 | gk=1, 75 | threshold_init=-150, 76 | candidate_p=[50000, 30000, 20000], 77 | ) 78 | 79 | opt = parser.parse_args(args=[]) 80 | opt = vars(opt) 81 | 82 | # rename alias 83 | # rename alias 84 | 85 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 86 | opt['model'].upper(), 87 | opt['alias'], 88 | opt['latent_dim'], 89 | opt['batch_size_train'], 90 | opt['fm_lr'], 91 | opt['fm_optimizer'], 92 | opt['threshold_type'].upper(), 93 | opt['threshold_init'], 94 | opt['g_type'], 95 | opt['gk'], 96 | opt['l2_penalty'] 97 | ) 98 | 99 | print(opt['alias']) 100 | random.seed(opt['seed']) 101 | # np.random.seed(opt['seed']) 102 | 
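# NOTE: a minimal sketch of stricter reproducibility, assuming an int seed;
# pinning cuDNN trades some speed for repeatable kernels:
#
#   np.random.seed(opt['seed'])
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False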
torch.manual_seed(opt['seed']) 103 | torch.cuda.manual_seed_all(opt['seed']) 104 | engine = Engine(opt) 105 | engine.train() 106 | -------------------------------------------------------------------------------- /train_avazu_retrain.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | import numpy as np 6 | 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 8 | torch.backends.cudnn.enabled = True 9 | 10 | if __name__ == '__main__': 11 | parser = setup_args() 12 | parser.set_defaults( 13 | alias='test', 14 | tensorboard='./tmp/runs/{factorizer}/{data_type}', 15 | ########## 16 | ## data ## 17 | ########## 18 | data_type='avazu', 19 | data_path='./data/{data_type}/', 20 | load_in_queue=False, 21 | category_only=False, 22 | rebuild_cache=False, 23 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 24 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 25 | ###################### 26 | ## train/test split ## 27 | ###################### 28 | train_test_split='lro', 29 | test_ratio=0.1, 30 | valid_ratio=1/9, 31 | ########################## 32 | ## Devices & Efficiency ## 33 | ########################## 34 | use_cuda=True, 35 | early_stop=5, 36 | log_interval=1, # 816 37 | eval_interval=1, 38 | display_interval=2000, # 10 epochs between 2 evaluations 39 | device_ids_test=[0], 40 | device_id=0, 41 | batch_size_train=1024, 42 | batch_size_valid=1024, 43 | batch_size_test=1024, 44 | ########### 45 | ## Model ## 46 | ########### 47 | factorizer='fm', 48 | model='fm', 49 | fm_lr=1e-3, 50 | # Deep 51 | mlp_dims=[400, 400, 400], 52 | # AutoInt 53 | has_residual=True, 54 | full_part=True, 55 | num_heads=2, 56 | num_layers=3, 57 | att_dropout=0, 58 | atten_embed_dim=64, 59 | # optimizer setting 60 | fm_optimizer='adam', 61 | fm_amsgrad=False, 62 | fm_eps=1e-8, 63 | fm_l2_regularization=1e-5, 64 | fm_betas=(0.9, 0.999), 65 | fm_grad_clip=100, # 0.1 66 | fm_lr_exp_decay=1, 67 | l2_penalty=0, 68 | ######### 69 | ## PEP ## 70 | ######### 71 | latent_dim=24, 72 | threshold_type='feature_dim', 73 | g_type='sigmoid', 74 | gk=1, 75 | threshold_init=-150, 76 | retrain_emb_param=29994, 77 | re_init=False, 78 | ) 79 | 80 | opt = parser.parse_args(args=[]) 81 | opt = vars(opt) 82 | 83 | # rename alias 84 | # rename alias 85 | 86 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 87 | opt['model'].upper(), 88 | opt['alias'], 89 | opt['latent_dim'], 90 | opt['batch_size_train'], 91 | opt['fm_lr'], 92 | opt['fm_optimizer'], 93 | opt['threshold_type'].upper(), 94 | opt['threshold_init'], 95 | opt['g_type'], 96 | opt['gk'], 97 | opt['l2_penalty'] 98 | ) 99 | 100 | print(opt['alias']) 101 | random.seed(opt['seed']) 102 | # np.random.seed(opt['seed']) 103 | torch.manual_seed(opt['seed']) 104 | torch.cuda.manual_seed_all(opt['seed']) 105 | engine = Engine(opt) 106 | engine.train() 107 | -------------------------------------------------------------------------------- /train_criteo.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | torch.backends.cudnn.enabled = True 8 | 9 | if __name__ == '__main__': 10 | parser = setup_args() 11 | parser.set_defaults( 12 | alias='test', 13 | 
tensorboard='./tmp/runs/{factorizer}/{data_type}', 14 | ########## 15 | ## data ## 16 | ########## 17 | data_type='criteo', 18 | data_path='./data/{data_type}/', 19 | load_in_queue=False, 20 | category_only=False, 21 | rebuild_cache=False, 22 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 23 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 24 | ###################### 25 | ## train/test split ## 26 | ###################### 27 | train_test_split='lro', 28 | test_ratio=0.1, 29 | valid_ratio=1/9, 30 | ########################## 31 | ## Devices & Efficiency ## 32 | ########################## 33 | use_cuda=True, 34 | early_stop=5, 35 | log_interval=1, 36 | eval_interval=1, 37 | display_interval=2000, 38 | device_ids_test=[0], 39 | device_id=0, 40 | batch_size_train=1024, 41 | batch_size_valid=1024, 42 | batch_size_test=1024, 43 | ########### 44 | ## Model ## 45 | ########### 46 | factorizer='fm', 47 | model='fm', 48 | fm_lr=1e-3, 49 | # Deep 50 | mlp_dims=[400, 400, 400], 51 | # AutoInt 52 | has_residual=True, 53 | full_part=True, 54 | num_heads=2, 55 | num_layers=3, 56 | att_dropout=0, 57 | atten_embed_dim=64, 58 | # 59 | fm_optimizer='adam', 60 | fm_amsgrad=False, 61 | fm_eps=1e-8, 62 | fm_l2_regularization=1e-5, 63 | fm_betas=(0.9, 0.999), 64 | fm_grad_clip=100, # 0.1 65 | fm_lr_exp_decay=1, 66 | l2_penalty=0, 67 | ######### 68 | ## PEP ## 69 | ######### 70 | latent_dim=24, 71 | threshold_type='feature_dim', 72 | g_type='sigmoid', 73 | gk=1, 74 | threshold_init=-150, 75 | candidate_p=[50000, 30000, 20000], 76 | ) 77 | opt = parser.parse_args(args=[]) 78 | opt = vars(opt) 79 | 80 | # rename alias 81 | # rename alias 82 | 83 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 84 | opt['model'].upper(), 85 | opt['alias'], 86 | opt['latent_dim'], 87 | opt['batch_size_train'], 88 | opt['fm_lr'], 89 | opt['fm_optimizer'], 90 | opt['threshold_type'].upper(), 91 | opt['threshold_init'], 92 | opt['g_type'], 93 | opt['gk'], 94 | opt['l2_penalty'] 95 | ) 96 | 97 | print(opt['alias']) 98 | random.seed(opt['seed']) 99 | # np.random.seed(opt['seed']) 100 | torch.manual_seed(opt['seed']) 101 | torch.cuda.manual_seed_all(opt['seed']) 102 | engine = Engine(opt) 103 | engine.train() 104 | -------------------------------------------------------------------------------- /train_criteo_retrain.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | torch.backends.cudnn.enabled = True 8 | 9 | if __name__ == '__main__': 10 | parser = setup_args() 11 | parser.set_defaults( 12 | alias='test', 13 | tensorboard='./tmp/runs/{factorizer}/{data_type}', 14 | ########## 15 | ## data ## 16 | ########## 17 | data_type='criteo', 18 | data_path='./data/{data_type}/', 19 | load_in_queue=False, 20 | category_only=False, 21 | rebuild_cache=False, 22 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 23 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 24 | ###################### 25 | ## train/test split ## 26 | ###################### 27 | train_test_split='lro', 28 | test_ratio=0.1, 29 | valid_ratio=1/9, 30 | ########################## 31 | ## Devices & Efficiency ## 32 | ########################## 33 | use_cuda=True, 34 | early_stop=5, 35 | log_interval=1, 36 | eval_interval=1, 
37 | display_interval=2000, 38 | device_ids_test=[0], 39 | device_id=0, 40 | batch_size_train=1024, 41 | batch_size_valid=1024, 42 | batch_size_test=1024, 43 | ########### 44 | ## Model ## 45 | ########### 46 | factorizer='fm', 47 | model='fm', 48 | fm_lr=1e-3, 49 | # Deep 50 | mlp_dims=[400, 400, 400], 51 | # AutoInt 52 | has_residual=True, 53 | full_part=True, 54 | num_heads=2, 55 | num_layers=3, 56 | att_dropout=0, 57 | atten_embed_dim=64, 58 | # 59 | fm_optimizer='adam', 60 | fm_amsgrad=False, 61 | fm_eps=1e-8, 62 | fm_l2_regularization=1e-5, 63 | fm_betas=(0.9, 0.999), 64 | fm_grad_clip=100, # 0.1 65 | fm_lr_exp_decay=1, 66 | l2_penalty=0, 67 | ######### 68 | ## PEP ## 69 | ######### 70 | latent_dim=24, 71 | threshold_type='feature_dim', 72 | g_type='sigmoid', 73 | gk=1, 74 | threshold_init=-150, 75 | retrain_emb_param=29994, 76 | re_init=False, 77 | ) 78 | opt = parser.parse_args(args=[]) 79 | opt = vars(opt) 80 | 81 | # rename alias 82 | # rename alias 83 | 84 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 85 | opt['model'].upper(), 86 | opt['alias'], 87 | opt['latent_dim'], 88 | opt['batch_size_train'], 89 | opt['fm_lr'], 90 | opt['fm_optimizer'], 91 | opt['threshold_type'].upper(), 92 | opt['threshold_init'], 93 | opt['g_type'], 94 | opt['gk'], 95 | opt['l2_penalty'] 96 | ) 97 | 98 | print(opt['alias']) 99 | random.seed(opt['seed']) 100 | # np.random.seed(opt['seed']) 101 | torch.manual_seed(opt['seed']) 102 | torch.cuda.manual_seed_all(opt['seed']) 103 | engine = Engine(opt) 104 | engine.train() 105 | -------------------------------------------------------------------------------- /train_ml-1m.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | import numpy as np 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | torch.backends.cudnn.enabled = True 8 | 9 | if __name__ == '__main__': 10 | parser = setup_args() 11 | parser.set_defaults( 12 | alias='test', 13 | tensorboard='./tmp/runs/{factorizer}/{data_type}', 14 | ########## 15 | ## data ## 16 | ########## 17 | data_type='ml-1m', 18 | data_path='./data/{data_type}/', 19 | load_in_queue=False, 20 | category_only=False, 21 | rebuild_cache=False, 22 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 23 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 24 | ###################### 25 | ## train/test split ## 26 | ###################### 27 | test_ratio=0.1, 28 | valid_ratio=1/9, 29 | ########################## 30 | ## Devices & Efficiency ## 31 | ########################## 32 | use_cuda=True, 33 | early_stop=40, 34 | log_interval=1, 35 | display_interval=500, 36 | eval_interval=5, # 10 epochs between 2 evaluations 37 | device_ids_test=[0], 38 | device_id=0, 39 | batch_size_train=1024, 40 | batch_size_valid=1024, 41 | batch_size_test=1024, 42 | ########### 43 | ## Model ## 44 | ########### 45 | factorizer='fm', 46 | model='fm', 47 | fm_lr=1e-3, 48 | # Deep 49 | mlp_dims=[100, 100], 50 | # AutoInt 51 | has_residual=True, 52 | full_part=True, 53 | num_heads=2, 54 | num_layers=3, 55 | att_dropout=0.4, 56 | atten_embed_dim=64, 57 | # optimizer setting 58 | fm_optimizer='adam', 59 | fm_amsgrad=False, 60 | fm_eps=1e-8, 61 | fm_l2_regularization=1e-5, 62 | fm_betas=(0.9, 0.999), 63 | fm_grad_clip=100, # 0.1 64 | fm_lr_exp_decay=1, 65 | l2_penalty=0, 66 | ######### 67 | ## PEP ## 68 | ######### 69 
| latent_dim=32, 70 | threshold_type='feature_dim', 71 | g_type='sigmoid', 72 | gk=1, 73 | threshold_init=-15, 74 | candidate_p=[50000, 30000, 20000], 75 | ) 76 | 77 | opt = parser.parse_args(args=[]) 78 | opt = vars(opt) 79 | 80 | # rename alias 81 | # rename alias 82 | 83 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 84 | opt['model'].upper(), 85 | opt['alias'], 86 | opt['latent_dim'], 87 | opt['batch_size_train'], 88 | opt['fm_lr'], 89 | opt['fm_optimizer'], 90 | opt['threshold_type'].upper(), 91 | opt['threshold_init'], 92 | opt['g_type'], 93 | opt['gk'], 94 | opt['l2_penalty'] 95 | ) 96 | print(opt['alias']) 97 | random.seed(opt['seed']) 98 | # np.random.seed(opt['seed']) 99 | torch.manual_seed(opt['seed']) 100 | torch.cuda.manual_seed_all(opt['seed']) 101 | engine = Engine(opt) 102 | engine.train() 103 | -------------------------------------------------------------------------------- /train_ml-1m_retrain.py: -------------------------------------------------------------------------------- 1 | from engine import setup_args, Engine 2 | import torch 3 | import os 4 | import random 5 | import numpy as np 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | torch.backends.cudnn.enabled = True 8 | 9 | if __name__ == '__main__': 10 | parser = setup_args() 11 | parser.set_defaults( 12 | alias='test', 13 | tensorboard='./tmp/runs/{factorizer}/{data_type}', 14 | ########## 15 | ## data ## 16 | ########## 17 | data_type='ml-1m', 18 | data_path='./data/{data_type}/', 19 | load_in_queue=False, 20 | category_only=False, 21 | rebuild_cache=False, 22 | eval_res_path='./tmp/res/{factorizer}/{data_type}/{alias}/{epoch_idx}.csv', 23 | emb_save_path='./tmp/embedding/{factorizer}/{data_type}/{alias}/{num_parameter}', 24 | ###################### 25 | ## train/test split ## 26 | ###################### 27 | test_ratio=0.1, 28 | valid_ratio=1/9, 29 | ########################## 30 | ## Devices & Efficiency ## 31 | ########################## 32 | use_cuda=True, 33 | early_stop=40, 34 | log_interval=1, 35 | display_interval=500, 36 | eval_interval=5, # 10 epochs between 2 evaluations 37 | device_ids_test=[0], 38 | device_id=0, 39 | batch_size_train=1024, 40 | batch_size_valid=1024, 41 | batch_size_test=1024, 42 | ########### 43 | ## Model ## 44 | ########### 45 | factorizer='fm', 46 | model='fm', 47 | fm_lr=1e-3, 48 | # Deep 49 | mlp_dims=[100, 100], 50 | # AutoInt 51 | has_residual=True, 52 | full_part=True, 53 | num_heads=2, 54 | num_layers=3, 55 | att_dropout=0.4, 56 | atten_embed_dim=64, 57 | # optimizer setting 58 | fm_optimizer='adam', 59 | fm_amsgrad=False, 60 | fm_eps=1e-8, 61 | fm_l2_regularization=1e-5, 62 | fm_betas=(0.9, 0.999), 63 | fm_grad_clip=100, # 0.1 64 | fm_lr_exp_decay=1, 65 | l2_penalty=0, 66 | ######### 67 | ## PEP ## 68 | ######### 69 | latent_dim=32, 70 | threshold_type='feature_dim', 71 | g_type='sigmoid', 72 | gk=1, 73 | threshold_init=-15, 74 | retrain_emb_param=29994, 75 | re_init=False, 76 | ) 77 | 78 | opt = parser.parse_args(args=[]) 79 | opt = vars(opt) 80 | 81 | # rename alias 82 | # rename alias 83 | 84 | opt['alias'] = '{}_{}_BaseDim{}_bsz{}_lr_{}_optim_{}_thresholdType{}_thres_init{}_{}-{}_l2_penalty{}'.format( 85 | opt['model'].upper(), 86 | opt['alias'], 87 | opt['latent_dim'], 88 | opt['batch_size_train'], 89 | opt['fm_lr'], 90 | opt['fm_optimizer'], 91 | opt['threshold_type'].upper(), 92 | opt['threshold_init'], 93 | opt['g_type'], 94 | opt['gk'], 95 | opt['l2_penalty'] 96 | ) 97 | print(opt['alias']) 98 | 
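# NOTE: with the defaults above, the formatted alias expands to something like
# the following (illustrative, derived from the defaults, not from a real run):
#   FM_test_BaseDim32_bsz1024_lr_0.001_optim_adam_thresholdTypeFEATURE_DIM_thres_init-15_sigmoid-1_l2_penalty0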
random.seed(opt['seed']) 99 | # np.random.seed(opt['seed']) 100 | torch.manual_seed(opt['seed']) 101 | torch.cuda.manual_seed_all(opt['seed']) 102 | engine = Engine(opt) 103 | engine.train() 104 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssui-liu/learnable-embed-sizes-for-RecSys/e4d07d2c5e4aaefa95b939b83600adfaf47d7a6f/utils/__init__.py -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import torch 3 | from torch import multiprocessing as mp 4 | import pandas as pd 5 | import numpy as np 6 | from datetime import datetime 7 | 8 | 9 | def evaluate_fm(factorizer, sampler, use_cuda, on='test'): 10 | all_logloss, all_auc = [], [] 11
| model = factorizer.model 12 | model.eval() 13 | for i in range(sampler.num_batches_test): 14 | data, labels = sampler.get_sample(on) 15 | 16 | if use_cuda: 17 | data, labels = data.cuda(), labels.cuda() 18 | prob_preference = model.forward(data) 19 | logloss = factorizer.criterion(prob_preference, labels.float()) / (data.size()[0]) 20 | all_logloss.append(logloss.detach().cpu().numpy()) 21 | 22 | prob_preference = torch.sigmoid(prob_preference).detach().cpu().numpy() 23 | labels = labels.detach().cpu().numpy() 24 | auc = metrics.roc_auc_score(labels, prob_preference) 25 | all_auc.append(auc) 26 | 27 | return np.mean(all_logloss), np.mean(all_auc) 28 | 29 | 30 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some handy functions for PyTorch model training ... 3 | """ 4 | import torch 5 | 6 | 7 | def get_grad_norm(model): 8 | grads = [] 9 | for p in model.parameters(): 10 | if p.grad is not None: 11 | grads.append(p.grad.data.view(-1, 1)) 12 | if len(grads) == 0: 13 | grads.append(torch.FloatTensor([0])) 14 | grad_norm = torch.norm(torch.cat(grads)) 15 | if grad_norm.is_cuda: 16 | grad_norm = grad_norm.cpu() 17 | return grad_norm.item() 18 | 19 | 20 | # Checkpoints 21 | def save_checkpoint(model, model_dir): 22 | torch.save(model.state_dict(), model_dir) 23 | 24 | 25 | def resume_checkpoint(model, model_dir, device_id): 26 | state_dict = torch.load(model_dir, 27 | map_location=lambda storage, loc: storage.cuda(device=device_id)) # ensure all storages are on the GPU 28 | model.load_state_dict(state_dict) 29 | 30 | 31 | # Hyper params 32 | def use_cuda(enabled, device_id=0): 33 | if enabled: 34 | assert torch.cuda.is_available(), 'CUDA is not available' 35 | torch.cuda.set_device(device_id) 36 | 37 | 38 | def use_optimizer(network, params): 39 | if params['optimizer'] == 'sgd': 40 | optimizer = torch.optim.SGD(network.parameters(), 41 | lr=params['lr'], 42 | weight_decay=params['l2_regularization']) 43 | elif params['optimizer'] == 'adam': 44 | optimizer = torch.optim.Adam(network.parameters(), 45 | lr=params['lr'], 46 | betas=params['betas'], 47 | weight_decay=params['l2_regularization'], 48 | amsgrad=params['amsgrad']) 49 | elif params['optimizer'] == 'rmsprop': 50 | optimizer = torch.optim.RMSprop(network.parameters(), 51 | lr=params['lr'], 52 | alpha=params['alpha'], 53 | momentum=params['momentum'], 54 | weight_decay=params['l2_regularization']) 55 | else: 56 | raise ValueError('Unknown optimizer: {}'.format(params['optimizer'])) 57 | return optimizer 58 | --------------------------------------------------------------------------------
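# NOTE: a minimal, self-contained sketch of how `use_optimizer` above is meant
# to be called with the de-prefixed `fm_*` options produced by setup_factorizer
# (`net` here is a hypothetical stand-in model):
#
#   import torch
#   from utils.train import use_optimizer
#
#   net = torch.nn.Linear(8, 1)
#   params = {'optimizer': 'adam', 'lr': 1e-3, 'betas': (0.9, 0.999),
#             'l2_regularization': 1e-5, 'amsgrad': False}
#   optimizer = use_optimizer(net, params)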