├── README.md
├── dataset
│   └── yelp2013
│       ├── download_yelp.sh
│       └── process_data.py
├── predefined_vocab
│   └── yelp2013
│       ├── 42939.vocab
│       └── word_vectors.npy
├── requirements.txt
└── src
    ├── main.py
    ├── model_src
    │   ├── CustomDataset
    │   │   ├── __pycache__
    │   │   │   ├── aapr.cpython-36.pyc
    │   │   │   ├── polmed.cpython-36.pyc
    │   │   │   └── yelp2013.cpython-36.pyc
    │   │   ├── aapr.py
    │   │   ├── polmed.py
    │   │   └── yelp2013.py
    │   ├── SubModules
    │   │   ├── Attention.py
    │   │   ├── TextLSTM.py
    │   │   ├── TextUtils.py
    │   │   └── __pycache__
    │   │       ├── Attention.cpython-36.pyc
    │   │       ├── TextLSTM.cpython-36.pyc
    │   │       └── TextUtils.cpython-36.pyc
    │   ├── layers.py
    │   └── model.py
    └── run.sh
/README.md:
--------------------------------------------------------------------------------
1 | # BasisCustomize
2 | Categorical Metadata Representation for Customized Text Classification
3 | 
4 | This PyTorch code was used in the experiments of the research paper
5 | 
6 | Jihyeok Kim*, Reinald Kim Amplayo*, Kyungjae Lee, Sua Sung, Minji Seo, and Seung-won Hwang.
7 | [**Categorical Metadata Representation for Customized Text Classification**. _TACL_, 2019.](https://www.mitpressjournals.org/doi/pdf/10.1162/tacl_a_00263)
8 | (* equal contribution)
9 | 
10 | ### Run the Code!
11 | 
12 | #### Prerequisites
13 | - ```$ sudo apt-get install p7zip```
14 | - PyTorch 1.0
15 | - Other requirements are listed in `requirements.txt`.
16 | 
17 | #### 1. Preprocess Dataset
18 | 
19 | We provide a shell script, `dataset/yelp2013/download_yelp.sh`, that downloads and preprocesses the Yelp 2013 dataset. Preprocessing can be done similarly for the other datasets (see below for download links).
20 | 
21 | We also provide the vocabulary and word vectors used in our experiments (in the `predefined_vocab/yelp2013` directory) to better replicate the results reported in the paper.
22 | 
23 | #### 2. Train and Test the Models
24 | 
25 | The `src/main.py` script trains the model on the given training and dev sets, and subsequently tests it on the given test set. There are multiple arguments that can be set, but the most important (and mandatory) ones are the following:
26 | 
27 | - `model_type`: the type and method of customization. Use `BiLSTM` for no customization, or `*_cust` / `*_basis_cust`, where the prefix `*` is one of `word`, `encoder`, `attention`, `linear`, or `bias` (e.g. `linear_basis_cust`).
28 | - `domain`: the dataset directory name (e.g. `yelp2013`)
29 | - `num_bases`: the number of bases (only required when basis customization is used)
30 | 
31 | An example execution is:
32 | 
33 | ~~~bash
34 | python3 -W ignore main.py \
35 | --model_type linear_basis_cust \
36 | --num_bases 4 \
37 | --domain yelp2013 \
38 | --vocab_dir ../predefined_vocab/yelp2013/42939.vocab \
39 | --pretrained_word_em_dir ../predefined_vocab/yelp2013/word_vectors.npy \
40 | --train_datadir ../dataset/yelp2013/processed_data/train.txt \
41 | --dev_datadir ../dataset/yelp2013/processed_data/dev.txt \
42 | --test_datadir ../dataset/yelp2013/processed_data/test.txt \
43 | --meta_dim 64 \
44 | --key_query_size 64 \
45 | --word_dim 300 \
46 | --state_size 256 \
47 | --valid_step 1000
48 | ~~~
49 | 
50 | ### Download the Datasets!
51 | 
52 | There are three datasets used in the paper: Yelp 2013, AAPR, and PolMed.
53 | 
54 | To download Yelp 2013, refer to the link provided by the original authors.
55 | 
56 | Although the AAPR and PolMed datasets were constructed by different authors (please refer to these links for AAPR and PolMed), we use specific data splits for them.
57 | Download our splits here.
58 | 
59 | ### Cite the Paper!
60 | 
61 | To cite the paper/code/data splits, please use this BibTeX:
62 | 
63 | ```
64 | @article{kim2019categorical,
65 |   Author = {Jihyeok Kim and Reinald Kim Amplayo and Kyungjae Lee and Sua Sung and Minji Seo and Seung-won Hwang},
66 |   Journal = {TACL},
67 |   Year = {2019},
68 |   Title = {Categorical Metadata Representation for Customized Text Classification}
69 | }
70 | ```
71 | 
72 | If you use specific datasets, please also cite their original authors:
73 | 
74 | Yelp 2013
75 | ```
76 | @inproceedings{tang2015learning,
77 |   Author = {Duyu Tang and Bing Qin and Ting Liu},
78 |   Booktitle = {ACL},
79 |   Location = {Beijing, China},
80 |   Year = {2015},
81 |   Title = {Learning Semantic Representations of Users and Products for Document Level Sentiment Classification},
82 | }
83 | ```
84 | 
85 | AAPR
86 | ```
87 | @inproceedings{yang2018automatic,
88 |   Author = {Pengcheng Yang and Xu Sun and Wei Li and Shuming Ma},
89 |   Booktitle = {ACL: Short Papers},
90 |   Location = {Melbourne, Australia},
91 |   Year = {2018},
92 |   Title = {Automatic Academic Paper Rating Based on Modularized Hierarchical Convolutional Neural Network},
93 | }
94 | ```
95 | 
96 | If you have any questions, please send Jihyeok Kim an email: zizi1532@yonsei.ac.kr
97 | 
98 | 
--------------------------------------------------------------------------------
/dataset/yelp2013/download_yelp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | fileid="1uETI9hoZxbGchmB7D_o6WJaEfZLwx1nc"
3 | filename="dataset.7z"
4 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
5 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename}
6 | echo "Downloaded yelp2013 dataset"
7 | 7zr x dataset.7z -y
8 | echo "Unzipped dataset.7z"
9 | mv dataset/yelp-2013-seg-20-20.dev.ss ./
10 | mv dataset/yelp-2013-seg-20-20.train.ss ./
11 | mv dataset/yelp-2013-seg-20-20.test.ss ./
12 | rm -r dataset/
13 | rm cookie
14 | rm dataset.7z
15 | echo "Preprocessing dataset"
16 | python process_data.py
17 | echo "Completed"
18 | 
--------------------------------------------------------------------------------
/dataset/yelp2013/process_data.py:
--------------------------------------------------------------------------------
1 | import os, pickle, numpy as np, time
2 | from tqdm import tqdm
3 | 
4 | if not os.path.exists("./processed_data"):
5 |     os.mkdir("./processed_data")
6 | with open('../../predefined_vocab/yelp2013/42939.vocab', "r") as f:
7 |     vocab = f.read().split("\n")
8 | word2idx = {w:i for i,w in enumerate(vocab)}
9 | 
10 | with open('./yelp-2013-seg-20-20.train.ss') as f:
11 |     data = f.read().split("\n")[:-1]
12 | user, product, rating, review = list(zip(*[x.split("\t\t") for x in data]))
13 | 
14 | users = set(user); user2idx = {u:i for i,u in enumerate(users)}
15 | products = set(product); product2idx = {p:i for i,p in enumerate(products)}
16 | 
17 | userid = [user2idx[x] for x in user] # convert user strings into integer indices
18 | productid = [product2idx[x] for x in product] # convert product strings into integer indices
19 | rating = [int(x)-1 for x in rating] # shift ratings so they range from 0 to 4
20 | i_unk = word2idx['<unk>']
21 | review = [[word2idx.get(x, i_unk) for x in xs.split()] for xs in review]
22 | length = [len(x) for x in review]
23 | review = ["_".join([str(x) for x in xs]) for xs in review]
24 | 
25 | with open('./processed_data/train.txt', 'w') as f:
26 |     f.write("\n".join(["{},{},{},{},{}".format(u,p,r,l,x) for u,p,r,l,x in zip(userid, productid, rating, length, review)]))
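# Optional sanity check (illustrative addition, not in the original script): each output line has the form
#   <user_idx>,<product_idx>,<rating>,<length>,<tok1>_<tok2>_..._<tokN>
# i.e. comma-separated metadata followed by underscore-joined word indices, so the stored
# length should equal the number of token indices.
with open('./processed_data/train.txt') as f:
    u, p, r, l, toks = f.readline().rstrip("\n").split(",")
    assert int(l) == len(toks.split("_"))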
27 | 
28 | def process_data(in_path, out_path):
29 |     with open(in_path, 'r') as f:
30 |         data = f.read().split("\n")[:-1]
31 |     user, product, rating, review = list(zip(*[x.split("\t\t") for x in data]))
32 |     userid = [user2idx[x] for x in user]
33 |     productid = [product2idx[x] for x in product]
34 |     rating = [int(x)-1 for x in rating]
35 |     i_unk = word2idx['<unk>']
36 |     review = [[word2idx.get(x, i_unk) for x in xs.split()] for xs in review]
37 |     length = [len(x) for x in review]
38 |     review = ["_".join([str(x) for x in xs]) for xs in review]
39 |     with open(out_path, 'w') as f:
40 |         f.write("\n".join(["{},{},{},{},{}".format(u,p,r,l,x) for u,p,r,l,x in zip(userid, productid, rating, length, review)]))
41 | 
42 | process_data(in_path='./yelp-2013-seg-20-20.dev.ss', out_path="./processed_data/dev.txt")
43 | process_data(in_path='./yelp-2013-seg-20-20.test.ss', out_path="./processed_data/test.txt")
44 | 
--------------------------------------------------------------------------------
/predefined_vocab/yelp2013/word_vectors.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/predefined_vocab/yelp2013/word_vectors.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.29.1
2 | pytorch-ignite==0.1.2
3 | numpy==1.15.4
4 | argparse==1.1
5 | colorlog
6 | 
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # utility packages
2 | import argparse, os, numpy as np, pickle, random
3 | import logging, colorlog
4 | from tqdm import tqdm
5 | 
6 | # PyTorch packages
7 | import torch, torch.nn as nn, torch.nn.functional as F
8 | 
9 | # Dataloaders
10 | from model_src.CustomDataset.yelp2013 import yelp2013
11 | from model_src.CustomDataset.polmed import polmed
12 | from model_src.CustomDataset.aapr import aapr
13 | 
14 | # PyTorch-Ignite packages
15 | from ignite.engine import Events, Engine
16 | from ignite.contrib.handlers import ProgressBar
17 | 
18 | # model
19 | from model_src.model import Classifier
20 | 
21 | 
22 | logging.disable(logging.DEBUG)
23 | colorlog.basicConfig(
24 |     filename=None,
25 |     level=logging.NOTSET,
26 |     format="%(log_color)s[%(levelname)s:%(asctime)s]%(reset)s %(message)s",
27 |     datefmt="%Y-%m-%d %H:%M:%S"
28 | )
29 | parser = argparse.ArgumentParser()
30 | baseline_models = ['BiLSTM']
31 | cust_models = ['word_cust', 'encoder_cust', 'attention_cust', 'linear_cust', 'bias_cust']
32 | basis_cust_models = ['word_basis_cust', 'encoder_basis_cust', 'attention_basis_cust', 'linear_basis_cust', 'bias_basis_cust']
33 | model_choices = baseline_models + cust_models + basis_cust_models
34 | parser.add_argument("--random_seed", type=int, default=33)
35 | parser.add_argument("--model_type", choices=model_choices, help="model type (customization method)")
36 | parser.add_argument("--domain", type=str, choices=['yelp2013', 'polmed', 'aapr'], default="yelp2013")
37 | parser.add_argument("--num_bases", type=int, default=0)
38 | parser.add_argument("--vocab_dir", type=str)
39 | parser.add_argument("--train_datadir", type=str, default="./processed_data/flat_data.p")
40 | parser.add_argument("--dev_datadir", 
type=str, default="./processed_data/flat_data.p") 41 | parser.add_argument("--test_datadir", type=str, default="./processed_data/flat_data.p") 42 | parser.add_argument("--word_dim", type=int, default=300, help="word vector dimension") 43 | parser.add_argument("--state_size", type=int, default=256, help="BiLSTM hidden dimension") 44 | parser.add_argument("--meta_dim", type=int, default=64, help="meta embedding latent vector dimension") 45 | parser.add_argument("--key_query_size", type=int, default=64, help="key and query dimension for meta context") 46 | parser.add_argument("--batch_size", type=int, default=32) 47 | parser.add_argument("--valid_step", type=int, default=1000, help="evaluation step using dev set") 48 | parser.add_argument("--epoch", type=int, default=10) 49 | parser.add_argument("--device", type=str, default="cuda") 50 | parser.add_argument("--pretrained_word_em_dir", type=str, default="") 51 | parser.add_argument("--max_grad_norm", type=float, default=3.0) 52 | 53 | args = parser.parse_args() 54 | if 'basis' in args.model_type: 55 | if args.num_bases==0: 56 | print(" must input number of bases (\"--num_bases\") for basis_cust model type") 57 | print(" e.g. python main.py word_basis_cust --num_bases 3") 58 | exit() 59 | # Manual Random Seed 60 | random.seed(args.random_seed) 61 | np.random.seed(args.random_seed) 62 | torch.manual_seed(args.random_seed) 63 | torch.cuda.manual_seed(args.random_seed) 64 | torch.backends.cudnn.deterministic=True 65 | 66 | class modelClassifier: 67 | def __init__(self): 68 | 69 | # Ignite engine 70 | self.engine = None 71 | self._engine_ready() 72 | 73 | # Dataloader 74 | domain_dataloader = { 75 | 'yelp2013':yelp2013(args), 76 | 'polmed':polmed(args), 77 | 'aapr':aapr(args), 78 | }[args.domain] 79 | self.train_dataloader = domain_dataloader.train_dataloader 80 | self.dev_dataloader = domain_dataloader.dev_dataloader 81 | self.test_dataloader = domain_dataloader.test_dataloader 82 | 83 | # MODEL DECLARATION 84 | self.model = Classifier(args).to(args.device) 85 | print("<< Model Configuration >>") 86 | print(self.model) 87 | print("*"*50) 88 | with open("./ModelDescription.txt", "w") as f: 89 | f.write(repr(self.model)) 90 | 91 | # OPTIMIZER DECLARATION 92 | parameters = filter(lambda p: p.requires_grad, self.model.parameters()) 93 | self.optimizer = torch.optim.Adadelta(parameters, lr=1.0, rho=0.9, eps=1e-6) 94 | self.criterion = { 95 | "yelp2013":F.cross_entropy, 96 | "polmed":F.cross_entropy, 97 | "aapr":F.binary_cross_entropy, 98 | }[args.domain] 99 | 100 | # PARAM SAVE DIRECTORY 101 | self.param_dir = './save_param/' 102 | if not os.path.exists(self.param_dir): os.mkdir(self.param_dir) # ./save_param 103 | self.param_dir = os.path.join(self.param_dir,args.domain) 104 | if not os.path.exists(self.param_dir): os.mkdir(self.param_dir) # ./save_param/{domain} 105 | self.param_dir = os.path.join(self.param_dir, args.model_type) 106 | if 'basis_cust' in args.model_type: 107 | self.param_dir += '({}).pth'.format(args.num_bases) 108 | else: 109 | self.param_dir += '.pth' 110 | 111 | def _init_param(self, model): 112 | colorlog.critical("[Init General Parameter] >> xavier_uniform_") 113 | for p in model.parameters(): 114 | if p.requires_grad: 115 | if len(p.shape)>1: 116 | nn.init.xavier_uniform_(p) 117 | else: 118 | nn.init.constant_(p, 0) 119 | if args.pretrained_word_em_dir: 120 | colorlog.critical("[Pretrained Word em loaded] from {}".format(args.pretrained_word_em_dir)) 121 | word_em = np.load(args.pretrained_word_em_dir) 122 | 
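            # word_em is expected to be a (vocab_size, word_dim) float array saved with np.save;
            # it is copied in place into the embedding table the model exposes as word_em_weight.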
model.word_em_weight.data.copy_(torch.from_numpy(word_em)) 123 | 124 | def _init_meta_param(self, model): 125 | colorlog.critical("[Init Meta Parameter] >> uniform_ [-0.01, 0.01]") 126 | for name, param in model.meta_param_manager.state_dict().items(): 127 | colorlog.info("{} intialized".format(name)) 128 | nn.init.uniform_(param, -0.01, 0.01) 129 | 130 | def _engine_ready(self): 131 | colorlog.info("[Ignite Engine Ready]") 132 | self.engine = Engine(self._update) 133 | ProgressBar().attach(self.engine) # support tqdm progress bar 134 | self.engine.add_event_handler(Events.STARTED, self._started) 135 | self.engine.add_event_handler(Events.COMPLETED, self._completed) 136 | self.engine.add_event_handler(Events.EPOCH_STARTED, self._epoch_started) 137 | self.engine.add_event_handler(Events.EPOCH_COMPLETED, self._epoch_completed) 138 | self.engine.add_event_handler(Events.ITERATION_STARTED, self._iteration_started) 139 | self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._iteration_completed) 140 | 141 | def _update(self, engine, sample_batch): 142 | target, kwinputs = sample_batch 143 | # Inference 144 | predict = self.model(**kwinputs) 145 | # Loss & Update 146 | loss = self.criterion(input=predict, target=target) 147 | loss.backward() 148 | nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm) 149 | # Loss logging 150 | self.engine.state.train_loss += loss.item()/args.valid_step 151 | return loss.item() # engine.state.output 152 | 153 | def _iteration_started(self, engine): 154 | self.model.zero_grad() 155 | self.optimizer.zero_grad() 156 | def _iteration_completed(self, engine): 157 | self.optimizer.step() 158 | # Evaluation 159 | if self.engine.state.iteration % args.valid_step == 0: 160 | dev_acc, dev_rmse = self.evaluation(self.dev_dataloader) 161 | if self.engine.state.best_dev_acc < dev_acc: 162 | self.engine.state.best_dev_acc = dev_acc 163 | self.engine.state.dev_rmse = dev_rmse 164 | torch.save(self.model.state_dict(), self.param_dir) 165 | colorlog.info(""" 166 | Model Type : {} 167 | EPOCH {} =====> TRAIN LOSS : {:.4f} 168 | VALIDATION ACCURACY : {:2.2f}% =====> BEST {:2.2f}% 169 | VALIDATION RMSE : {:2.4f}""".format( 170 | args.model_type, 171 | self.engine.state.epoch, self.engine.state.train_loss, 172 | dev_acc*100, self.engine.state.best_dev_acc*100, 173 | dev_rmse, 174 | )) 175 | self.engine.state.train_loss = 0 176 | def _started(self, engine): 177 | # Model Initialization 178 | self._init_param(self.model) 179 | if 'cust' in args.model_type: 180 | self._init_meta_param(self.model) 181 | self.model.train() 182 | self.model.zero_grad() 183 | self.optimizer.zero_grad() 184 | # ignite engine state intialization 185 | self.engine.state.best_dev_acc = -1 186 | self.engine.state.dev_rmse = -1 187 | self.engine.state.train_loss = 0 188 | def _completed(self, engine): 189 | colorlog.info("*"*20+" Training is DONE" + "*"*20) 190 | def _epoch_started(self, engine): 191 | colorlog.info('>' * 50) 192 | colorlog.info('EPOCH: {}'.format(self.engine.state.epoch)) 193 | def _epoch_completed(self, engine): 194 | pass 195 | 196 | def evaluation(self, dataloader): 197 | colorlog.info(" EVALUATION ... 
") 198 | # HISTORY DECLARATION 199 | num_data = len(dataloader.dataset) 200 | predicted_label = np.empty(num_data).astype(np.int64) 201 | target_label = np.empty(num_data).astype(np.int64) 202 | self.model.eval() 203 | with torch.no_grad(): 204 | batch_size = dataloader.batch_size 205 | for i_batch, sample_batch in enumerate(dataloader): 206 | target_batch, kwinputs = sample_batch 207 | predict_batch = self.model(**kwinputs) 208 | 209 | # ACCURACY, RMSE 210 | _, predict_batch = torch.max(predict_batch, dim=1) 211 | predicted_label[i_batch*batch_size:(i_batch+1)*batch_size] = predict_batch.cpu().data.numpy() 212 | target_label[i_batch*batch_size:(i_batch+1)*batch_size] = target_batch.cpu().data.numpy() 213 | 214 | acc = (predicted_label==target_label).mean() 215 | rmse = ((predicted_label-target_label)**2).mean()**0.5 216 | self.model.train() 217 | return acc, rmse 218 | 219 | def train(self): 220 | self.engine.run(self.train_dataloader, max_epochs=args.epoch) 221 | def test(self): 222 | # LOAD PRETRAINED PARAMETERS 223 | state_dict = torch.load(self.param_dir) 224 | self.model.load_state_dict(state_dict) 225 | 226 | # EVALUATION 227 | test_acc, test_rmse = self.evaluation(self.test_dataloader) 228 | 229 | # SAVE FINAL EVALUATION PERFORMANCE 230 | colorlog.info(""" 231 | " Evaluation with test data set " 232 | << Model Type : {} >> 233 | TEST ACCURACY : {:2.2f}% 234 | TEST RMSE : {:2.4f} 235 | 236 | DEV ACCURACY : {:2.2f}% 237 | DEV RMSE : {:2.4f} 238 | """.format( 239 | args.model_type, 240 | test_acc*100, 241 | test_rmse, 242 | self.engine.state.best_dev_acc*100, 243 | self.engine.state.dev_rmse, 244 | )) 245 | return 246 | 247 | classifier = modelClassifier() 248 | try: 249 | classifier.train() 250 | except KeyboardInterrupt: 251 | print("KeyboardInterrupt occurs") 252 | print("Start Test Evaluation") 253 | 254 | classifier.test() 255 | -------------------------------------------------------------------------------- /src/model_src/CustomDataset/__pycache__/aapr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/CustomDataset/__pycache__/aapr.cpython-36.pyc -------------------------------------------------------------------------------- /src/model_src/CustomDataset/__pycache__/polmed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/CustomDataset/__pycache__/polmed.cpython-36.pyc -------------------------------------------------------------------------------- /src/model_src/CustomDataset/__pycache__/yelp2013.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/CustomDataset/__pycache__/yelp2013.cpython-36.pyc -------------------------------------------------------------------------------- /src/model_src/CustomDataset/aapr.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | 3 | class aapr(Dataset): 4 | def __init__(self, path): 5 | self.data = self.read_data(path) 6 | def __getitem__(self, index): return self.data[index] 7 | def __len__(self): return len(self.data) 8 | def custom_collate_fn(self, sample_batch): 9 | pass 
10 |     def read_data(self, path):
11 |         pass
--------------------------------------------------------------------------------
/src/model_src/CustomDataset/polmed.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | 
3 | class polmed(Dataset):
4 |     def __init__(self, path):
5 |         self.data = self.read_data(path)
6 |     def __getitem__(self, index): return self.data[index]
7 |     def __len__(self): return len(self.data)
8 |     def custom_collate_fn(self, sample_batch):
9 |         pass
10 |     def read_data(self, path):
11 |         pass
--------------------------------------------------------------------------------
/src/model_src/CustomDataset/yelp2013.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import pickle, numpy as np, os
4 | from model_src.SubModules.TextUtils import text_padding
5 | 
6 | class CustomDataset(Dataset):
7 |     def __init__(self, data):
8 |         self.data = data
9 |     def __getitem__(self, index): return self.data[index]
10 |     def __len__(self): return len(self.data)
11 | class yelp2013():
12 |     def __init__(self, args):
13 |         self.device = args.device
14 | 
15 |         # load vocabulary
16 |         with open(args.vocab_dir, "r") as f:
17 |             vocab = f.read().split('\n')
18 |         self.word2idx = {w:i for i, w in enumerate(vocab)}
19 |         self.idx2word = {i:w for i, w in enumerate(vocab)}
20 |         self._ipad = self.word2idx['<pad>']
21 |         args._ipad = self._ipad
22 |         args.vocab_size = len(self.word2idx)
23 | 
24 |         self.train_dataloader = DataLoader(
25 |             dataset=CustomDataset(
26 |                 data=self.cust_read_data(args.train_datadir) if 'cust' in args.model_type \
27 |                     else self.basic_read_data(args.train_datadir)),
28 |             batch_size=args.batch_size,
29 |             collate_fn=self.cust_collate_fn if 'cust' in args.model_type else self.basic_collate_fn,
30 |             shuffle=True,
31 |         )
32 |         self.dev_dataloader = DataLoader(
33 |             dataset=CustomDataset(
34 |                 data=self.cust_read_data(args.dev_datadir) if 'cust' in args.model_type \
35 |                     else self.basic_read_data(args.dev_datadir)),
36 |             batch_size=args.batch_size,
37 |             collate_fn=self.cust_collate_fn if 'cust' in args.model_type else self.basic_collate_fn,
38 |             shuffle=False,
39 |         )
40 |         self.test_dataloader = DataLoader(
41 |             dataset=CustomDataset(
42 |                 data=self.cust_read_data(args.test_datadir) if 'cust' in args.model_type \
43 |                     else self.basic_read_data(args.test_datadir)),
44 |             batch_size=args.batch_size,
45 |             collate_fn=self.cust_collate_fn if 'cust' in args.model_type else self.basic_collate_fn,
46 |             shuffle=False,
47 |         )
48 | 
49 |         args.num_label = 5 # ratings from 0 to 4
50 |         # (name of meta unit, number of distinct meta values)
51 |         args.meta_units = [('user', 1631), ('product', 1633)]
52 |     def cust_collate_fn(self, sample_batch):
53 |         user, product, rating, length, review = list(zip(*sample_batch))
54 |         review = text_padding(text=review, length=length, padding_idx=self._ipad)
55 |         mask = [[1]*l + [0]*(review.shape[1]-l) for l in length]
56 |         user = torch.tensor(user).to(self.device)
57 |         product = torch.tensor(product).to(self.device)
58 |         rating = torch.tensor(rating).to(self.device)
59 |         length = torch.tensor(length).to(self.device)
60 |         review = torch.tensor(review).to(self.device)
61 |         mask = torch.tensor(mask).to(self.device)
62 |         return rating, {"review":review, "mask":mask, "length":length, "user":user, "product":product}
63 |     def basic_collate_fn(self, sample_batch):
64 |         rating, length, review = list(zip(*sample_batch))
65
| review = text_padding(text=review, length=length, padding_idx=self._ipad) 66 | mask = [[1]*l + [0]*(review.shape[1]-l) for l in length] 67 | rating = torch.tensor(rating).to(self.device) 68 | length = torch.tensor(length).to(self.device) 69 | review = torch.tensor(review).to(self.device) 70 | mask = torch.tensor(mask).to(self.device) 71 | return rating, {"review":review, "mask":mask, "length":length} 72 | 73 | def cust_read_data(self, path): 74 | with open(path, 'r') as f: 75 | data = f.read().split("\n") 76 | user, product, rating, length, review = list(zip(*[x.split(",") for x in data])) 77 | user = [int(x) for x in user] 78 | product = [int(x) for x in product] 79 | rating = [int(x) for x in rating] 80 | length = [int(x) for x in length] 81 | review = [[int(x) for x in xs.split("_")] for xs in review] 82 | return list(zip(user, product, rating, length, review)) 83 | def basic_read_data(self, path): 84 | with open(path, 'r') as f: 85 | data = f.read().split("\n") 86 | _, _, rating, length, review = list(zip(*[x.split(",") for x in data])) 87 | rating = [int(x) for x in rating] 88 | length = [int(x) for x in length] 89 | review = [[int(x) for x in xs.split("_")] for xs in review] 90 | return list(zip(rating, length, review)) -------------------------------------------------------------------------------- /src/model_src/SubModules/Attention.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn, torch.nn.functional as F 2 | from torch.autograd import Variable 3 | def masked_softmax(logits, mask, dim=1, epsilon=1e-5): 4 | """ logits, mask has same size """ 5 | masked_logits = logits.masked_fill(mask == 0, -1e9) 6 | max_logits = torch.max(masked_logits, dim=dim, keepdim=True)[0] 7 | exps = torch.exp(masked_logits-max_logits) 8 | masked_exps = exps * mask.float() 9 | masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon 10 | return masked_exps/masked_sums 11 | 12 | class AttentionWithoutQuery(nn.Module): 13 | def __init__(self, encoder_dim, device=torch.device('cpu')): 14 | super(AttentionWithoutQuery, self).__init__() 15 | self.encoder_dim=encoder_dim 16 | self.device=device 17 | def forward(self, encoder_dim, length=None): 18 | """ 19 | encoded_vecs: batch_size, max_length, encoder_hidden_dim 20 | (optional) length: list of lengths of encoded_vecs 21 | > if length is given then perform masked_softmax 22 | > None indicate fixed number of length (all same length in batch) 23 | """ 24 | pass 25 | class LinearAttentionWithoutQuery(AttentionWithoutQuery): 26 | def __init__(self, encoder_dim, device=torch.device('cpu')): 27 | super().__init__(encoder_dim, device) 28 | self.z = nn.Linear(self.encoder_dim, 1, bias=False) 29 | def forward(self, encoded_vecs, mask=None): 30 | logits = self.z(encoded_vecs).squeeze(dim=2) 31 | if (mask is not None): 32 | # batch_size, max_length 33 | attention = masked_softmax(logits=logits, mask=mask, dim=1) 34 | else: 35 | # batch_size, max_length 36 | attention = F.softmax(logits, dim=1) 37 | return ( 38 | torch.bmm(attention.unsqueeze(dim=1), encoded_vecs).squeeze(dim=1), 39 | attention 40 | ) 41 | class MLPAttentionWithoutQuery(AttentionWithoutQuery): 42 | def __init__(self, encoder_dim, device=torch.device('cpu')): 43 | """ 44 | ev_t: encoded_vecs 45 | u_t = tanh(W*(ev_t)+b) 46 | a_t = softmax(v^T u_t) 47 | """ 48 | super(MLPAttentionWithoutQuery, self).__init__(encoder_dim, device) 49 | self.W = nn.Sequential( 50 | nn.Linear( 51 | self.encoder_dim, 52 | self.encoder_dim 53 | ), 54 | nn.Tanh(), 
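            # final projection: map each encoded time step to a single scalar attention logit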
55 | nn.Linear(self.encoder_dim, 1, bias=False) 56 | ) 57 | def forward(self, encoded_vecs, length=None): 58 | 59 | # batch_size, max_length 60 | logits = self.W(encoded_vecs).squeeze(dim=2) 61 | if (length is not None): 62 | N, L = logits.size() 63 | mask = [[1]*l + [0]*(L-l) for l in length] 64 | mask = torch.LongTensor(mask).to(self.device) 65 | # batch_size, max_length 66 | attention = masked_softmax(logits=logits, mask=mask, dim=1) 67 | else: 68 | # batch_size, max_length 69 | attention = F.softmax(logits, dim=1) 70 | 71 | # batch_size, encoder_dim 72 | return ( 73 | torch.bmm(attention.unsqueeze(dim=1), encoded_vecs).squeeze(dim=1), 74 | attention 75 | ) 76 | 77 | class AttentionWithQuery(nn.Module): 78 | """ AttentionWithQuery 79 | e.g. Language Translation, SA with meta information 80 | """ 81 | def __init__(self, encoder_dim, query_dim, device=torch.device('cpu')): 82 | super(AttentionWithQuery, self).__init__() 83 | self.encoder_dim=encoder_dim 84 | self.query_dim=query_dim 85 | self.device=device 86 | def forward(self, encoded_sequence, length=None): pass 87 | class LinearAttentionWithQuery(AttentionWithQuery): 88 | def __init__(self, encoder_dim, query_dim, device=torch.device('cpu')): 89 | super().__init__(encoder_dim, query_dim, device) 90 | def forward(self, encoded_vecs, query, mask=None): 91 | logits = (encoded_vecs*query).sum(dim=2) 92 | if (mask is not None): 93 | # batch_size, max_length 94 | attention = masked_softmax(logits=logits, mask=mask, dim=1) 95 | else: 96 | # batch_size, max_length 97 | attention = F.softmax(logits, dim=1) 98 | return ( 99 | torch.bmm(attention.unsqueeze(dim=1), encoded_vecs).squeeze(dim=1), 100 | attention 101 | ) 102 | class MLPAttentionWithQuery(AttentionWithQuery): 103 | def __init__(self, encoder_dim, query_dim, device=torch.device('cpu')): 104 | """ ev_t: encoded_vecs, q_t: query 105 | u_t = tanh(W*(ev_t, q_t)+b) 106 | a_t = softmax(v^T u_t) 107 | """ 108 | super(MLPAttentionWithQuery, self).__init__(encoder_dim, query_dim, device) 109 | self.W = nn.Sequential( 110 | nn.Linear( 111 | self.encoder_dim+self.query_dim, 112 | self.encoder_dim 113 | ), 114 | nn.Tanh(), 115 | nn.Linear(self.encoder_dim, 1, bias=False) 116 | ) 117 | for p in self.parameters(): 118 | if p.dim()>1: 119 | nn.init.xavier_normal_(p) 120 | def forward(self, encoded_vecs, query, length=None): 121 | """ 122 | encoded_vecs: batch_size, max_length, encoder_hidden_dim 123 | query: batch_size, max_length, query_dim 124 | (optional) length: list of lengths of encoded_vecs 125 | > if length is given then perform masked_softmax 126 | > None indicate fixed number of length (all same length in batch) 127 | """ 128 | # in case, query.size()=(batch_size, query_dim) 129 | # indicates query is length independent, e.g. 
attention for classification 130 | if query.dim()==2: 131 | query = query.unsqueeze(dim=1).repeat(1, encoded_vecs.size(1), 1) 132 | # batch_size, max_length 133 | logits = self.W(torch.cat([encoded_vecs, query], dim=2)).squeeze(dim=2) 134 | if (length is None): 135 | # batch_size, max_length 136 | attention = F.softmax(logits, dim=1) 137 | else: 138 | N, L = logits.size() 139 | mask = [[1]*l + [0]*(L-l) for l in length] 140 | mask = Variable(torch.LongTensor(mask)).to(self.device) 141 | # batch_size, max_length 142 | attention = masked_softmax(logits=logits, mask=mask, dim=1) 143 | 144 | # batch_size, encoder_dim 145 | return ( 146 | torch.bmm(attention.unsqueeze(dim=1), encoded_vecs).squeeze(dim=1), 147 | attention 148 | ) 149 | 150 | -------------------------------------------------------------------------------- /src/model_src/SubModules/TextLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import itertools # for MultiDocumentTextLSTM 6 | class TextLSTM(nn.Module): 7 | def __init__(self, 8 | input_size, hidden_size, 9 | num_layers=1, bidirectional=False, dropout=0, bias=True, 10 | batch_first=True, 11 | device='cpu'): 12 | super(TextLSTM, self).__init__() 13 | self.batch_first=batch_first 14 | self.bidirectional=bidirectional 15 | self.num_layers = num_layers 16 | self.hidden_size = hidden_size 17 | self.device = device 18 | if self.num_layers==1: dropout=0 19 | self.rnn = nn.LSTM( 20 | input_size=input_size, 21 | hidden_size=hidden_size, 22 | num_layers=num_layers, 23 | bias=bias, 24 | batch_first=batch_first, 25 | bidirectional=bidirectional, 26 | dropout=dropout) 27 | def zero_init(self, batch_size): 28 | nd = 1 if not self.bidirectional else 2 29 | h0 = Variable(torch.zeros((self.num_layers*nd, batch_size, self.hidden_size))).to(self.device) 30 | c0 = Variable(torch.zeros((self.num_layers*nd, batch_size, self.hidden_size))).to(self.device) 31 | return (h0, c0) 32 | def forward(self, inputs, length, rnn_init=None, is_sorted=False): 33 | if rnn_init is None: 34 | rnn_init = self.zero_init(inputs.size(0)) 35 | if not is_sorted: 36 | sort_idx = torch.sort(-length)[1] 37 | inputs = inputs[sort_idx] 38 | length = length[sort_idx] 39 | # h0: size=(num_layers*bidriectional, batch_size, hidden_dim) 40 | # c0: size=(num_layers*bidriectional, batch_size, hidden_dim) 41 | h0, c0 = rnn_init 42 | rnn_init = (h0[:, sort_idx, :], c0[:, sort_idx, :]) 43 | unsort_idx = torch.sort(sort_idx)[1] 44 | x_pack = nn.utils.rnn.pack_padded_sequence(inputs, length, batch_first=self.batch_first) 45 | output, (hn, cn) = self.rnn(x_pack, rnn_init) 46 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) 47 | if not is_sorted: 48 | output = output[unsort_idx] 49 | hn = hn[:, unsort_idx, :] 50 | cn = cn[:, unsort_idx, :] 51 | # batch_size, length, hidden_size 52 | # batch_size, num_layers*bidirectional, hidden_size 53 | return output, (hn,cn) 54 | 55 | class MultiDocumentTextLSTM(TextLSTM): 56 | def __init__(self, *args, **kwargs): 57 | super(MultiDocumentTextLSTM, self).__init__(*args, **kwargs) 58 | def txt_pad(self, text, length, padidx=0): 59 | """ 60 | text: list of token indices 61 | length: list of length of tokens (same size with text) 62 | return: padded text tokens, ndarray, np.int64 63 | """ 64 | maxlen = max(length) 65 | padded_sentences = [] 66 | for l, x in zip(length, text): 67 | if l pack documents from (N, 
D, L, word_dim) to (N*D, L, word_dim)
78 |         > pass LSTM
79 |         > unpack documents from (N*D, L, hidden_dim) to (N,D,L,hidden_dim)
80 |         """
81 |         batch_size = inputs.size(0)
82 |         max_num_docs = inputs.size(1)
83 |         max_length = inputs.size(2)
84 |         word_dim = inputs.size(3)
85 |         # 1. Get batch index for unpacking
86 |         # batch_index = itertools.chain(*[[i]*d for i, d in enumerate(num_docs)])
87 | 
88 |         # 2. Pack documents (merge the batch and document dimensions)
89 |         length_padded = self.txt_pad(length, num_docs) # per-sample length lists padded to max_num_docs
90 |         length_merged = torch.tensor(length_padded).view(-1) # N*D
91 |         inputs_merged = inputs.contiguous().view(batch_size*max_num_docs, max_length, word_dim) # N*D, L, word_dim
92 | 
93 |         # 3. Pass LSTM
94 |         # output: N*D, L, hidden_dim
95 |         output, (hn,cn) = super(MultiDocumentTextLSTM, self).forward(inputs_merged, length_merged, rnn_init, is_sorted)
96 | 
97 |         # 4. Unpack documents
98 |         output = output.view(batch_size, max_num_docs, -1, output.size(-1))
99 |         hn = hn.view(hn.size(0), batch_size, max_num_docs, -1)
100 |         cn = cn.view(cn.size(0), batch_size, max_num_docs, -1)
101 |         return output, (hn,cn)
--------------------------------------------------------------------------------
/src/model_src/SubModules/TextUtils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.nn.utils.rnn import pack_padded_sequence as pack
3 | import torch, torch.nn as nn, torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | def text_padding(text, length, padding_idx, eos_idx=None, sos_idx=None):
7 |     """
8 |     text: list of token indices
9 |     length: list of lengths of tokens (same size as text)
10 |     return: padded text tokens, ndarray, np.int64
11 |     """
12 |     maxlen = max(length)
13 |     num_data = len(text)
14 |     _append_length = 0
15 |     st = 0
16 |     if eos_idx is not None:
17 |         _append_length += 1
18 |     if sos_idx is not None:
19 |         _append_length += 1
20 |         st = 1
21 |     # fill the whole array with padding_idx (works whether padding_idx is 0 or any other index)
22 |     padded_sentences = np.full((num_data, maxlen+_append_length), padding_idx, dtype=np.int64)
23 |     if sos_idx is not None:
24 |         padded_sentences[:, 0] = sos_idx
25 | 
26 |     if eos_idx is not None:
27 |         for i, (l, x) in enumerate(zip(length, text)):
28 |             padded_sentences[i][st:st+l] = x
29 |             padded_sentences[i][st+l] = eos_idx
30 |     else:
31 |         for i, (l, x) in enumerate(zip(length, text)):
32 |             padded_sentences[i][st:st+l] = x
33 | 
34 |     return padded_sentences
35 | 
--------------------------------------------------------------------------------
/src/model_src/SubModules/__pycache__/Attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/SubModules/__pycache__/Attention.cpython-36.pyc
--------------------------------------------------------------------------------
/src/model_src/SubModules/__pycache__/TextLSTM.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/SubModules/__pycache__/TextLSTM.cpython-36.pyc
--------------------------------------------------------------------------------
/src/model_src/SubModules/__pycache__/TextUtils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihyukkim-nlp/BasisCustomize/3a03e3d40c9ec64ddd364aad5c83d098256e9e25/src/model_src/SubModules/__pycache__/TextUtils.cpython-36.pyc -------------------------------------------------------------------------------- /src/model_src/layers.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn, torch.nn.functional as F 2 | from .SubModules.Attention import LinearAttentionWithQuery, LinearAttentionWithoutQuery 3 | from .SubModules.TextLSTM import TextLSTM 4 | 5 | class BasicWordEmb(nn.Module): 6 | def __init__(self, args): 7 | super().__init__() 8 | self.word_em = nn.Embedding(args.vocab_size, args.word_dim, padding_idx=args._ipad) 9 | def forward(self, review, **kwargs): 10 | return self.word_em(review) 11 | class CustWordEmb(nn.Module): 12 | def __init__(self, args, meta_param_manager): 13 | super().__init__() 14 | self.word_dim = args.word_dim 15 | self.word_em = nn.Embedding(args.vocab_size, args.word_dim, padding_idx=args._ipad) 16 | for name, num_meta in args.meta_units: 17 | setattr(self, "num_"+name, num_meta) 18 | # word embedding transformation parameters 19 | setattr(self, name, nn.Embedding(num_meta, args.word_dim*args.word_dim)) 20 | meta_param_manager.register("CustWordEmb."+name, getattr(self, name).weight) 21 | def forward(self, review, **kwargs): 22 | x = self.word_em(review) 23 | r = None 24 | for name, idx in kwargs.items(): 25 | v=getattr(self, name)(idx).view(x.shape[0], self.word_dim, self.word_dim) 26 | rv = torch.bmm(x, v) 27 | if (r is not None): r += rv 28 | else: r = rv 29 | x = x + torch.tanh(r) # residual addition 30 | return x 31 | class BasisCustWordEmb(nn.Module): 32 | def __init__(self, args, meta_param_manager): 33 | super().__init__() 34 | self.word_dim = args.word_dim 35 | self.word_em = nn.Embedding(args.vocab_size, args.word_dim, padding_idx=args._ipad) 36 | for name, num_meta in args.meta_units: 37 | setattr(self, "num_"+name, num_meta) 38 | # word embedding transformation parameters 39 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 40 | meta_param_manager.register("BasisCustWordEmb."+name, getattr(self, name).weight) 41 | self.P = nn.Sequential( 42 | nn.Linear(args.meta_dim*len(args.meta_units), args.key_query_size), # From MetaData to Query 43 | nn.Tanh(), 44 | nn.Linear(args.key_query_size, args.num_bases, bias=False), # Calculate Weights of each Basis: Key & Query Inner-product 45 | nn.Softmax(dim=1), 46 | nn.Linear(args.num_bases, args.word_dim*args.word_dim), # Weighted Sum of Bases 47 | ) 48 | def forward(self, review, **kwargs): 49 | x = self.word_em(review) 50 | query = torch.cat( 51 | [getattr(self, name)(idx) 52 | for name, idx in kwargs.items()], dim=1) 53 | t = self.P(query).view(x.shape[0], self.word_dim, self.word_dim) 54 | r = torch.bmm(x, t) 55 | return x+torch.tanh(r) 56 | 57 | 58 | class BasicBiLSTM(nn.Module): 59 | def __init__(self, args): 60 | super().__init__() 61 | self.LSTM = TextLSTM( 62 | input_size=args.word_dim, 63 | hidden_size=args.state_size//2, # //2 for bidirectional 64 | bidirectional=True, 65 | device=args.device 66 | ) 67 | def forward(self, x, length, **kwargs): 68 | return self.LSTM(inputs=x, length=length)[0] 69 | class BasisCustBiLSTM(nn.Module): 70 | def __init__(self, args, meta_param_manager): 71 | super().__init__() 72 | self.device = args.device 73 | self.num_bases = args.num_bases 74 | self.each_state = args.state_size//2 75 | self.word_dim = args.word_dim 76 | for name, num_meta in args.meta_units: 77 
| setattr(self, "num_"+name, num_meta) 78 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 79 | meta_param_manager.register("BasisCustBiLSTM."+name, getattr(self, name).weight) 80 | self.weight_ih_l0 = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2, args.word_dim)) 81 | self.weight_hh_l0 = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2, args.state_size//2)) 82 | self.bias_l0 = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2)) 83 | self.weight_ih_l0_reverse = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2, args.word_dim)) 84 | self.weight_hh_l0_reverse = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2, args.state_size//2)) 85 | self.bias_l0_reverse = nn.Parameter(torch.zeros(args.num_bases, args.state_size*2)) 86 | self.P = nn.Sequential( 87 | nn.Linear(args.meta_dim*len(args.meta_units), args.key_query_size), # From MetaData to Query 88 | nn.Tanh(), 89 | nn.Linear(args.key_query_size, args.num_bases, bias=False), # Calculate Weights of each Basis: Key & Query Inner-product 90 | nn.Softmax(dim=1), 91 | ) 92 | def forward(self, x, length, **kwargs): 93 | # low-rank factorization 94 | # c_batch = self.encoder_coefficient(usr_batch, prd_batch) # batch_size, num_bases 95 | query = torch.cat([getattr(self, name)(idx) 96 | for name, idx in kwargs.items()], dim=1) 97 | c_batch = self.P(query) 98 | num_bases = self.num_bases 99 | cell_size = self.each_state 100 | input_size = self.word_dim 101 | 102 | batch_size = x.size(0) 103 | maxlength = torch.max(length).item() 104 | 105 | # make variable for backward path 106 | reverse_idx = torch.arange(maxlength-1, -1, -1).to(self.device) 107 | # reverse_idx = torch.from_numpy(reverse_idx) 108 | x_reverse = x[:, reverse_idx, :] 109 | 110 | weight_ih_l0 = torch.mm(c_batch , self.weight_ih_l0.view(num_bases, -1)).view(batch_size, cell_size*4, input_size) # batch_size, cell_size*4, input_size 111 | weight_hh_l0 = torch.mm(c_batch , self.weight_hh_l0.view(num_bases, -1)).view(batch_size, cell_size*4, cell_size) # batch_size, cell_size*4, cell_size 112 | bias_l0 = torch.mm(c_batch, self.bias_l0) # batch_size, cell_size*4 113 | weight_ih_l0_reverse = torch.mm(c_batch , self.weight_ih_l0_reverse.view(num_bases, -1)).view(batch_size, cell_size*4, input_size) # batch_size, cell_size*4, input_size 114 | weight_hh_l0_reverse = torch.mm(c_batch , self.weight_hh_l0_reverse.view(num_bases, -1)).view(batch_size, cell_size*4, cell_size) # batch_size, cell_size*4, cell_size 115 | bias_l0_reverse = torch.mm(c_batch, self.bias_l0_reverse) # batch_size, cell_size*4 116 | 117 | (h0, c0) = torch.zeros((2, batch_size, cell_size, 1)).to(self.device) # only for forward path 118 | (h0_reverse, c0_reverse) = torch.zeros((2, batch_size, cell_size, 1)).to(self.device) # only for forward path 119 | hidden = (h0, c0) 120 | hidden_reverse = (h0_reverse, c0_reverse) 121 | htops = None 122 | htops_reverse = None 123 | for i in range(maxlength): 124 | hx, cx = hidden # batch_size, cell_size, 1 125 | ix = x[:, i, :] # batch_size, input_size 126 | ix = ix.unsqueeze(dim=2) # batch_size, input_size, 1 127 | 128 | i2h = torch.bmm(weight_ih_l0, ix) 129 | i2h = i2h.squeeze(dim=2) # batch_size, cell_size*4 130 | h2h = torch.bmm(weight_hh_l0, hx) 131 | h2h = h2h.squeeze(dim=2) # batch_size, cell_size*4 132 | 133 | gates = i2h + h2h + bias_l0 # batch_size, cell_size*4 134 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 135 | 136 | ingate = torch.sigmoid(ingate) 137 | forgetgate = torch.sigmoid(forgetgate) 138 | cellgate 
= torch.tanh(cellgate) # o_t 139 | outgate = torch.sigmoid(outgate) 140 | 141 | cx = cx.squeeze(dim=2) # batch_size, cell_size 142 | cy = (forgetgate * cx) + (ingate * cellgate) 143 | hy = outgate * torch.tanh(cy) # batch_size, cell_size 144 | 145 | mask = (length-1) < i 146 | if mask.sum()>0: 147 | cy[mask] = torch.zeros(mask.sum(), cell_size).to(self.device) 148 | hy[mask] = torch.zeros(mask.sum(), cell_size).to(self.device) 149 | 150 | if (htops is None): htops = hy.unsqueeze(dim=1) 151 | else: htops = torch.cat((htops, hy.unsqueeze(dim=1)), dim=1) 152 | 153 | cx = cy.unsqueeze(dim=2) 154 | hx = hy.unsqueeze(dim=2) 155 | hidden = (hx, cx) 156 | 157 | ############################################################################### 158 | 159 | # reverse 160 | hx_reverse, cx_reverse = hidden_reverse # batch_size, cell_size, 1 161 | ix_reverse = x_reverse[:, i, :] # batch_size, input_size 162 | ix_reverse = ix_reverse.unsqueeze(dim=2) # batch_size, input_size, 1 163 | 164 | i2h = torch.bmm(weight_ih_l0_reverse, ix_reverse) 165 | i2h = i2h.squeeze(dim=2) # batch_size, cell_size*4 166 | h2h = torch.bmm(weight_hh_l0_reverse, hx_reverse) 167 | h2h = h2h.squeeze(dim=2) # batch_size, cell_size*4 168 | 169 | gates = i2h + h2h + bias_l0_reverse # batch_size, cell_size*4 170 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 171 | 172 | ingate = torch.sigmoid(ingate) 173 | forgetgate = torch.sigmoid(forgetgate) 174 | cellgate = torch.tanh(cellgate) # o_t 175 | outgate = torch.sigmoid(outgate) 176 | 177 | cx_reverse = cx_reverse.squeeze(dim=2) # batch_size, cell_size 178 | cy_reverse = (forgetgate * cx_reverse) + (ingate * cellgate) 179 | hy_reverse = outgate * torch.tanh(cy_reverse) # batch_size, cell_size 180 | 181 | # mask 182 | mask_reverse = (maxlength-i) > length 183 | # mask_reverse = np.nonzero(mask_reverse)[0] 184 | # mask_reverse = torch.from_numpy(mask_reverse).to(self.device) 185 | if mask_reverse.sum() > 0: 186 | cy_reverse[mask_reverse] = torch.zeros(mask_reverse.sum(), cell_size).to(self.device) 187 | hy_reverse[mask_reverse] = torch.zeros(mask_reverse.sum(), cell_size).to(self.device) 188 | 189 | if (htops_reverse is None): htops_reverse = hy_reverse.unsqueeze(dim=1) 190 | else: htops_reverse = torch.cat((htops_reverse, hy_reverse.unsqueeze(dim=1)), dim=1) 191 | 192 | cx_reverse = cy_reverse.unsqueeze(dim=2) 193 | hx_reverse = hy_reverse.unsqueeze(dim=2) 194 | hidden_reverse = (hx_reverse, cx_reverse) 195 | 196 | # reverse order of backward batch 197 | reverse_idx = torch.arange(maxlength-1, -1, -1).to(self.device) 198 | # reverse_idx = torch.from_numpy(reverse_idx).to(self.device) 199 | htops_reverse = htops_reverse[:, reverse_idx, :] 200 | 201 | # concatenate forward and backward path 202 | hiddens = torch.cat((htops, htops_reverse), dim=2) 203 | return hiddens 204 | 205 | class BasicAttention(nn.Module): 206 | def __init__(self, args): 207 | super().__init__() 208 | self.attention = LinearAttentionWithoutQuery( 209 | encoder_dim=args.state_size, 210 | device=args.device, 211 | ) 212 | def forward(self, x, mask, **kwargs): 213 | return self.attention(x, mask=mask)[0] 214 | class CustAttention(nn.Module): 215 | def __init__(self, args, meta_param_manager): 216 | super().__init__() 217 | for name, num_meta in args.meta_units: 218 | setattr(self, "num_"+name, num_meta) 219 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 220 | meta_param_manager.register("CustAttention."+name, getattr(self, name).weight) 221 | self.attention = 
LinearAttentionWithQuery(encoder_dim=args.state_size, query_dim=args.meta_dim*len(args.meta_units)) 222 | def forward(self, x, mask, **kwargs): 223 | return self.attention( 224 | x, 225 | query=torch.cat([ 226 | getattr(self, name)(idx) 227 | for name, idx in kwargs.items()], dim=1 228 | ).unsqueeze(dim=1).repeat(1, x.shape[1], 1), 229 | mask=mask)[0] 230 | class BasisCustAttention(nn.Module): 231 | def __init__(self, args, meta_param_manager): 232 | super().__init__() 233 | for name, num_meta in args.meta_units: 234 | setattr(self, "num_"+name, num_meta) 235 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 236 | meta_param_manager.register("BasisCustAttention."+name, getattr(self, name).weight) 237 | self.P = nn.Sequential( 238 | nn.Linear(args.meta_dim*len(args.meta_units), args.key_query_size), # From MetaData to Query 239 | nn.Tanh(), 240 | nn.Linear(args.key_query_size, args.num_bases, bias=False), # Calculate Weights of each Basis: Key & Query Inner-product 241 | nn.Softmax(dim=1), 242 | nn.Linear(args.num_bases, args.state_size), # Weighted Sum of Bases 243 | ) 244 | self.attention = LinearAttentionWithQuery(encoder_dim=args.state_size, query_dim=args.state_size) 245 | def forward(self, x, mask, **kwargs): 246 | return self.attention( 247 | x, 248 | query=self.P(torch.cat([ 249 | getattr(self, name)(idx) 250 | for name, idx in kwargs.items()], dim=1 251 | ).unsqueeze(dim=1).repeat(1, x.shape[1], 1)), 252 | mask=mask)[0] 253 | 254 | class BasicLinear(nn.Module): 255 | def __init__(self, args): 256 | super().__init__() 257 | self.W = nn.Linear(args.state_size, args.num_label, bias=False) 258 | def forward(self, x, **kwargs): 259 | return self.W(x) 260 | class CustLinear(nn.Module): 261 | def __init__(self, args, meta_param_manager): 262 | super().__init__() 263 | self.state_size = args.state_size 264 | self.num_label= args.num_label 265 | for name, num_meta in args.meta_units: 266 | setattr(self, "num_"+name, num_meta) 267 | setattr(self, name, nn.Embedding(num_meta, args.state_size*args.num_label)) 268 | meta_param_manager.register("CustLinear."+name, getattr(self, name).weight) 269 | def forward(self, x, **kwargs): 270 | W = torch.cat([ 271 | getattr(self, name)(idx).view(x.shape[0], self.state_size, self.num_label) 272 | for name, idx in kwargs.items()], dim=1) 273 | x = x.unsqueeze(dim=1).repeat(1,1,len(kwargs)) 274 | return torch.bmm(x, W).squeeze(dim=1) 275 | class BasisCustLinear(nn.Module): 276 | def __init__(self, args, meta_param_manager): 277 | super().__init__() 278 | self.state_size = args.state_size 279 | self.num_label = args.num_label 280 | for name, num_meta in args.meta_units: 281 | setattr(self, "num_"+name, num_meta) 282 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 283 | meta_param_manager.register("BasisCustLinear."+name, getattr(self, name).weight) 284 | self.P = nn.Sequential( 285 | nn.Linear(args.meta_dim*len(args.meta_units), args.key_query_size), # From MetaData to Query 286 | nn.Tanh(), 287 | nn.Linear(args.key_query_size, args.num_bases, bias=False), # Calculate Weights of each Basis: Key & Query Inner-product 288 | nn.Softmax(dim=1), 289 | nn.Linear(args.num_bases, args.state_size*args.num_label), # Weighted Sum of Bases 290 | ) 291 | def forward(self, x, **kwargs): 292 | W = self.P( 293 | torch.cat([getattr(self, name)(idx) for name, idx in kwargs.items()], dim=1) 294 | ).view(x.shape[0], self.state_size, self.num_label) 295 | return torch.bmm(x.unsqueeze(dim=1), W).squeeze(dim=1) 296 | 297 | class BasicBias(nn.Module): 298 | 
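    # a single learned bias over the output labels, shared by every example (no metadata customization)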
def __init__(self, args): 299 | super().__init__() 300 | self.b = nn.Parameter(torch.zeros((1, args.num_label))) 301 | def forward(self, **kwargs): 302 | return self.b 303 | class CustBias(nn.Module): 304 | def __init__(self, args, meta_param_manager): 305 | super().__init__() 306 | for name, num_meta in args.meta_units: 307 | setattr(self, "num_"+name, num_meta) 308 | setattr(self, name, nn.Embedding(num_meta, args.state_size)) 309 | meta_param_manager.register("CustBias."+name, getattr(self, name).weight) 310 | self.Y = nn.Linear(args.state_size*len(args.meta_units), args.num_label, bias=False) 311 | def forward(self, **kwargs): 312 | return self.Y(torch.cat([ 313 | getattr(self, name)(idx) 314 | for name, idx in kwargs.items()], dim=1)) 315 | class BasisCustBias(nn.Module): 316 | def __init__(self, args, meta_param_manager): 317 | super().__init__() 318 | for name, num_meta in args.meta_units: 319 | setattr(self, "num_"+name, num_meta) 320 | setattr(self, name, nn.Embedding(num_meta, args.meta_dim)) 321 | meta_param_manager.register("BasisCustBias."+name, getattr(self, name).weight) 322 | self.P = nn.Sequential( 323 | nn.Linear(args.meta_dim*len(args.meta_units), args.key_query_size), # From MetaData to Query 324 | nn.Tanh(), 325 | nn.Linear(args.key_query_size, args.num_bases, bias=False), # Calculate Weights of each Basis: Key & Query Inner-product 326 | nn.Softmax(dim=1), 327 | nn.Linear(args.num_bases, args.state_size), # Weighted Sum of Bases 328 | ) 329 | self.Y = nn.Linear(args.state_size, args.num_label, bias=False) 330 | def forward(self, **kwargs): 331 | return self.Y( 332 | self.P(torch.cat([getattr(self, name)(idx) for name, idx in kwargs.items()], dim=1)) 333 | ) 334 | 335 | class MetaParamManager: 336 | def __init__(self): 337 | self.meta_em = {} 338 | def state_dict(self): 339 | return self.meta_em 340 | def register(self, name, param): 341 | self.meta_em[name]=param 342 | -------------------------------------------------------------------------------- /src/model_src/model.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn, torch.nn.functional as F 2 | from .layers import * 3 | 4 | class Classifier(nn.Module): 5 | def __init__(self, args): 6 | super().__init__() 7 | self.meta_param_manager = MetaParamManager() 8 | if args.model_type == 'word_cust': 9 | self.word_em = CustWordEmb(args, self.meta_param_manager) 10 | elif args.model_type == 'word_basis_cust': 11 | self.word_em = BasisCustWordEmb(args, self.meta_param_manager) 12 | else: 13 | self.word_em = BasicWordEmb(args) 14 | 15 | if args.model_type == 'encoder_cust': 16 | raise Exception("Out-Of-Memory occurs ... 
") 17 | elif args.model_type == 'encoder_basis_cust': 18 | self.encoder = BasisCustBiLSTM(args, self.meta_param_manager) 19 | else: 20 | self.encoder = BasicBiLSTM(args) 21 | 22 | if args.model_type == 'attention_cust': 23 | self.attention = CustAttention(args, self.meta_param_manager) 24 | elif args.model_type == 'attention_basis_cust': 25 | self.attention = BasisCustAttention(args, self.meta_param_manager) 26 | else: 27 | self.attention = BasicAttention(args) 28 | 29 | if args.model_type == 'linear_cust': 30 | self.W = CustLinear(args, self.meta_param_manager) 31 | elif args.model_type == 'linear_basis_cust': 32 | self.W = BasisCustLinear(args, self.meta_param_manager) 33 | else: 34 | self.W = BasicLinear(args) 35 | 36 | if args.model_type == 'bias_cust': 37 | self.b = CustBias(args, self.meta_param_manager) 38 | elif args.model_type == 'bias_basis_cust': 39 | self.b = BasisCustBias(args, self.meta_param_manager) 40 | else: 41 | self.b = BasicBias(args) 42 | self.word_em_weight = self.word_em.word_em.weight # for pretrained word em vector loading 43 | def forward(self, review, length, mask, **kwargs): 44 | x = self.word_em(review, **kwargs) 45 | # 2. BiLSTM 46 | x = self.encoder(x, length, **kwargs) 47 | # 3. Attention 48 | x = self.attention(x, mask, **kwargs) 49 | # 4. FC Weight Matrix 50 | x = self.W(x, **kwargs) 51 | # 5. FC bias 52 | x += self.b(**kwargs) 53 | return x -------------------------------------------------------------------------------- /src/run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | echo "running bash script ... " 3 | python3 -W ignore main.py \ 4 | --model_type linear_basis_cust \ 5 | --num_bases 4 \ 6 | --domain yelp2013 \ 7 | --vocab_dir ../predefined_vocab/yelp2013/42939.vocab \ 8 | --pretrained_word_em_dir ../predefined_vocab/yelp2013/word_vectors.npy \ 9 | --train_datadir ../dataset/yelp2013/processed_data/train.txt \ 10 | --dev_datadir ../dataset/yelp2013/processed_data/dev.txt \ 11 | --test_datadir ../dataset/yelp2013/processed_data/test.txt \ 12 | --meta_dim 64 \ 13 | --key_query_size 64 \ 14 | --word_dim 300 \ 15 | --state_size 256 \ 16 | --valid_step 1000 \ 17 | --------------------------------------------------------------------------------