├── README.md ├── config.py ├── discrete_encoders.py ├── evaluate.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # BinarySentEmb 2 | Code for the ACL 2019 paper: Learning Compressed Sentence Representations for On-Device Text Processing. 3 | 4 | 5 | This repository contains the source code necessary to reproduce the results presented in the following paper: 6 | * [*Learning Compressed Sentence Representations for On-Device Text Processing*](https://arxiv.org/pdf/1906.08340.pdf) (ACL 2019) 7 | 8 | This project is maintained by [Pengyu Cheng](https://linear95.github.io/). Feel free to contact pengyu.cheng@duke.edu about any relevant issues. 9 | 10 | ## Dependencies 11 | This code is written in Python. The dependencies are: 12 | * Python 3.6 13 | * PyTorch>=0.4 (0.4.1 is recommended) 14 | * NLTK>=3 15 | 16 | 17 | ## Download pretrained models 18 | 19 | First, download the [GloVe](https://nlp.stanford.edu/projects/glove/) pretrained word embeddings: 20 | 21 | ```bash 22 | mkdir dataset/GloVe 23 | curl -Lo dataset/GloVe/glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip 24 | unzip dataset/GloVe/glove.840B.300d.zip -d dataset/GloVe/ 25 | ``` 26 | Then, follow the instructions in [InferSent](https://github.com/facebookresearch/InferSent) to download the pretrained universal sentence encoder: 27 | 28 | ```bash 29 | mkdir encoder 30 | curl -Lo encoder/infersent1.pkl https://dl.fbaipublicfiles.com/infersent/infersent1.pkl 31 | ``` 32 | 33 | Furthermore, download our pretrained binary sentence encoder from [here](https://drive.google.com/open?id=12lzqtxQwktywXRc1HsQ36ptHGfGOTcIJ). Make sure the binary encoder is also placed in the `./encoder/` folder. 34 | 35 | ## Train a binary encoder 36 | To train a binary sentence encoder, first download `data.py`, `mutils.py`, and `models.py` from [InferSent](https://github.com/facebookresearch/InferSent). 37 | 38 | Then, run the command: 39 | 40 | ```bash 41 | python train.py 42 | ``` 43 | 44 | ## Evaluate the binary encoder on transfer tasks 45 | Follow the instructions in [SentEval](https://github.com/facebookresearch/SentEval) to download the sentence embedding evaluation toolkit and datasets. 46 | 47 | Download the original InferSent encoder model from [here](https://github.com/facebookresearch/InferSent). 48 | 49 | To reproduce the results of our pretrained binary sentence encoder, run the command: 50 | ```bash 51 | python evaluate.py 52 | ``` 53 | 54 | ## Citation 55 | Please cite our ACL 2019 paper if you find the code useful.
56 | 57 | ```latex 58 | @article{shen2019learning, 59 | title={Learning Compressed Sentence Representations for On-Device Text Processing}, 60 | author={Shen, Dinghan and Cheng, Pengyu and Sundararaman, Dhanasekar and Zhang, Xinyuan and Yang, Qian and Tang, Meng and Celikyilmaz, Asli and Carin, Lawrence}, 61 | journal={arXiv preprint arXiv:1906.08340}, 62 | year={2019} 63 | } 64 | ``` 65 | 66 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | lty = 'linear' 2 | fc_dim = 2048 3 | random_seed = 2019 4 | RAN_LOAD_PATH = '../nonlinear_ae/' 5 | PCA_LOAD_PATH = RAN_LOAD_PATH 6 | 7 | INFERSENT_VERSION = 1 # version of InferSent 8 | dim = 2048 9 | model_name = 'bEncoder2048.pkl' 10 | encoder_type = 'AE'#'AE' #'PCA','Random' 'Id' 'HT' 11 | sim_type = 'cosine'#'cosine' #'hamming' 12 | 13 | -------------------------------------------------------------------------------- /discrete_encoders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class IdEncoder(nn.Module): 10 | def __init__(self): 11 | super(IdEncoder,self).__init__() 12 | 13 | def encode(self,x): 14 | return x 15 | 16 | 17 | class HTEncoder(nn.Module): 18 | def __init__(self, load_dir): 19 | super(HTEncoder,self).__init__() 20 | self.emb_mean = torch.from_numpy(np.load(load_dir + 'emb_mean.npy')).cuda() 21 | self.emb_mean = self.emb_mean.view(1,4096) 22 | 23 | def encode(self,x): 24 | # return (x>self.emb_mean).float().cuda() 25 | return (x>0.07).float().cuda() 26 | 27 | 28 | 29 | class RandomEncoder(nn.Module): 30 | def __init__(self,dim,load_dir): 31 | super(RandomEncoder,self).__init__() 32 | self.project_mat = torch.randn(4096,dim).cuda() 33 | self.emb_mean = torch.from_numpy(np.load(load_dir + 'emb_mean.npy')).cuda() 34 | self.emb_mean = self.emb_mean.view(1,4096) 35 | 36 | def encode(self, x): 37 | random_project_emb = torch.matmul(x - self.emb_mean, self.project_mat) 38 | discrete_emb = (random_project_emb > 0.).float().cuda() 39 | return random_project_emb + (discrete_emb-random_project_emb).detach() 40 | 41 | 42 | 43 | class PCAEncoder(nn.Module): 44 | def __init__(self,dim,load_dir): 45 | super(PCAEncoder,self).__init__() 46 | np_project_mat = np.load(load_dir + 'trans_mat.npy') 47 | self.project_mat = torch.from_numpy(np_project_mat[:,:dim]).cuda() 48 | self.emb_mean = torch.from_numpy(np.load(load_dir + 'emb_mean.npy')).cuda() 49 | self.emb_mean = self.emb_mean.view(1,4096) 50 | 51 | def encode(self, x): 52 | pca_emb = torch.matmul(x - self.emb_mean, self.project_mat) 53 | discrete_emb = (pca_emb > 0.).float().cuda() 54 | return pca_emb + (discrete_emb-pca_emb).detach() 55 | 56 | 57 | 58 | class LinearAutoEncoder(nn.Module): 59 | def __init__(self, dim): 60 | super(LinearAutoEncoder, self).__init__() 61 | 62 | self.encoder = nn.Sequential( 63 | nn.Linear(2* 2048, dim), 64 | ) 65 | 66 | self.decoder = nn.Sequential( 67 | nn.Linear(dim, 2*2048) 68 | ) 69 | 70 | def forward(self, x): 71 | logits = self.encoder(x) 72 | latent_code = (logits>0.).float().cuda() 73 | to_decoder = logits+(latent_code-logits).detach() 74 | predict = self.decoder(2.*to_decoder-1.) 
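# Note: `to_decoder` above is a straight-through estimator: its forward value is the binary code `latent_code`, while gradients flow through `logits` because the (latent_code - logits) term is detached; `2.*to_decoder-1.` maps the {0,1} code to {-1,+1} before decoding.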
75 | return predict 76 | 77 | def encode(self, x): 78 | logits = self.encoder(x) 79 | latent_code = (logits>0.).float().cuda() 80 | to_decoder = logits+(latent_code-logits).detach() 81 | return to_decoder 82 | 83 | 84 | 85 | class NonlinearAutoEncoder(nn.Module): 86 | def __init__(self, dim, fc_dim=2048): #fc_dim for dimension of fully-connect layers 87 | super(NonlinearAutoEncoder, self).__init__() 88 | 89 | self.encoder = nn.Sequential( 90 | nn.Linear(2* 2048, fc_dim), 91 | nn.Tanh(), 92 | nn.Linear(fc_dim, dim) 93 | ) 94 | self.decoder = nn.Sequential( 95 | nn.Linear(dim, 2*2048) 96 | ) 97 | 98 | def forward(self, x): 99 | logits = self.encoder(x) 100 | latent_code = (logits>0.).float().cuda() 101 | to_decoder = logits+(latent_code-logits).detach() 102 | predict = self.decoder(2.*to_decoder-1.) 103 | return predict 104 | 105 | def encode(self, x): 106 | logits = self.encoder(x) 107 | latent_code = (logits>0.).float().cuda() 108 | to_decoder = logits+(latent_code-logits).detach() 109 | return to_decoder 110 | 111 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | InferSent models. See https://github.com/facebookresearch/InferSent. 4 | """ 5 | 6 | #from __future__ import absolute_import, division, unicode_literals 7 | 8 | import sys 9 | import os 10 | import numpy as np 11 | import torch 12 | import logging 13 | 14 | import discrete_encoders as DisEnc 15 | import config 16 | 17 | # Set PATHs 18 | PATH_TO_INFERSENT = '../../InferSent-master/' # path to Infersent 19 | PATH_TO_SENTEVAL = '../../SentEval-master/' # path to SentEval 20 | PATH_TO_DATA = PATH_TO_SENTEVAL + 'data' # path to transfer task datasets 21 | PATH_TO_W2V = './dataset/GloVe/glove.840B.300d.txt' # path to GloVe word embedding 22 | PATH_TO_CONT_ENCODER = './encoder/infersent1.pkl' 23 | PATH_TO_B_ENCODER = './encoder/bEncoder2048.pkl' 24 | 25 | 26 | #assert os.path.isfile(INFERSENT_PATH) and os.path.isfile(PATH_TO_W2V), \ 'Set MODEL and GloVe PATHs' 27 | 28 | # import senteval 29 | sys.path.insert(0, PATH_TO_SENTEVAL) 30 | import senteval.engine as engine_cosine 31 | #import senteval.engine_hamming as engine_hamming 32 | 33 | sys.path.insert(0, PATH_TO_INFERSENT) 34 | from models import InferSent 35 | 36 | def prepare(params, samples): 37 | params.infersent.build_vocab([' '.join(s) for s in samples], tokenize=False) 38 | 39 | 40 | def batcher(params, batch): 41 | sentences = [' '.join(s) for s in batch] 42 | embeddings = params.infersent.encode(sentences, bsize=params.batch_size, tokenize=False) 43 | embeddings = torch.from_numpy(embeddings).float().cuda() 44 | return params.autoencoder.encode(embeddings).data.cpu().numpy() 45 | 46 | 47 | def hamming_similarity(s1,s2): 48 | return -np.sum(np.abs(s1-s2),axis = -1) 49 | 50 | 51 | # define senteval params 52 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5, 53 | 'classifier' :{'nhid': 0, 'optim': 'adam', 'batch_size': 64, 54 | 'tenacity': 5, 'epoch_size': 4} 55 | } 56 | # Set up logger 57 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 58 | 59 | if __name__ == "__main__": 60 | # Load InferSent model 61 | params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048, 62 | 'pool_type': 'max', 'dpout_model': 0.0, 'version': 1} 63 | model = InferSent(params_model) 64 | model.load_state_dict(torch.load(PATH_TO_CONT_ENCODER)) 65 | model.set_w2v_path(PATH_TO_W2V) 66 | 67 | 
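# The discrete encoder is chosen by config.encoder_type: 'AE' loads the pretrained binary autoencoder from PATH_TO_B_ENCODER, 'PCA' and 'Random' binarize a projection (PCA transform or random matrix) of the mean-centered InferSent embedding, 'Id' returns the continuous embedding unchanged, and 'HT' applies a hard threshold.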
model_name = config.encoder_type 68 | if config.encoder_type == 'AE': 69 | dis_encoder = DisEnc.LinearAutoEncoder(config.dim) 70 | dis_encoder.load_state_dict(torch.load(PATH_TO_B_ENCODER)) 71 | model_name = model_name + '_' + config.model_name #+'V'+str(config.INFERSENT_VERSION) 72 | elif config.encoder_type == 'PCA': 73 | dis_encoder = DisEnc.PCAEncoder(config.dim,config.PCA_LOAD_PATH) 74 | elif config.encoder_type == 'Random': 75 | dis_encoder = DisEnc.RandomEncoder(config.dim,config.RAN_LOAD_PATH) 76 | elif config.encoder_type == 'Id': 77 | dis_encoder = DisEnc.IdEncoder() 78 | elif config.encoder_type == 'HT': 79 | dis_encoder = DisEnc.HTEncoder(config.RAN_LOAD_PATH) 80 | 81 | 82 | print('testing '+model_name) 83 | 84 | params_senteval['infersent'] = model.cuda() 85 | params_senteval['autoencoder'] = dis_encoder.cuda() 86 | params_senteval['similarity'] = hamming_similarity 87 | 88 | if config.sim_type == 'cosine': 89 | se = engine_cosine.SE(params_senteval, batcher, prepare) 90 | elif config.sim_type == 'hamming': 91 | se = engine_hamming.SE(params_senteval, batcher, prepare)  # requires the senteval engine_hamming module (its import is commented out above) 92 | 93 | results = se.eval(['MR', 'CR', 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'MRPC', 'SICKRelatedness', 'STSBenchmark', 'SICKEntailment', 'MPQA', 'SUBJ', 'SST2', 'SST5']) 94 | 95 | print(results) 96 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from data import get_nli, get_batch, build_vocab 15 | from mutils import get_optimizer 16 | from models import InferSent 17 | 18 | #from models import autoencoder 19 | import discrete_encoders as DisEnc 20 | 21 | 22 | model_name = "bEncoder" 23 | 24 | parser = argparse.ArgumentParser(description='NLI training') 25 | # paths 26 | parser.add_argument("--nlipath", type=str, default='./dataset/SNLI/', help="NLI data path (SNLI or MultiNLI)") 27 | parser.add_argument("--outputdir", type=str, default='savedir/', help="Output directory") 28 | parser.add_argument("--word_emb_path", type=str, default="./dataset/GloVe/glove.840B.300d.txt", help="word embedding file path") 29 | 30 | # training 31 | parser.add_argument("--n_epochs", type=int, default = 20) 32 | parser.add_argument("--batch_size", type=int, default=64) 33 | parser.add_argument("--dpout_model", type=float, default=0., help="encoder dropout") 34 | parser.add_argument("--dpout_fc", type=float, default=0., help="classifier dropout") 35 | parser.add_argument("--nonlinear_fc", type=float, default=0, help="use nonlinearity in fc") 36 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.00005", help="adam or sgd,lr=0.1") 37 | parser.add_argument("--lr", type=float, default=5e-5, help="learning rate") 38 | parser.add_argument("--lrshrink", type=float, default=5, help="shrink factor for sgd") 39 | parser.add_argument("--decay", type=float, default=0.99, help="lr decay") 40 | parser.add_argument("--minlr", type=float, default=1e-5, help="minimum lr") 41 | parser.add_argument("--struct_coef", type=float, default=1e-5, help="coefficient for structural loss") 42 | 43 | 44 | parser.add_argument("--max_norm", type=float, default=5., help="max norm (grad clipping)") 45 | 46 | # model 47 | parser.add_argument("--encoder_type", type=str,
default='InferSent', help="see list of encoders") 48 | parser.add_argument("--enc_lstm_dim", type=int, default=2048, help="encoder nhid dimension") 49 | #parser.add_argument("--sent_emb_dim", type=int, default=512, help="encoder output dimension") 50 | parser.add_argument("--n_enc_layers", type=int, default=1, help="encoder num layers") 51 | parser.add_argument("--fc_dim", type=int, default=512, help="nhid of fc layers") 52 | parser.add_argument("--n_classes", type=int, default=3, help="entailment/neutral/contradiction") 53 | parser.add_argument("--pool_type", type=str, default='max', help="max or mean") 54 | 55 | # gpu 56 | parser.add_argument("--gpu_id", type=int, default= 0, help="GPU ID") 57 | parser.add_argument("--seed", type=int, default= 2019 , help="seed") 58 | 59 | # data 60 | parser.add_argument("--word_emb_dim", type=int, default=300, help="word embedding dimension") 61 | 62 | parser.add_argument("--dis_emb_dim", type=int, default = 1024, help = 'discrete_embedding_dim') 63 | 64 | params, _ = parser.parse_known_args() 65 | 66 | ae_lr = params.lr 67 | # set gpu device 68 | torch.cuda.set_device(params.gpu_id) 69 | 70 | # print parameters passed, and all parameters 71 | print('\ntogrep : {0}\n'.format(sys.argv[1:])) 72 | print(params) 73 | 74 | 75 | """ 76 | SEED 77 | """ 78 | np.random.seed(params.seed) 79 | torch.manual_seed(params.seed) 80 | torch.cuda.manual_seed(params.seed) 81 | 82 | """ 83 | DATA 84 | """ 85 | train, valid, test = get_nli(params.nlipath) 86 | word_vec = build_vocab(train['s1'] + train['s2'] + 87 | valid['s1'] + valid['s2'] + 88 | test['s1'] + test['s2'], params.word_emb_path) 89 | 90 | for split in ['s1', 's2']: 91 | for data_type in ['train', 'valid', 'test']: 92 | eval(data_type)[split] = np.array([['<s>'] + 93 | [word for word in sent.split() if word in word_vec] + 94 | ['</s>'] for sent in eval(data_type)[split]])  # wrap each sentence with the '<s>'/'</s>' tokens expected by InferSent 95 | 96 | 97 | """ 98 | MODEL 99 | """ 100 | # model config 101 | config_nli_model = { 102 | 'n_words' : len(word_vec) , 103 | 'word_emb_dim' : params.word_emb_dim , 104 | 'enc_lstm_dim' : params.enc_lstm_dim , 105 | 'n_enc_layers' : params.n_enc_layers , 106 | 'dpout_model' : params.dpout_model , 107 | 'dpout_fc' : params.dpout_fc , 108 | 'fc_dim' : params.fc_dim , 109 | 'bsize' : params.batch_size , 110 | 'n_classes' : params.n_classes , 111 | 'pool_type' : params.pool_type , 112 | 'nonlinear_fc' : params.nonlinear_fc , 113 | 'encoder_type' : params.encoder_type , 114 | 'use_cuda' : True , 115 | } 116 | 117 | # model 118 | encoder_types = ['InferSent', 'BLSTMprojEncoder', 'BGRUlastEncoder', 119 | 'InnerAttentionMILAEncoder', 'InnerAttentionYANGEncoder', 120 | 'InnerAttentionNAACLEncoder', 'ConvNetEncoder', 'LSTMEncoder'] 121 | assert params.encoder_type in encoder_types, "encoder_type must be in " + \ 122 | str(encoder_types) 123 | 124 | infersent_net = InferSent(config_nli_model) 125 | print(infersent_net) 126 | 127 | infersent_net.load_state_dict(torch.load('./encoder/infersent1.pkl')) 128 | infersent_net.cuda() 129 | 130 | for parameters_infer in infersent_net.parameters(): 131 | parameters_infer.requires_grad = False  # freeze the pretrained InferSent encoder; only the binary autoencoder is trained 132 | 133 | 134 | ae_model = DisEnc.LinearAutoEncoder(params.dis_emb_dim).cuda() 135 | 136 | print(ae_model) 137 | 138 | def cos_distance(a,b): 139 | return (1.-torch.nn.functional.cosine_similarity(a,b)) 140 | 141 | def hamming_distance(a,b): 142 | #return (a-b).abs().sum() 143 | return torch.nn.functional.pairwise_distance(a,b)  # L2 distance; on {0,1} codes this equals the square root of the Hamming distance 144 | 145 | def mse(a,b): 146 | return ((a-b)*(a-b)).mean() 147 | 148 | 149 | optimizer = 
torch.optim.Adam(ae_model.parameters(), lr=ae_lr, weight_decay=1e-5) 150 | 151 | """ 152 | TRAIN 153 | """ 154 | val_mse_best = 100000 155 | adam_stop = False 156 | stop_training = False 157 | 158 | 159 | 160 | def trainepoch(epoch): 161 | print('\nTRAINING : Epoch ' + str(epoch)) 162 | ae_model.train() 163 | all_costs = [] 164 | logs = [] 165 | words_count = 0 166 | 167 | last_time = time.time() 168 | correct = 0. 169 | # shuffle the data 170 | permutation = np.random.permutation(len(train['s1'])) 171 | 172 | s1 = train['s1'][permutation] 173 | s2 = train['s2'][permutation] 174 | 175 | # optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * params.decay if epoch>1\ 176 | # and 'sgd' in params.optimizer else optimizer.param_groups[0]['lr'] 177 | # print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr'])) 178 | 179 | for stidx in range(0, len(s1), params.batch_size): 180 | if params.batch_size > len(s1)-stidx: 181 | sub_batch_size = (len(s1)-stidx)//2 182 | else: 183 | sub_batch_size = params.batch_size // 2 184 | 185 | # prepare batch 186 | s1_batch, s1_len = get_batch(s1[stidx:stidx + sub_batch_size], 187 | word_vec, params.word_emb_dim) 188 | s2_batch, s2_len = get_batch(s2[stidx:stidx + sub_batch_size], 189 | word_vec, params.word_emb_dim) 190 | s3_batch, s3_len = get_batch(s1[stidx + sub_batch_size :stidx + 2 * sub_batch_size], 191 | word_vec, params.word_emb_dim) 192 | s4_batch, s4_len = get_batch(s2[stidx + sub_batch_size :stidx + 2 * sub_batch_size], 193 | word_vec, params.word_emb_dim) 194 | 195 | s1_batch, s2_batch, s3_batch, s4_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda()), Variable(s3_batch.cuda()), Variable(s4_batch.cuda()) 196 | 197 | # model forward 198 | cont_code1 = infersent_net((s1_batch, s1_len)) 199 | cont_code2 = infersent_net((s2_batch, s2_len)) 200 | 201 | cont_code3 = infersent_net((s3_batch, s3_len)) 202 | cont_code4 = infersent_net((s4_batch, s4_len)) 203 | 204 | output1 = ae_model(cont_code1) 205 | output2 = ae_model(cont_code2) 206 | output3 = ae_model(cont_code3) 207 | output4 = ae_model(cont_code4) 208 | 209 | discrete_code1 = ae_model.encode(cont_code1) 210 | discrete_code2 = ae_model.encode(cont_code2) 211 | discrete_code3 = ae_model.encode(cont_code3) 212 | discrete_code4 = ae_model.encode(cont_code4) 213 | 214 | cont_code_dist_1 = cos_distance(cont_code1,cont_code2) 215 | cont_code_dist_2 = cos_distance(cont_code3,cont_code4) 216 | discrete_code_dist1 = hamming_distance(discrete_code1,discrete_code2) # or cos 217 | discrete_code_dist2 = hamming_distance(discrete_code3,discrete_code4) # or cos 218 | 219 | struct_sign = ((cont_code_dist_1 -cont_code_dist_2) > 0.).float().cuda() 220 | struct_loss = (F.relu( (discrete_code_dist2-discrete_code_dist1)*(2*struct_sign-1.).detach()) ** 2).mean() 221 | recons_loss = (mse(cont_code1,output1) +mse(cont_code2,output2) +mse(cont_code3,output3) +mse(cont_code4,output4))/4. 
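# Training objective (combined below): recons_loss is the mean-squared reconstruction error of the continuous InferSent embeddings from their binary codes, while struct_loss (weighted by --struct_coef) penalizes sentence pairs whose binary-code distances violate the ordering of the corresponding cosine distances in the continuous space; gradients reach the encoder through the straight-through estimator in LinearAutoEncoder.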
222 | 223 | loss = params.struct_coef * struct_loss + recons_loss 224 | 225 | all_costs.append([struct_loss.sqrt().data.cpu().detach().numpy(),recons_loss.sqrt().data.cpu().detach().numpy()]) 226 | 227 | # backward 228 | optimizer.zero_grad() 229 | loss.backward() 230 | 231 | 232 | # optimizer step 233 | optimizer.step() 234 | #optimizer.param_groups[0]['lr'] = current_lr 235 | 236 | if len(all_costs) % 100 == 0: 237 | losses = np.mean(all_costs, axis = 0) 238 | print('{0} at epoch {1} ; struct_loss {2}; recons_loss {3}'.format(stidx,epoch,losses[0],losses[1] )) 239 | all_costs = [] 240 | 241 | train_acc = (np.mean(all_costs)) 242 | print('results : epoch {0} ; mean accuracy train : {1}' 243 | .format(epoch,(train_acc))) 244 | 245 | evaluate(epoch) 246 | return 0 247 | 248 | 249 | 250 | def evaluate(epoch, eval_type='valid', final_eval=False): 251 | infersent_net.eval() 252 | ae_model.eval() 253 | correct = 0. 254 | global stop_training, adam_stop 255 | 256 | if eval_type == 'valid': 257 | print('\nVALIDATION : Epoch {0}'.format(epoch)) 258 | 259 | s1 = test['s1'] if eval_type == 'valid' else test['s1'] 260 | s2 = test['s2'] if eval_type == 'valid' else test['s2'] 261 | target = test['label'] if eval_type == 'valid' else test['label'] 262 | 263 | all_costs = [] 264 | 265 | 266 | for stidx in range(0, len(s1), params.batch_size): 267 | if params.batch_size > len(s1)-stidx: 268 | sub_batch_size = (len(s1)-stidx)//2 269 | else: 270 | sub_batch_size = params.batch_size // 2 271 | 272 | # prepare batch 273 | s1_batch, s1_len = get_batch(s1[stidx:stidx + sub_batch_size], 274 | word_vec, params.word_emb_dim) 275 | s2_batch, s2_len = get_batch(s2[stidx:stidx + sub_batch_size], 276 | word_vec, params.word_emb_dim) 277 | s3_batch, s3_len = get_batch(s1[stidx + sub_batch_size :stidx + 2 * sub_batch_size], 278 | word_vec, params.word_emb_dim) 279 | s4_batch, s4_len = get_batch(s2[stidx + sub_batch_size :stidx + 2 * sub_batch_size], 280 | word_vec, params.word_emb_dim) 281 | 282 | s1_batch, s2_batch, s3_batch, s4_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda()), Variable(s3_batch.cuda()), Variable(s4_batch.cuda()) 283 | # tgt_batch = Variable(torch.LongTensor(target[stidx:stidx + params.batch_size])).cuda() 284 | k = s1_batch.size(1) # actual batch size 285 | 286 | # model forward 287 | cont_code1 = infersent_net((s1_batch, s1_len)) 288 | cont_code2 = infersent_net((s2_batch, s2_len)) 289 | 290 | cont_code3 = infersent_net((s3_batch, s3_len)) 291 | cont_code4 = infersent_net((s4_batch, s4_len)) 292 | 293 | output1 = ae_model(cont_code1) 294 | output2 = ae_model(cont_code2) 295 | output3 = ae_model(cont_code3) 296 | output4 = ae_model(cont_code4) 297 | 298 | discrete_code1 = ae_model.encode(cont_code1) 299 | discrete_code2 = ae_model.encode(cont_code2) 300 | discrete_code3 = ae_model.encode(cont_code3) 301 | discrete_code4 = ae_model.encode(cont_code4) 302 | 303 | cont_code_dist_1 = cos_distance(cont_code1,cont_code2) 304 | cont_code_dist_2 = cos_distance(cont_code3,cont_code4) 305 | discrete_code_dist1 = hamming_distance(discrete_code1,discrete_code2) # or cos 306 | discrete_code_dist2 = hamming_distance(discrete_code3,discrete_code4) # or cos 307 | 308 | struct_sign = ((cont_code_dist_1 -cont_code_dist_2) > 0.).float().cuda() 309 | struct_loss = F.relu( (discrete_code_dist2-discrete_code_dist1)*(2*struct_sign-1.).detach()).mean() 310 | recons_loss = (mse(cont_code1,output1) +mse(cont_code2,output2) +mse(cont_code3,output3) +mse(cont_code4,output4))/4. 
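# Validation mirrors the training losses for logging only (no backward pass); struct_loss is reported without squaring and recons_loss is reported as an RMSE via .sqrt() below.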
311 | 312 | # loss 313 | all_costs.append([struct_loss.data.cpu().detach().numpy(),recons_loss.sqrt().data.cpu().detach().numpy()]) 314 | 315 | losses = np.mean(all_costs, axis = 0) 316 | print('valid set at epoch{0} ; struct_loss {1}; recons_loss {2}'.format(epoch,losses[0],losses[1] )) 317 | 318 | return 0 319 | 320 | 321 | """ 322 | Train model on Natural Language Inference task 323 | """ 324 | epoch = 1 325 | 326 | 327 | while epoch <= params.n_epochs:# and (not stop_training): 328 | train_acc = trainepoch(epoch) 329 | # evaluate(epoch, 'valid') 330 | epoch += 1 331 | 332 | --------------------------------------------------------------------------------