├── README.md ├── config.py ├── data ├── preprd │ └── mvp │ │ ├── new_user │ │ ├── dict │ │ ├── test │ │ └── train │ │ └── partial │ │ ├── dict.json │ │ ├── partial10 │ │ ├── partial20 │ │ ├── partial30 │ │ ├── partial40 │ │ ├── partial50 │ │ ├── partial60 │ │ ├── partial70 │ │ ├── partial80 │ │ ├── partial90 │ │ ├── sep.py │ │ └── users.json └── raw │ ├── company_info.csv │ ├── purchase_history.json │ └── user_info.csv ├── dataset.py ├── exp.py ├── main.py └── model ├── etna.py └── functions.py /README.md: -------------------------------------------------------------------------------- 1 | # ETNA : Embedding Transformation Network with Attention 2 | 3 | This is our Pytorch implementation for the paper: 4 | 5 | Raehyun Kim and Hyunjae Kim (2019). *[Predicting Multiple Demographic Attributes with Task Specific Embedding Transformation and Attention Network.](https://arxiv.org/abs/1903.10144)* In Proceedings of SIAM International Conference on Data Mining (SDM'19) 6 | 7 | The code is tested under a Linux desktop (w/ TiTan X - Pascal) with Pytorch 1.0.0. and Python 3. 8 | 9 | 10 | ## MVP (Multi-Vendor loyalty Program) Dataset 11 | We provide dataset for demographic prediction. You can find our raw dataset in (`data/raw`). 12 | 13 | MVP dataset consists of three files. `[Company_info, User_info, Purchase_history]` 14 | 15 | * `company_info.csv` : Company's industrial categories are included in company info. 16 | * `user_info.csv` : User's demographic information (processed as class). 17 | * `purchase_history.json` : Each user's purchasing history. 18 | 19 | ## Model Training and Evaluation 20 | We have two type of task settings. (New user and partial prediction) 21 | 22 | And user should specify observation ratio for partial prediction task. 23 | 24 | To train our model on `partial task with 50% of observation ratio` (with default hyper-parameters): 25 | 26 | ``` 27 | python main.py --model_type ETNA --task_type partial50 28 | ``` 29 | 30 | Experiments on other observed ratios are also available as follows: 31 | ``` 32 | python main.py --model_type ETNA --task_type partial10 33 | ``` 34 | 35 | If you want to test our model on validation set for searching your own hyper-parameters, use '--do-validation' argument. 36 | 37 | Note that we do not provide validation set, so you should use some portion of the training set as validation set. 38 | 39 | 40 | ## Reference 41 | Please cite our paper if you use the code or datasets. 42 | ``` 43 | @inproceedings{kim2019predicting, 44 | title={Predicting multiple demographic attributes with task specific embedding transformation and attention network}, 45 | author={Kim, Raehyun and Kim, Hyunjae and Lee, Janghyuk and Kang, Jaewoo}, 46 | booktitle={Proceedings of the 2019 SIAM International Conference on Data Mining}, 47 | pages={765--773}, 48 | year={2019}, 49 | organization={SIAM} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | # data 8 | parser.add_argument('--dataset', type=str, default='mvp') 9 | parser.add_argument('--data-path', type=str, default="./data/preprd", 10 | help="") 11 | parser.add_argument('--rand-seed', type=int, default=1) 12 | parser.add_argument('--data-shuffle', type=int, default=0) 13 | parser.add_argument('--num-workers', type=int, default=2) 14 | 15 | # task settings 16 | parser.add_argument('--task-type', type=str, default='partial50', required=True, 17 | help="[partial50, new_user]") 18 | 19 | # optimizations 20 | parser.add_argument('--opt', type=str, default='Adam', 21 | help="Adam / RMSprop / SGD / Adagrad / Adadelta / Adamax") 22 | parser.add_argument('--learning-rate', type=float, default=1e-3) 23 | parser.add_argument('--momentum', type=float, default=0.9) 24 | 25 | # embeddings 26 | parser.add_argument('--item-emb-size', type=int, default=100) 27 | 28 | # training parameters 29 | parser.add_argument('--batch-size', type=int, default=64) 30 | parser.add_argument('--user-emb-dim', type=int, default=100) 31 | parser.add_argument('--num-negs', type=int, default=1) 32 | parser.add_argument('--max-epoch', type=int, default=100) 33 | parser.add_argument('--grad-max-norm', type=float, default=5) 34 | parser.add_argument('--num-batches', type=int, default=20) 35 | 36 | # model selection 37 | parser.add_argument('--model-type', type=str, default='ETNA', required=True, 38 | help="[POP, ETN, ETNA]") 39 | 40 | # debugging and analysis 41 | parser.add_argument('--do-validation', action='store_true', default=False) 42 | parser.add_argument('--no-cuda', action='store_true', default=False, 43 | help="Whether not to use CUDA when available") 44 | parser.add_argument('--save-log', type=int, default=0) 45 | parser.add_argument('--print-per-step', type=int, default=9999) 46 | 47 | # regularization 48 | parser.add_argument('--no-early-stop', action='store_true', default=False) 49 | parser.add_argument('--weight-decay', type=float, default=1e-5) 50 | parser.add_argument('--lr-decay', type=float, default=0.9) 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | -------------------------------------------------------------------------------- /data/preprd/mvp/new_user/dict: -------------------------------------------------------------------------------- 1 | {"attr_len": [2, 4, 2], "dict": ["null", "UNK", "O11712", "K10309", "K10326", "K00002", "O11682", "I60003", "I10174", "O11219", "K10246", "K10552", "I10141", "O11776", "O30025", "O11057", "O11652", "K10621", "O11570", "O11426", "I10060", "O11396", "K10435", "K10413", "K10563", "I10090", "O11075", "O11051", "I10097", "K10491", "O11004", "O11045", "O11711", "O11415", "O11632", "K10220", "K10536", "O11654", "O11298", "O11609", "O11106", "K10678", "O11430", "O11202", "K10569", "O11270", "O11722", "I10133", "I10170", "O11121", "O10033", "O11428", "O11517", "O11364", "O11459", "O11414", "K00016", "O11565", "O11003", "I60001", "O11165", "I50003", "K10035", "O11563", "O11541", "K10002", "K10516", "O11559", "O11489", "O11660", "O11686", "O11512", "K10314", "O11560", "O30021", "K10617", "O11347", "K10192", "K10500", "O11619", "O11792", "K10672", "K10656", "K00012", "I10059", "K10620", "I10105", "O11515", "O11710", "O11727", "O11222", "O11680", "I10185", "O40002", "O11684", "I10093", "I10190", "O20020", "K10530", "O11758", "K10117", "O11606", "O11635", "O11745", "O11763", "K10529", "O11599", "O11647", "O20008", "K10462", "O11673", "O11697", "O11360", "I10143", "K10084", "O11579", "I10123", "K10266", " ", "O11133", "O11368", "I10079", "K00007", "I10007", "O11779", "O11694", "K10519", "O11136", "O11555", "O11636", "O11617", "O11468", "O11159", "O30017", "O11464", "I50001", "I10063", "O11329", "I10118", "O11575", "K10460", "O11558", "O11037", "K10437", "K10334", "O11183", "O11408", "O11664", "O11346", "O11491", "O11486", "I10181", "O11508", "K10219", "O11510", "I10213", "K10555", "I10089", "K10571", "O11277", "O11421", "K10576", "O10029", "K10036", "O11131", "K10009", "O11383", "K10231", "O11532", "K10282", "O10013", "O11371", "I10110", "O11128", "O11576", "K10661", "O11530", "O11243", "O11582", "K10329", "O11753", "O11523", "I10137", "K10175", "O11225", "O11234", "K00025", "O11240", "O11174", "O11769", "K10587", "O11248", "O11411", "U00004", "K10520", "K10586", "O11457", "O11422", "O11752", "O11701", "I10108", "I10165", "O11519", "O11197", "O11465", "I20001", "K10616", "O11623", "O11081", "K00010", "K10609", "I10160", "O20023", "K10421", "O11323", "O11042", "O11765", "O11718", "O30002", "O11591", "K10505", "O11492", "O11438", "O11275", "K10381", "O11634", "I60004", "I10115", "O11705", "I10102", "I10109", "K10281", "O11087", "O11461", "O11552", "O11375", "O11011", "O11600", "O10008", "G10001", "K10273", "O11578", "O11546", "U00002", "K10365", "O11685", "O11666", "I10113", "K10188", "O11649", "K10303", "I10095", "O11480", "O11613", "K10575", "O11379", "O11646", "O11259", "O11640", "O11611", "O11307", "O11285", "O11268", "O11678", "K10251", "I10011", "I10111", "O11659", "O11380", "K10394", "O11717", "O11500", "I10179", "O11303", "O11198", "O11412", "O11518", "O11413", "O11389", "I10119", "O11495", "I10104", "O11076", "I10175", "K10202", "K10005", "O11672", "O20011", "O11378", "O11699", "O11695", "O11653", "O11638", "O11603", "I10107", "O11164", "O11742", "O11392", "O11665", "O11580", "O11524", "I10122", "I10062", "O11488", "O11627", "O11050", "O11592", "O11644", "K10169", "K10675", "O90006", "O11146", "I10083", "I10112", "I10138", "O11674", "I10150", "O11237", "O11052", "O11467", "O11551", "I10129", "I10154", "I10070", "O11436", "O11733", "O11336", "I10101", "K10503", "O11481", "O11425", "O11484", "I10134", "I10125", "K10538", "O11156", "O11048", "I80001", "A00001", "O11735", "I10092", "O10003", "K10016", "O11677", "O11245", "O11482", "K10470", "O11584", "O11564", "K00003", "I10194", "O11676", "O11618", "O11044", "I10078", "I10005", "O11341", "U00006", "O11692", "K10388", "I10205", "I10164", "O11320", "O11487", "O11557", "I10103", "I10001", "O20016", "O11356", "K10297", "O11474", "O11593", "I10132", "K10615", "O11255", "K10004", "I10086", "O11271", "O11526", "O11641", "O11648", "K10180", "O11639", "I10068", "O11504", "O11483", "O11708", "O11410", "O11274", "U00001", "O11335", "K10424", "I10106", "I10148", "O11566", "I10100", "I10114", "K10582", "O11669", "O10028", "K10223", "O11511", "O11299", "O11454", "O11621", "O11055", "O11598", "I10120", "K10032", "I10084", "O11397", "K10204", "O11605", "O11625", "I10193", "I10096", "I10200", "K10614", "I10082", "K10540", "K10019", "O11707", "O11631", "O11191", "O11521", "O11671", "I60005", "O11597", "I10155", "I10080", "K10488", "O11748", "O11429", "K10471", "O11444", "O11516", "K10433", "O11331", "O11154", "K10434", "O11628", "O11538", "U00003", "O11447", "O11614", "I10130", "K10671", "I90003", "O11005", "O20019", "I10206", "O11657", "O11455", "K10574", "O11141", "I10191", "K10271", "I10151", "K10665", "K10517", "K10331", "K10406", "O11725", "K00004", "K00014", "O11007", "O11505", "O11583", "O11645", "K10353", "O11569", "K10624", "O11620", "O11496", "I10142", "O11667", "O11406", "O11466", "K10634", "O11252", "K10199", "O11616", "K10397", "I10171", "I10204", "O11437", "O20001", "I10075", "O11350", "O11478", "O11503", "I10002", "K10548", "O11544", "K10566", "O11476", "O10004"]} -------------------------------------------------------------------------------- /data/preprd/mvp/partial/dict.json: -------------------------------------------------------------------------------- 1 | {"dict": ["null", "UNK", "O11712", "K10309", "K10326", "K00002", "O11682", "I60003", "I10174", "O11219", "K10246", "K10552", "I10141", "O11776", "O30025", "O11057", "O11652", "K10621", "O11570", "O11426", "I10060", "O11396", "K10435", "K10413", "K10563", "I10090", "O11075", "O11051", "I10097", "K10491", "O11004", "O11045", "O11711", "O11415", "O11632", "K10220", "K10536", "O11654", "O11298", "O11609", "O11106", "K10678", "O11430", "O11202", "K10569", "O11270", "O11722", "I10133", "I10170", "O11121", "O10033", "O11428", "O11517", "O11364", "O11459", "O11414", "K00016", "O11565", "O11003", "I60001", "O11165", "I50003", "K10035", "O11563", "O11541", "K10002", "K10516", "O11559", "O11489", "O11660", "O11686", "O11512", "K10314", "O11560", "O30021", "K10617", "O11347", "K10192", "K10500", "O11619", "O11792", "K10672", "K10656", "K00012", "I10059", "K10620", "I10105", "O11515", "O11710", "O11727", "O11222", "O11680", "I10185", "O40002", "O11684", "I10093", "I10190", "O20020", "K10530", "O11758", "K10117", "O11606", "O11635", "O11745", "O11763", "K10529", "O11599", "O11647", "O20008", "K10462", "O11673", "O11697", "O11360", "I10143", "K10084", "O11579", "I10123", "K10266", " ", "O11133", "O11368", "I10079", "K00007", "I10007", "O11779", "O11694", "K10519", "O11136", "O11555", "O11636", "O11617", "O11468", "O11159", "O30017", "O11464", "I50001", "I10063", "O11329", "I10118", "O11575", "K10460", "O11558", "O11037", "K10437", "K10334", "O11183", "O11408", "O11664", "O11346", "O11491", "O11486", "I10181", "O11508", "K10219", "O11510", "I10213", "K10555", "I10089", "K10571", "O11277", "O11421", "K10576", "O10029", "K10036", "O11131", "K10009", "O11383", "K10231", "O11532", "K10282", "O10013", "O11371", "I10110", "O11128", "O11576", "K10661", "O11530", "O11243", "O11582", "K10329", "O11753", "O11523", "I10137", "K10175", "O11225", "O11234", "K00025", "O11240", "O11174", "O11769", "K10587", "O11248", "O11411", "U00004", "K10520", "K10586", "O11457", "O11422", "O11752", "O11701", "I10108", "I10165", "O11519", "O11197", "O11465", "I20001", "K10616", "O11623", "O11081", "K00010", "K10609", "I10160", "O20023", "K10421", "O11323", "O11042", "O11765", "O11718", "O30002", "O11591", "K10505", "O11492", "O11438", "O11275", "K10381", "O11634", "I60004", "I10115", "O11705", "I10102", "I10109", "K10281", "O11087", "O11461", "O11552", "O11375", "O11011", "O11600", "O10008", "G10001", "K10273", "O11578", "O11546", "U00002", "K10365", "O11685", "O11666", "I10113", "K10188", "O11649", "K10303", "I10095", "O11480", "O11613", "K10575", "O11379", "O11646", "O11259", "O11640", "O11611", "O11307", "O11285", "O11268", "O11678", "K10251", "I10011", "I10111", "O11659", "O11380", "K10394", "O11717", "O11500", "I10179", "O11303", "O11198", "O11412", "O11518", "O11413", "O11389", "I10119", "O11495", "I10104", "O11076", "I10175", "K10202", "K10005", "O11672", "O20011", "O11378", "O11699", "O11695", "O11653", "O11638", "O11603", "I10107", "O11164", "O11742", "O11392", "O11665", "O11580", "O11524", "I10122", "I10062", "O11488", "O11627", "O11050", "O11592", "O11644", "K10169", "K10675", "O90006", "O11146", "I10083", "I10112", "I10138", "O11674", "I10150", "O11237", "O11052", "O11467", "O11551", "I10129", "I10154", "I10070", "O11436", "O11733", "O11336", "I10101", "K10503", "O11481", "O11425", "O11484", "I10134", "I10125", "K10538", "O11156", "O11048", "I80001", "A00001", "O11735", "I10092", "O10003", "K10016", "O11677", "O11245", "O11482", "K10470", "O11584", "O11564", "K00003", "I10194", "O11676", "O11618", "O11044", "I10078", "I10005", "O11341", "U00006", "O11692", "K10388", "I10205", "I10164", "O11320", "O11487", "O11557", "I10103", "I10001", "O20016", "O11356", "K10297", "O11474", "O11593", "I10132", "K10615", "O11255", "K10004", "I10086", "O11271", "O11526", "O11641", "O11648", "K10180", "O11639", "I10068", "O11504", "O11483", "O11708", "O11410", "O11274", "U00001", "O11335", "K10424", "I10106", "I10148", "O11566", "I10100", "I10114", "K10582", "O11669", "O10028", "K10223", "O11511", "O11299", "O11454", "O11621", "O11055", "O11598", "I10120", "K10032", "I10084", "O11397", "K10204", "O11605", "O11625", "I10193", "I10096", "I10200", "K10614", "I10082", "K10540", "K10019", "O11707", "O11631", "O11191", "O11521", "O11671", "I60005", "O11597", "I10155", "I10080", "K10488", "O11748", "O11429", "K10471", "O11444", "O11516", "K10433", "O11331", "O11154", "K10434", "O11628", "O11538", "U00003", "O11447", "O11614", "I10130", "K10671", "I90003", "O11005", "O20019", "I10206", "O11657", "O11455", "K10574", "O11141", "I10191", "K10271", "I10151", "K10665", "K10517", "K10331", "K10406", "O11725", "K00004", "K00014", "O11007", "O11505", "O11583", "O11645", "K10353", "O11569", "K10624", "O11620", "O11496", "I10142", "O11667", "O11406", "O11466", "K10634", "O11252", "K10199", "O11616", "K10397", "I10171", "I10204", "O11437", "O20001", "I10075", "O11350", "O11478", "O11503", "I10002", "K10548", "O11544", "K10566", "O11476", "O10004"], "attr_len": [2, 4, 2]} -------------------------------------------------------------------------------- /data/preprd/mvp/partial/partial10: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/demographic-prediction/ef9608ec3abe1594786cf0b1576a958ba0a55adf/data/preprd/mvp/partial/partial10 -------------------------------------------------------------------------------- /data/preprd/mvp/partial/sep.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | 5 | for r in [10,20,30,40,50,60,70,80,90]: 6 | data = json.load(open('partial'+str(r))) 7 | json.dump({'history':data['history'], 'label':data['label']}, open('partial'+str(r))) 8 | 9 | -------------------------------------------------------------------------------- /data/raw/company_info.csv: -------------------------------------------------------------------------------- 1 | company_idx,company_cd,cate_lg,cate_sm 2 | 1,UNK,, 3 | 2,O11712,food,Restaurant 4 | 3,K10309,food,Groceries 5 | 4,K10326,unk,Books 6 | 5,K00002,life,Sanitary pad 7 | 6,O11682,event,E-commerce 8 | 7,I60003,event,Event 9 | 8,I10174,communication,Communications 10 | 9,O11219,necessities,E-commerce 11 | 10,K10246,unk,Books 12 | 11,K10552,unk,Books 13 | 12,I10141,e-commerce,Restaurant 14 | 13,O11776,necessities,Communications 15 | 14,O30025,life,Hospital 16 | 15,O11057,finance,Credit card 17 | 16,O11652,finance,Loan 18 | 17,K10621,necessities,Retailer 19 | 18,O11570,e-commerce,E-commerce 20 | 19,O11426,food,Cafe 21 | 20,I10060,e-commerce,E-commerce 22 | 21,O11396,necessities,Necessities 23 | 22,K10435,life,Office supplies 24 | 23,K10413,unk,Books 25 | 24,K10563,necessities,Supermarket 26 | 25,I10090,food,Food 27 | 26,O11075,life,Hotel 28 | 27,O11051,necessities,Baby clothes 29 | 28,I10097,communication,Communications 30 | 29,K10491,mileage,Mileage 31 | 30,O11004,life,Automobile repair 32 | 31,O11045,necessities,Sports clothes 33 | 32,O11711,food,Restaurant 34 | 33,O11415,life,Resort 35 | 34,O11632,life,Sauna 36 | 35,K10220,necessities,Farm product 37 | 36,K10536,unk,Food 38 | 37,O11654,life,Beauty 39 | 38,O11298,event,Event 40 | 39,O11609,food,Cafe 41 | 40,O11106,necessities,Record shop 42 | 41,K10678,e-commerce,E-commerce 43 | 42,O11430,entertainment,Theater 44 | 43,O11202,finance,Credit card 45 | 44,K10569,life,Automobile repair 46 | 45,O11270,necessities,Supermarket 47 | 46,O11722,life,Maternity clinic 48 | 47,I10133,necessities,Gas 49 | 48,I10170,e-commerce,Entertainment 50 | 49,O11121,necessities,Automobile accessaries 51 | 50,O10033,food,Cafe 52 | 51,O11428,entertainment,Theater 53 | 52,O11517,necessities,Health food 54 | 53,O11364,event,Event 55 | 54,O11459,life,Sauna 56 | 55,O11414,life,Resort 57 | 56,K00016,e-commerce,E-commerce 58 | 57,O11565,finance,Credit card 59 | 58,O11003,necessities,Apparel 60 | 59,I60001,communication,Communications 61 | 60,O11165,event,Event 62 | 61,I50003,e-commerce,Entertainment 63 | 62,K10035,food,Food 64 | 63,O11563,e-commerce,E-commerce 65 | 64,O11541,life,Tour 66 | 65,K10002,food,Food 67 | 66,K10516,unk,Gift 68 | 67,O11559,e-commerce,E-commerce 69 | 68,O11489,life,Life etc 70 | 69,O11660,finance,Loan 71 | 70,O11686,event,Cosmetics 72 | 71,O11512,necessities,Apparel 73 | 72,K10314,food,Food 74 | 73,O11560,necessities,Tour 75 | 74,O30021,finance,Financial 76 | 75,K10617,unk,Educational institute 77 | 76,O11347,finance,Insuarance 78 | 77,K10192,life,Cosmetics 79 | 78,K10500,unk,E-commerce 80 | 79,O11619,event,Event 81 | 80,O11792,life,Hairdresser 82 | 81,K10672,food,Food 83 | 82,K10656,event,Retailer 84 | 83,K00012,unk,Life etc 85 | 84,I10059,necessities,Mobile device 86 | 85,K10620,necessities,Bath supplies 87 | 86,I10105,e-commerce,Entertainment 88 | 87,O11515,life,Automobile repair 89 | 88,O11710,food,Restaurant 90 | 89,O11727,life,Food 91 | 90,O11222,food,Cafe 92 | 91,O11680,e-commerce,E-commerce 93 | 92,I10185,life,Event 94 | 93,O40002,life,Hotel 95 | 94,O11684,food,Restaurant 96 | 95,I10093,entertainment,Sports event 97 | 96,I10190,e-commerce,E-commerce 98 | 97,O20020,necessities,Books 99 | 98,K10530,life,Contraceptive equipment 100 | 99,O11758,necessities,Apparel 101 | 100,K10117,necessities,Groceries 102 | 101,O11606,food,Cafe 103 | 102,O11635,e-commerce,Food 104 | 103,O11745,life,Transportation 105 | 104,O11763,life,Parcel 106 | 105,K10529,food,Restaurant 107 | 106,O11599,life,Resort 108 | 107,O11647,e-commerce,E-commerce 109 | 108,O20008,necessities,Supermarket 110 | 109,K10462,food,Food 111 | 110,O11673,food,Café 112 | 111,O11697,e-commerce,Apparel 113 | 112,O11360,mileage,Mileage 114 | 113,I10143,event,Event 115 | 114,K10084,life,Laundry soap 116 | 115,O11579,necessities,Gift 117 | 116,I10123,necessities,Apparel 118 | 117,K10266,unk,Books 119 | 118, ,, 120 | 119,O11133,e-commerce,Online Game 121 | 120,O11368,e-commerce,E-commerce 122 | 121,I10079,e-commerce,Duty free 123 | 122,K00007,life,Cosmetics 124 | 123,I10007,life,Car rent 125 | 124,O11779,food,Restaurant 126 | 125,O11694,entertainment,Performance 127 | 126,K10519,unk,Books 128 | 127,O11136,e-commerce,Communications 129 | 128,O11555,e-commerce,Communications 130 | 129,O11636,entertainment,Entertainment 131 | 130,O11617,life,Automobile repair 132 | 131,O11468,finance,Credit card 133 | 132,O11159,necessities,Home appliance 134 | 133,O30017,life,Hairdresser 135 | 134,O11464,food,Bakery 136 | 135,I50001,fuel,Gas 137 | 136,I10063,life,Automobile repair 138 | 137,O11329,necessities,Stationery store 139 | 138,I10118,event,Event 140 | 139,O11575,event,Event 141 | 140,K10460,life,Home electronic appliances 142 | 141,O11558,necessities,E-commerce 143 | 142,O11037,event,Event 144 | 143,K10437,unk,Books 145 | 144,K10334,food,Food 146 | 145,O11183,necessities,Supermarket 147 | 146,O11408,necessities,Books 148 | 147,O11664,necessities,Retailer 149 | 148,O11346,necessities,Hotel 150 | 149,O11491,necessities,Apparel 151 | 150,O11486,entertainment,Video 152 | 151,I10181,necessities,Retailer 153 | 152,O11508,food,Restaurant 154 | 153,K10219,necessities,Health food 155 | 154,O11510,finance,Insuarance 156 | 155,I10213,finance,Financial 157 | 156,K10555,unk,Books 158 | 157,I10089,necessities,Gas 159 | 158,K10571,event,Gas 160 | 159,O11277,food,Bar 161 | 160,O11421,necessities,Pet 162 | 161,K10576,unk,Books 163 | 162,O10029,food,Bakery 164 | 163,K10036,fuel,Gas 165 | 164,O11131,necessities,Sports clothes 166 | 165,K10009,necessities,Medicine 167 | 166,O11383,entertainment,Theater 168 | 167,K10231,unk,Convenience store 169 | 168,O11532,life,Life etc 170 | 169,K10282,unk,Books 171 | 170,O10013,life,Educational institute 172 | 171,O11371,life,Sauna 173 | 172,I10110,finance,Credit card 174 | 173,O11128,finance,Credit card 175 | 174,O11576,necessities,Financial 176 | 175,K10661,necessities,Stationery store 177 | 176,O11530,necessities,Communications 178 | 177,O11243,e-commerce,E-commerce 179 | 178,O11582,entertainment,Theater 180 | 179,K10329,unk,Books 181 | 180,O11753,life,Beauty 182 | 181,O11523,necessities,Stationery store 183 | 182,I10137,life,Transportation 184 | 183,K10175,unk,Books 185 | 184,O11225,e-commerce,E-commerce 186 | 185,O11234,necessities,Supermarket 187 | 186,K00025,food,Groceries 188 | 187,O11240,e-commerce,E-commerce 189 | 188,O11174,food,Bakery 190 | 189,O11769,necessities,Food 191 | 190,K10587,unk,Books 192 | 191,O11248,e-commerce,Automobile repair 193 | 192,O11411,life,Life etc 194 | 193,U00004,necessities,Supermarket 195 | 194,K10520,unk,Books 196 | 195,K10586,unk,Books 197 | 196,O11457,life,Airline 198 | 197,O11422,e-commerce,E-commerce 199 | 198,O11752,finance,Insuarance 200 | 199,O11701,food,Fast food 201 | 200,I10108,e-commerce,Entertainment 202 | 201,I10165,necessities,Health food 203 | 202,O11519,entertainment,Sports activity 204 | 203,O11197,necessities,Apparel 205 | 204,O11465,mileage,Mileage 206 | 205,I20001,life,Photo Studio 207 | 206,K10616,unk,Books 208 | 207,O11623,entertainment,Video 209 | 208,O11081,life,Hairdresser 210 | 209,K00010,unk,Food 211 | 210,K10609,unk,Event 212 | 211,I10160,e-commerce,E-commerce 213 | 212,O20023,necessities,Supermarket 214 | 213,K10421,life,Parcel 215 | 214,O11323,e-commerce,Theater 216 | 215,O11042,necessities,Bedding 217 | 216,O11765,event,Cosmetics 218 | 217,O11718,life,Event 219 | 218,O30002,finance,Credit card 220 | 219,O11591,food,Restaurant 221 | 220,K10505,food,Food 222 | 221,O11492,entertainment,Theater 223 | 222,O11438,life,Resort 224 | 223,O11275,necessities,Convenience store 225 | 224,K10381,unk,Event 226 | 225,O11634,e-commerce,Educational institute 227 | 226,I60004,event,Event 228 | 227,I10115,necessities,Gas 229 | 228,O11705,life,Medical etc 230 | 229,I10102,e-commerce,Entertainment 231 | 230,I10109,e-commerce,Entertainment 232 | 231,K10281,life,Laundry soap 233 | 232,O11087,e-commerce,Books 234 | 233,O11461,e-commerce,Books 235 | 234,O11552,life,Educational institute 236 | 235,O11375,food,Fast food 237 | 236,O11011,necessities,Apparel 238 | 237,O11600,necessities,Flower 239 | 238,O10008,necessities,Optician 240 | 239,G10001,finance,Credit card 241 | 240,K10273,unk,Necessities 242 | 241,O11578,necessities,Stationery store 243 | 242,O11546,entertainment,Theater 244 | 243,U00002,necessities,Supermarket 245 | 244,K10365,life,Necessities 246 | 245,O11685,food,Restaurant 247 | 246,O11666,necessities,Duty free 248 | 247,I10113,life,Automobile repair 249 | 248,K10188,unk,Kitchen supplies 250 | 249,O11649,necessities,Home electronic appliances 251 | 250,K10303,life,Office supplies 252 | 251,I10095,event,Gas 253 | 252,O11480,e-commerce,Tour 254 | 253,O11613,necessities,Home electronic appliances 255 | 254,K10575,food,Bar 256 | 255,O11379,e-commerce,Duty free 257 | 256,O11646,necessities,Cosmetics 258 | 257,O11259,necessities,Duty free 259 | 258,O11640,food,Restaurant 260 | 259,O11611,necessities,Office supplies 261 | 260,O11307,necessities,Interior design 262 | 261,O11285,e-commerce,Automobile repair 263 | 262,O11268,necessities,Supermarket 264 | 263,O11678,life,Hairdresser 265 | 264,K10251,unk,Home appliance 266 | 265,I10011,life,Hairdresser 267 | 266,I10111,e-commerce,Event 268 | 267,O11659,life,Chauffeur service 269 | 268,O11380,finance,Insuarance 270 | 269,K10394,unk,Books 271 | 270,O11717,necessities,Baby carriage 272 | 271,O11500,necessities,Retailer 273 | 272,I10179,e-commerce,Self-Defense 274 | 273,O11303,necessities,Optician 275 | 274,O11198,necessities,Supermarket 276 | 275,O11412,life,Resort 277 | 276,O11518,life,Resort 278 | 277,O11413,life,Resort 279 | 278,O11389,life,Market research 280 | 279,I10119,communication,Communications 281 | 280,O11495,e-commerce,Educational institute 282 | 281,I10104,e-commerce,Entertainment 283 | 282,O11076,food,Fast food 284 | 283,I10175,life,Automobile repair 285 | 284,K10202,unk,Books 286 | 285,K10005,life,Kitchen supplies 287 | 286,O11672,food,Bakery 288 | 287,O20011,necessities,Optician 289 | 288,O11378,event,Performance 290 | 289,O11699,life,Museum 291 | 290,O11695,entertainment,Performance 292 | 291,O11653,finance,Financial 293 | 292,O11638,life,Resort 294 | 293,O11603,food,Restaurant 295 | 294,I10107,e-commerce,Entertainment 296 | 295,O11164,entertainment,Theater 297 | 296,O11742,entertainment,Video 298 | 297,O11392,e-commerce,Performance 299 | 298,O11665,life,Fast food 300 | 299,O11580,necessities,Flower 301 | 300,O11524,food,Bakery 302 | 301,I10122,event,Event 303 | 302,I10062,life,Automobile repair 304 | 303,O11488,necessities,Retailer 305 | 304,O11627,food,Restaurant 306 | 305,O11050,entertainment,Sports event 307 | 306,O11592,e-commerce,E-commerce 308 | 307,O11644,event,Cosmetics 309 | 308,K10169,unk,Books 310 | 309,K10675,e-commerce,E-commerce 311 | 310,O90006,finance,Credit card 312 | 311,O11146,life,Hairdresser 313 | 312,I10083,necessities,Necessities 314 | 313,I10112,life,Automobile repair 315 | 314,I10138,food,Gas 316 | 315,O11674,necessities,Apparel 317 | 316,I10150,life,Event 318 | 317,O11237,mileage,Mileage 319 | 318,O11052,necessities,Baby clothes 320 | 319,O11467,necessities,Apparel 321 | 320,O11551,life,Educational institute 322 | 321,I10129,e-commerce,Automobile accessaries 323 | 322,I10154,food,Restaurant 324 | 323,I10070,e-commerce,Entertainment 325 | 324,O11436,necessities,E-commerce 326 | 325,O11733,e-commerce,Books 327 | 326,O11336,necessities,Stationery store 328 | 327,I10101,e-commerce,Entertainment 329 | 328,K10503,life,Retailer 330 | 329,O11481,e-commerce,Theater 331 | 330,O11425,e-commerce,Fast food 332 | 331,O11484,finance,Credit card 333 | 332,I10134,communication,Mileage 334 | 333,I10125,event,Event 335 | 334,K10538,life,Home electronic appliances 336 | 335,O11156,food,Bar 337 | 336,O11048,food,Fast-food 338 | 337,I80001,event,Event 339 | 338,A00001,event,Event 340 | 339,O11735,necessities,Stationery store 341 | 340,I10092,event,Event 342 | 341,O10003,food,Bakery 343 | 342,K10016,food,Food 344 | 343,O11677,life,Hairdresser 345 | 344,O11245,necessities,Department store 346 | 345,O11482,necessities,Dairy product 347 | 346,K10470,life,Necessities 348 | 347,O11584,life,Educational institute 349 | 348,O11564,life,Chauffeur service 350 | 349,K00003,food,Dairy product 351 | 350,I10194,life,Photo 352 | 351,O11676,food,Restaurant 353 | 352,O11618,event,Event 354 | 353,O11044,necessities,Necessities 355 | 354,I10078,necessities,Apparel 356 | 355,I10005,e-commerce,Entertainment 357 | 356,O11341,necessities,Pet 358 | 357,U00006,necessities,Supermarket 359 | 358,O11692,necessities,Convenience store 360 | 359,K10388,unk,Event 361 | 360,I10205,life,Event 362 | 361,I10164,e-commerce,Performance 363 | 362,O11320,necessities,Mobile device 364 | 363,O11487,necessities,Health food 365 | 364,O11557,necessities,Stationery store 366 | 365,I10103,e-commerce,Entertainment 367 | 366,I10001,necessities,Supermarket 368 | 367,O20016,necessities,Pharmacy 369 | 368,O11356,e-commerce,E-commerce 370 | 369,K10297,life,E-commerce 371 | 370,O11474,necessities,Home appliance 372 | 371,O11593,necessities,Optician 373 | 372,I10132,necessities,Duty free 374 | 373,K10615,necessities,Necessities 375 | 374,O11255,necessities,Retailer 376 | 375,K10004,food,Food 377 | 376,I10086,e-commerce,Entertainment 378 | 377,O11271,food,Bakery 379 | 378,O11526,life,Chauffeur service 380 | 379,O11641,food,Fast food 381 | 380,O11648,food,Fast food 382 | 381,K10180,unk,Books 383 | 382,O11639,food,Restaurant 384 | 383,I10068,e-commerce,E-commerce 385 | 384,O11504,finance,Credit card 386 | 385,O11483,food,Restaurant 387 | 386,O11708,life,Life etc 388 | 387,O11410,food,Bakery 389 | 388,O11274,entertainment,Theater 390 | 389,U00001,necessities,Supermarket 391 | 390,O11335,finance,Financial 392 | 391,K10424,unk,Mileage 393 | 392,I10106,e-commerce,Entertainment 394 | 393,I10148,communication,Mileage 395 | 394,O11566,finance,Credit card 396 | 395,I10100,communication,Communications 397 | 396,I10114,necessities,Event 398 | 397,K10582,e-commerce,Entertainment 399 | 398,O11669,necessities,Convenience store 400 | 399,O10028,food,Fast food 401 | 400,K10223,food,Bakery 402 | 401,O11511,food,Restaurant 403 | 402,O11299,entertainment,Theater 404 | 403,O11454,necessities,Books 405 | 404,O11621,life,Airline 406 | 405,O11055,necessities,Department store 407 | 406,O11598,life,Educational institute 408 | 407,I10120,e-commerce,Entertainment 409 | 408,K10032,life,Necessities 410 | 409,I10084,life,Gas 411 | 410,O11397,necessities,Cosmetics 412 | 411,K10204,unk,Books 413 | 412,O11605,life,Market research 414 | 413,O11625,food,Bakery 415 | 414,I10193,mileage,Mileage 416 | 415,I10096,communication,Communications 417 | 416,I10200,food,Fast food 418 | 417,K10614,necessities,Financial 419 | 418,I10082,e-commerce,Entertainment 420 | 419,K10540,unk,Books 421 | 420,K10019,life,Cosmetics 422 | 421,O11707,life,Car wash 423 | 422,O11631,necessities,Cosmetics 424 | 423,O11191,entertainment,Leisure 425 | 424,O11521,life,Tour 426 | 425,O11671,life,Beauty 427 | 426,I60005,e-commerce,E-commerce 428 | 427,O11597,necessities,Cosmetics 429 | 428,I10155,life,Chauffeur service 430 | 429,I10080,necessities,Duty free 431 | 430,K10488,e-commerce,E-commerce 432 | 431,O11748,life,Event 433 | 432,O11429,mileage,Mileage 434 | 433,K10471,life,Kitchen supplies 435 | 434,O11444,necessities,Apparel 436 | 435,O11516,life,Real estate 437 | 436,K10433,unk,Cosmetics 438 | 437,O11331,food,Cafe 439 | 438,O11154,necessities,Supermarket 440 | 439,K10434,life,Insuarance 441 | 440,O11628,food,Fast food 442 | 441,O11538,life,Sports 443 | 442,U00003,necessities,Supermarket 444 | 443,O11447,life,Market research 445 | 444,O11614,food,Bakery 446 | 445,I10130,necessities,E-commerce 447 | 446,K10671,e-commerce,E-commerce 448 | 447,I90003,e-commerce,Life etc 449 | 448,O11005,entertainment,Amusement park 450 | 449,O20019,necessities,Jewelry 451 | 450,I10206,e-commerce,Video 452 | 451,O11657,necessities,Office supplies 453 | 452,O11455,event,Event 454 | 453,K10574,necessities,Retailer 455 | 454,O11141,life,Hairdresser 456 | 455,I10191,e-commerce,Gift 457 | 456,K10271,unk,Books 458 | 457,I10151,e-commerce,Event 459 | 458,K10665,food,Food 460 | 459,K10517,life,Real estate 461 | 460,K10331,life,Office supplies 462 | 461,K10406,necessities,Necessities 463 | 462,O11725,life,Hospital 464 | 463,K00004,life,Department store 465 | 464,K00014,unk,Bakery 466 | 465,O11007,food,Fast-food 467 | 466,O11505,food,Restaurant 468 | 467,O11583,life,Educational institute 469 | 468,O11645,necessities,Retailer 470 | 469,K10353,food,Restaurant 471 | 470,O11569,necessities,Department store 472 | 471,K10624,life,Retailer 473 | 472,O11620,necessities,Financial 474 | 473,O11496,entertainment,Gym 475 | 474,I10142,event,Gas 476 | 475,O11667,food,Cafe 477 | 476,O11406,e-commerce,Entertainment 478 | 477,O11466,necessities,Cosmetics 479 | 478,K10634,unk,Books 480 | 479,O11252,e-commerce,E-commerce 481 | 480,K10199,unk,Books 482 | 481,O11616,event,Retailer 483 | 482,K10397,unk,Necessities 484 | 483,I10171,e-commerce,E-commerce 485 | 484,I10204,life,Market research 486 | 485,O11437,life,Laundry 487 | 486,O20001,necessities,Department store 488 | 487,I10075,event,Event 489 | 488,O11350,necessities,Supermarket 490 | 489,O11478,e-commerce,Books 491 | 490,O11503,food,Restaurant 492 | 491,I10002,necessities,Supermarket 493 | 492,K10548,life,Cosmetics 494 | 493,O11544,e-commerce,E-commerce 495 | 494,K10566,unk,Books 496 | 495,O11476,necessities,Pharmacy 497 | 496,O10004,food,Restaurant 498 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import random 4 | import os 5 | import sys 6 | import torch 7 | import re 8 | from torch.utils.data import Dataset 9 | from config import get_args 10 | 11 | args = get_args() 12 | 13 | # set random seeds 14 | np.random.seed(1) 15 | random.seed(1) 16 | torch.manual_seed(1) 17 | 18 | class Dictionary(object): 19 | NULL = '' 20 | UNK = '' 21 | 22 | def __init__(self, data_path, task_type): 23 | if task_type == 'new_user': 24 | data_path = os.path.join(data_path, task_type, 'dict') 25 | else: 26 | data_path = os.path.join(data_path, 'partial', 'dict.json') 27 | load_file = json.load(open(data_path)) 28 | self.dict = load_file['dict'] 29 | self.attr_len = load_file['attr_len'] 30 | 31 | def __len__(self): 32 | return len(self.dict) 33 | 34 | def __iter__(self): 35 | return iter(self.dict) 36 | 37 | def add(self, item): 38 | if item not in self.dict: 39 | self.dict.append(item) 40 | 41 | class DemoAttrDataset(Dataset): 42 | def __init__(self, logger, data_type, data_path, task_type, model_type): 43 | self.data_type = data_type 44 | self.history = self.label = self.observed = None 45 | 46 | if task_type == 'new_user': 47 | data = json.load(open(os.path.join(data_path, task_type, data_type))) 48 | history = data['history'] 49 | label = data['label'] 50 | observed = data['observed'] 51 | else: 52 | users = json.load(open(os.path.join(data_path, 'partial', 'users.json'))) 53 | history = users['history'] 54 | label = users['attribute'] 55 | observed = json.load(open(os.path.join(data_path, 'partial', task_type)))['observed'] 56 | 57 | shuffled_idx = list(range(len(history))) 58 | self.history = np.asarray(history)[shuffled_idx].tolist() 59 | self.label = np.asarray(label)[shuffled_idx].tolist() 60 | self.observed = np.asarray(observed)[shuffled_idx].tolist() 61 | 62 | if 'partial' in task_type and any([True if t in self.data_type else False for t in ['valid', 'test']]): 63 | self.observed = np.invert(np.asarray(self.observed).astype(bool)).astype(int).tolist() 64 | 65 | logger.info("{} {} samples are loaded".format(self.__len__(), self.data_type)) 66 | 67 | 68 | def __len__(self): 69 | return len(self.label) 70 | 71 | def __getitem__(self, index): 72 | return self.history[index], self.label[index], self.observed[index] 73 | 74 | 75 | def batchify(batch): 76 | history, label, observed = [],[],[] 77 | 78 | for ex in batch: 79 | history.append(ex[0]) 80 | label.append(ex[1]) 81 | observed.append(ex[2]) 82 | 83 | maxlen_history = max([len(h) for h in history]) 84 | x = torch.LongTensor(len(history), maxlen_history).zero_() 85 | x_mask = torch.ByteTensor(len(history), maxlen_history).zero_() 86 | for i, h in enumerate(history): 87 | x[i, :len(h)].copy_(torch.from_numpy(np.asarray(h))) 88 | x_mask[i, :len(h)].fill_(1) 89 | y = np.asarray(label) 90 | ob = np.asarray(observed) 91 | 92 | return x, x_mask, y, ob 93 | 94 | 95 | -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import copy 4 | from collections import Counter 5 | from functools import reduce 6 | import logging 7 | import numpy as np 8 | from operator import mul 9 | import os 10 | from sklearn.metrics import hamming_loss 11 | import sys 12 | import time 13 | 14 | import torch 15 | from torch.autograd import Variable 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | from tensorboardX import SummaryWriter 19 | 20 | from dataset import Dictionary 21 | from model.etna import ETNADemoPredictor 22 | 23 | 24 | class Experiment: 25 | def __init__(self, args, logger): 26 | 27 | self.args = args 28 | self.logger = logger 29 | Dict = Dictionary( 30 | data_path=os.path.join(args.data_path, args.dataset), 31 | task_type=args.task_type) 32 | self.dict = Dict.dict 33 | self.attr_len = Dict.attr_len 34 | self.all_the_poss = reduce(mul, Dict.attr_len, 1) 35 | self.logger.info("Experiment initializing . . . ") 36 | 37 | # build models 38 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 39 | if args.model_type == 'POP': 40 | self.model = 'POP' 41 | elif any([True if args.model_type == m else False for m in ['ETN', 'ETNA']]): 42 | self.model = ETNADemoPredictor(logger, args.model_type, self.dict.__len__(), 43 | args.item_emb_size, Dict.attr_len, args.no_cuda).to(device) 44 | else: 45 | sys.exit() 46 | 47 | if args.model_type != 'POP': 48 | self.select_optimizer(self.model) 49 | self.logger.info(self.model) 50 | self.step_count = 0 51 | 52 | 53 | def select_optimizer(self, model): 54 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 55 | if(self.args.opt == 'Adam'): 56 | model.optimizer = optim.Adam(parameters, lr=self.args.learning_rate, 57 | weight_decay=self.args.weight_decay) 58 | elif(self.args.opt == 'RMSprop'): 59 | model.optimizer = optim.RMSprop(parameters, lr=self.args.learning_rate, 60 | weight_decay=self.args.weight_decay, 61 | momentum=self.args.momentum) 62 | elif(self.args.opt == 'SGD'): 63 | model.optimizer = optim.SGD(parameters, lr=self.args.learning_rate, 64 | weight_decay=self.args.weight_decay, 65 | momentum=self.args.momentum) 66 | elif(self.args.opt == 'Adagrad'): 67 | model.optimizer = optim.Adagrad(parameters, lr=self.args.learning_rate) 68 | elif(self.args.opt == 'Adadelta'): 69 | model.optimizer = optim.Adadelta(parameters, lr=self.args.learning_rate) 70 | 71 | def adjust_lr(self): 72 | for param_group in self.model.optimizer.param_groups: 73 | param_group['lr'] *= self.args.lr_decay 74 | 75 | def run_epoch(self, epoch, data_loader, dataset, trainable=False): 76 | num_samples = data_loader.dataset.__len__() 77 | num_steps = (num_samples // self.args.batch_size) + 1 78 | self.num_steps = num_steps 79 | 80 | 81 | self.y_em_counter, self.yp_counter, self.yt_counter = Counter(), Counter(), Counter() 82 | 83 | self.hm_acc = self.num_users = 0 84 | loss_sum = 0 85 | for i, (x, x_mask, y, ob) in enumerate(data_loader): 86 | t0 = time.clock() 87 | self.step = i+1 88 | self.step_count += 1 89 | 90 | # change the mode 91 | if self.args.model_type != 'POP': 92 | if trainable: 93 | self.model.train() 94 | self.model.optimizer.zero_grad() 95 | else: 96 | self.model.eval() 97 | 98 | prob, loss = self.model(x, x_mask, y, ob, trainable) 99 | 100 | if trainable: 101 | loss.backward() 102 | nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_max_norm) 103 | self.model.optimizer.step() 104 | 105 | ls = loss.item() 106 | loss_sum += ls 107 | else: 108 | prob = None 109 | 110 | self.accumulate_score(prob, y, ob) 111 | 112 | if (i+1) % self.args.print_per_step == 0: 113 | hm, macP, macR, macF1, wP, wR, wF1 = self.get_score() 114 | t1 = time.clock() 115 | self.logger.info("< {} : step {} > Loss={:5.3f}, time:{:5.2}, Hamming={:2.3f}" 116 | .format(data_loader.dataset.data_type, self.step, loss_sum/self.step, t1-t0, hm)) 117 | self.logger.info("macro - macP:{:2.3f}, macR:{:2.3f}, macF1:{:2.3f}" 118 | .format(macP, macR, macF1)) 119 | self.logger.info("weighted - wP:{:2.3f}, wR:{:2.3f}, wF1:{:2.3f}" 120 | .format(wP, wR, wF1)) 121 | 122 | hm, macP, macR, macF1, wP, wR, wF1 = self.get_score() 123 | return loss_sum / num_steps, hm, macP, macR, macF1, wP, wR, wF1 124 | 125 | 126 | def accumulate_score(self, prob, label, observed): 127 | y_numbering = np.asarray([[j if l else 0 for j, l in enumerate(ll)] \ 128 | for i, ll in enumerate(label)]) 129 | 130 | if self.args.model_type == 'POP': 131 | popular = [[0, 1, 0, 1, 0, 0, 0, 1] \ 132 | for _ in range(y_numbering.shape[0])] 133 | prob = popular 134 | 135 | for b_idx, ob in enumerate(observed): 136 | pred, true = [],[] 137 | start = 0 138 | for a_idx, al in enumerate(self.attr_len): 139 | end = start + al 140 | if sum(ob[start:end]): 141 | p = np.argmax(prob[b_idx][start:end], 0) + start 142 | t = sum(y_numbering[b_idx][start:end]) 143 | pred.append(p) 144 | true.append(t) 145 | start += al 146 | 147 | if pred and true: 148 | self.yp_counter[str(pred)] += 1 149 | self.yt_counter[str(true)] += 1 150 | if np.array_equal(pred, true): 151 | self.y_em_counter[str(true)] += 1 152 | 153 | # calculate and accumulate hamming loss 154 | self.hm_acc += hamming_loss(true, pred) 155 | 156 | self.num_users += 1 157 | 158 | def get_score(self): 159 | # for divide-by-zero exception 160 | if not self.num_users: num_users = 1 161 | else: num_users = self.num_users 162 | 163 | hm_loss = self.hm_acc / num_users 164 | 165 | macP = macR = macF1 = wP = wR = wF1 = 0 166 | 167 | # macro and weighted Precision 168 | for y, cnt in self.yp_counter.items(): 169 | if y in self.y_em_counter.keys(): 170 | macP += (self.y_em_counter[y] / cnt) 171 | if y in self.yt_counter.keys(): 172 | wP += (self.y_em_counter[y] / cnt) * self.yt_counter[y] 173 | macP /= len(self.yt_counter) 174 | wP /= num_users 175 | 176 | # macro and weighted Recall 177 | for y, cnt in self.yt_counter.items(): 178 | if y in self.y_em_counter.keys(): 179 | wR += self.y_em_counter[y] 180 | macR += (self.y_em_counter[y] / cnt) 181 | macR /= len(self.yt_counter) 182 | wR /= num_users 183 | 184 | # calculate F1 using computed precision and recall. 185 | # this code includes exception. 186 | if macP == 0 and macR == 0: 187 | macF1 = 0 188 | else: 189 | macF1 = (2 * macP * macR) / (macP + macR) 190 | if wP == 0 and wR == 0: 191 | wF1 = 0 192 | else: 193 | wF1 = (2 * wP * wR) / (wP + wR) 194 | return hm_loss, macP, macR, macF1, wP, wR, wF1 195 | 196 | 197 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import numpy as np 4 | import random 5 | import os 6 | import sys 7 | import time 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import uuid 11 | 12 | from dataset import DemoAttrDataset, batchify 13 | from exp import Experiment 14 | from config import get_args 15 | 16 | def run_experiment(args, logger): 17 | 18 | # generate a data loader for validation set 19 | if args.do_validation: 20 | eval_loader = DataLoader( 21 | dataset=DemoAttrDataset( 22 | logger=logger, 23 | data_type='valid', 24 | data_path=os.path.join(args.data_path, args.dataset), 25 | task_type=args.task_type, 26 | model_type=args.model_type), 27 | batch_size=args.batch_size, 28 | shuffle=False, 29 | num_workers=args.num_workers, 30 | collate_fn=batchify) 31 | else: 32 | eval_loader = DataLoader( 33 | dataset=DemoAttrDataset( 34 | logger=logger, 35 | data_type='test', 36 | data_path=os.path.join(args.data_path, args.dataset), 37 | task_type=args.task_type, 38 | model_type=args.model_type), 39 | batch_size=args.batch_size, 40 | shuffle=False, 41 | num_workers=args.num_workers, 42 | collate_fn=batchify) 43 | 44 | train_dataset = DemoAttrDataset( 45 | logger=logger, 46 | data_type='train', 47 | data_path=os.path.join(args.data_path, args.dataset), 48 | task_type=args.task_type, 49 | model_type=args.model_type) 50 | train_loader = DataLoader( 51 | dataset=train_dataset, 52 | batch_size=args.batch_size, 53 | shuffle=True, 54 | num_workers=args.num_workers, 55 | collate_fn=batchify) 56 | 57 | exp = Experiment(args, logger) 58 | 59 | max_score = max_loss = stop_cnt = 0 60 | max_macP = max_macR = max_macF1 = max_wP = max_wR = max_wF1 = 0 61 | pre_wR = 0 62 | for epoch in range(args.max_epoch): 63 | logger.info("++ Epoch : {} ++ \n".format(epoch+1)) 64 | 65 | tr_t0 = time.clock() 66 | tr_loss, tr_hm, \ 67 | tr_macP, tr_macR, tr_macF1, tr_wP, tr_wR, tr_wF1 = \ 68 | exp.run_epoch(epoch, train_loader, args.dataset, trainable=True) 69 | tr_t1 = time.clock() 70 | 71 | eval_t0 = time.clock() 72 | eval_loss, eval_hm, \ 73 | eval_macP, eval_macR, eval_macF1, eval_wP, eval_wR, eval_wF1 = \ 74 | exp.run_epoch(epoch, eval_loader, args.dataset, trainable=False) 75 | eval_t1 = time.clock() 76 | 77 | # print training scores 78 | logger.info("### Training # Loss={:5.3f}, time:{:5.2}, Hamming={:2.3f}" 79 | .format(tr_loss, tr_t1-tr_t0, tr_hm)) 80 | logger.info("# macro - macP:{:2.3f}, macR:{:2.3f}, macF1:{:2.3f}" 81 | .format(tr_macP, tr_macR, tr_macF1)) 82 | logger.info("# weighted - wP:{:2.3f}, wR:{:2.3f}, wF1:{:2.3f} \n" 83 | .format(tr_wP, tr_wR, tr_wF1)) 84 | 85 | # print val/test scores 86 | logger.info("%%% Evaluation % Loss={:5.3f}, time:{:5.2}, Hamming={:2.3f}" 87 | .format(eval_loss, eval_t1-eval_t0, eval_hm)) 88 | logger.info("% macro - macP:{:2.3f}, macR:{:2.3f}, macF1:{:2.3f}" 89 | .format(eval_macP, eval_macR, eval_macF1)) 90 | logger.info("% weighted - wP:{:2.3f}, wR:{:2.3f}, wF1:{:2.3f} \n" 91 | .format(eval_wP, eval_wR, eval_wF1)) 92 | 93 | # early stop 94 | if max_score < eval_wF1: 95 | max_epoch = epoch+1 96 | max_score = eval_wF1 97 | max_loss = eval_loss 98 | max_hm = eval_hm 99 | max_macP = eval_macP 100 | max_macR = eval_macR 101 | max_macF1 = eval_macF1 102 | max_wP = eval_wP 103 | max_wR = eval_wR 104 | max_wF1 = eval_wF1 105 | #model_params = exp.model.item_emb.weight 106 | stop_cnt = 0 107 | else: 108 | # lr decay 109 | exp.adjust_lr() 110 | stop_cnt += 1 111 | if args.model_type == 'POP': break 112 | 113 | if stop_cnt >= 5 and not args.no_early_stop: 114 | return max_epoch, max_loss, max_hm, \ 115 | max_macP, max_macR, max_macF1, \ 116 | max_wP, max_wR, max_wF1 117 | return max_epoch, max_loss, max_hm, \ 118 | max_macP, max_macR, max_macF1, \ 119 | max_wP, max_wR, max_wF1 120 | 121 | 122 | def main(): 123 | # get all arguments 124 | args = get_args() 125 | 126 | # set random seeds 127 | #np.random.seed(args.rand_seed) 128 | #random.seed(args.rand_seed) 129 | #torch.manual_seed(args.rand_seed) 130 | 131 | # set a logger 132 | model_id = time.strftime("%Y%m%d-") + str(uuid.uuid4())[:8] 133 | formatter = logging.Formatter('%(asctime)s: %(message)s ', '%m/%d/%Y %I:%M:%S %p') 134 | logger = logging.getLogger(model_id) 135 | logger.setLevel(logging.INFO) 136 | streamHandler = logging.StreamHandler() 137 | streamHandler.setFormatter(formatter) 138 | logger.addHandler(streamHandler) 139 | if args.save_log: 140 | fileHandler = logging.FileHandler('./save/log/'+model_id+'.log') 141 | fileHandler.setFormatter(formatter) 142 | logger.addHandler(fileHandler) 143 | logger.info('log file : ./save/log/'+model_id+'.log') 144 | logger.info(args) 145 | 146 | ep, loss, hm, macP, macR, macF1, wP, wR, wF1 = run_experiment(args, logger) 147 | logger.info("[Final score - ep:{}] Loss={:5.3f}, Hamming={:2.3f}" 148 | .format(ep, loss, hm)) 149 | logger.info("[ macro ] macP:{:2.3f}, macR:{:2.3f}, macF1:{:2.3f}" 150 | .format(macP, macR, macF1)) 151 | logger.info("[ weighted ] wP:{:2.3f}, wR:{:2.3f}, wF1:{:2.3f}" 152 | .format(wP, wR, wF1)) 153 | if args.save_log: 154 | logger.info('log file : ./save/log/'+model_id+'.log') 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | 160 | -------------------------------------------------------------------------------- /model/etna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from functools import reduce 6 | import numpy as np 7 | import sys 8 | import time 9 | 10 | from .functions import compute_loss 11 | 12 | np.set_printoptions(threshold=np.inf) 13 | torch.set_printoptions(threshold=5000) 14 | 15 | class ETNADemoPredictor(nn.Module): 16 | def __init__(self, logger, model_type, len_dict, item_emb_size, attr_len, no_cuda): 17 | super(ETNADemoPredictor, self).__init__() 18 | 19 | self.logger = logger 20 | self.model_type = model_type 21 | self.attr_len = attr_len 22 | self.item_emb_size = item_emb_size 23 | self.no_cuda = no_cuda 24 | self.optimizer = None 25 | 26 | # item embedding matrix 27 | self.item_emb = nn.Embedding(len_dict, item_emb_size, padding_idx=0) 28 | # item transformation matrix 29 | self.emb_tran = nn.ModuleList([nn.Linear(item_emb_size, item_emb_size, bias=False) for i in range(len(attr_len))]) 30 | 31 | if model_type == 'ETNA': 32 | # item attention matrix 33 | self.item_att_W = nn.ModuleList([nn.Linear(item_emb_size, 1) for i in range(len(attr_len))]) 34 | 35 | # prediction matrix for each attribute 36 | self.W_all = nn.ModuleList() 37 | for i, _ in enumerate(attr_len): 38 | self.W_all.append(nn.Linear(item_emb_size, attr_len[i], bias=False)) 39 | 40 | 41 | def forward(self, x, x_mask, y, ob, trainable=False): 42 | 43 | def get_attention(w, embed, len): 44 | # Attention score with non linear 45 | att_u = F.relu(w(embed).squeeze(2)) 46 | 47 | att_score = torch.zeros(att_u.size()) 48 | if torch.cuda.is_available() and not self.no_cuda: 49 | att_score = att_score.cuda() 50 | 51 | for i, l in enumerate(len): 52 | candi = att_u[i][:l] 53 | a = F.softmax(candi, 0) 54 | att_score[i][:l] = a 55 | 56 | attnd_emb = embed * att_score.unsqueeze(2) 57 | 58 | rep = torch.tanh(torch.sum(attnd_emb, 1)) 59 | return rep, att_score 60 | 61 | def item_attention(embed, share_emb, len): 62 | # embed : [B,K,emb] --> att_u [B,K,1] for each attribute 63 | batch = embed[0].size(0) 64 | attr_rep = [] 65 | att_scores = [] 66 | 67 | for i, attr_w in enumerate(self.item_att_W): 68 | rep, att = get_attention(attr_w, embed[i], len) 69 | attr_rep.append(rep.unsqueeze(2)) 70 | att_scores.append(att) 71 | # user_rep : [B, 3(num attr), emb] 72 | user_rep = torch.cat(attr_rep, 2).view(batch,-1) 73 | 74 | return user_rep, att_scores 75 | 76 | 77 | def attr_attention(embed, attr_att_W, len): 78 | batch = embed.size(0) 79 | if self.learning_form=='separated': 80 | user_rep = [] 81 | for attr_w in attr_att_W: 82 | rep, att_score = get_attention(attr_w, embed, len) 83 | user_rep.append(rep.unsqueeze(2)) 84 | user_rep = torch.cat(user_rep, 2).view(batch,-1) 85 | else: 86 | user_rep = get_attention(attr_att_W, embed, len) 87 | return user_rep 88 | 89 | 90 | y = torch.from_numpy(y).float() 91 | ob = torch.from_numpy(ob).float() 92 | x_len = torch.sum(x_mask.long(), 1) 93 | if torch.cuda.is_available() and not self.no_cuda: 94 | x = x.cuda() 95 | x_mask = x_mask.cuda() 96 | y = y.cuda() 97 | ob = ob.cuda() 98 | x_len = x_len.cuda() 99 | 100 | # Shared Embedding Layer 101 | embed = self.item_emb(x) 102 | 103 | embeds = [] 104 | for tran_w in self.emb_tran: 105 | # Embedding Transformation Layer 106 | attr_embed = F.relu(tran_w(embed)) 107 | embeds.append(attr_embed) 108 | 109 | if self.model_type == 'ETNA': 110 | # Task-Specific Attention Layer 111 | user_rep, att_scores = item_attention(embeds, False, x_len) 112 | else: 113 | # In ETN, user representations are computed by averaging item embedding vectors. 114 | user_rep = torch.stack(embeds) 115 | x_mask_ = x_mask.unsqueeze(0).unsqueeze(3).expand(user_rep.size()) 116 | user_rep = user_rep*x_mask_.float() 117 | user_rep = user_rep.sum(2).transpose(1,0).contiguous().view(y.size(0), -1) 118 | user_rep = user_rep / x_len.unsqueeze(1).float() 119 | # add a non-linear 120 | user_rep = torch.sigmoid(user_rep) 121 | 122 | # Prediction Layer 123 | for i, W in enumerate(self.W_all): 124 | if i == 0: 125 | W_user = W(user_rep[:,:self.item_emb_size]) 126 | else: 127 | W_user = torch.cat((W_user, W(user_rep[:,i*self.item_emb_size:(i+1)*self.item_emb_size])), 1) 128 | 129 | # all attr are observed in new-user prediction 130 | loss = 0 131 | s = e = 0 132 | for i, t in enumerate(self.attr_len): 133 | e += t 134 | lg, ls = compute_loss(W_user, y, ob, s, e, self.no_cuda) 135 | loss += ls 136 | if i == 0: 137 | logit = lg 138 | else: 139 | logit = np.concatenate((logit, lg), 1) 140 | s = e 141 | return logit, loss 142 | 143 | 144 | -------------------------------------------------------------------------------- /model/functions.py: -------------------------------------------------------------------------------- 1 | # some operations common in different models 2 | 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | def draw_neg_sample(batch_size, attr_len, label, observed): 11 | # weight [batch, all_posible] 12 | # find label index 13 | val_label = label*observed 14 | neg_samples = [] 15 | for val_l in val_label: 16 | neg_idx = [] 17 | val_y = val_l.nonzero() 18 | for attr_y in val_y: 19 | start = end = 0 20 | for n in attr_len: 21 | end = start + n 22 | if start <= attr_y < end: 23 | candidate = [i for i in range(start,end) if i!=attr_y] 24 | neg = random.sample(candidate, 1) 25 | neg_idx.append(neg) 26 | start += n 27 | neg_sample = torch.zeros(label.size(1)) 28 | for idx in neg_idx: 29 | neg_sample[idx] = 1 30 | neg_samples.append(neg_sample) 31 | return torch.stack(neg_samples) 32 | 33 | def compute_loss(W_user, label, ob, start, end, no_cuda, weight=None): 34 | W_c = W_user * ob 35 | W_c = W_c.transpose(1,0)[start:end].transpose(1,0) 36 | y = label.transpose(1,0)[start:end].transpose(1,0) 37 | 38 | prob = F.softmax(W_c, dim=1).cpu().detach().numpy() 39 | 40 | c_idx = [i for i, s in enumerate(W_c.sum(1).cpu().detach().numpy()) if s] 41 | 42 | if c_idx: 43 | c_idx = (torch.from_numpy(np.asarray(c_idx))).long() 44 | if torch.cuda.is_available() and not no_cuda: 45 | c_idx = c_idx.cuda() 46 | W_c = torch.index_select(W_c, 0, c_idx) 47 | y_c = torch.index_select(y, 0, c_idx) 48 | 49 | all_possible = [[1 if i==j else 0 for j in range(end-start)] \ 50 | for i in range(end-start)] 51 | all_possible = (torch.from_numpy(np.asarray( 52 | all_possible))).float() 53 | if torch.cuda.is_available() and not no_cuda: 54 | all_possible = all_possible.cuda() 55 | 56 | denom = 0 57 | for case in all_possible: 58 | denom += torch.sum(W_c*case, 1).exp() 59 | obj = torch.sum(W_c*y_c, 1).exp() / denom 60 | 61 | if weight is not None: 62 | weighted = torch.sum(y_c * weight, 1) 63 | loss = -torch.sum(obj.log()*weighted) 64 | else: 65 | loss = -torch.sum(obj.log()) 66 | batch_size = y_c.size(0) 67 | else: 68 | loss = torch.tensor(0., requires_grad=True).float() 69 | if torch.cuda.is_available() and not no_cuda: 70 | loss = loss.cuda() 71 | batch_size = 1 72 | 73 | return prob, loss / batch_size 74 | 75 | 76 | 77 | --------------------------------------------------------------------------------