├── .gitignore
├── LICENSE
├── README.md
└── code
├── data_prep
├── chn_hotel_dataset.py
└── yelp_dataset.py
├── layers.py
├── models.py
├── options.py
├── train.py
├── utils.py
└── vocab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Xilun Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Language-Adversarial Training for Cross-Lingual Text Classification
2 |
3 | This repo contains the source code for our TACL journal paper:
4 |
5 | [**Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification**](https://arxiv.org/abs/1606.01614)
6 |
7 | [Xilun Chen](http://www.cs.cornell.edu/~xlchen/),
8 | Yu Sun,
9 | [Ben Athiwaratkun](http://www.benathiwaratkun.com/),
10 | [Claire Cardie](http://www.cs.cornell.edu/home/cardie/),
11 | [Kilian Weinberger](http://kilian.cs.cornell.edu/)
12 |
13 | Transactions of the Association for Computational Linguistics (TACL)
14 |
15 | [paper (arXiv)](https://arxiv.org/abs/1606.01614),
16 | [bibtex (arXiv)](http://www.cs.cornell.edu/~xlchen/resources/bibtex/adan.bib),
17 | [paper (TACL)](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00039),
18 | [bibtex](http://www.cs.cornell.edu/~xlchen/resources/bibtex/adan_tacl.bib),
19 | [talk@EMNLP2018](https://vimeo.com/306129914)
20 |
21 | ## Introduction
22 |
23 |
24 | ADAN transfers the knowledge learned from labeled data on a resource-rich source language to low-resource languages where only unlabeled data exists.
25 | It achieves cross-lingual model transfer by learning language-invariant features through Language-Adversarial Training.
26 |
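27 | Concretely, ADAN has three components: a feature extractor F, a sentiment classifier P, and a language discriminator Q. Q is trained to separate source-language (English) features from target-language (Chinese) features, while F and P are trained to classify sentiment on the labeled source data and, at the same time, to make the two feature distributions indistinguishable to Q.
28 | The toy-sized sketch below illustrates this alternating objective with stand-in linear modules; all names and sizes here are illustrative only, and the real training loop (including critic scheduling and weight clipping) lives in `code/train.py`:
29 | 
30 | ```python
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as functional
34 | 
35 | feat = nn.Linear(50, 16)   # stand-in for the feature extractor F
36 | clf = nn.Linear(16, 5)     # stand-in for the sentiment classifier P
37 | critic = nn.Linear(16, 1)  # stand-in for the language discriminator Q
38 | opt_fp = torch.optim.Adam(list(feat.parameters()) + list(clf.parameters()), lr=5e-4)
39 | opt_q = torch.optim.Adam(critic.parameters(), lr=5e-4)
40 | lambd, n_critic = 0.01, 5
41 | 
42 | x_en, y_en = torch.randn(8, 50), torch.randint(0, 5, (8,))  # labeled source batch
43 | x_ch = torch.randn(8, 50)                                   # unlabeled target batch
44 | 
45 | # Q step: the language discriminator learns to separate source from target features
46 | for _ in range(n_critic):
47 |     opt_q.zero_grad()
48 |     loss_q = -critic(feat(x_en)).mean() + critic(feat(x_ch)).mean()
49 |     loss_q.backward()
50 |     opt_q.step()
51 | 
52 | # F & P step: fit source sentiment labels while making the features fool the
53 | # discriminator (whose own weights are not updated in this step)
54 | opt_fp.zero_grad()
55 | f_en, f_ch = feat(x_en), feat(x_ch)
56 | loss = functional.cross_entropy(clf(f_en), y_en) + lambd * (critic(f_en).mean() - critic(f_ch).mean())
57 | loss.backward()
58 | opt_fp.step()
59 | ```
60 | 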
27 | ## Requirements
28 | - Python 3.6
29 | - PyTorch 0.4
30 | - PyTorchNet (for confusion matrix)
31 | - scipy
32 | - tqdm (for progress bar)
33 |
34 | ## File Structure
35 |
36 | ```
37 | .
38 | ├── README.md
39 | └── code
40 | ├── data_prep (data processing scripts)
41 | │ ├── chn_hotel_dataset.py (processing the Chinese Hotel Review dataset)
42 | │ └── yelp_dataset.py (processing the English Yelp Review dataset)
43 | ├── layers.py (lower-level helper modules)
44 | ├── models.py (higher-level modules)
45 | ├── options.py (hyper-parameters aka. all the knobs you may want to turn)
46 | ├── train.py (main file to train the model)
47 | ├── utils.py (helper functions)
48 | └── vocab.py (vocabulary)
49 | ```
50 |
51 | ## Dataset
52 |
53 | The datasets can be downloaded separately [here](https://drive.google.com/drive/folders/1_JSr_VBVQ33hS0PuFjg68d3ePBr_eISF?usp=sharing).
54 |
55 | To support new datasets, simply write a new script under ```data_prep``` similar to the current ones and update ```train.py``` to correctly load it.
56 |
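57 | A minimal sketch of such a script is shown below. The file name, class name, and data format are assumptions (they mirror the existing `yelp_dataset.py`: one tokenized review per line in `X_file`, one 1-based integer label per line in `Y_file`), and each item follows the `(([token_ids], length), label)` layout that `train.py` and the collate functions in `utils.py` expect:
58 | 
59 | ```python
60 | # data_prep/my_reviews_dataset.py -- hypothetical example, not shipped with this repo
61 | import numpy as np
62 | import torch
63 | from torch.utils.data import Dataset
64 | 
65 | 
66 | class MyReviewsDataset(Dataset):
67 |     def __init__(self, X_file, Y_file, vocab, max_seq_len, update_vocab):
68 |         self.X = []
69 |         with open(X_file) as inf:
70 |             for line in inf:
71 |                 words = line.rstrip().split()
72 |                 if max_seq_len > 0:
73 |                     words = words[:max_seq_len]
74 |                 if update_vocab:
75 |                     for w in words:
76 |                         vocab.add_word(w)
77 |                 # store (token ids, length) so the collate function can pad per batch
78 |                 self.X.append(([vocab.lookup(w) for w in words], len(words)))
79 |         self.Y = (torch.from_numpy(np.loadtxt(Y_file)) - 1).long()
80 |         self.num_labels = 5  # must match the label set of the other datasets
81 |         assert len(self.X) == len(self.Y), 'X and Y have different lengths'
82 | 
83 |     def __len__(self):
84 |         return len(self.Y)
85 | 
86 |     def __getitem__(self, idx):
87 |         return (self.X[idx], self.Y[idx])
88 | 
89 |     def get_max_seq_len(self):
90 |         if not hasattr(self, 'max_seq_len'):
91 |             self.max_seq_len = max(x[1] for x in self.X)
92 |         return self.max_seq_len
93 | ```
94 | 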
57 | ## Run Experiments
58 |
59 | ```bash
60 | python train.py --model_save_file {path_to_save_the_model}
61 | ```
62 |
63 | By default, the code uses CNN as the feature extractor.
64 | To use the LSTM (with dot attention) feature extractor:
65 |
66 | ```bash
67 | python train.py --model lstm --F_layers 2 --model_save_file {path_to_save_the_model}
68 | ```
69 |
--------------------------------------------------------------------------------
/code/data_prep/chn_hotel_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class ChnHtlDataset(Dataset):
8 | def __init__(self, X, Y, num_train_lines, vocab, max_seq_len, update_vocab):
9 | # data is assumed to be pre-shuffled
10 | if num_train_lines > 0:
11 | X = X[:num_train_lines]
12 | Y = Y[:num_train_lines]
13 | if update_vocab:
14 | for x in X:
15 | for w in x:
16 | vocab.add_word(w)
17 | # save lengths
18 | self.X = [([vocab.lookup(w) for w in x], len(x)) for x in X]
19 | if max_seq_len > 0:
20 | self.set_max_seq_len(max_seq_len)
21 | self.Y = Y
22 | self.num_labels = 5
23 | assert len(self.X) == len(self.Y), 'X and Y have different lengths'
24 | print('Loaded Chinese Hotel dataset of {} samples'.format(len(self.X)))
25 |
26 | def __len__(self):
27 | return len(self.Y)
28 |
29 | def __getitem__(self, idx):
30 | return (self.X[idx], self.Y[idx])
31 |
32 | def set_max_seq_len(self, max_seq_len):
33 | self.X = [(x[0][:max_seq_len], min(x[1], max_seq_len)) for x in self.X]
34 | self.max_seq_len = max_seq_len
35 |
36 | def get_max_seq_len(self):
37 | if not hasattr(self, 'max_seq_len'):
38 | self.max_seq_len = max([x[1] for x in self.X])
39 | return self.max_seq_len
40 |
41 | def get_subset(self, num_lines):
42 | import copy  # X is already numericalized; just slice a shallow copy of the dataset
43 | subset = copy.copy(self)
44 | subset.X, subset.Y = self.X[:num_lines], self.Y[:num_lines]
45 | return subset
44 |
45 |
46 | def get_chn_htl_datasets(vocab, X_filename, Y_filename, num_train_lines, max_seq_len):
47 | """
48 | dataset is pre-shuffled
49 | split: 150k train + 10k valid + 10k test
50 | """
51 | num_train = 150000
52 | num_valid = num_test = 10000
53 | num_total = num_train + num_valid + num_test
54 | raw_X = []
55 | with open(X_filename) as inf:
56 | for line in inf:
57 | words = line.rstrip().split()
58 | if max_seq_len > 0:
59 | words = words[:max_seq_len]
60 | raw_X.append(words)
61 | Y = (torch.from_numpy(np.loadtxt(Y_filename)) - 1).long()
62 | assert num_total == len(raw_X) == len(Y), 'X and Y have different lengths'
63 |
64 | train_dataset = ChnHtlDataset(raw_X[:num_train], Y[:num_train], num_train_lines,
65 | vocab, max_seq_len, update_vocab=True)
66 | valid_dataset = ChnHtlDataset(raw_X[num_train:num_train+num_valid],
67 | Y[num_train:num_train+num_valid],
68 | 0,
69 | vocab,
70 | max_seq_len,
71 | update_vocab=False)
72 | test_dataset = ChnHtlDataset(raw_X[num_train+num_valid:],
73 | Y[num_train+num_valid:],
74 | 0,
75 | vocab,
76 | max_seq_len,
77 | update_vocab=False)
78 | return train_dataset, valid_dataset, test_dataset
79 |
--------------------------------------------------------------------------------
/code/data_prep/yelp_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class YelpDataset(Dataset):
8 | def __init__(self, X_file, Y_file, num_train_lines, vocab, max_seq_len, update_vocab):
9 | self.raw_X = []
10 | self.X = []
11 | with open(X_file) as inf:
12 | cnt = 0
13 | for line in inf:
14 | words = line.rstrip().split()
15 | if max_seq_len > 0:
16 | words = words[:max_seq_len]
17 | self.raw_X.append(words)
18 | if update_vocab:
19 | for w in words:
20 | vocab.add_word(w)
21 | # save lengths
22 | self.X.append(([vocab.lookup(w) for w in words], len(words)))
23 | cnt += 1
24 | if num_train_lines > 0 and cnt >= num_train_lines:
25 | break
26 |
27 | self.max_seq_len = max_seq_len
28 | if isinstance(Y_file, str):
29 | self.Y = (torch.from_numpy(np.loadtxt(Y_file)) - 1).long()
30 | else:
31 | self.Y = Y_file
32 | if num_train_lines > 0:
33 | self.X = self.X[:num_train_lines]
34 | self.Y = self.Y[:num_train_lines]
35 | self.num_labels = 5
36 | # self.Y = self.Y.to(opt.device)
37 | assert len(self.X) == len(self.Y), 'X and Y have different lengths'
38 | print('Loaded Yelp dataset of {} samples'.format(len(self.X)))
39 |
40 | def __len__(self):
41 | return len(self.Y)
42 |
43 | def __getitem__(self, idx):
44 | return (self.X[idx], self.Y[idx])
45 |
46 | def set_max_seq_len(self, max_seq_len):
47 | self.X = [(x[0][:max_seq_len], min(x[1], max_seq_len)) for x in self.X]
48 | self.max_seq_len = max_seq_len
49 |
50 | def get_max_seq_len(self):
51 | if not hasattr(self, 'max_seq_len'):
52 | self.max_seq_len = max([x[1] for x in self.X])
53 | return self.max_seq_len
54 |
55 | def get_yelp_datasets(vocab,
56 | X_train_filename,
57 | Y_train_filename,
58 | num_train_lines,
59 | X_test_filename,
60 | Y_test_filename,
61 | max_seq_len):
62 | train_dataset = YelpDataset(X_train_filename, Y_train_filename,
63 | num_train_lines, vocab, max_seq_len, update_vocab=True)
64 | valid_dataset = YelpDataset(X_test_filename, Y_test_filename,
65 | 0, vocab, max_seq_len, update_vocab=True)
66 | return train_dataset, valid_dataset
67 |
--------------------------------------------------------------------------------
/code/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import autograd, nn
3 | import torch.nn.functional as functional
4 |
5 | import utils
6 |
7 | class AveragingLayer(nn.Module):
8 | def __init__(self, word_emb):
9 | super(AveragingLayer, self).__init__()
10 | self.word_emb = word_emb
11 |
12 | def forward(self, input):
13 | """
14 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size))
15 | """
16 | data, lengths = input
17 | embeds = self.word_emb(data)
18 | X = embeds.sum(1).squeeze(1)
19 | lengths = lengths.view(-1, 1).expand_as(X)
20 | return X / lengths.float()
21 |
22 |
23 | class SummingLayer(nn.Module):
24 | def __init__(self, word_emb):
25 | super(SummingLayer, self).__init__()
26 | self.word_emb = word_emb
27 |
28 | def forward(self, input):
29 | """
30 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size))
31 | """
32 | data, _ = input
33 | embeds = self.word_emb(data)
34 | X = embeds.sum(1).squeeze()
35 | return X
36 |
37 |
38 | class DotAttentionLayer(nn.Module):
39 | def __init__(self, hidden_size):
40 | super(DotAttentionLayer, self).__init__()
41 | self.hidden_size = hidden_size
42 | self.W = nn.Linear(hidden_size, 1, bias=False)
43 |
44 | def forward(self, input):
45 | """
46 | input: (unpacked_padded_output: batch_size x seq_len x hidden_size, lengths: batch_size)
47 | """
48 | inputs, lengths = input
49 | batch_size, max_len, _ = inputs.size()
50 | flat_input = inputs.contiguous().view(-1, self.hidden_size)
51 | logits = self.W(flat_input).view(batch_size, max_len)
52 | alphas = functional.softmax(logits, dim=-1)
53 |
54 | # computing mask
55 | idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0).to(inputs.device)
56 | mask = (idxes < lengths.unsqueeze(1)).float()
57 | 
58 | alphas = alphas * mask
59 | # renormalize the attention weights of the real (non-padded) tokens to sum to one
60 | alphas = alphas / torch.sum(alphas, 1).view(-1, 1)
61 | output = torch.bmm(alphas.unsqueeze(1), inputs).squeeze(1)
62 | return output
63 | 
--------------------------------------------------------------------------------
/code/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as functional
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 | 
6 | from layers import AveragingLayer, SummingLayer, DotAttentionLayer
7 | from options import opt
8 | 
9 | 
10 | class DANFeatureExtractor(nn.Module):
11 | def __init__(self,
12 | vocab,
13 | num_layers,
14 | hidden_size,
15 | dropout,
16 | batch_norm=False):
17 | super(DANFeatureExtractor, self).__init__()
18 | self.word_emb = vocab.init_embed_layer()
19 | if opt.sum_pooling:
20 | self.avg = SummingLayer(self.word_emb)
21 | else:
22 | self.avg = AveragingLayer(self.word_emb)
23 | 
24 | # feed-forward layers on top of the pooled word embeddings
25 | assert num_layers >= 0, 'Invalid layer numbers'
26 | self.fcnet = nn.Sequential()
27 | for i in range(num_layers):
28 | if dropout > 0:
29 | self.fcnet.add_module('f-dropout-{}'.format(i), nn.Dropout(p=dropout))
30 | if i == 0:
31 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(vocab.emb_size, hidden_size))
32 | else:
33 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
34 | if batch_norm:
35 | self.fcnet.add_module('f-bn-{}'.format(i), nn.BatchNorm1d(hidden_size))
36 | self.fcnet.add_module('f-relu-{}'.format(i), nn.ReLU())
37 |
38 | def forward(self, input):
39 | return self.fcnet(self.avg(input))
40 |
41 |
42 | class LSTMFeatureExtractor(nn.Module):
43 | def __init__(self,
44 | vocab,
45 | num_layers,
46 | hidden_size,
47 | dropout,
48 | bdrnn,
49 | attn_type):
50 | super(LSTMFeatureExtractor, self).__init__()
51 | self.num_layers = num_layers
52 | self.bdrnn = bdrnn
53 | self.attn_type = attn_type
54 | self.hidden_size = hidden_size//2 if bdrnn else hidden_size
55 | self.n_cells = self.num_layers*2 if bdrnn else self.num_layers
56 |
57 | self.word_emb = vocab.init_embed_layer()
58 | self.rnn = nn.LSTM(input_size=vocab.emb_size, hidden_size=self.hidden_size,
59 | num_layers=num_layers, dropout=dropout, bidirectional=bdrnn)
60 | if attn_type == 'dot':
61 | self.attn = DotAttentionLayer(hidden_size)
62 |
63 | def forward(self, input):
64 | data, lengths = input
65 | lengths_list = lengths.tolist()
66 | batch_size = len(data)
67 | embeds = self.word_emb(data)
68 | packed = pack_padded_sequence(embeds, lengths_list, batch_first=True)
69 | state_shape = self.n_cells, batch_size, self.hidden_size
70 | h0 = c0 = embeds.data.new(*state_shape).zero_()  # zero-initialized LSTM states
71 | output, (ht, ct) = self.rnn(packed, (h0, c0))
72 |
73 | if self.attn_type == 'last':
74 | return ht[-1] if not self.bdrnn \
75 | else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
76 | elif self.attn_type == 'avg':
77 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0]
78 | return torch.sum(unpacked_output, 1) / lengths.float().view(-1, 1)
79 | elif self.attn_type == 'dot':
80 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0]
81 | return self.attn((unpacked_output, lengths))
82 | else:
83 | raise Exception('Please specify valid attention (pooling) mechanism')
84 |
85 |
86 | class CNNFeatureExtractor(nn.Module):
87 | def __init__(self,
88 | vocab,
89 | num_layers,
90 | hidden_size,
91 | kernel_num,
92 | kernel_sizes,
93 | dropout):
94 | super(CNNFeatureExtractor, self).__init__()
95 | self.word_emb = vocab.init_embed_layer()
96 | self.kernel_num = kernel_num
97 | self.kernel_sizes = kernel_sizes
98 |
99 | self.convs = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, vocab.emb_size)) for K in kernel_sizes])
100 |
101 | assert num_layers >= 0, 'Invalid layer numbers'
102 | self.fcnet = nn.Sequential()
103 | for i in range(num_layers):
104 | if dropout > 0:
105 | self.fcnet.add_module('f-dropout-{}'.format(i), nn.Dropout(p=dropout))
106 | if i == 0:
107 | self.fcnet.add_module('f-linear-{}'.format(i),
108 | nn.Linear(len(kernel_sizes)*kernel_num, hidden_size))
109 | else:
110 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
111 | self.fcnet.add_module('f-relu-{}'.format(i), nn.ReLU())
112 |
113 | def forward(self, input):
114 | data, lengths = input
115 | batch_size = len(data)
116 | embeds = self.word_emb(data)
117 | # conv
118 | embeds = embeds.unsqueeze(1) # batch_size, 1, seq_len, emb_size
119 | x = [functional.relu(conv(embeds)).squeeze(3) for conv in self.convs]
120 | x = [functional.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
121 | x = torch.cat(x, 1)
122 | # fcnet
123 | return self.fcnet(x)
124 |
125 |
126 | class SentimentClassifier(nn.Module):
127 | def __init__(self,
128 | num_layers,
129 | hidden_size,
130 | output_size,
131 | dropout,
132 | batch_norm=False):
133 | super(SentimentClassifier, self).__init__()
134 | assert num_layers >= 0, 'Invalid layer numbers'
135 | self.net = nn.Sequential()
136 | for i in range(num_layers):
137 | if dropout > 0:
138 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout))
139 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
140 | if batch_norm:
141 | self.net.add_module('p-bn-{}'.format(i), nn.BatchNorm1d(hidden_size))
142 | self.net.add_module('p-relu-{}'.format(i), nn.ReLU())
143 |
144 | self.net.add_module('p-linear-final', nn.Linear(hidden_size, output_size))
145 | self.net.add_module('p-logsoftmax', nn.LogSoftmax(dim=-1))
146 |
147 | def forward(self, input):
148 | return self.net(input)
149 |
150 |
151 | class LanguageDetector(nn.Module):
152 | def __init__(self,
153 | num_layers,
154 | hidden_size,
155 | dropout,
156 | batch_norm=False):
157 | super(LanguageDetector, self).__init__()
158 | assert num_layers >= 0, 'Invalid layer numbers'
159 | self.net = nn.Sequential()
160 | for i in range(num_layers):
161 | if dropout > 0:
162 | self.net.add_module('q-dropout-{}'.format(i), nn.Dropout(p=dropout))
163 | self.net.add_module('q-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
164 | if batch_norm:
165 | self.net.add_module('q-bn-{}'.format(i), nn.BatchNorm1d(hidden_size))
166 | self.net.add_module('q-relu-{}'.format(i), nn.ReLU())
167 |
168 | self.net.add_module('q-linear-final', nn.Linear(hidden_size, 1))
169 |
170 | def forward(self, input):
171 | return self.net(input)
172 |
--------------------------------------------------------------------------------
/code/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--max_epoch', type=int, default=30)
6 | # parser.add_argument('--dataset', default='yelp') # yelp, yelp-aren or amazon
7 | # path to the datasets
8 | parser.add_argument('--src_data_dir', default='../data/yelp-700k/')
9 | parser.add_argument('--tgt_data_dir', default='../data/hotel-170k/')
10 | parser.add_argument('--en_train_lines', type=int, default=0) # set to 0 to use all
11 | parser.add_argument('--ch_train_lines', type=int, default=0) # set to 0 to use all
12 | parser.add_argument('--max_seq_len', type=int, default=0) # set to 0 to not truncate
13 | parser.add_argument('--random_seed', type=int, default=1)
14 | parser.add_argument('--model_save_file', default='./save/adan')
15 | parser.add_argument('--batch_size', type=int, default=100)
16 | parser.add_argument('--learning_rate', type=float, default=0.0005)
17 | parser.add_argument('--Q_learning_rate', type=float, default=0.0005)
18 | # path to BWE
19 | parser.add_argument('--emb_filename', default='../data/bwe/Stanford_BWE.txt')
20 | parser.add_argument('--fix_emb', action='store_true')
21 | parser.add_argument('--random_emb', action='store_true')
22 | # use a fixed <unk> token for all words without pretrained embeddings when building vocab
23 | parser.add_argument('--fix_unk', action='store_true')
24 | parser.add_argument('--emb_size', type=int, default=50)
25 | parser.add_argument('--model', default='cnn') # dan or lstm or cnn
26 | # for LSTM model
27 | parser.add_argument('--attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
28 | parser.add_argument('--bdrnn', dest='bdrnn', action='store_true', default=True) # bi-directional LSTM
29 | # use deep averaging network or deep summing network (for DAN model)
30 | parser.add_argument('--sum_pooling', dest='sum_pooling', action='store_true')
31 | parser.add_argument('--avg_pooling', dest='sum_pooling', action='store_false')
32 | # for CNN model
33 | parser.add_argument('--kernel_num', type=int, default=400)
34 | parser.add_argument('--kernel_sizes', type=int, nargs='+', default=[3,4,5])
35 | parser.add_argument('--hidden_size', type=int, default=900)
36 |
37 | parser.add_argument('--F_layers', type=int, default=1)
38 | parser.add_argument('--P_layers', type=int, default=2)
39 | parser.add_argument('--Q_layers', type=int, default=2)
40 | parser.add_argument('--n_critic', type=int, default=5)
41 | parser.add_argument('--lambd', type=float, default=0.01)
42 | parser.add_argument('--F_bn', dest='F_bn', action='store_true')
43 | parser.add_argument('--no_F_bn', dest='F_bn', action='store_false')
44 | parser.add_argument('--P_bn', dest='P_bn', action='store_true', default=True)
45 | parser.add_argument('--no_P_bn', dest='P_bn', action='store_false')
46 | parser.add_argument('--Q_bn', dest='Q_bn', action='store_true', default=True)
47 | parser.add_argument('--no_Q_bn', dest='Q_bn', action='store_false')
48 | parser.add_argument('--dropout', type=float, default=0.2)
49 | parser.add_argument('--clip_lower', type=float, default=-0.01)
50 | parser.add_argument('--clip_upper', type=float, default=0.01)
51 | parser.add_argument('--device', type=str, default='cuda')
52 | parser.add_argument('--debug', dest='debug', action='store_true')
53 | opt = parser.parse_args()
54 |
55 | if not torch.cuda.is_available():
56 | opt.device = 'cpu'
57 |
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import sys
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as functional
11 | import torch.optim as optim
12 | from torch.utils.data import DataLoader
13 | from torchnet.meter import ConfusionMeter
14 |
15 | from data_prep.yelp_dataset import get_yelp_datasets
16 | from data_prep.chn_hotel_dataset import get_chn_htl_datasets
17 | from models import *
18 | from options import opt
19 | from vocab import Vocab
20 | import utils
21 |
22 | random.seed(opt.random_seed)
23 | torch.manual_seed(opt.random_seed)
24 |
25 | # save logs
26 | if not os.path.exists(opt.model_save_file):
27 | os.makedirs(opt.model_save_file)
28 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG if opt.debug else logging.INFO)
29 | log = logging.getLogger(__name__)
30 | fh = logging.FileHandler(os.path.join(opt.model_save_file, 'log.txt'))
31 | log.addHandler(fh)
32 |
33 | # output options
34 | log.info('Training ADAN with options:')
35 | log.info(opt)
36 |
37 |
38 | def train(opt):
39 | # vocab
40 | log.info(f'Loading Embeddings...')
41 | vocab = Vocab(opt.emb_filename)
42 | # datasets
43 | log.info(f'Loading data...')
44 | yelp_X_train = os.path.join(opt.src_data_dir, 'X_train.txt.tok.shuf.lower')
45 | yelp_Y_train = os.path.join(opt.src_data_dir, 'Y_train.txt.shuf')
46 | yelp_X_test = os.path.join(opt.src_data_dir, 'X_test.txt.tok.lower')
47 | yelp_Y_test = os.path.join(opt.src_data_dir, 'Y_test.txt')
48 | yelp_train, yelp_valid = get_yelp_datasets(vocab, yelp_X_train, yelp_Y_train,
49 | opt.en_train_lines, yelp_X_test, yelp_Y_test, opt.max_seq_len)
50 | chn_X_file = os.path.join(opt.tgt_data_dir, 'X.sent.txt.shuf.lower')
51 | chn_Y_file = os.path.join(opt.tgt_data_dir, 'Y.txt.shuf')
52 | chn_train, chn_valid, chn_test = get_chn_htl_datasets(vocab, chn_X_file, chn_Y_file,
53 | opt.ch_train_lines, opt.max_seq_len)
54 | log.info('Done loading datasets.')
55 | opt.num_labels = yelp_train.num_labels
56 |
57 | if opt.max_seq_len <= 0:
58 | # set to true max_seq_len in the datasets
59 | opt.max_seq_len = max(yelp_train.get_max_seq_len(),
60 | chn_train.get_max_seq_len())
61 | # dataset loaders
62 | my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
63 | yelp_train_loader = DataLoader(yelp_train, opt.batch_size,
64 | shuffle=True, collate_fn=my_collate)
65 | yelp_train_loader_Q = DataLoader(yelp_train,
66 | opt.batch_size,
67 | shuffle=True, collate_fn=my_collate)
68 | chn_train_loader = DataLoader(chn_train, opt.batch_size,
69 | shuffle=True, collate_fn=my_collate)
70 | chn_train_loader_Q = DataLoader(chn_train,
71 | opt.batch_size,
72 | shuffle=True, collate_fn=my_collate)
73 | yelp_train_iter_Q = iter(yelp_train_loader_Q)
74 | chn_train_iter = iter(chn_train_loader)
75 | chn_train_iter_Q = iter(chn_train_loader_Q)
76 |
77 | yelp_valid_loader = DataLoader(yelp_valid, opt.batch_size,
78 | shuffle=False, collate_fn=my_collate)
79 | chn_valid_loader = DataLoader(chn_valid, opt.batch_size,
80 | shuffle=False, collate_fn=my_collate)
81 | chn_test_loader = DataLoader(chn_test, opt.batch_size,
82 | shuffle=False, collate_fn=my_collate)
83 |
84 | # models
85 | if opt.model.lower() == 'dan':
86 | F = DANFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout, opt.F_bn)
87 | elif opt.model.lower() == 'lstm':
88 | F = LSTMFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout,
89 | opt.bdrnn, opt.attn)
90 | elif opt.model.lower() == 'cnn':
91 | F = CNNFeatureExtractor(vocab, opt.F_layers,
92 | opt.hidden_size, opt.kernel_num, opt.kernel_sizes, opt.dropout)
93 | else:
94 | raise Exception('Unknown model')
95 | P = SentimentClassifier(opt.P_layers, opt.hidden_size, opt.num_labels,
96 | opt.dropout, opt.P_bn)
97 | Q = LanguageDetector(opt.Q_layers, opt.hidden_size, opt.dropout, opt.Q_bn)
98 | F, P, Q = F.to(opt.device), P.to(opt.device), Q.to(opt.device)
99 | optimizer = optim.Adam(list(F.parameters()) + list(P.parameters()),
100 | lr=opt.learning_rate)
101 | optimizerQ = optim.Adam(Q.parameters(), lr=opt.Q_learning_rate)
102 |
103 | # training
104 | best_acc = 0.0
105 | for epoch in range(opt.max_epoch):
106 | F.train()
107 | P.train()
108 | Q.train()
109 | yelp_train_iter = iter(yelp_train_loader)
110 | # training accuracy
111 | correct, total = 0, 0
112 | sum_en_q, sum_ch_q = (0, 0.0), (0, 0.0)
113 | grad_norm_p, grad_norm_q = (0, 0.0), (0, 0.0)
114 | for i, (inputs_en, targets_en) in tqdm(enumerate(yelp_train_iter),
115 | total=len(yelp_train)//opt.batch_size):
116 | try:
117 | inputs_ch, _ = next(chn_train_iter) # Chinese labels are not used
118 | except StopIteration:
119 | # Chinese training data is exhausted; restart the iterator
120 | chn_train_iter = iter(chn_train_loader)
121 | inputs_ch, _ = next(chn_train_iter)
122 |
123 | # Q iterations
124 | n_critic = opt.n_critic
125 | if n_critic>0 and ((epoch==0 and i<=25) or (i%500==0)):
126 | n_critic = 10
127 | utils.freeze_net(F)
128 | utils.freeze_net(P)
129 | utils.unfreeze_net(Q)
130 | for qiter in range(n_critic):
131 | # clip Q weights
132 | for p in Q.parameters():
133 | p.data.clamp_(opt.clip_lower, opt.clip_upper)
134 | Q.zero_grad()
135 | # get a minibatch of data
136 | try:
137 | # labels are not used
138 | q_inputs_en, _ = next(yelp_train_iter_Q)
139 | except StopIteration:
140 | # check if dataloader is exhausted
141 | yelp_train_iter_Q = iter(yelp_train_loader_Q)
142 | q_inputs_en, _ = next(yelp_train_iter_Q)
143 | try:
144 | q_inputs_ch, _ = next(chn_train_iter_Q)
145 | except StopIteration:
146 | chn_train_iter_Q = iter(chn_train_loader_Q)
147 | q_inputs_ch, _ = next(chn_train_iter_Q)
148 |
149 | features_en = F(q_inputs_en)
150 | o_en_ad = Q(features_en)
151 | l_en_ad = torch.mean(o_en_ad)
152 | (-l_en_ad).backward()
153 | log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
154 | sum_en_q = (sum_en_q[0] + 1, sum_en_q[1] + l_en_ad.item())
155 |
156 | features_ch = F(q_inputs_ch)
157 | o_ch_ad = Q(features_ch)
158 | l_ch_ad = torch.mean(o_ch_ad)
159 | l_ch_ad.backward()
160 | log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
161 | sum_ch_q = (sum_ch_q[0] + 1, sum_ch_q[1] + l_ch_ad.item())
162 |
163 | optimizerQ.step()
164 |
165 | # F&P iteration
166 | utils.unfreeze_net(F)
167 | utils.unfreeze_net(P)
168 | utils.freeze_net(Q)
169 | if opt.fix_emb:
170 | utils.freeze_net(F.word_emb)
171 | # clip Q weights
172 | for p in Q.parameters():
173 | p.data.clamp_(opt.clip_lower, opt.clip_upper)
174 | F.zero_grad()
175 | P.zero_grad()
176 |
177 | features_en = F(inputs_en)
178 | o_en_sent = P(features_en)
179 | l_en_sent = functional.nll_loss(o_en_sent, targets_en)
180 | l_en_sent.backward(retain_graph=True)
181 | o_en_ad = Q(features_en)
182 | l_en_ad = torch.mean(o_en_ad)
183 | (opt.lambd*l_en_ad).backward(retain_graph=True)
184 | # training accuracy
185 | _, pred = torch.max(o_en_sent, 1)
186 | total += targets_en.size(0)
187 | correct += (pred == targets_en).sum().item()
188 |
189 | features_ch = F(inputs_ch)
190 | o_ch_ad = Q(features_ch)
191 | l_ch_ad = torch.mean(o_ch_ad)
192 | (-opt.lambd*l_ch_ad).backward()
193 |
194 | optimizer.step()
195 |
196 | # end of epoch
197 | log.info('Ending epoch {}'.format(epoch+1))
198 | # logs
199 | if sum_en_q[0] > 0:
200 | log.info(f'Average English Q output: {sum_en_q[1]/sum_en_q[0]}')
201 | log.info(f'Average Foreign Q output: {sum_ch_q[1]/sum_ch_q[0]}')
202 | # evaluate
203 | log.info('Training Accuracy: {}%'.format(100.0*correct/total))
204 | log.info('Evaluating English Validation set:')
205 | evaluate(opt, yelp_valid_loader, F, P)
206 | log.info('Evaluating Foreign validation set:')
207 | acc = evaluate(opt, chn_valid_loader, F, P)
208 | if acc > best_acc:
209 | log.info(f'New Best Foreign validation accuracy: {acc}')
210 | best_acc = acc
211 | torch.save(F.state_dict(),
212 | '{}/netF_epoch_{}.pth'.format(opt.model_save_file, epoch))
213 | torch.save(P.state_dict(),
214 | '{}/netP_epoch_{}.pth'.format(opt.model_save_file, epoch))
215 | torch.save(Q.state_dict(),
216 | '{}/netQ_epoch_{}.pth'.format(opt.model_save_file, epoch))
217 | log.info('Evaluating Foreign test set:')
218 | evaluate(opt, chn_test_loader, F, P)
219 | log.info(f'Best Foreign validation accuracy: {best_acc}')
220 |
221 |
222 | def evaluate(opt, loader, F, P):
223 | F.eval()
224 | P.eval()
225 | it = iter(loader)
226 | correct = 0
227 | total = 0
228 | confusion = ConfusionMeter(opt.num_labels)
229 | with torch.no_grad():
230 | for inputs, targets in tqdm(it):
231 | outputs = P(F(inputs))
232 | _, pred = torch.max(outputs, 1)
233 | confusion.add(pred.data, targets.data)
234 | total += targets.size(0)
235 | correct += (pred == targets).sum().item()
236 | accuracy = correct / total
237 | log.info('Accuracy on {} samples: {}%'.format(total, 100.0*accuracy))
238 | log.debug(confusion.conf)
239 | return accuracy
240 |
241 |
242 | if __name__ == '__main__':
243 | train(opt)
244 |
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | from options import opt
5 | 
6 |
7 | def freeze_net(net):
8 | for p in net.parameters():
9 | p.requires_grad = False
10 |
11 |
12 | def unfreeze_net(net):
13 | for p in net.parameters():
14 | p.requires_grad = True
15 |
16 |
17 | def sorted_collate(batch):
18 | return my_collate(batch, sort=True)
19 |
20 |
21 | def unsorted_collate(batch):
22 | return my_collate(batch, sort=False)
23 |
24 |
25 | def my_collate(batch, sort):
26 | x, y = zip(*batch)
27 | x, y = pad(x, y, opt.eos_idx, sort)
28 | x = (x[0].to(opt.device), x[1].to(opt.device))
29 | y = y.to(opt.device)
30 | return (x, y)
31 |
32 |
33 | def pad(x, y, eos_idx, sort):
34 | inputs, lengths = zip(*x)
35 | max_len = max(lengths)
36 | # pad sequences
37 | padded_inputs = torch.full((len(inputs), max_len), eos_idx, dtype=torch.long)
38 | for i, row in enumerate(inputs):
39 | assert eos_idx not in row, f'EOS in sequence {row}'
40 | padded_inputs[i][:len(row)] = torch.tensor(row, dtype=torch.long)
41 | lengths = torch.tensor(lengths, dtype=torch.long)
42 | y = torch.tensor(y, dtype=torch.long).view(-1)
43 | if sort:
44 | # sort by length
45 | sort_len, sort_idx = lengths.sort(0, descending=True)
46 | padded_inputs = padded_inputs.index_select(0, sort_idx)
47 | y = y.index_select(0, sort_idx)
48 | return (padded_inputs, sort_len), y
49 | else:
50 | return (padded_inputs, lengths), y
51 |
52 |
53 | def zero_eos(emb, eos_idx):
54 | emb.weight.data[eos_idx].zero_()
55 |
--------------------------------------------------------------------------------
/code/vocab.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from options import opt
6 |
7 |
8 | class Vocab:
9 | def __init__(self, txt_file):
10 | self.vocab_size, self.emb_size = 0, opt.emb_size
11 | self.embeddings = []
12 | self.w2vvocab = {}
13 | self.v2wvocab = []
14 | # load pretrained word embeddings
15 | with open(txt_file, 'r') as inf:
16 | parts = inf.readline().split()
17 | assert len(parts) == 2
18 | vs, es = int(parts[0]), int(parts[1])
19 | assert es == self.emb_size
20 | # add an UNK token
21 | self.pretrained = np.empty((vs, es), dtype=np.float64)
22 | self.pt_v2wvocab = []
23 | self.pt_w2vvocab = {}
24 | cnt = 0
25 | for line in inf:
26 | parts = line.rstrip().split(' ')
27 | word = parts[0]
28 | # add to vocab
29 | self.pt_v2wvocab.append(word)
30 | self.pt_w2vvocab[word] = cnt
31 | # load vector
32 | if len(parts) == 2: # comma separated
33 | vecs = parts[-1]
34 | vector = [float(x) for x in vecs.split(',')]
35 | else:
36 | vector = [float(x) for x in parts[-self.emb_size:]]
37 | self.pretrained[cnt] = vector
38 | cnt += 1
39 | # add <unk> token
40 | self.unk_tok = '<unk>'
41 | self.add_word(self.unk_tok)
42 | self.unk_idx = self.w2vvocab[self.unk_tok]
43 | # add EOS token
44 | self.eos_tok = '<eos>'
45 | self.add_word(self.eos_tok)
46 | opt.eos_idx = self.eos_idx = self.w2vvocab[self.eos_tok]
47 | self.embeddings[self.eos_idx][:] = 0
48 |
49 | def base_form(word):
50 | return word.strip().lower()
51 |
52 | def new_rand_emb(self):
53 | vec = np.random.normal(0, 1, size=self.emb_size)
54 | vec /= sum(x*x for x in vec) ** .5
55 | return vec
56 |
57 | def init_embed_layer(self):
58 | # free some memory
59 | self.clear_pretrained_vectors()
60 | emb = nn.Embedding.from_pretrained(torch.tensor(self.embeddings, dtype=torch.float),
61 | freeze=opt.fix_emb)
62 | assert len(emb.weight) == self.vocab_size
63 | return emb
64 |
65 | def add_word(self, word):
66 | word = Vocab.base_form(word)
67 | if word not in self.w2vvocab:
68 | if not opt.random_emb and hasattr(self, 'pt_w2vvocab'):
69 | if opt.fix_unk and word not in self.pt_w2vvocab:
70 | # use fixed unk token, do not update vocab
71 | return
72 | if word in self.pt_w2vvocab:
73 | vector = self.pretrained[self.pt_w2vvocab[word]].copy()
74 | else:
75 | vector = self.new_rand_emb()
76 | else:
77 | vector = self.new_rand_emb()
78 | self.v2wvocab.append(word)
79 | self.w2vvocab[word] = self.vocab_size
80 | self.embeddings.append(vector)
81 | self.vocab_size += 1
82 |
83 | def clear_pretrained_vectors(self):
84 | del self.pretrained
85 | del self.pt_w2vvocab
86 | del self.pt_v2wvocab
87 |
88 | def lookup(self, word):
89 | word = Vocab.base_form(word)
90 | if word in self.w2vvocab:
91 | return self.w2vvocab[word]
92 | return self.unk_idx
93 |
94 | def get_word(self, i):
95 | return self.v2wvocab[i]
96 |
--------------------------------------------------------------------------------