├── .gitignore ├── LICENSE ├── README.md └── code ├── data_prep ├── chn_hotel_dataset.py └── yelp_dataset.py ├── layers.py ├── models.py ├── options.py ├── train.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Xilun Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language-Adversarial Training for Cross-Lingual Text Classification 2 | 3 | This repo contains the source code for our TACL journal paper: 4 | 5 | [**Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification**](https://arxiv.org/abs/1606.01614) 6 |
7 | [Xilun Chen](http://www.cs.cornell.edu/~xlchen/), 8 | Yu Sun, 9 | [Ben Athiwaratkun](http://www.benathiwaratkun.com/), 10 | [Claire Cardie](http://www.cs.cornell.edu/home/cardie/), 11 | [Kilian Weinberger](http://kilian.cs.cornell.edu/) 12 |
13 | Transactions of the Association for Computational Linguistics (TACL) 14 |
15 | [paper (arXiv)](https://arxiv.org/abs/1606.01614), 16 | [bibtex (arXiv)](http://www.cs.cornell.edu/~xlchen/resources/bibtex/adan.bib), 17 | [paper (TACL)](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00039), 18 | [bibtex](http://www.cs.cornell.edu/~xlchen/resources/bibtex/adan_tacl.bib), 19 | [talk@EMNLP2018](https://vimeo.com/306129914) 20 | 21 | ## Introduction 22 | 23 | 24 |

24 | ADAN transfers the knowledge learned from labeled data on a resource-rich source language to low-resource languages where only unlabeled data exists. 25 | It achieves cross-lingual model transfer via learning language-invariant features extracted by Language-Adversarial Training.

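The two alternating updates that `code/train.py` implements can be summarized as follows (a compact restatement with notation added here, not taken from the original README): the language detector Q is trained as a Wasserstein-style critic that separates source-language (English) from target-language (Chinese) feature distributions, while the feature extractor F and the sentiment classifier P jointly minimize the classification loss plus λ times the critic's estimate of that divergence.

```latex
% Critic step: Q weights are clipped to [clip_lower, clip_upper] after each update.
\max_{Q}\; \mathbb{E}_{x \sim \mathrm{EN}}\big[Q(F(x))\big] \;-\; \mathbb{E}_{x \sim \mathrm{CH}}\big[Q(F(x))\big]

% Feature extractor / classifier step (lambda corresponds to the --lambd option):
\min_{F,P}\; \mathcal{L}_{\mathrm{NLL}}\big(P(F(x_{\mathrm{EN}})),\, y_{\mathrm{EN}}\big)
  \;+\; \lambda \Big(\mathbb{E}_{x \sim \mathrm{EN}}\big[Q(F(x))\big] - \mathbb{E}_{x \sim \mathrm{CH}}\big[Q(F(x))\big]\Big)
```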
26 | 27 | ## Requirements 28 | - Python 3.6 29 | - PyTorch 0.4 30 | - PyTorchNet (for confusion matrix) 31 | - scipy 32 | - tqdm (for progress bar) 33 | 34 | ## File Structure 35 | 36 | ``` 37 | . 38 | ├── README.md 39 | └── code 40 | ├── data_prep (data processing scripts) 41 | │   ├── chn_hotel_dataset.py (processing the Chinese Hotel Review dataset) 42 | │   └── yelp_dataset.py (processing the English Yelp Review dataset) 43 | ├── layers.py (lower-level helper modules) 44 | ├── models.py (higher-level modules) 45 | ├── options.py (hyper-parameters aka. all the knobs you may want to turn) 46 | ├── train.py (main file to train the model) 47 | ├── utils.py (helper functions) 48 | └── vocab.py (vocabulary) 49 | ``` 50 | 51 | ## Dataset 52 | 53 | The datasets can be downloaded separately [here](https://drive.google.com/drive/folders/1_JSr_VBVQ33hS0PuFjg68d3ePBr_eISF?usp=sharing). 54 | 55 | To support new datasets, simply write a new script under ```data_prep``` similar to the current ones and update ```train.py``` to correctly load it. 56 | 57 | ## Run Experiments 58 | 59 | ```bash 60 | python train.py --model_save_file {path_to_save_the_model} 61 | ``` 62 | 63 | By default, the code uses CNN as the feature extractor. 64 | To use the LSTM (with dot attention) feature extractor: 65 | 66 | ```bash 67 | python train.py --model lstm --F_layers 2 --model_save_file {path_to_save_the_model} 68 | ``` 69 | -------------------------------------------------------------------------------- /code/data_prep/chn_hotel_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class ChnHtlDataset(Dataset): 8 | def __init__(self, X, Y, num_train_lines, vocab, max_seq_len, update_vocab): 9 | # data is assumed to be pre-shuffled 10 | if num_train_lines > 0: 11 | X = X[:num_train_lines] 12 | Y = Y[:num_train_lines] 13 | if update_vocab: 14 | for x in X: 15 | for w in x: 16 | vocab.add_word(w) 17 | # save lengths 18 | self.X = [([vocab.lookup(w) for w in x], len(x)) for x in X] 19 | if max_seq_len > 0: 20 | self.set_max_seq_len(max_seq_len) 21 | self.Y = Y 22 | self.num_labels = 5 23 | assert len(self.X) == len(self.Y), 'X and Y have different lengths' 24 | print('Loaded Chinese Hotel dataset of {} samples'.format(len(self.X))) 25 | 26 | def __len__(self): 27 | return len(self.Y) 28 | 29 | def __getitem__(self, idx): 30 | return (self.X[idx], self.Y[idx]) 31 | 32 | def set_max_seq_len(self, max_seq_len): 33 | self.X = [(x[0][:max_seq_len], min(x[1], max_seq_len)) for x in self.X] 34 | self.max_seq_len = max_seq_len 35 | 36 | def get_max_seq_len(self): 37 | if not hasattr(self, 'max_seq_len'): 38 | self.max_seq_len = max([x[1] for x in self.X]) 39 | return self.max_seq_len 40 | 41 | def get_subset(self, num_lines): 42 | return ChnHtlDataset(self.X[:num_lines], self.Y[:num_lines], 43 | 0, self.max_seq_len) 44 | 45 | 46 | def get_chn_htl_datasets(vocab, X_filename, Y_filename, num_train_lines, max_seq_len): 47 | """ 48 | dataset is pre-shuffled 49 | split: 150k train + 10k valid + 10k test 50 | """ 51 | num_train = 150000 52 | num_valid = num_test = 10000 53 | num_total = num_train + num_valid + num_test 54 | raw_X = [] 55 | with open(X_filename) as inf: 56 | for line in inf: 57 | words = line.rstrip().split() 58 | if max_seq_len > 0: 59 | words = words[:max_seq_len] 60 | raw_X.append(words) 61 | Y = (torch.from_numpy(np.loadtxt(Y_filename)) - 1).long() 62 | assert 
num_total == len(raw_X) == len(Y), 'X and Y have different lengths' 63 | 64 | train_dataset = ChnHtlDataset(raw_X[:num_train], Y[:num_train], num_train_lines, 65 | vocab, max_seq_len, update_vocab=True) 66 | valid_dataset = ChnHtlDataset(raw_X[num_train:num_train+num_valid], 67 | Y[num_train:num_train+num_valid], 68 | 0, 69 | vocab, 70 | max_seq_len, 71 | update_vocab=False) 72 | test_dataset = ChnHtlDataset(raw_X[num_train+num_valid:], 73 | Y[num_train+num_valid:], 74 | 0, 75 | vocab, 76 | max_seq_len, 77 | update_vocab=False) 78 | return train_dataset, valid_dataset, test_dataset 79 | -------------------------------------------------------------------------------- /code/data_prep/yelp_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class YelpDataset(Dataset): 8 | def __init__(self, X_file, Y_file, num_train_lines, vocab, max_seq_len, update_vocab): 9 | self.raw_X = [] 10 | self.X = [] 11 | with open(X_file) as inf: 12 | cnt = 0 13 | for line in inf: 14 | words = line.rstrip().split() 15 | if max_seq_len > 0: 16 | words = words[:max_seq_len] 17 | self.raw_X.append(words) 18 | if update_vocab: 19 | for w in words: 20 | vocab.add_word(w) 21 | # save lengths 22 | self.X.append(([vocab.lookup(w) for w in words], len(words))) 23 | cnt += 1 24 | if num_train_lines > 0 and cnt >= num_train_lines: 25 | break 26 | 27 | self.max_seq_len = max_seq_len 28 | if isinstance(Y_file, str): 29 | self.Y = (torch.from_numpy(np.loadtxt(Y_file)) - 1).long() 30 | else: 31 | self.Y = Y_file 32 | if num_train_lines > 0: 33 | self.X = self.X[:num_train_lines] 34 | self.Y = self.Y[:num_train_lines] 35 | self.num_labels = 5 36 | # self.Y = self.Y.to(opt.device) 37 | assert len(self.X) == len(self.Y), 'X and Y have different lengths' 38 | print('Loaded Yelp dataset of {} samples'.format(len(self.X))) 39 | 40 | def __len__(self): 41 | return len(self.Y) 42 | 43 | def __getitem__(self, idx): 44 | return (self.X[idx], self.Y[idx]) 45 | 46 | def set_max_seq_len(self, max_seq_len): 47 | self.X = [(x[0][:max_seq_len], min(x[1], max_seq_len)) for x in self.X] 48 | self.max_seq_len = max_seq_len 49 | 50 | def get_max_seq_len(self): 51 | if not hasattr(self, 'max_seq_len'): 52 | self.max_seq_len = max([x[1] for x in self.X]) 53 | return self.max_seq_len 54 | 55 | def get_yelp_datasets(vocab, 56 | X_train_filename, 57 | Y_train_filename, 58 | num_train_lines, 59 | X_test_filename, 60 | Y_test_filename, 61 | max_seq_len): 62 | train_dataset = YelpDataset(X_train_filename, Y_train_filename, 63 | num_train_lines, vocab, max_seq_len, update_vocab=True) 64 | valid_dataset = YelpDataset(X_test_filename, Y_test_filename, 65 | 0, vocab, max_seq_len, update_vocab=True) 66 | return train_dataset, valid_dataset 67 | -------------------------------------------------------------------------------- /code/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd, nn 3 | import torch.nn.functional as functional 4 | 5 | import utils 6 | 7 | class AveragingLayer(nn.Module): 8 | def __init__(self, word_emb): 9 | super(AveragingLayer, self).__init__() 10 | self.word_emb = word_emb 11 | 12 | def forward(self, input): 13 | """ 14 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size)) 15 | """ 16 | data, lengths = input 17 | embeds = self.word_emb(data) 18 | X = 
embeds.sum(1).squeeze(1) 19 | lengths = lengths.view(-1, 1).expand_as(X) 20 | return X / lengths.float() 21 | 22 | 23 | class SummingLayer(nn.Module): 24 | def __init__(self, word_emb): 25 | super(SummingLayer, self).__init__() 26 | self.word_emb = word_emb 27 | 28 | def forward(self, input): 29 | """ 30 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size)) 31 | """ 32 | data, _ = input 33 | embeds = self.word_emb(data) 34 | X = embeds.sum(1).squeeze() 35 | return X 36 | 37 | 38 | class DotAttentionLayer(nn.Module): 39 | def __init__(self, hidden_size): 40 | super(DotAttentionLayer, self).__init__() 41 | self.hidden_size = hidden_size 42 | self.W = nn.Linear(hidden_size, 1, bias=False) 43 | 44 | def forward(self, input): 45 | """ 46 | input: (unpacked_padded_output: batch_size x seq_len x hidden_size, lengths: batch_size) 47 | """ 48 | inputs, lengths = input 49 | batch_size, max_len, _ = inputs.size() 50 | flat_input = inputs.contiguous().view(-1, self.hidden_size) 51 | logits = self.W(flat_input).view(batch_size, max_len) 52 | alphas = functional.softmax(logits, dim=-1) 53 | 54 | # computing mask 55 | idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0).to(inputs.device) 56 | mask = (idxes < lengths.unsqueeze(1)).float() 57 | alphas = alphas * mask 58 | # renormalize over the non-padded positions 59 | alphas = alphas / torch.sum(alphas, 1, keepdim=True) 60 | output = torch.bmm(alphas.unsqueeze(1), inputs).squeeze(1) 61 | return output 62 | -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd, nn 3 | import torch.nn.functional as functional 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | from layers import * 7 | from options import opt 8 | 9 | 10 | class DANFeatureExtractor(nn.Module): 11 | def __init__(self, 12 | vocab, 13 | num_layers, 14 | hidden_size, 15 | dropout, 16 | batch_norm=False): 17 | super(DANFeatureExtractor, self).__init__() 18 | self.word_emb = vocab.init_embed_layer() 19 | 20 | if opt.sum_pooling: 21 | self.avg = SummingLayer(self.word_emb) 22 | else: 23 | self.avg = AveragingLayer(self.word_emb) 24 | 25 | assert num_layers >= 0, 'Invalid layer numbers' 26 | self.fcnet = nn.Sequential() 27 | for i in range(num_layers): 28 | if dropout > 0: 29 | self.fcnet.add_module('f-dropout-{}'.format(i), nn.Dropout(p=dropout)) 30 | if i == 0: 31 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(vocab.emb_size, hidden_size)) 32 | else: 33 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size)) 34 | if batch_norm: 35 | self.fcnet.add_module('f-bn-{}'.format(i), nn.BatchNorm1d(hidden_size)) 36 | self.fcnet.add_module('f-relu-{}'.format(i), nn.ReLU()) 37 | 38 | def forward(self, input): 39 | return self.fcnet(self.avg(input)) 40 | 41 | 42 | class LSTMFeatureExtractor(nn.Module): 43 | def __init__(self, 44 | vocab, 45 | num_layers, 46 | hidden_size, 47 | dropout, 48 | bdrnn, 49 | attn_type): 50 | super(LSTMFeatureExtractor, self).__init__() 51 | self.num_layers = num_layers 52 | self.bdrnn = bdrnn 53 | self.attn_type = attn_type 54 | self.hidden_size = hidden_size//2 if bdrnn else hidden_size 55 | self.n_cells = self.num_layers*2 if bdrnn else self.num_layers 56 | 57 | self.word_emb = vocab.init_embed_layer() 58 | self.rnn = nn.LSTM(input_size=vocab.emb_size, hidden_size=self.hidden_size, 59 | num_layers=num_layers, dropout=dropout, bidirectional=bdrnn) 60 | if attn_type == 'dot': 61 | self.attn = DotAttentionLayer(hidden_size) 62 | 63 | def forward(self, input): 64 | data, lengths = input 65 | lengths_list = lengths.tolist() 66 | batch_size = len(data) 67 | embeds = self.word_emb(data) 68 | packed = pack_padded_sequence(embeds, lengths_list, batch_first=True) 69 | state_shape = self.n_cells, batch_size, self.hidden_size 70 | h0 = c0 = embeds.data.new(*state_shape) 71 | output, (ht, ct) = self.rnn(packed, (h0, c0)) 72 | 73 | if self.attn_type == 'last': 74 | return ht[-1] if not self.bdrnn \ 75 | else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1) 76 | elif self.attn_type == 'avg': 77 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0] 78 | return torch.sum(unpacked_output, 1) / lengths.float().view(-1, 1) 79 | elif self.attn_type == 'dot': 80 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0] 81 | return self.attn((unpacked_output, lengths)) 82 | else: 83 | raise Exception('Please specify valid 
attention (pooling) mechanism') 84 | 85 | 86 | class CNNFeatureExtractor(nn.Module): 87 | def __init__(self, 88 | vocab, 89 | num_layers, 90 | hidden_size, 91 | kernel_num, 92 | kernel_sizes, 93 | dropout): 94 | super(CNNFeatureExtractor, self).__init__() 95 | self.word_emb = vocab.init_embed_layer() 96 | self.kernel_num = kernel_num 97 | self.kernel_sizes = kernel_sizes 98 | 99 | self.convs = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, vocab.emb_size)) for K in kernel_sizes]) 100 | 101 | assert num_layers >= 0, 'Invalid layer numbers' 102 | self.fcnet = nn.Sequential() 103 | for i in range(num_layers): 104 | if dropout > 0: 105 | self.fcnet.add_module('f-dropout-{}'.format(i), nn.Dropout(p=dropout)) 106 | if i == 0: 107 | self.fcnet.add_module('f-linear-{}'.format(i), 108 | nn.Linear(len(kernel_sizes)*kernel_num, hidden_size)) 109 | else: 110 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size)) 111 | self.fcnet.add_module('f-relu-{}'.format(i), nn.ReLU()) 112 | 113 | def forward(self, input): 114 | data, lengths = input 115 | batch_size = len(data) 116 | embeds = self.word_emb(data) 117 | # conv 118 | embeds = embeds.unsqueeze(1) # batch_size, 1, seq_len, emb_size 119 | x = [functional.relu(conv(embeds)).squeeze(3) for conv in self.convs] 120 | x = [functional.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 121 | x = torch.cat(x, 1) 122 | # fcnet 123 | return self.fcnet(x) 124 | 125 | 126 | class SentimentClassifier(nn.Module): 127 | def __init__(self, 128 | num_layers, 129 | hidden_size, 130 | output_size, 131 | dropout, 132 | batch_norm=False): 133 | super(SentimentClassifier, self).__init__() 134 | assert num_layers >= 0, 'Invalid layer numbers' 135 | self.net = nn.Sequential() 136 | for i in range(num_layers): 137 | if dropout > 0: 138 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout)) 139 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size)) 140 | if batch_norm: 141 | self.net.add_module('p-bn-{}'.format(i), nn.BatchNorm1d(hidden_size)) 142 | self.net.add_module('p-relu-{}'.format(i), nn.ReLU()) 143 | 144 | self.net.add_module('p-linear-final', nn.Linear(hidden_size, output_size)) 145 | self.net.add_module('p-logsoftmax', nn.LogSoftmax(dim=-1)) 146 | 147 | def forward(self, input): 148 | return self.net(input) 149 | 150 | 151 | class LanguageDetector(nn.Module): 152 | def __init__(self, 153 | num_layers, 154 | hidden_size, 155 | dropout, 156 | batch_norm=False): 157 | super(LanguageDetector, self).__init__() 158 | assert num_layers >= 0, 'Invalid layer numbers' 159 | self.net = nn.Sequential() 160 | for i in range(num_layers): 161 | if dropout > 0: 162 | self.net.add_module('q-dropout-{}'.format(i), nn.Dropout(p=dropout)) 163 | self.net.add_module('q-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size)) 164 | if batch_norm: 165 | self.net.add_module('q-bn-{}'.format(i), nn.BatchNorm1d(hidden_size)) 166 | self.net.add_module('q-relu-{}'.format(i), nn.ReLU()) 167 | 168 | self.net.add_module('q-linear-final', nn.Linear(hidden_size, 1)) 169 | 170 | def forward(self, input): 171 | return self.net(input) 172 | -------------------------------------------------------------------------------- /code/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--max_epoch', type=int, default=30) 6 | # parser.add_argument('--dataset', default='yelp') # yelp, yelp-aren or 
amazon 7 | # path to the datasets 8 | parser.add_argument('--src_data_dir', default='../data/yelp-700k/') 9 | parser.add_argument('--tgt_data_dir', default='../data/hotel-170k/') 10 | parser.add_argument('--en_train_lines', type=int, default=0) # set to 0 to use all 11 | parser.add_argument('--ch_train_lines', type=int, default=0) # set to 0 to use all 12 | parser.add_argument('--max_seq_len', type=int, default=0) # set to 0 to not truncate 13 | parser.add_argument('--random_seed', type=int, default=1) 14 | parser.add_argument('--model_save_file', default='./save/adan') 15 | parser.add_argument('--batch_size', type=int, default=100) 16 | parser.add_argument('--learning_rate', type=float, default=0.0005) 17 | parser.add_argument('--Q_learning_rate', type=float, default=0.0005) 18 | # path to BWE 19 | parser.add_argument('--emb_filename', default='../data/bwe/Stanford_BWE.txt') 20 | parser.add_argument('--fix_emb', action='store_true') 21 | parser.add_argument('--random_emb', action='store_true') 22 | # use a fixed token for all words without pretrained embeddings when building vocab 23 | parser.add_argument('--fix_unk', action='store_true') 24 | parser.add_argument('--emb_size', type=int, default=50) 25 | parser.add_argument('--model', default='cnn') # dan or lstm or cnn 26 | # for LSTM model 27 | parser.add_argument('--attn', default='dot') # attention mechanism (for LSTM): avg, last, dot 28 | parser.add_argument('--bdrnn', dest='bdrnn', action='store_true', default=True) # bi-directional LSTM 29 | # use deep averaging network or deep summing network (for DAN model) 30 | parser.add_argument('--sum_pooling/', dest='sum_pooling', action='store_true') 31 | parser.add_argument('--avg_pooling/', dest='sum_pooling', action='store_false') 32 | # for CNN model 33 | parser.add_argument('--kernel_num', type=int, default=400) 34 | parser.add_argument('--kernel_sizes', type=int, nargs='+', default=[3,4,5]) 35 | parser.add_argument('--hidden_size', type=int, default=900) 36 | 37 | parser.add_argument('--F_layers', type=int, default=1) 38 | parser.add_argument('--P_layers', type=int, default=2) 39 | parser.add_argument('--Q_layers', type=int, default=2) 40 | parser.add_argument('--n_critic', type=int, default=5) 41 | parser.add_argument('--lambd', type=float, default=0.01) 42 | parser.add_argument('--F_bn/', dest='F_bn', action='store_true') 43 | parser.add_argument('--no_F_bn/', dest='F_bn', action='store_false') 44 | parser.add_argument('--P_bn/', dest='P_bn', action='store_true', default=True) 45 | parser.add_argument('--no_P_bn/', dest='P_bn', action='store_false') 46 | parser.add_argument('--Q_bn/', dest='Q_bn', action='store_true', default=True) 47 | parser.add_argument('--no_Q_bn/', dest='Q_bn', action='store_false') 48 | parser.add_argument('--dropout', type=float, default=0.2) 49 | parser.add_argument('--clip_lower', type=float, default=-0.01) 50 | parser.add_argument('--clip_upper', type=float, default=0.01) 51 | parser.add_argument('--device', type=str, default='cuda') 52 | parser.add_argument('--debug/', dest='debug', action='store_true') 53 | opt = parser.parse_args() 54 | 55 | if not torch.cuda.is_available(): 56 | opt.device = 'cpu' 57 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import 
torch.nn.functional as functional 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | from torchnet.meter import ConfusionMeter 14 | 15 | from data_prep.yelp_dataset import get_yelp_datasets 16 | from data_prep.chn_hotel_dataset import get_chn_htl_datasets 17 | from models import * 18 | from options import opt 19 | from vocab import Vocab 20 | import utils 21 | 22 | random.seed(opt.random_seed) 23 | torch.manual_seed(opt.random_seed) 24 | 25 | # save logs 26 | if not os.path.exists(opt.model_save_file): 27 | os.makedirs(opt.model_save_file) 28 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG if opt.debug else logging.INFO) 29 | log = logging.getLogger(__name__) 30 | fh = logging.FileHandler(os.path.join(opt.model_save_file, 'log.txt')) 31 | log.addHandler(fh) 32 | 33 | # output options 34 | log.info('Training ADAN with options:') 35 | log.info(opt) 36 | 37 | 38 | def train(opt): 39 | # vocab 40 | log.info(f'Loading Embeddings...') 41 | vocab = Vocab(opt.emb_filename) 42 | # datasets 43 | log.info(f'Loading data...') 44 | yelp_X_train = os.path.join(opt.src_data_dir, 'X_train.txt.tok.shuf.lower') 45 | yelp_Y_train = os.path.join(opt.src_data_dir, 'Y_train.txt.shuf') 46 | yelp_X_test = os.path.join(opt.src_data_dir, 'X_test.txt.tok.lower') 47 | yelp_Y_test = os.path.join(opt.src_data_dir, 'Y_test.txt') 48 | yelp_train, yelp_valid = get_yelp_datasets(vocab, yelp_X_train, yelp_Y_train, 49 | opt.en_train_lines, yelp_X_test, yelp_Y_test, opt.max_seq_len) 50 | chn_X_file = os.path.join(opt.tgt_data_dir, 'X.sent.txt.shuf.lower') 51 | chn_Y_file = os.path.join(opt.tgt_data_dir, 'Y.txt.shuf') 52 | chn_train, chn_valid, chn_test = get_chn_htl_datasets(vocab, chn_X_file, chn_Y_file, 53 | opt.ch_train_lines, opt.max_seq_len) 54 | log.info('Done loading datasets.') 55 | opt.num_labels = yelp_train.num_labels 56 | 57 | if opt.max_seq_len <= 0: 58 | # set to true max_seq_len in the datasets 59 | opt.max_seq_len = max(yelp_train.get_max_seq_len(), 60 | chn_train.get_max_seq_len()) 61 | # dataset loaders 62 | my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate 63 | yelp_train_loader = DataLoader(yelp_train, opt.batch_size, 64 | shuffle=True, collate_fn=my_collate) 65 | yelp_train_loader_Q = DataLoader(yelp_train, 66 | opt.batch_size, 67 | shuffle=True, collate_fn=my_collate) 68 | chn_train_loader = DataLoader(chn_train, opt.batch_size, 69 | shuffle=True, collate_fn=my_collate) 70 | chn_train_loader_Q = DataLoader(chn_train, 71 | opt.batch_size, 72 | shuffle=True, collate_fn=my_collate) 73 | yelp_train_iter_Q = iter(yelp_train_loader_Q) 74 | chn_train_iter = iter(chn_train_loader) 75 | chn_train_iter_Q = iter(chn_train_loader_Q) 76 | 77 | yelp_valid_loader = DataLoader(yelp_valid, opt.batch_size, 78 | shuffle=False, collate_fn=my_collate) 79 | chn_valid_loader = DataLoader(chn_valid, opt.batch_size, 80 | shuffle=False, collate_fn=my_collate) 81 | chn_test_loader = DataLoader(chn_test, opt.batch_size, 82 | shuffle=False, collate_fn=my_collate) 83 | 84 | # models 85 | if opt.model.lower() == 'dan': 86 | F = DANFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout, opt.F_bn) 87 | elif opt.model.lower() == 'lstm': 88 | F = LSTMFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout, 89 | opt.bdrnn, opt.attn) 90 | elif opt.model.lower() == 'cnn': 91 | F = CNNFeatureExtractor(vocab, opt.F_layers, 92 | opt.hidden_size, opt.kernel_num, opt.kernel_sizes, opt.dropout) 93 | else: 94 | raise Exception('Unknown model') 95 | 
P = SentimentClassifier(opt.P_layers, opt.hidden_size, opt.num_labels, 96 | opt.dropout, opt.P_bn) 97 | Q = LanguageDetector(opt.Q_layers, opt.hidden_size, opt.dropout, opt.Q_bn) 98 | F, P, Q = F.to(opt.device), P.to(opt.device), Q.to(opt.device) 99 | optimizer = optim.Adam(list(F.parameters()) + list(P.parameters()), 100 | lr=opt.learning_rate) 101 | optimizerQ = optim.Adam(Q.parameters(), lr=opt.Q_learning_rate) 102 | 103 | # training 104 | best_acc = 0.0 105 | for epoch in range(opt.max_epoch): 106 | F.train() 107 | P.train() 108 | Q.train() 109 | yelp_train_iter = iter(yelp_train_loader) 110 | # training accuracy 111 | correct, total = 0, 0 112 | sum_en_q, sum_ch_q = (0, 0.0), (0, 0.0) 113 | grad_norm_p, grad_norm_q = (0, 0.0), (0, 0.0) 114 | for i, (inputs_en, targets_en) in tqdm(enumerate(yelp_train_iter), 115 | total=len(yelp_train)//opt.batch_size): 116 | try: 117 | inputs_ch, _ = next(chn_train_iter) # Chinese labels are not used 118 | except: 119 | # check if Chinese data is exhausted 120 | chn_train_iter = iter(chn_train_loader) 121 | inputs_ch, _ = next(chn_train_iter) 122 | 123 | # Q iterations 124 | n_critic = opt.n_critic 125 | if n_critic>0 and ((epoch==0 and i<=25) or (i%500==0)): 126 | n_critic = 10 127 | utils.freeze_net(F) 128 | utils.freeze_net(P) 129 | utils.unfreeze_net(Q) 130 | for qiter in range(n_critic): 131 | # clip Q weights 132 | for p in Q.parameters(): 133 | p.data.clamp_(opt.clip_lower, opt.clip_upper) 134 | Q.zero_grad() 135 | # get a minibatch of data 136 | try: 137 | # labels are not used 138 | q_inputs_en, _ = next(yelp_train_iter_Q) 139 | except StopIteration: 140 | # check if dataloader is exhausted 141 | yelp_train_iter_Q = iter(yelp_train_loader_Q) 142 | q_inputs_en, _ = next(yelp_train_iter_Q) 143 | try: 144 | q_inputs_ch, _ = next(chn_train_iter_Q) 145 | except StopIteration: 146 | chn_train_iter_Q = iter(chn_train_loader_Q) 147 | q_inputs_ch, _ = next(chn_train_iter_Q) 148 | 149 | features_en = F(q_inputs_en) 150 | o_en_ad = Q(features_en) 151 | l_en_ad = torch.mean(o_en_ad) 152 | (-l_en_ad).backward() 153 | log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}') 154 | sum_en_q = (sum_en_q[0] + 1, sum_en_q[1] + l_en_ad.item()) 155 | 156 | features_ch = F(q_inputs_ch) 157 | o_ch_ad = Q(features_ch) 158 | l_ch_ad = torch.mean(o_ch_ad) 159 | l_ch_ad.backward() 160 | log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}') 161 | sum_ch_q = (sum_ch_q[0] + 1, sum_ch_q[1] + l_ch_ad.item()) 162 | 163 | optimizerQ.step() 164 | 165 | # F&P iteration 166 | utils.unfreeze_net(F) 167 | utils.unfreeze_net(P) 168 | utils.freeze_net(Q) 169 | if opt.fix_emb: 170 | utils.freeze_net(F.word_emb) 171 | # clip Q weights 172 | for p in Q.parameters(): 173 | p.data.clamp_(opt.clip_lower, opt.clip_upper) 174 | F.zero_grad() 175 | P.zero_grad() 176 | 177 | features_en = F(inputs_en) 178 | o_en_sent = P(features_en) 179 | l_en_sent = functional.nll_loss(o_en_sent, targets_en) 180 | l_en_sent.backward(retain_graph=True) 181 | o_en_ad = Q(features_en) 182 | l_en_ad = torch.mean(o_en_ad) 183 | (opt.lambd*l_en_ad).backward(retain_graph=True) 184 | # training accuracy 185 | _, pred = torch.max(o_en_sent, 1) 186 | total += targets_en.size(0) 187 | correct += (pred == targets_en).sum().item() 188 | 189 | features_ch = F(inputs_ch) 190 | o_ch_ad = Q(features_ch) 191 | l_ch_ad = torch.mean(o_ch_ad) 192 | (-opt.lambd*l_ch_ad).backward() 193 | 194 | optimizer.step() 195 | 196 | # end of epoch 197 | log.info('Ending epoch {}'.format(epoch+1)) 198 | # logs 199 | if 
sum_en_q[0] > 0: 200 | log.info(f'Average English Q output: {sum_en_q[1]/sum_en_q[0]}') 201 | log.info(f'Average Foreign Q output: {sum_ch_q[1]/sum_ch_q[0]}') 202 | # evaluate 203 | log.info('Training Accuracy: {}%'.format(100.0*correct/total)) 204 | log.info('Evaluating English Validation set:') 205 | evaluate(opt, yelp_valid_loader, F, P) 206 | log.info('Evaluating Foreign validation set:') 207 | acc = evaluate(opt, chn_valid_loader, F, P) 208 | if acc > best_acc: 209 | log.info(f'New Best Foreign validation accuracy: {acc}') 210 | best_acc = acc 211 | torch.save(F.state_dict(), 212 | '{}/netF_epoch_{}.pth'.format(opt.model_save_file, epoch)) 213 | torch.save(P.state_dict(), 214 | '{}/netP_epoch_{}.pth'.format(opt.model_save_file, epoch)) 215 | torch.save(Q.state_dict(), 216 | '{}/netQ_epoch_{}.pth'.format(opt.model_save_file, epoch)) 217 | log.info('Evaluating Foreign test set:') 218 | evaluate(opt, chn_test_loader, F, P) 219 | log.info(f'Best Foreign validation accuracy: {best_acc}') 220 | 221 | 222 | def evaluate(opt, loader, F, P): 223 | F.eval() 224 | P.eval() 225 | it = iter(loader) 226 | correct = 0 227 | total = 0 228 | confusion = ConfusionMeter(opt.num_labels) 229 | with torch.no_grad(): 230 | for inputs, targets in tqdm(it): 231 | outputs = P(F(inputs)) 232 | _, pred = torch.max(outputs, 1) 233 | confusion.add(pred.data, targets.data) 234 | total += targets.size(0) 235 | correct += (pred == targets).sum().item() 236 | accuracy = correct / total 237 | log.info('Accuracy on {} samples: {}%'.format(total, 100.0*accuracy)) 238 | log.debug(confusion.conf) 239 | return accuracy 240 | 241 | 242 | if __name__ == '__main__': 243 | train(opt) 244 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import torch 4 | from torch.utils.serialization import load_lua 5 | from options import opt 6 | 7 | def freeze_net(net): 8 | for p in net.parameters(): 9 | p.requires_grad = False 10 | 11 | 12 | def unfreeze_net(net): 13 | for p in net.parameters(): 14 | p.requires_grad = True 15 | 16 | 17 | def sorted_collate(batch): 18 | return my_collate(batch, sort=True) 19 | 20 | 21 | def unsorted_collate(batch): 22 | return my_collate(batch, sort=False) 23 | 24 | 25 | def my_collate(batch, sort): 26 | x, y = zip(*batch) 27 | x, y = pad(x, y, opt.eos_idx, sort) 28 | x = (x[0].to(opt.device), x[1].to(opt.device)) 29 | y = y.to(opt.device) 30 | return (x, y) 31 | 32 | 33 | def pad(x, y, eos_idx, sort): 34 | inputs, lengths = zip(*x) 35 | max_len = max(lengths) 36 | # pad sequences 37 | padded_inputs = torch.full((len(inputs), max_len), eos_idx, dtype=torch.long) 38 | for i, row in enumerate(inputs): 39 | assert eos_idx not in row, f'EOS in sequence {row}' 40 | padded_inputs[i][:len(row)] = torch.tensor(row, dtype=torch.long) 41 | lengths = torch.tensor(lengths, dtype=torch.long) 42 | y = torch.tensor(y, dtype=torch.long).view(-1) 43 | if sort: 44 | # sort by length 45 | sort_len, sort_idx = lengths.sort(0, descending=True) 46 | padded_inputs = padded_inputs.index_select(0, sort_idx) 47 | y = y.index_select(0, sort_idx) 48 | return (padded_inputs, sort_len), y 49 | else: 50 | return (padded_inputs, lengths), y 51 | 52 | 53 | def zero_eos(emb, eos_idx): 54 | emb.weight.data[eos_idx].zero_() 55 | -------------------------------------------------------------------------------- /code/vocab.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from options import opt 6 | 7 | 8 | class Vocab: 9 | def __init__(self, txt_file): 10 | self.vocab_size, self.emb_size = 0, opt.emb_size 11 | self.embeddings = [] 12 | self.w2vvocab = {} 13 | self.v2wvocab = [] 14 | # load pretrained embeddings 15 | with open(txt_file, 'r') as inf: 16 | parts = inf.readline().split() 17 | assert len(parts) == 2 18 | vs, es = int(parts[0]), int(parts[1]) 19 | assert es == self.emb_size 20 | # add an UNK token 21 | self.pretrained = np.empty((vs, es), dtype=np.float) 22 | self.pt_v2wvocab = [] 23 | self.pt_w2vvocab = {} 24 | cnt = 0 25 | for line in inf: 26 | parts = line.rstrip().split(' ') 27 | word = parts[0] 28 | # add to vocab 29 | self.pt_v2wvocab.append(word) 30 | self.pt_w2vvocab[word] = cnt 31 | # load vector 32 | if len(parts) == 2: # comma separated 33 | vecs = parts[-1] 34 | vector = [float(x) for x in vecs.split(',')] 35 | else: 36 | vector = [float(x) for x in parts[-self.emb_size:]] 37 | self.pretrained[cnt] = vector 38 | cnt += 1 39 | # add <unk> token 40 | self.unk_tok = '<unk>' 41 | self.add_word(self.unk_tok) 42 | self.unk_idx = self.w2vvocab[self.unk_tok] 43 | # add EOS token 44 | self.eos_tok = '<eos>' 45 | self.add_word(self.eos_tok) 46 | opt.eos_idx = self.eos_idx = self.w2vvocab[self.eos_tok] 47 | self.embeddings[self.eos_idx][:] = 0 48 | 49 | def base_form(word): 50 | return word.strip().lower() 51 | 52 | def new_rand_emb(self): 53 | vec = np.random.normal(0, 1, size=self.emb_size) 54 | vec /= sum(x*x for x in vec) ** .5 55 | return vec 56 | 57 | def init_embed_layer(self): 58 | # free some memory 59 | self.clear_pretrained_vectors() 60 | emb = nn.Embedding.from_pretrained(torch.tensor(self.embeddings, dtype=torch.float), 61 | freeze=opt.fix_emb) 62 | assert len(emb.weight) == self.vocab_size 63 | return emb 64 | 65 | def add_word(self, word): 66 | word = Vocab.base_form(word) 67 | if word not in self.w2vvocab: 68 | if not opt.random_emb and hasattr(self, 'pt_w2vvocab'): 69 | if opt.fix_unk and word not in self.pt_w2vvocab: 70 | # use fixed unk token, do not update vocab 71 | return 72 | if word in self.pt_w2vvocab: 73 | vector = self.pretrained[self.pt_w2vvocab[word]].copy() 74 | else: 75 | vector = self.new_rand_emb() 76 | else: 77 | vector = self.new_rand_emb() 78 | self.v2wvocab.append(word) 79 | self.w2vvocab[word] = self.vocab_size 80 | self.embeddings.append(vector) 81 | self.vocab_size += 1 82 | 83 | def clear_pretrained_vectors(self): 84 | del self.pretrained 85 | del self.pt_w2vvocab 86 | del self.pt_v2wvocab 87 | 88 | def lookup(self, word): 89 | word = Vocab.base_form(word) 90 | if word in self.w2vvocab: 91 | return self.w2vvocab[word] 92 | return self.unk_idx 93 | 94 | def get_word(self, i): 95 | return self.v2wvocab[i] 96 | --------------------------------------------------------------------------------
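As a complement to the training command in the README, the sketch below shows one way to reload a saved feature extractor and sentiment classifier and score a single target-language review. It is a minimal illustration, not part of the repository: the checkpoint epoch, the save directory, and the example sentence are placeholders, and the checkpoint names simply follow the pattern `train.py` writes (`netF_epoch_{}.pth` / `netP_epoch_{}.pth`). It assumes it is run from inside `code/` with the default CNN extractor, and that the vocabulary is regrown exactly as in `train.py` so the saved embedding matrix fits.

```python
# Minimal inference sketch (an assumption-laden illustration, not repository code).
import os
import torch

from data_prep.yelp_dataset import get_yelp_datasets
from data_prep.chn_hotel_dataset import get_chn_htl_datasets
from models import CNNFeatureExtractor, SentimentClassifier
from options import opt
from vocab import Vocab
import utils

vocab = Vocab(opt.emb_filename)  # also sets opt.eos_idx, which utils.pad() relies on

# Grow the vocabulary in the same order train.py does, so that the size of the
# embedding matrix matches the checkpoint; the datasets themselves are discarded.
get_yelp_datasets(vocab,
                  os.path.join(opt.src_data_dir, 'X_train.txt.tok.shuf.lower'),
                  os.path.join(opt.src_data_dir, 'Y_train.txt.shuf'),
                  opt.en_train_lines,
                  os.path.join(opt.src_data_dir, 'X_test.txt.tok.lower'),
                  os.path.join(opt.src_data_dir, 'Y_test.txt'),
                  opt.max_seq_len)
get_chn_htl_datasets(vocab,
                     os.path.join(opt.tgt_data_dir, 'X.sent.txt.shuf.lower'),
                     os.path.join(opt.tgt_data_dir, 'Y.txt.shuf'),
                     opt.ch_train_lines, opt.max_seq_len)

F = CNNFeatureExtractor(vocab, opt.F_layers, opt.hidden_size,
                        opt.kernel_num, opt.kernel_sizes, opt.dropout).to(opt.device)
P = SentimentClassifier(opt.P_layers, opt.hidden_size, 5,  # 5 rating classes
                        opt.dropout, opt.P_bn).to(opt.device)
# Placeholder epoch number; use whichever checkpoint scored best on validation.
F.load_state_dict(torch.load('./save/adan/netF_epoch_10.pth', map_location=opt.device))
P.load_state_dict(torch.load('./save/adan/netP_epoch_10.pth', map_location=opt.device))
F.eval()
P.eval()

# Placeholder pre-tokenized Chinese review; the CNN needs at least max(kernel_sizes) tokens.
tokens = '房间 很 干净 服务 也 不错'.split()
ids = [vocab.lookup(w) for w in tokens]
# unsorted_collate pads the batch, attaches lengths and moves tensors to opt.device;
# the label 0 is a dummy value that is ignored at inference time.
inputs, _ = utils.unsorted_collate([((ids, len(ids)), 0)])
with torch.no_grad():
    log_probs = P(F(inputs))               # log-probabilities over the 5 classes
pred = torch.max(log_probs, 1)[1].item() + 1  # labels are stored as 0-4; +1 gives 1-5 stars
print('predicted rating:', pred)
```

Because everything above is driven by `options.py`, any non-default flags used during training (for example `--max_seq_len` or `--en_train_lines`) have to be passed again when running this sketch, otherwise the rebuilt vocabulary will not line up with the saved embedding weights.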