":
69 | # todo: smart sampling per dim distribution
70 | # sample = numpy.random.uniform(low=-0.5, high=0.5,
71 | # size=embeddings.shape[1])
72 | sample = numpy.random.normal(mu, sigma / 4)
73 | filtered_embeddings[token_id] = sample
74 |
75 | mask[token_id] = 1
76 | missing.append(token_id)
77 | else:
78 | filtered_embeddings[token_id] = embeddings[word2idx[token]]
79 |
80 | print(f"Missing tokens from the pretrained embeddings: {len(missing)}")
81 |
82 | return filtered_embeddings, mask, missing
83 |
84 | def read_fasttext(self, file):
85 | """
86 | Create an Embeddings Matrix, in which each row corresponds to
87 | the word vector from the pretrained word embeddings.
88 | If a word is missing then obtain a representation on-the-fly
89 | using fasttext.
90 |
 91 |         Args:
 92 |             file: path to the pretrained fastText model
 93 |                 (in a format loadable by gensim's FastText)
 94 |
 95 |         Returns:
 96 |             the embeddings matrix and a list of the missing tokens
97 | """
98 | model = FastText.load_fasttext_format(file)
99 |
100 | embeddings = numpy.zeros((len(self), model.vector_size))
101 |
102 | missing = []
103 |
104 | for token_id, token in tqdm(self.id2tok.items(),
105 | desc="Reading embeddings...",
106 |                                     total=len(self.id2tok)):
107 | if token not in model.wv.vocab:
108 | missing.append(token)
109 | embeddings[token_id] = model[token]
110 |
111 | print(f"Missing tokens from the pretrained embeddings: {len(missing)}")
112 |
113 | return embeddings, missing
114 |
115 | def add_token(self, token):
116 | index = len(self.tok2id)
117 |
118 | if token not in self.tok2id:
119 | self.tok2id[token] = index
120 | self.id2tok[index] = token
121 | self.size = len(self)
122 |
123 | def __add_special_tokens(self):
124 | self.add_token(self.PAD)
125 | self.add_token(self.SOS)
126 | self.add_token(self.EOS)
127 | self.add_token(self.UNK)
128 |
129 | def from_file(self, file, skip=0):
130 | self.__add_special_tokens()
131 |
132 |         with open(file) as f:
133 |             for line in f.readlines()[skip:]:
134 |                 token = line.split()[0]
135 |                 self.add_token(token)
136 |
137 | def to_file(self, file):
138 | with open(file, "w") as f:
139 | f.write("\n".join(self.tok2id.keys()))
140 |
141 | def is_corrupt(self):
142 |         return any(self.id2tok[index] != tok
143 |                    for tok, index in self.tok2id.items())
144 |
145 | def get_tokens(self):
146 | return [self.id2tok[key] for key in sorted(self.id2tok.keys())]
147 |
148 | def build(self, size=None):
149 | self.__add_special_tokens()
150 |
151 | for w, k in self.vocab.most_common(size):
152 | self.add_token(w)
153 |
154 | def __len__(self):
155 | return len(self.tok2id)
156 |
--------------------------------------------------------------------------------
/modules/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from numpy import mean
3 | from torch.nn import functional as F
4 | from torch.nn.functional import _gumbel_softmax_sample
5 |
6 |
7 | def sequence_mask(lengths, max_len=None):
8 | """
9 | Creates a boolean mask from sequence lengths.
10 | """
11 | batch_size = lengths.numel()
12 | max_len = max_len or lengths.max()
13 | return (torch.arange(0, max_len, device=lengths.device)
14 | .type_as(lengths)
15 | .unsqueeze(0).expand(batch_size, max_len)
16 | .lt(lengths.unsqueeze(1)))
17 |
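# Illustrative usage sketch (added for exposition, not part of the
# original module): for a batch with lengths [3, 1] and max_len=4 the
# mask marks the valid (non-padded) timesteps of each sequence.
def _demo_sequence_mask():
    demo_lengths = torch.tensor([3, 1])
    demo_mask = sequence_mask(demo_lengths, max_len=4)
    # demo_mask:
    # tensor([[ True,  True,  True, False],
    #         [ True, False, False, False]])
    return demo_mask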
18 |
19 | def masked_normalization(logits, mask):
20 | scores = F.softmax(logits, dim=-1)
21 |
22 | # apply the mask - zero out masked timesteps
23 | masked_scores = scores * mask.float()
24 |
25 | # re-normalize the masked scores
26 | normed_scores = masked_scores.div(masked_scores.sum(-1, keepdim=True))
27 |
28 | return normed_scores
29 |
30 |
31 | def masked_mean(vecs, mask):
32 | masked_vecs = vecs * mask.float()
33 |
34 | mean = masked_vecs.sum(1) / mask.sum(1)
35 |
36 | return mean
37 |
38 |
 39 | def masked_normalization_inf(logits, mask):
 40 |     # -inf on masked timesteps -> zero probability after the softmax;
 41 |     # "mask == 0" handles both bool and uint8 masks
 42 |     logits.masked_fill_(mask == 0, float('-inf'))
 43 |     scores = F.softmax(logits, dim=-1)
 44 |
 45 |     return scores
46 |
47 |
48 | def expected_vecs(dists, vecs):
49 | flat_probs = dists.contiguous().view(dists.size(0) * dists.size(1),
50 | dists.size(2))
51 | flat_embs = flat_probs.mm(vecs)
52 | embs = flat_embs.view(dists.size(0), dists.size(1), flat_embs.size(1))
53 | return embs
54 |
55 |
56 | def straight_softmax(logits, tau=1, hard=False, target_mask=None):
57 | y_soft = F.softmax(logits.squeeze() / tau, dim=1)
58 |
59 | if target_mask is not None:
60 | y_soft = y_soft * target_mask.float()
 61 |         y_soft = y_soft.div(y_soft.sum(-1, keepdim=True))  # re-normalize
62 |
63 | if hard:
64 | shape = logits.size()
65 | _, k = y_soft.max(-1)
66 | y_hard = logits.new_zeros(*shape).scatter_(-1, k.view(-1, 1), 1.0)
67 | y = y_hard - y_soft.detach() + y_soft
68 | return y
69 | else:
70 | return y_soft
71 |
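# Illustrative sketch (added for exposition): with hard=True the forward
# value is (numerically) one-hot, while gradients still flow through the
# soft distribution - the "straight-through" trick also used below.
def _demo_straight_through():
    logits = torch.randn(4, 10, requires_grad=True)
    y = straight_softmax(logits, tau=0.5, hard=True)
    assert torch.allclose(y.detach(), y.detach().round())
    y.sum().backward()            # gradients reach `logits` via y_soft
    assert logits.grad is not None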
72 |
73 | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, target_mask=None):
74 | r"""
75 | Sample from the Gumbel-Softmax distribution and optionally discretize.
76 |
77 | Args:
78 | logits: `[batch_size, num_features]` unnormalized log probabilities
79 | tau: non-negative scalar temperature
80 | hard: if ``True``, the returned samples will be discretized as one-hot vectors,
81 | but will be differentiated as if it is the soft sample in autograd
82 |
83 | Returns:
84 | Sampled tensor of shape ``batch_size x num_features`` from the Gumbel-Softmax distribution.
85 | If ``hard=True``, the returned samples will be one-hot, otherwise they will
86 | be probability distributions that sum to 1 across features
87 |
88 | Constraints:
89 |
 90 |     - Currently only works on a 2D input :attr:`logits` tensor of shape ``batch_size x num_features``
91 |
92 | Based on
93 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
94 | (MIT license)
95 | """
96 | shape = logits.size()
97 | assert len(shape) == 2
98 | y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
99 |
100 | if target_mask is not None:
101 | y_soft = y_soft * target_mask.float()
102 |         y_soft = y_soft.div(y_soft.sum(-1, keepdim=True))  # re-normalize
103 |
104 | if hard:
105 | _, k = y_soft.max(-1)
106 | # this bit is based on
107 | # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
108 | y_hard = logits.new_zeros(*shape).scatter_(-1, k.view(-1, 1), 1.0)
109 | # this cool bit of code achieves two things:
110 | # - makes the output value exactly one-hot (since we add then
111 | # subtract y_soft value)
112 | # - makes the gradient equal to y_soft gradient (since we strip
113 | # all other gradients)
114 | y = y_hard - y_soft.detach() + y_soft
115 | else:
116 | y = y_soft
117 | return y
118 |
119 |
120 | def avg_vectors(vectors, mask, energies=None):
121 | if energies is None:
122 | centroid = masked_mean(vectors, mask)
123 | return centroid, None
124 |
125 | else:
126 | masked_scores = energies * mask.float()
127 | normed_scores = masked_scores.div(masked_scores.sum(1, keepdim=True))
128 | centroid = (vectors * normed_scores).sum(1)
129 | return centroid, normed_scores
130 |
131 |
132 | def aeq(*args):
133 | """
134 | Assert all arguments have the same value
135 | """
136 |     arguments = iter(args)
137 | first = next(arguments)
138 | assert all(arg == first for arg in arguments), \
139 | "Not all arguments have the same value: " + str(args)
140 |
141 |
142 | def module_grad_wrt_loss(optimizers, module, loss, prefix=None):
143 | loss.backward(retain_graph=True)
144 |
145 | grad_norms = [(n, p.grad.norm().item())
146 | for n, p in module.named_parameters()]
147 |
148 | if prefix is not None:
149 | grad_norms = [g for g in grad_norms if g[0].startswith(prefix)]
150 |
151 | mean_norm = mean([gn for n, gn in grad_norms])
152 |
153 | for optimizer in optimizers:
154 | optimizer.zero_grad()
155 |
156 | return mean_norm
157 |
158 |
159 | def index_mask(mask_row, mask_col, index):
160 |     # build a (mask_row x mask_col) binary mask with ones at the
161 |     # positions given by `index` (one row of indices per example)
162 |     A = torch.zeros((mask_row, mask_col), device=index.device)
163 |
164 |     # offset each row's indices so they address the flattened mask
165 |     row_offset = torch.arange(mask_row, device=index.device) * mask_col
166 |     flat_index = (index.long() + row_offset.unsqueeze(-1)).view(-1)
167 |
168 |     flat_A = A.view(-1)
169 |     flat_A[flat_index] = 1
170 |     A = flat_A.view(mask_row, -1)
171 |
172 |     return A
174 |
175 |
176 | def kl_categorical(p_logit, q_logit):
177 | p = F.softmax(p_logit, dim=-1)
178 | _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), 1)
179 | return torch.mean(_kl)
180 |
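# Illustrative check (added for exposition): the KL divergence of a
# distribution with itself is zero.
def _demo_kl_categorical():
    p_logit = torch.randn(3, 5)
    assert torch.allclose(kl_categorical(p_logit, p_logit),
                          torch.tensor(0.))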
--------------------------------------------------------------------------------
/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
4 |
5 | from modules.helpers import sequence_mask, masked_normalization_inf
6 |
7 |
8 | class GaussianNoise(nn.Module):
9 | def __init__(self, stddev, mean=.0):
10 | """
11 | Additive Gaussian Noise layer
12 | Args:
13 | stddev (float): the standard deviation of the distribution
14 | mean (float): the mean of the distribution
15 | """
16 | super().__init__()
17 | self.stddev = stddev
18 | self.mean = mean
19 |
 20 |     def forward(self, x):
 21 |         if self.training:
 22 |             # sample additive noise with the same shape, device and
 23 |             # dtype as the input
 24 |             noise = torch.randn_like(x) * self.stddev + self.mean
 25 |             return x + noise
 26 |         return x
27 |
28 | def __repr__(self):
29 | return '{} (mean={}, stddev={})'.format(self.__class__.__name__,
30 | str(self.mean),
31 | str(self.stddev))
32 |
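# Illustrative usage sketch (added for exposition): noise is injected
# only in training mode; in eval mode the input passes through intact.
def _demo_gaussian_noise():
    layer = GaussianNoise(stddev=0.1)
    x = torch.ones(2, 4)
    layer.train()
    noisy = layer(x)              # x + N(mean=0, std=0.1) noise
    layer.eval()
    assert torch.equal(layer(x), x)
    return noisy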
33 |
34 | class Embed(nn.Module):
35 | def __init__(self,
36 | num_embeddings,
37 | embedding_dim,
38 | embeddings=None,
39 | noise=.0,
40 | dropout=.0,
41 | trainable=True, grad_mask=None, norm=False):
42 | """
43 | Define the layer of the model and perform the initializations
44 | of the layers (wherever it is necessary)
45 | Args:
46 | embeddings (numpy.ndarray): the 2D ndarray with the word vectors
47 | noise (float):
48 | dropout (float):
49 | trainable (bool):
50 | """
51 | super(Embed, self).__init__()
52 |
53 | self.norm = norm
54 |
55 | # define the embedding layer, with the corresponding dimensions
56 | self.embedding = nn.Embedding(num_embeddings=num_embeddings,
57 | embedding_dim=embedding_dim)
58 |
59 | # initialize the weights of the Embedding layer,
60 | # with the given pre-trained word vectors
61 | if embeddings is not None:
62 | print("Initializing Embedding layer with pre-trained weights!")
63 | self.init_embeddings(embeddings, trainable)
64 |
65 | # the dropout "layer" for the word embeddings
66 | self.dropout = nn.Dropout(dropout)
67 |
68 | # the gaussian noise "layer" for the word embeddings
69 | self.noise = GaussianNoise(noise)
70 |
71 | self.grad_mask = grad_mask
72 |
73 | if self.norm:
74 | self.layer_norm = nn.LayerNorm(embedding_dim)
75 |
76 | if self.grad_mask is not None:
77 | self.set_grad_mask(self.grad_mask)
78 |
 79 |     def _emb_hook(self, grad):
 80 |         return grad * self.grad_mask.unsqueeze(1).type_as(grad)
81 |
82 | def set_grad_mask(self, mask):
83 | self.grad_mask = torch.from_numpy(mask)
84 | self.embedding.weight.register_hook(self._emb_hook)
85 |
86 | def init_embeddings(self, weights, trainable):
87 | self.embedding.weight = nn.Parameter(torch.from_numpy(weights),
88 | requires_grad=trainable)
89 |
90 | def regularize(self, embeddings):
91 | if self.noise.stddev > 0:
92 | embeddings = self.noise(embeddings)
93 |
94 | if self.dropout.p > 0:
95 | embeddings = self.dropout(embeddings)
96 |
97 | return embeddings
98 |
99 | def expectation(self, dists):
100 | """
101 | Obtain a weighted sum (expectation) of all the embeddings, from a
102 | given probability distribution.
103 |
104 | """
105 | flat_probs = dists.contiguous().view(dists.size(0) * dists.size(1), dists.size(2))
106 | flat_embs = flat_probs.mm(self.embedding.weight)
107 | embs = flat_embs.view(dists.size(0), dists.size(1), flat_embs.size(1))
108 |
109 | # apply layer normalization on the expectation
110 | if self.norm:
111 | embs = self.layer_norm(embs)
112 |
113 | # apply all embedding layer's regularizations
114 | embs = self.regularize(embs)
115 |
116 | return embs
117 |
118 | def forward(self, x):
119 | """
120 | This is the heart of the model. This function, defines how the data
121 | passes through the network.
122 | Args:
123 | x (): the input data (the sentences)
124 |
125 | Returns: the logits for each class
126 |
127 | """
128 | embeddings = self.embedding(x)
129 |
130 | if self.norm:
131 | embeddings = self.layer_norm(embeddings)
132 |
133 | embeddings = self.regularize(embeddings)
134 |
135 | return embeddings
136 |
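# Illustrative sketch (added for exposition): under a one-hot
# "distribution" Embed.expectation reduces to an ordinary lookup, which
# is what makes it a differentiable relaxation of embedding selection.
def _demo_embed_expectation():
    embed = Embed(num_embeddings=10, embedding_dim=8)
    ids = torch.tensor([[2, 5, 7]])
    one_hot = torch.zeros(1, 3, 10)
    one_hot[0, torch.arange(3), ids[0]] = 1.0
    assert torch.allclose(embed.expectation(one_hot), embed(ids))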
137 |
138 | class Attention(nn.Module):
139 | def __init__(self,
140 | input_size,
141 | context_size,
142 | batch_first=True,
143 | non_linearity="tanh",
144 | method="general",
145 | coverage=False):
146 | super(Attention, self).__init__()
147 |
148 | self.batch_first = batch_first
149 | self.method = method
150 | self.coverage = coverage
151 |
152 | if self.method not in ["dot", "general", "concat", "additive"]:
153 | raise ValueError("Please select a valid attention type.")
154 |
155 | if self.coverage:
156 | self.W_c = nn.Linear(1, context_size, bias=False)
157 | self.method = "additive"
158 |
159 | if non_linearity == "relu":
160 | self.activation = nn.ReLU()
161 | else:
162 | self.activation = nn.Tanh()
163 |
164 | if self.method == "general":
165 | self.W_h = nn.Linear(input_size, context_size)
166 |
167 | elif self.method == "additive":
168 | self.W_h = nn.Linear(input_size, context_size)
169 | self.W_s = nn.Linear(context_size, context_size)
170 | self.W_v = nn.Linear(context_size, 1)
171 |
172 | elif self.method == "concat":
173 | self.W_h = nn.Linear(input_size + context_size, context_size)
174 | self.W_v = nn.Linear(context_size, 1)
175 |
176 | def score(self, sequence, query, coverage=None):
177 | batch_size, max_length, feat_size = sequence.size()
178 |
179 | if self.method == "dot":
180 | energies = torch.matmul(sequence, query.unsqueeze(2)).squeeze(2)
181 |
182 | elif self.method == "additive":
183 | enc = self.W_h(sequence)
184 | dec = self.W_s(query)
185 | sums = enc + dec.unsqueeze(1)
186 |
187 | if self.coverage:
188 | cov = self.W_c(coverage.unsqueeze(-1))
189 | sums = sums + cov
190 |
191 | energies = self.W_v(self.activation(sums)).squeeze(2)
192 |
193 | elif self.method == "general":
194 | h = self.W_h(sequence)
195 | energies = torch.matmul(h, query.unsqueeze(2)).squeeze(2)
196 |
197 | elif self.method == "concat":
198 | c = query.unsqueeze(1).expand(-1, max_length, -1)
199 | u = self.W_h(torch.cat([sequence, c], -1))
200 | energies = self.W_v(self.activation(u)).squeeze(2)
201 |
202 | else:
203 | raise ValueError
204 |
205 | return energies
206 |
207 | def forward(self, sequence, query, lengths, coverage=None):
208 |
209 | energies = self.score(sequence, query, coverage)
210 |
211 | # construct a mask, based on sentence lengths
212 | mask = sequence_mask(lengths, energies.size(1))
213 |
214 | scores = masked_normalization_inf(energies, mask)
215 | # scores = self.masked_normalization(energies, mask)
216 |
217 | contexts = (sequence * scores.unsqueeze(-1)).sum(1)
218 |
219 | return contexts, scores
220 |
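# Illustrative usage sketch (added for exposition; the sizes are
# arbitrary): attends over a padded batch with one query per example.
def _demo_attention():
    attn = Attention(input_size=32, context_size=32, method="general")
    seq = torch.randn(4, 7, 32)           # (batch, time, features)
    query = torch.randn(4, 32)            # one query vector per example
    lengths = torch.tensor([7, 5, 3, 1])
    contexts, scores = attn(seq, query, lengths)
    # contexts: (4, 32) weighted sums; scores: (4, 7), zero on padding
    return contexts, scores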
--------------------------------------------------------------------------------
/modules/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from modules.modules import RecurrentHelper, AttSeqDecoder, SeqReader
6 | from modules.helpers import sequence_mask, avg_vectors, index_mask
7 |
8 |
9 | class Seq2Seq2Seq(nn.Module, RecurrentHelper):
10 |
11 | def __init__(self, n_tokens, **kwargs):
12 | super(Seq2Seq2Seq, self).__init__()
13 |
14 | ############################################
15 | # Attributes
16 | ############################################
17 | self.n_tokens = n_tokens
18 | self.bridge_hidden = kwargs.get("bridge_hidden", False)
19 | self.bridge_non_linearity = kwargs.get("bridge_non_linearity", None)
20 | self.detach_hidden = kwargs.get("detach_hidden", False)
21 | self.input_feeding = kwargs.get("input_feeding", False)
22 | self.length_control = kwargs.get("length_control", False)
23 | self.bi_encoder = kwargs.get("rnn_bidirectional", False)
24 | self.rnn_type = kwargs.get("rnn_type", "LSTM")
25 | self.layer_norm = kwargs.get("layer_norm", False)
26 | self.sos = kwargs.get("sos", 1)
27 | self.sample_embed_noise = kwargs.get("sample_embed_noise", 0)
28 | self.topic_idf = kwargs.get("topic_idf", False)
29 | self.dec_token_dropout = kwargs.get("dec_token_dropout", .0)
30 | self.enc_token_dropout = kwargs.get("enc_token_dropout", .0)
31 |
32 | self.batch_size = kwargs.get("batch_size")
33 | self.sent_num = kwargs.get("sent_num")
34 | self.sent_len = kwargs.get("sent_len")
35 |
36 | # tie embedding layers to output layers (vocabulary projections)
37 | kwargs["tie_weights"] = kwargs.get("tie_embedding_outputs", False)
38 |
39 | ############################################
40 | # Layers
41 | ############################################
42 |
43 | # backward-compatibility for older version of the project
44 | kwargs["rnn_size"] = kwargs.get("enc_rnn_size", kwargs.get("rnn_size"))
45 | self.inp_encoder = SeqReader(self.n_tokens, **kwargs)
46 | enc_size = self.inp_encoder.rnn_size
47 | self.sent_classification = torch.nn.Linear(enc_size, 1)
48 | self.sent_similar = torch.nn.Linear(enc_size*2, 1)
49 |
50 | # backward-compatibility for older version of the project
51 | kwargs["rnn_size"] = kwargs.get("dec_rnn_size", kwargs.get("rnn_size"))
52 | self.dia_nsent = AttSeqDecoder(self.n_tokens, enc_size, **kwargs)
53 | self.sum_nsent = AttSeqDecoder(self.n_tokens, enc_size, **kwargs)
54 |
55 | # create a dummy embedding layer, which will retrieve the idf values
56 | # of each word, given the word ids
57 | if self.topic_idf:
58 | self.idf = nn.Embedding(num_embeddings=n_tokens, embedding_dim=1)
59 | self.idf.weight.requires_grad = False
60 |
61 | if self.bridge_hidden:
62 | self._initialize_bridge(enc_size,
63 | kwargs["dec_rnn_size"],
64 | kwargs["rnn_layers"])
65 |
66 | def _initialize_bridge(self, enc_hidden_size, dec_hidden_size, num_layers):
67 | """
68 | adapted from
69 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/encoders/rnn_encoder.py#L85
70 | """
71 |
72 | # LSTM has hidden and cell state, other only one
73 | number_of_states = 2 if self.rnn_type == "LSTM" else 1
74 |
75 | if self.length_control:
76 | # add a parameter, for scaling the absolute target length
77 | self.Wl = nn.Parameter(torch.rand(1))
78 | # the length information will contain 2 additional dimensions,
79 | # - the target length
80 | # - the expansion / compression ratio given the source length
81 | enc_hidden_size += 2
82 |
83 | # Build a linear layer for each
84 | self.src_bridge = nn.ModuleList([nn.Linear(enc_hidden_size,
85 | dec_hidden_size)
86 | for _ in range(number_of_states)])
87 | self.trg_bridge = nn.ModuleList([nn.Linear(enc_hidden_size,
88 | dec_hidden_size)
89 | for _ in range(number_of_states)])
90 |
91 | def _bridge(self, bridge, hidden, src_lengths=None, trg_lengths=None):
92 | """Forward hidden state through bridge."""
93 |
94 | def _fix_hidden(_hidden):
95 | # The encoder hidden is (layers*directions) x batch x dim.
96 | # We need to convert it to layers x batch x (directions*dim).
97 | fwd_final = _hidden[0:_hidden.size(0):2]
98 | bwd_final = _hidden[1:_hidden.size(0):2]
99 | final = torch.cat([fwd_final, bwd_final], dim=2)
100 | return final
101 |
102 | def bottle_hidden(linear, states, length_feats=None):
103 | if length_feats is not None:
104 | lf = length_feats.unsqueeze(0).repeat(states.size(0), 1, 1)
105 | _states = torch.cat([states, lf], -1)
106 | result = linear(_states)
107 | else:
108 | result = linear(states)
109 |
110 | if self.bridge_non_linearity == "tanh":
111 | result = torch.tanh(result)
112 | elif self.bridge_non_linearity == "relu":
113 | result = F.relu(result)
114 |
115 | return result
116 |
117 | if self.length_control:
118 | ratio = trg_lengths.float() / src_lengths.float()
119 | lengths = trg_lengths.float() * self.Wl
120 | L = torch.stack([ratio, lengths], -1)
121 | else:
122 | L = None
123 |
124 | if isinstance(hidden, tuple): # LSTM
125 | # concat directions
126 | hidden = tuple(_fix_hidden(h) for h in hidden)
127 | outs = tuple([bottle_hidden(state, hidden[ix], L)
128 | for ix, state in enumerate(bridge)])
129 | else:
130 | outs = bottle_hidden(bridge[0], hidden)
131 |
132 | return outs
133 |
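    @staticmethod
    def _demo_fix_hidden():
        # Illustrative sketch (added for exposition): a 2-layer
        # bidirectional encoder yields a hidden state of shape
        # (layers * directions, batch, dim) = (4, B, H); the interleaved
        # fwd/bwd slices are concatenated to (2, B, 2H), which is what
        # _fix_hidden inside _bridge does.
        hidden = torch.randn(4, 8, 16)
        fwd = hidden[0:hidden.size(0):2]
        bwd = hidden[1:hidden.size(0):2]
        assert torch.cat([fwd, bwd], dim=2).shape == (2, 8, 32)
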
134 | def initialize_embeddings(self, embs, trainable=False):
135 |
136 | freeze = not trainable
137 |
138 | embeddings = torch.from_numpy(embs).float()
139 | embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze)
140 |
141 |         # share a single embedding layer between the encoder and
142 |         # the two decoders
143 |         self.inp_encoder.embed.embedding = embedding_layer
144 |         self.dia_nsent.embed.embedding = embedding_layer
145 |         self.sum_nsent.embed.embedding = embedding_layer
146 |
147 | def initialize_embeddings_idf(self, idf):
148 | idf_embs = torch.from_numpy(idf).float().unsqueeze(-1)
149 | self.idf = nn.Embedding.from_pretrained(idf_embs, freeze=True)
150 |
151 | def set_embedding_gradient_mask(self, mask):
152 | self.inp_encoder.embed.set_grad_mask(mask)
153 | self.cmp_encoder.embed.set_grad_mask(mask)
154 | self.compressor.embed.set_grad_mask(mask)
155 | self.decompressor.embed.set_grad_mask(mask)
156 | self.original_task.embed.set_grad_mask(mask)
157 |
158 | def _fake_inputs(self, inputs, latent_lengths, pad=1):
159 | batch_size, seq_len = inputs.size()
160 |
161 | if latent_lengths is not None:
162 | max_length = max(latent_lengths)
163 | else:
164 | max_length = seq_len + pad
165 |
166 | fakes = torch.zeros(batch_size, max_length, device=inputs.device)
167 | fakes = fakes.type_as(inputs)
168 | fakes[:, 0] = self.sos
169 | return fakes
170 |
171 | #def generate(self, inputs, nsent, src_lengths, trg_seq_len, nsent_len, sampling):
172 | def generate(self, inputs, src_lengths, trg_seq_len):
173 | # ENCODER1
174 | enc1_results = self.inp_encoder(inputs, None, src_lengths)
175 | outs_enc1, hn_enc1 = enc1_results[-2:]
176 |
177 | # DECODER1
178 | dec_init = self._bridge(self.src_bridge, hn_enc1, src_lengths, trg_seq_len)
179 | inp_fake = self._fake_inputs(inputs, trg_seq_len)
180 | dec1_results = self.compressor(inp_fake, outs_enc1, dec_init,
181 | argmax=True,
182 | enc_lengths=src_lengths,
183 | sampling_prob=1.,
184 | desired_lengths=trg_seq_len)
185 |
186 | return enc1_results, dec1_results
187 |
188 | def summary(self, inp_src, sent_len, sent_num):
189 | imp_src_org = inp_src.view(self.batch_size, self.sent_num, self.sent_len)
190 | inp_src = imp_src_org.view(imp_src_org.size(0) * imp_src_org.size(1), imp_src_org.size(2))
191 | inp_length = imp_src_org.size(2)
192 | sent_len = sent_len.view(self.batch_size * self.sent_num)
193 | enc1_results = self.inp_encoder(inp_src, None, sent_len, word_dropout=self.enc_token_dropout)
194 | outs_enc, hn_enc = enc1_results[-2:]
195 |
196 | sent_len_mask = torch.unsqueeze(sequence_mask(sent_len, max_len=self.sent_len), -1).float()
197 | outs_enc = torch.mul(outs_enc, sent_len_mask)
198 | outs_enc = outs_enc.view(self.batch_size, self.sent_num, self.sent_len, -1)
199 | outs_enc = torch.sum(outs_enc, dim=2)
200 |
201 | sent_num_mask = torch.unsqueeze(sequence_mask(sent_num, max_len=self.sent_num), -1).float()
202 | sent_sum_prb = self.sent_classification(outs_enc)
203 | sent_sum_prb = nn.functional.softmax(sent_sum_prb, dim=1)
204 | sent_sum_prb = torch.mul(sent_sum_prb, sent_num_mask)
205 | """_, top_k_index = torch.topk(sent_sum_prb, k=k, dim=1)
206 | top_k_index = torch.squeeze(top_k_index)
207 | top_k_mask = index_mask(self.batch_size, self.sent_num, top_k_index)
208 | top_k_mask = torch.unsqueeze(top_k_mask, dim=-1).cuda()
209 | outs_enc_filter = outs_enc.mul(top_k_mask)"""
210 |
211 | return sent_sum_prb
212 |
213 | def forward(self, k, inp_src, inp_sim, inp_trg, sim_len, sent_len, sent_num, trg_lengths):
214 | """
215 | enc1------------------>dec1
216 | | |
217 | | |
218 | summary------>enc2---->dec2
219 |
220 |         (extractive summarization)
221 |
222 |         inp_src: input source (batch x sent_num x sent_len)
223 |         inp_sim: the k sentences most similar to the nth sentence (batch x k x sim_len)
224 |         inp_trg: the nth sentence (batch x nsent_len)
225 |         sim_len: length of each of the k similar sentences
226 |         sent_len: length of each sentence in a dialogue
227 |         sent_num: number of sentences in a dialogue
228 |         trg_lengths: length of the nth sentence
229 | """
230 |
231 | # --------------------------------------------
232 | # ENCODER (encode each sentence)
233 | # --------------------------------------------
234 | # encode dialogue
235 | imp_src_org = inp_src.view(self.batch_size, self.sent_num, self.sent_len)
236 | inp_src = imp_src_org.view(imp_src_org.size(0) * imp_src_org.size(1), imp_src_org.size(2))
237 | inp_length = imp_src_org.size(2)
238 | sent_len = sent_len.view(self.batch_size * self.sent_num)
239 | enc1_results = self.inp_encoder(inp_src, None, sent_len, word_dropout=self.enc_token_dropout)
240 | outs_enc, hn_enc = enc1_results[-2:]
241 |
242 | sent_len_mask = torch.unsqueeze(sequence_mask(sent_len, max_len=self.sent_len), -1).float()
243 | outs_enc = torch.mul(outs_enc, sent_len_mask)
244 | outs_enc = outs_enc.view(self.batch_size, self.sent_num, self.sent_len, -1)
245 | outs_enc = torch.sum(outs_enc, dim=2)
246 |
247 | sent_num_mask = torch.unsqueeze(sequence_mask(sent_num, max_len=self.sent_num), -1).float()
248 | sent_sum_prb = self.sent_classification(outs_enc)
249 | sent_sum_prb = nn.functional.softmax(sent_sum_prb, dim=1)
250 | sent_sum_prb = torch.mul(sent_sum_prb, sent_num_mask)
251 | _, top_k_index = torch.topk(sent_sum_prb, k=k, dim=1)
252 | top_k_index = torch.squeeze(top_k_index)
253 | top_k_mask = index_mask(self.batch_size, self.sent_num, top_k_index)
254 |         top_k_mask = torch.unsqueeze(top_k_mask, dim=-1).to(outs_enc.device)
255 | outs_enc_filter = outs_enc.mul(top_k_mask)
256 |
257 | # encode k similar sentences to nth sentence
258 | k_num = sim_len.size(1)
259 | inp_sim = inp_sim.view(self.batch_size, k_num, -1)
260 | inp_sim = inp_sim.view(self.batch_size * k_num, -1)
261 | sim_len = sim_len.view(self.batch_size * k_num)
262 | enc2_results = self.inp_encoder(inp_sim, None, sim_len, word_dropout=self.enc_token_dropout)
263 | outs_enc_sim, hn_enc_sim = enc2_results[-2:]
264 | outs_enc_sim = torch.sum(outs_enc_sim, dim=1)
265 | outs_enc_sim = outs_enc_sim.view(self.batch_size, k_num, -1)
266 |
267 | ## initiate decoder
268 | hn_enc_rst = []
269 | for index, hn_emc_tmp in enumerate(hn_enc):
270 | hn_emc_tmp = hn_emc_tmp.chunk(self.batch_size, dim=1)
271 | rst = []
272 | for _, sample in enumerate(hn_emc_tmp):
273 | sample = torch.sum(sample, dim=1)
274 | rst.append(sample)
275 | hn_emc_tmp = torch.stack(rst, dim=1)
276 | hn_enc_rst.append(hn_emc_tmp)
277 | hn_enc = tuple(hn_enc_rst)
278 |
279 |         sent_len = sent_len.view(self.batch_size, self.sent_num)
280 | _dec_init = self._bridge(self.src_bridge, hn_enc, sent_num, trg_lengths)
281 |
282 | # -------------------------------------------------------------
283 | # DECODER-1 (generate nth sentence based on original dialogue)
284 | # -------------------------------------------------------------
285 | dec1_results = self.dia_nsent(inp_trg, outs_enc, _dec_init,
286 | enc_lengths=sent_num,
287 | sampling_prob=1.,
288 | desired_lengths=trg_lengths)
289 |
290 | # --------------------------------------------------
291 | # DECODER-2 (generate nth sentence based on summary)
292 | # --------------------------------------------------
293 | dec2_results = self.sum_nsent(inp_trg, outs_enc_filter, _dec_init,
294 | enc_lengths=sent_num,
295 | sampling_prob=1.,
296 | desired_lengths=trg_lengths)
297 |
298 | # --------------------------------------------------
299 | # Predict similar sentences
300 | # --------------------------------------------------
301 | outs_enc_pre = torch.unsqueeze(torch.sum(outs_enc, dim=1), dim=1)
302 | outs_enc_filter_pre = torch.unsqueeze(torch.sum(outs_enc_filter, dim=1), dim=1)
303 |
304 | outs_enc_pre = outs_enc_pre.expand(outs_enc_sim.size(0), outs_enc_sim.size(1), outs_enc_sim.size(2))
305 | outs_enc_filter_pre = outs_enc_filter_pre.expand(outs_enc_sim.size(0), outs_enc_sim.size(1), outs_enc_sim.size(2))
306 | outs_enc_pre = torch.cat((outs_enc_pre, outs_enc_sim), dim=-1)
307 | outs_enc_filter_pre = torch.cat((outs_enc_filter_pre, outs_enc_sim), dim=-1)
308 |
309 | outs_enc_pre = self.sent_similar(outs_enc_pre)
310 | outs_enc_filter_pre = self.sent_similar(outs_enc_filter_pre)
311 |
312 | dialog_pre = torch.squeeze(outs_enc_pre, dim=-1)
313 | summary_pre = torch.squeeze(outs_enc_filter_pre, dim=-1)
314 |
315 | return sent_sum_prb, outs_enc, outs_enc_filter, dec1_results, dec2_results, sent_len, dialog_pre, summary_pre
316 |
--------------------------------------------------------------------------------
/modules/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/modules/training/__init__.py
--------------------------------------------------------------------------------
/modules/training/base_trainer.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy
4 | import torch
5 |
6 | class BaseTrainer:
7 | def __init__(self, train_loader, valid_loader,
8 | config, device,
9 | batch_end_callbacks=None, loss_weights=None,
10 | parallel=False,
11 | **kwargs):
12 |
13 | self.train_loader = train_loader
14 | self.valid_loader = valid_loader
15 | self.device = device
16 | self.loss_weights = loss_weights
17 |
18 | self.config = config
19 | self.log_interval = self.config["log_interval"]
20 | self.batch_size = self.config["batch_size"]
21 | self.checkpoint_interval = self.config["checkpoint_interval"]
22 | self.clip = self.config["model"]["clip"]
23 |
24 | if batch_end_callbacks is None:
25 | self.batch_end_callbacks = []
26 | else:
27 | self.batch_end_callbacks = [c for c in batch_end_callbacks if callable(c)]
28 |
29 | self.epoch = 0
30 | self.step = 0
31 | self.progress_log = None
32 |
33 | # init dataset
34 | self.train_set_size = self._get_dataset_size(self.train_loader)
35 | self.val_set_size = self._get_dataset_size(self.valid_loader)
36 |
37 | self.n_batches = math.ceil(
38 | float(self.train_set_size) / self.batch_size)
39 | self.total_steps = self.n_batches * self.config["epochs"]
40 |
41 | if self.loss_weights is not None:
42 | self.loss_weights = [self.anneal_init(w) for w in
43 | self.loss_weights]
44 |
45 | @staticmethod
46 | def _roll_seq(x, dim=1, shift=1):
47 | length = x.size(dim) - shift
48 |
49 | seq = torch.cat([x.narrow(dim, shift, length),
50 | torch.zeros_like(x[:, :1])], dim)
51 |
52 | return seq
53 |
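    # Illustrative sketch (added for exposition): _roll_seq shifts each
    # sequence one step to the left and zero-pads the end, e.g. to turn
    # decoder inputs into next-token targets.
    @staticmethod
    def _demo_roll_seq():
        x = torch.tensor([[1, 2, 3, 4]])
        assert BaseTrainer._roll_seq(x).tolist() == [[2, 3, 4, 0]]
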
54 | @staticmethod
55 | def _get_dataset_size(loader):
56 | """
57 | If the trainer holds multiple datasets, then the size
58 | is estimated based on the largest one.
59 | """
60 | if isinstance(loader, (tuple, list)):
61 | return len(loader[0].dataset)
62 | else:
63 | return len(loader.dataset)
64 |
65 | def anneal_init(self, param, steps=None):
66 | if isinstance(param, list):
67 | if steps is None:
68 | steps = self.total_steps
69 | return numpy.geomspace(param[0], param[1], num=steps).tolist()
70 | else:
71 | return param
72 |
 73 |     def anneal_step(self, param):
 74 |         if isinstance(param, list):
 75 |             try:
 76 |                 _val = param[self.step]
 77 |             except IndexError:  # past the schedule: keep final value
 78 |                 _val = param[-1]
79 | else:
80 | _val = param
81 |
82 | return _val
83 |
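    @staticmethod
    def _demo_annealing():
        # Illustrative sketch (added for exposition): anneal_init turns
        # a [start, end] loss weight into a geometric schedule over the
        # total steps; anneal_step then indexes it with the global step.
        schedule = numpy.geomspace(0.1, 1.0, num=5).tolist()
        return schedule           # [0.1, ~0.18, ~0.32, ~0.56, 1.0]
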
84 | def _tensors_to_device(self, batch):
85 | """batch_trans = []
86 | for sample in batch:
87 | batch_trans.append([i for item in sample for i in item])
88 | res_tmp = list(map(lambda x: x.to(self.device), batch_trans))"""
89 | return list(map(lambda x: x.to(self.device), batch))
90 |
91 | def _batch_to_device(self, batch):
92 |
93 | if torch.is_tensor(batch[0]):
94 | batch = self._tensors_to_device(batch)
95 | else:
96 | batch = list(map(lambda x: self._tensors_to_device(x), batch))
97 |
98 | return batch
99 |
100 | @staticmethod
101 | def _multi_dataset_iter(loader, strategy, step=1):
102 | # todo: generalize to N datasets. For now works only with 2.
103 | sizes = [len(x) for x in loader]
104 |
105 | iter_a = iter(loader[0])
106 | iter_b = iter(loader[1])
107 |
108 |         if strategy == "spread":
109 |             step = math.floor((sizes[0] - sizes[1]) / (sizes[1] - 1))
110 |
111 |             for i in range(max(sizes)):
112 |                 if i % (step + 1) == 0:
113 |                     batch_a = next(iter_a)
114 |                     batch_b = next(iter_b, None)
115 |
116 |                     if batch_b is not None:
117 |                         yield batch_a, batch_b
118 |                     else:
119 |                         yield batch_a
120 |                 else:
121 |                     yield next(iter_a)
122 |
123 |         elif strategy == "modulo":
124 | for i in range(max(sizes)):
125 | if i % step == 0:
126 | batch_a = next(iter_a)
127 | batch_b = next(iter_b, None)
128 |
129 | if batch_b is None: # reset iterator b
130 | iter_b = iter(loader[1])
131 | batch_b = next(iter_b, None)
132 |
133 | yield batch_a, batch_b
134 | else:
135 | yield next(iter_a)
136 |
137 | elif strategy == "cycle":
138 | for i in range(max(sizes)):
139 | batch_a = next(iter_a)
140 | batch_b = next(iter_b, None)
141 |
142 | if batch_b is None: # reset iterator b
143 | iter_b = iter(loader[1])
144 | batch_b = next(iter_b, None)
145 |
146 | yield batch_a, batch_b
147 |
148 | elif strategy == "beginning":
149 | for i in range(max(sizes)):
150 | batch_a = next(iter_a)
151 | batch_b = next(iter_b, None)
152 |
153 | if batch_b is not None:
154 | yield batch_a, batch_b
155 | else:
156 | yield batch_a
157 |         else:
158 |             raise ValueError("Invalid iteration strategy!")
159 |
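    # Illustrative sketch (added for exposition): with the "cycle"
    # strategy the smaller loader is restarted, so that every batch of
    # the larger one gets a partner (plain lists stand in for loaders).
    @staticmethod
    def _demo_multi_dataset_iter():
        big, small = [1, 2, 3, 4], ["a", "b"]
        pairs = list(BaseTrainer._multi_dataset_iter([big, small], "cycle"))
        assert pairs == [(1, "a"), (2, "b"), (3, "a"), (4, "b")]
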
160 | def _dataset_iterator(self, loader, strategy=None, step=1):
161 | # if all datasets have the same size
162 | if isinstance(loader, (tuple, list)):
163 | if len(set(len(x) for x in loader)) == 1:
164 | return zip(*loader)
165 | else:
166 | return self._multi_dataset_iter(loader, strategy, step)
167 | else:
168 | return loader
169 |
170 | def _aggregate_losses(self, batch_losses, loss_weights=None):
171 | """
172 | This function computes a weighted sum of the models losses
173 | Args:
174 | batch_losses(torch.Tensor, tuple):
175 |
176 | Returns:
177 | loss_sum (int): the aggregation of the constituent losses
178 | loss_list (list, int): the constituent losses
179 |
180 | """
181 | if isinstance(batch_losses, (tuple, list)):
182 |
183 | if loss_weights is None:
184 | loss_weights = self.loss_weights
185 | loss_weights = [self.anneal_step(w) for w in loss_weights]
186 |
187 | if loss_weights is None:
188 | loss_sum = sum(batch_losses)
189 | loss_list = [x.item() for x in batch_losses]
190 | else:
191 | loss_sum = sum(w * x for x, w in
192 | zip(batch_losses, loss_weights))
193 |
194 | loss_list = [w * x.item() for x, w in
195 | zip(batch_losses, loss_weights)]
196 | else:
197 | loss_sum = batch_losses
198 | loss_list = batch_losses.item()
199 | return loss_sum, loss_list
200 |
--------------------------------------------------------------------------------
/modules/training/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy
4 | import torch
5 | from torch.nn.utils import clip_grad_norm_
6 |
7 | from modules.training.base_trainer import BaseTrainer
8 | from utils._logging import epoch_progress
9 | from utils.training import save_checkpoint
10 |
11 |
12 | class Trainer(BaseTrainer):
13 | """
14 | An abstract class representing a Trainer.
 15 |     A Trainer object is responsible for handling the training process
 16 |     and provides various helper methods.
17 |
18 | All other trainers should subclass it.
19 | All subclasses should override process_batch, which handles the way
20 | you feed the input data to the model and performs a forward pass.
21 | """
22 |
23 | def __init__(self, model, train_loader, valid_loader, criterion,
24 | optimizers, config, device,
25 | batch_end_callbacks=None, loss_weights=None, **kwargs):
26 |
27 | super().__init__(train_loader, valid_loader, config, device,
28 | batch_end_callbacks, loss_weights, **kwargs)
29 |
30 | self.model = model
31 | self.criterion = criterion
32 | self.optimizers = optimizers
33 |
34 | if not isinstance(self.optimizers, (tuple, list)):
35 | self.optimizers = [self.optimizers]
36 |
37 | def _process_batch(self, *args):
38 | raise NotImplementedError
39 |
40 | def _seq_loss(self, logits, labels):
41 |
42 | """
43 | Compute a sequence loss (i.e. per timestep).
44 | Used for tasks such as Translation, Language Modeling and
45 | Sequence Labelling.
46 | """
47 | _logits = logits.contiguous().view(-1, logits.size(-1))
48 | _labels = labels.contiguous().view(-1)
49 | loss = self.criterion(_logits, _labels)
50 |
51 | return loss
52 |
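    @staticmethod
    def _demo_seq_loss_shapes():
        # Illustrative sketch (added for exposition): _seq_loss flattens
        # (batch, time, vocab) logits and (batch, time) labels so that a
        # token-level criterion such as nn.CrossEntropyLoss applies.
        logits = torch.randn(2, 5, 100)
        labels = torch.randint(0, 100, (2, 5))
        criterion = torch.nn.CrossEntropyLoss()
        return criterion(logits.view(-1, 100), labels.view(-1))
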
53 | def grads(self):
54 | """
55 | Get the list of the norms of the gradients for each parameter
56 | """
57 | return [(name, parameter.grad.norm().item())
58 | for name, parameter in self.model.named_parameters()
59 | if parameter.requires_grad and parameter.grad is not None]
60 |
61 | def train_epoch(self, pre_train_epoch, batch_num, writer):
62 | """
63 | Train the network for one epoch and return the average loss.
64 | * This will be a pessimistic approximation of the true loss
65 | of the network, as the loss of the first batches will be higher
 66 |             than the final one.
67 |
68 | Returns:
69 | loss (float, list(float)): list of mean losses
70 |
71 | """
72 | self.model.train()
73 | losses = []
74 |
75 | self.epoch += 1
76 | epoch_start = time.time()
77 |
78 | iterator = self._dataset_iterator(self.train_loader)
79 | for i_batch, batch in enumerate(iterator, 1):
80 |
81 | self.step += 1
82 |
83 | # zero gradients
84 | for optimizer in self.optimizers:
85 | optimizer.zero_grad()
86 |
87 | batch = self._batch_to_device(batch)
88 | if batch[0].size(0) != self.batch_size:
89 | continue
90 |
91 | # return here only the first batch losses, in order to avoid
92 | # breaking the existing framework
93 | # pre-train enc1-dec3 by self-supervised training
94 | batch_losses, batch_outputs = self._process_batch(*batch)
95 |
96 | # aggregate the losses into a single loss value
97 | loss_sum, loss_list = self._aggregate_losses(batch_losses)
98 | losses.append(loss_list)
99 | writer.add_scalar('Train/loss', loss_sum, self.step)
100 | loss_count = 0
101 | writer.add_scalar('Train/nsent1_loss', loss_list[loss_count], self.step)
102 | if self.config["model"]["n_sent_sum_loss"]:
103 | loss_count += 1
104 | writer.add_scalar('Train/nsent2_loss', loss_list[loss_count], self.step)
105 | if self.config["model"]["prior_loss"]:
106 | loss_count += 1
107 | writer.add_scalar('Train/lm_loss', loss_list[loss_count], self.step)
108 | if self.config["model"]["topic_loss"]:
109 | loss_count += 1
110 | writer.add_scalar('Train/topic_loss', loss_list[loss_count], self.step)
111 | if self.config["model"]["length_loss"]:
112 | loss_count += 1
113 | writer.add_scalar('Train/length_loss', loss_list[loss_count], self.step)
114 | if self.config["model"]["doc_sum_kl_loss"]:
115 | loss_count += 1
116 | writer.add_scalar('Train/kl_loss', loss_list[loss_count], self.step)
117 | if self.config["model"]["doc_sum_sim_loss"]:
118 | loss_count += 1
119 | writer.add_scalar('Train/doc_sim_loss', loss_list[loss_count], self.step)
120 | if self.config["model"]["sum_loss"]:
121 | loss_count += 1
122 | writer.add_scalar('Train/sum_loss', loss_list[loss_count], self.step)
123 | if self.config["model"]["nsent_classification"]:
124 | loss_count += 1
125 | writer.add_scalar('Train/cls_loss', loss_list[loss_count], self.step)
126 | if self.config["model"]["nsent_classification_sum"]:
127 | loss_count += 1
128 | writer.add_scalar('Train/cls_sum_loss', loss_list[loss_count], self.step)
129 | if self.config["model"]["nsent_classification_kl"]:
130 | loss_count += 1
131 | writer.add_scalar('Train/cla_kl_loss', loss_list[loss_count], self.step)
132 | writer.flush()
133 |
134 | # back-propagate
135 | loss_sum.backward()
136 |
137 | if self.clip is not None:
138 | # clip_grad_norm_(self.model.parameters(), self.clip)
139 | for optimizer in self.optimizers:
140 | clip_grad_norm_((p for group in optimizer.param_groups
141 | for p in group['params']), self.clip)
142 |
143 | # update weights
144 | for optimizer in self.optimizers:
145 | optimizer.step()
146 |
147 | if self.step % self.log_interval == 0:
148 | self.progress_log = epoch_progress(self.epoch, i_batch,
149 | self.batch_size,
150 | self.train_set_size,
151 | epoch_start)
152 |
153 | for c in self.batch_end_callbacks:
154 | if callable(c):
155 | c(batch, losses, loss_list, batch_outputs, self.epoch)
156 | try:
157 | return numpy.array(losses).mean(axis=0)
158 |         except Exception:  # parallel losses (ragged loss lists)
159 | return numpy.array([x[:len(self.loss_weights) - 1]
160 | for x in losses]).mean(axis=0)
161 |
162 | def eval_epoch(self):
163 | """
164 | Evaluate the network for one epoch and return the average loss.
165 |
166 | Returns:
167 | loss (float, list(float)): list of mean losses
168 |
169 | """
170 | self.model.eval()
171 | losses = []
172 |
173 | iterator = self._dataset_iterator(self.valid_loader)
174 | with torch.no_grad():
175 | for i_batch, batch in enumerate(iterator, 1):
176 | batch = self._batch_to_device(batch)
177 |
178 | batch_losses, batch_outputs = self._process_batch(*batch)
179 |
180 | # aggregate the losses into a single loss value
181 | loss, _losses = self._aggregate_losses(batch_losses)
182 | losses.append(_losses)
183 |
184 | return numpy.array(losses).mean(axis=0)
185 |
186 | def get_state(self):
187 | """
188 | Return a dictionary with the current state of the model.
189 | The state should contain all the important properties which will
190 |         be saved when taking a model checkpoint.
191 | Returns:
192 | state (dict)
193 |
194 | """
195 | state = {
196 | "config": self.config,
197 | "epoch": self.epoch,
198 | "step": self.step,
199 | "model": self.model.state_dict(),
200 | "model_class": self.model.__class__.__name__,
201 | "optimizers": [x.state_dict() for x in self.optimizers],
202 | }
203 |
204 | return state
205 |
206 | def checkpoint(self, name=None, timestamp=False, tags=None, verbose=False):
207 |
208 | if name is None:
209 | name = self.config["name"]
210 |
211 | return save_checkpoint(self.get_state(),
212 | name=name, tag=tags, timestamp=timestamp,
213 | verbose=verbose)
214 |
--------------------------------------------------------------------------------
/mylogger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/mylogger/__init__.py
--------------------------------------------------------------------------------
/mylogger/attention.py:
--------------------------------------------------------------------------------
1 | import html
2 | import os
3 |
4 | import numpy
5 | import numpy as np
6 |
7 |
8 | def viz_sequence(words, scores=None, color="255, 0, 0"):
9 | text = []
10 |
11 | if scores is None:
12 | scores = [0] * len(words)
13 | else:
14 | # mean = numpy.mean(scores)
15 | # std = numpy.std(scores)
16 | # scores = [(x - mean) / (6 * std) for x in scores]
17 |
18 | length = len([x for x in scores if x != 0])
19 | scores = [x / sum(scores) for x in scores]
20 | mean = numpy.mean(scores[:length])
21 | std = numpy.std(scores[:length])
22 |
23 | scores = [max(0, (x - mean) / (6 * std)) for x in scores]
24 |
25 | # score = (score - this.att_mean) / (4 * this.att_std);
 26 |     for word, score in zip(words, scores):
 27 |         text.append(f"<span style='background-color: rgba({color}, {score})'>"
 28 |                     f"{html.escape(word)} </span>")
29 | return "".join(text)
30 |
31 |
32 | def viz_summary(seqs):
33 | txt = ""
34 | for name, data, color in seqs:
35 | if isinstance(data, tuple):
36 | _text = viz_sequence(data[0], data[1], color=color)
37 | length = len(data[0])
38 | else:
39 | _text = viz_sequence(data)
40 | length = len(data)
41 |
 42 |         txt += f"{name}({length}): {_text} <br/>"
 43 |
 44 |     return f"{txt} <hr>"
45 |
46 |
47 | def sample(words):
48 | return np.random.dirichlet(np.ones(len(words)))
49 |
50 |
51 | def samples2dom(samples):
 52 |     dom = """
 53 |     <html>
 54 |     <head>
 55 |     <!-- the inline style/script block of the original file
 56 |          (stripped in this dump) went here -->
 57 |     </head>
 79 |     <body>
 80 |     """
81 | for s in samples:
82 | dom += viz_summary(s)
83 |
 84 |     dom += """
 85 |     </body>
 86 |     </html>
 87 |     """
88 | return dom
89 |
90 |
91 | def samples2html(samples):
 92 |     dom = """
123 |     <html><head><!-- inline <style> block stripped in this dump --></head><body>
124 |     """
124 | """
125 |
126 | for s in samples:
127 | dom += viz_summary(s)
128 |
129 | dom += """
130 |
131 | """
132 | return dom
133 |
134 |
135 | def viz_seq3(dom):
136 | # or simply save in an html file and open in browser
137 | file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
138 | 'attention.html')
139 |
140 | with open(file, 'w') as f:
141 | f.write(dom)
142 |
143 | # samples = []
144 | # for i in range(10):
145 | # source = lorem.sentence().split()
146 | # scores = sample(source)
147 | # summary = lorem.sentence().split()
148 | # reconstruction = lorem.sentence().split()
149 | # samples.append(((source, scores), summary, reconstruction))
150 | # viz_seq3(samples2html(samples))
151 |
--------------------------------------------------------------------------------
/mylogger/db.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/mylogger/db.json
--------------------------------------------------------------------------------
/mylogger/experiment.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import sys
5 | import time
6 | from collections import defaultdict
7 | from datetime import datetime
8 |
9 | from pymongo import MongoClient
10 | from tabulate import tabulate
11 |
12 | from mylogger.helpers import dict_to_html, files_to_dict
13 | from mylogger.plotting import Visualizer
14 | from sys_config import VIS, BASE_DIR
15 |
16 |
17 | class Experiment(object):
18 | """
19 | Experiment class
20 | """
21 |
22 | def __init__(self, name, config, desc=None,
23 | output_dir=None,
24 | src_dirs=None,
25 | use_db=True,
26 | db_host="localhost",
27 | db_port=27017,
28 | db_uri=None,
29 | db_name="experiments"):
30 | """
31 |
32 | Metrics = history of values
33 | Values = state of values
34 | Args:
35 | name:
36 | config:
37 | desc:
38 | output_dir:
39 | src_dirs:
40 | use_db:
41 | db_host:
42 | db_port:
43 | db_uri: mongodb://[username:password@]host1[:port1]
44 | db_name:
45 | """
46 | self.name = name
47 | self.desc = desc
48 | self.config = config
49 | self.metrics = defaultdict(Metric)
50 | self.values = defaultdict(Value)
51 |
52 | self.use_db = use_db
53 | self.db_host = db_host
54 | self.db_port = db_port
55 | self.db_uri = db_uri
56 | self.db_name = db_name
57 |
58 | # the src files (dirs) to backup
59 | if src_dirs is not None:
60 | self.src = files_to_dict(src_dirs)
61 | else:
62 | self.src = None
63 |
64 | # the currently running script
65 | self.src_main = sys.argv[0]
66 |
67 | self.timestamp_start = datetime.now()
68 | self.timestamp_update = datetime.now()
69 | self.last_update = time.time()
70 |
71 | if output_dir is not None:
72 | self.output_dir = output_dir
73 | else:
74 | self.output_dir = BASE_DIR
75 |
76 | server = VIS["server"]
77 | port = VIS["port"]
78 | base_url = VIS["base_url"]
79 | http_proxy_host = VIS["http_proxy_host"]
80 | http_proxy_port = VIS["http_proxy_port"]
81 | self.enabled = VIS["enabled"]
82 | vis_log_file = os.path.join(self.output_dir, f"{self.name}.vis")
83 |
84 | if self.enabled:
85 | self.viz = Visualizer(env=name,
86 | server=server,
87 | port=port,
88 | base_url=base_url,
89 | http_proxy_host=http_proxy_host,
90 | http_proxy_port=http_proxy_port,
91 | log_to_filename=vis_log_file)
92 |
93 | self.add_value("config", "text")
94 | self.update_value("config", dict_to_html(self.config))
95 |
96 | # connect to MongoDB
97 | if self.use_db and self.enabled:
98 | if self.db_uri:
99 | self.db_client = MongoClient(self.db_uri)
100 | else:
101 | self.db_client = MongoClient(self.db_host, self.db_port)
102 |
103 | self.db = self.db_client[self.db_name]
104 | self.db_collection = self.db.experiments
105 | self.db_record = None
106 |
107 | #############################################################
108 | # Metric
109 | #############################################################
110 | def add_metric(self, key, vis_type, title=None, tags=None):
111 | """
112 | Add a new metric to the experiment.
113 | Metrics hold a history of all the inserted values.
114 | The last value(s) will be used for presentation (plotting and console)
115 | Args:
116 | key (str): the name of the value. This will be used for getting
117 | a handle of the metric
118 | vis_type (str): the visualization type
119 | tags (list): list of tags e.g. ["train_set", "val_set"]
120 | title (str): used for presentation purposes (figure, console...)
121 |
122 | Returns:
123 |
124 | """
125 | self.metrics[key] = Metric(key, vis_type, tags, title)
126 |
127 | def get_metric(self, key):
128 | """
129 | Returns a handle to the metric with the given key
130 | Args:
131 | key:
132 |
133 | Returns:
134 |
135 | """
136 | return self.metrics[key]
137 |
138 | def update_metric(self, key, value, tag=None):
139 | """
140 | Add new value to the given metric
141 | Args:
142 | key:
143 | value:
144 | tag:
145 |
146 | Returns:
147 |
148 | """
149 | self.get_metric(key).add(value, tag)
150 |
151 | try:
152 | if self.enabled:
153 | self.__plot_metric(key)
154 |
155 |         except IndexError:
156 |             pass
157 |
158 |         except Exception as e:
159 |             print(f"An error occurred while trying to plot metric:{key} ({e})")
160 |
161 | def __plot_metric(self, key):
162 |
163 | metric = self.get_metric(key)
164 |
165 | if metric.vis_type == "line":
166 |
167 | if metric.tags is not None:
168 | x = [[len(metric.values[tag])] for tag in metric.tags]
169 | y = [[metric.values[tag][-1]] for tag in metric.tags]
170 | else:
171 | x = [len(metric.values)]
172 | y = [metric.values[-1]]
173 | self.viz.plot_line(y, x, metric.title, metric.tags)
174 |
175 | elif metric.vis_type == "scatter":
176 | raise NotImplementedError
177 | elif metric.vis_type == "bar":
178 | raise NotImplementedError
179 | else:
180 | raise NotImplementedError
181 |
182 | #############################################################
183 | # Value
184 | #############################################################
185 | def add_value(self, key, vis_type, title=None, tags=None, init=None):
186 | self.values[key] = Value(key, vis_type, tags, title)
187 |
188 | def get_value(self, key):
189 | return self.values[key]
190 |
191 | def update_value(self, key, value, tag=None):
192 | """
193 | Update the state of the given value
194 | Args:
195 | key:
196 | value:
197 | tag:
198 |
199 | Returns:
200 |
201 | """
202 | self.get_value(key).update(value, tag)
203 |
204 | try:
205 | if self.enabled:
206 | self.__plot_value(key)
207 |
208 |         except IndexError:
209 |             pass
210 |
211 |         except Exception as e:
212 |             print(f"An error occurred while trying to plot value:{key} ({e})")
213 |
214 | def __plot_value(self, key):
215 | value = self.get_value(key)
216 |
217 | if value.vis_type == "text":
218 | self.viz.plot_text(value.value, value.title, pre=value.pre)
219 | elif value.vis_type == "scatter":
220 | if value.tags is not None:
221 | raise NotImplementedError
222 | else:
223 | data = value.value
224 |
225 | self.viz.plot_scatter(data[0], data[1], value.title)
226 | elif value.vis_type == "heatmap":
227 | if value.tags is not None:
228 | raise NotImplementedError
229 | else:
230 | data = value.value
231 |
232 | self.viz.plot_heatmap(data[0], data[1], value.title)
233 | elif value.vis_type == "bar":
234 | if value.tags is not None:
235 | raise NotImplementedError
236 | else:
237 | data = value.value
238 |
239 | self.viz.plot_bar(data[0], data[1], value.title)
240 | else:
241 | raise NotImplementedError
242 |
243 | #############################################################
244 | # Persistence
245 | #############################################################
246 | def _state_dict(self):
247 | omit = ["db", "db_client", "db_collection"]
248 | state = {k: v for k, v in self.__dict__.items() if k not in omit}
249 |
250 | return state
251 |
252 | def to_db(self):
253 | self.timestamp_update = datetime.now()
254 | # record = self._state_dict()
255 |
256 | # todo: avoid this workaround
257 | record = json.loads(self._serialize())
258 |
259 | if self.db_record is None:
260 |             self.db_record = self.db_collection.insert_one(record).inserted_id
261 | else:
262 | self.db_collection.replace_one({"_id": self.db_record}, record)
263 |
264 | def _serialize(self):
265 |
266 | data = json.dumps(self._state_dict(),
267 | default=lambda o: getattr(o, '__dict__', str(o)))
268 | return data
269 |
270 | def to_json(self):
271 | self.timestamp_update = datetime.now()
272 | name = self.name + "_{}.json".format(self.get_timestamp())
273 | filename = os.path.join(self.output_dir, name)
274 | with open(filename, 'w', encoding='utf-8') as f:
275 | f.write(self._serialize())
276 |
277 | def get_timestamp(self):
278 | return self.timestamp_start.strftime("%y-%m-%d_%H:%M:%S")
279 |
280 | def to_pickle(self):
281 | self.timestamp_update = datetime.now()
282 | name = self.name + "_{}.pickle".format(self.get_timestamp())
283 | filename = os.path.join(self.output_dir, name)
284 | with open(filename, 'wb') as f:
285 | pickle.dump(self._state_dict(), f)
286 |
287 | def save(self):
288 | try:
289 | self.to_pickle()
290 |         except Exception:
291 | print("Failed to save to pickle...")
292 |
293 | # try:
294 | # self.to_json()
295 | # except:
296 | # print("Failed to save to json...")
297 |
298 | # try:
299 | # self.to_db()
300 | # except:
301 | # print("Failed to save to db...")
302 |
303 | def log_metrics(self, keys, epoch):
304 |
305 | _metrics = [self.metrics[key] for key in keys]
306 | _tags = _metrics[0].tags
307 | if _tags is not None:
308 | values = [[tag] + [metric.values[tag][-1] for metric in _metrics] for tag in _tags]
309 | headers = ["TAG"] + [metric.title.upper() for metric in _metrics]
310 | else:
311 | values = [[metric.values[-1] for metric in _metrics]]
312 | headers = [metric.title.upper() for metric in _metrics]
313 |
314 | log_output = tabulate(values, headers, floatfmt=".4f")
315 |
316 | return log_output
317 |
318 |
319 | class Metric(object):
320 | """
321 |     A Metric holds the history of a monitored value of the model.
322 |
323 |     A Metric object has to have a name,
324 |     a vis_type which defines how it will be visualized,
325 |     and optionally a set of tags (e.g. datasets) it is attached to.
326 | """
327 |
328 | def __init__(self, key, vis_type, tags=None, title=None):
329 | """
330 |
331 | Args:
332 | key (str): the name of the metric
333 | vis_type (str): the visualization type
334 | tags (list): list of tags
335 | title (str): used for presentation purposes (figure, console...)
336 | """
337 | self.key = key
338 | self.title = title
339 | self.vis_type = vis_type
340 | self.tags = tags
341 |
342 | assert vis_type in ["line"]
343 |
344 | if tags is not None:
345 | self.values = {tag: [] for tag in tags}
346 | else:
347 | self.values = []
348 |
349 | if title is None:
350 | self.title = key
351 |
352 | def add(self, value, tag=None):
353 | """
354 | Add a value to the list of values of this metric
355 | Args:
356 | value (int, float):
357 | tag (str):
358 |
359 | Returns:
360 |
361 | """
362 | if self.tags is not None:
363 | self.values[tag].append(value)
364 | else:
365 | self.values.append(value)
366 |
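# Illustrative usage sketch (added for exposition): a tagged Metric
# keeps one history per tag.
def _demo_metric():
    m = Metric("loss", "line", tags=["train", "val"])
    m.add(0.9, tag="train")
    m.add(1.1, tag="val")
    assert m.values == {"train": [0.9], "val": [1.1]}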
367 |
368 | class Value(object):
369 | """
370 |     Value holds the current state (not the history) of a monitored quantity.
371 | """
372 |
373 | def __init__(self, key, vis_type, tags=None, title=None, pre=True):
374 | """
375 |
376 | Args:
377 | key (str): the name of the value
378 | vis_type (str): the visualization type
379 | tags (list): list of tags
380 | title (str): used for presentation purposes (figure, console...)
381 | """
382 | self.key = key
383 | self.title = title
384 | self.vis_type = vis_type
385 | self.tags = tags
386 | self.pre = pre
387 |
388 | assert vis_type in ["text", "scatter", "bar", "heatmap"]
389 |
390 | if tags is not None:
391 | self.value = {tag: [] for tag in tags}
392 | else:
393 | self.value = []
394 |
395 | if title is None:
396 | self.title = key
397 |
398 | def update(self, value, tag=None):
399 | """
400 | Update the value
401 | Args:
402 | value (int, float):
403 | tag (str):
404 |
405 | Returns:
406 |
407 | """
408 | if self.tags is not None:
409 | self.value[tag] = value
410 | else:
411 | self.value = value
412 |
--------------------------------------------------------------------------------
/mylogger/helpers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | from pathlib import Path
5 |
6 | from glob2 import glob
7 |
8 |
9 | def files_to_dict(dirs, safe=True):
10 | file_data = defaultdict(dict)
11 |
12 | for dir in dirs:
13 |         for file in glob(os.path.join(dir, "*.py")):
14 | _dir = os.path.split(dir)[1]
15 | filename = os.path.basename(file)
16 | if safe:
17 | filename = filename.replace('.', '[dot]')
18 | file_data[_dir][filename] = Path(file).read_text()
19 |
20 | return file_data
21 |
22 |
23 | def dict_to_html(config):
24 | indent = 2
25 | msg = json.dumps(config, indent=indent)
26 | msg = "\n".join([line[2:].rstrip() for line in msg.split("\n")
27 | if len(line.strip()) > 3])
28 | # format with html
29 | msg = msg.replace('{', '')
30 | msg = msg.replace('}', '')
31 |     # msg = msg.replace('\n', '<br/>')
32 | return msg
33 |
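
For context, a small hypothetical example of the two helpers (directory names are illustrative):

    snapshot = files_to_dict(["models", "modules"])
    html = dict_to_html({"lr": 0.001, "batch_size": 32})

The `[dot]` substitution keeps filenames usable as keys in stores that reserve dots in key names (e.g. MongoDB).
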
--------------------------------------------------------------------------------
/mylogger/inspection.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from visdom import Visdom
3 |
4 | from mylogger.plotting import plot_line
5 |
6 |
7 | class Inspector(object):
8 | """
9 | Class for inspecting the internals of neural networks
10 | """
11 |
12 | def __init__(self, model, stats):
13 | """
14 |
15 | Args:
16 | model (torch.nn.Module): the PyTorch model
17 | stats (list): list of stats names. e.g. ["std", "mean"]
18 | """
19 |
20 | # watch only trainable layers
21 | self.watched_layers = {}
22 | for name, module in self.get_watched_modules(model):
23 | self.watched_layers[name] = {stat: [] for stat in stats}
24 |
25 | self.viz = Visdom()
26 | self.update_state(model)
27 |
28 | def get_watched_modules(self, model):
29 | all_modules = []
30 | for name, module in model.named_modules():
31 | if len(list(module.parameters())) > 0 and all(
32 | param.requires_grad for param in module.parameters()):
33 | all_modules.append((name, module))
34 |
35 | # filter parent nodes
36 |         filtered_modules = []
37 |         for name, module in all_modules:
38 |             if not any(
39 |                     (name in n and name != n) for n, m in all_modules):
40 |                 filtered_modules.append((name, module))
41 | 
42 |         return filtered_modules
43 |
44 | def plot_layer(self, name, weights):
45 | self.viz.histogram(X=weights,
46 | win=name,
47 | opts=dict(title="{} weights dist".format(name),
48 | numbins=40))
49 | for stat_name, stat_val in self.watched_layers[name].items():
50 | stat_val.append(getattr(numpy, stat_name)(weights))
51 |
52 | plot_name = "{}-{}".format(name, stat_name)
53 | plot_line(self.viz, numpy.array(stat_val), plot_name, [plot_name])
54 |
55 | def update_state(self, model):
56 | gen = (child for child in model.named_modules()
57 | if child[0] in self.watched_layers)
58 | for name, layer in gen:
59 | weights = [param.data.cpu().numpy() for param in
60 | layer.parameters()]
61 | if len(weights) > 0:
62 | weights = numpy.concatenate([w.ravel() for w in weights])
63 | self.plot_layer(name, weights)
64 |
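
A hypothetical usage sketch (assumes a running Visdom server; the model here is a stand-in):

    import torch
    model = torch.nn.Sequential(torch.nn.Linear(10, 10),
                                torch.nn.Tanh(),
                                torch.nn.Linear(10, 2))
    inspector = Inspector(model, stats=["mean", "std"])
    inspector.update_state(model)   # call again after each epoch

Because stats are looked up with getattr(numpy, stat_name), any numpy reduction name ("mean", "std", "min", ...) is a valid stat.
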
--------------------------------------------------------------------------------
/mylogger/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from visdom import Visdom
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 |
6 |
7 | class Visualizer:
8 |
9 | def __init__(self, env="main",
10 | server="http://localhost",
11 | port=8097,
12 | base_url="/",
13 | http_proxy_host=None,
14 | http_proxy_port=None,
15 | log_to_filename=None):
16 | self._viz = Visdom(env=env,
17 | server=server,
18 | port=port,
19 | http_proxy_host=http_proxy_host,
20 | http_proxy_port=http_proxy_port,
21 | log_to_filename=log_to_filename,
22 | use_incoming_socket=False)
23 | self._viz.close(env=env)
24 |
25 | def plot_line(self, values, steps, name, legend=None):
26 | if legend is None:
27 | opts = dict(title=name)
28 | else:
29 | opts = dict(title=name, legend=legend)
30 |
31 | self._viz.line(
32 | X=numpy.column_stack(steps),
33 | Y=numpy.column_stack(values),
34 | win=name,
35 | update='append',
36 | opts=opts
37 | )
38 |
39 | def plot_text(self, text, title, pre=True):
40 | _width = max([len(x) for x in text.split("\n")]) * 10
41 |         _height = len(text.split("\n")) * 20
42 |         _height = max(_height, 120)
43 | if pre:
44 |             text = "<pre>{}</pre>".format(text)
45 |
46 | self._viz.text(text, win=title, opts=dict(title=title,
47 | width=min(_width, 400),
48 |                                                   height=min(_height, 400)))
49 |
50 | def plot_bar(self, data, labels, title):
51 | self._viz.bar(win=title, X=data,
52 | opts=dict(legend=labels, stacked=False, title=title))
53 |
54 | def plot_scatter(self, data, labels, title):
55 | X = numpy.concatenate(data, axis=0)
56 | Y = numpy.concatenate([numpy.full(len(d), i)
57 | for i, d in enumerate(data, 1)], axis=0)
58 | self._viz.scatter(win=title, X=X, Y=Y,
59 | opts=dict(legend=labels, title=title,
60 | markersize=5,
61 | webgl=True,
62 | width=400,
63 | height=400,
64 | markeropacity=0.5))
65 |
66 | def plot_heatmap(self, data, labels, title):
67 | self._viz.heatmap(win=title,
68 | X=data,
69 | opts=dict(
70 | title=title,
71 | columnnames=labels[1],
72 | rownames=labels[0],
73 | width=700,
74 | height=700,
75 | layoutopts={'plotly': {
76 | 'xaxis': {
77 | 'side': 'top',
78 | 'tickangle': -60,
79 | # 'autorange': "reversed"
80 | },
81 | 'yaxis': {
82 | 'autorange': "reversed"
83 | },
84 | }
85 | }
86 | ))
87 |
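
A hypothetical usage sketch (assumes a Visdom server listening on localhost:8097):

    viz = Visualizer(env="demo")
    viz.plot_line(values=[0.95, 0.80], steps=[1, 1],
                  name="loss", legend=["train", "val"])
    viz.plot_text("epoch 1 done", title="status")

plot_line appends to the window on every call, so passing the latest point per series each step builds the curves incrementally.
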
--------------------------------------------------------------------------------
/rouge-test.py:
--------------------------------------------------------------------------------
1 | import files2rouge
2 | import chardet
3 | import codecs
4 | import os
5 |
6 | dec_path = ""
7 | ref_path = ""
8 | result_id_path = ""
9 |
10 |
11 | def tokens_to_ids(token_list1, token_list2):
12 | ids = {}
13 | out1 = []
14 | out2 = []
15 | for token in token_list1:
16 | out1.append(ids.setdefault(token, len(ids)))
17 | for token in token_list2:
18 | out2.append(ids.setdefault(token, len(ids)))
19 |
20 | return out1, out2
21 |
22 | def write_id(id_lst, file):
23 | for id in id_lst:
24 | file.write(str(id)+" ")
25 | file.write("\n")
26 |
27 | def trans_id(dec_path, ref_path):
28 | dec_files = codecs.open(dec_path, encoding="utf-8").read().split("\n")
29 | ref_files = codecs.open(ref_path, encoding="utf-8").read().split("\n")
30 |
31 | dec_files_id = codecs.open(os.path.join(result_id_path, "decode_id_tmp.txt"), 'a')
32 | ref_files_id = codecs.open(os.path.join(result_id_path, "reference_id_tmp.txt"), 'a')
33 |
34 | sample_num = len(dec_files)
35 | for index in range(sample_num):
36 | dec_file = dec_files[index].split(" ")
37 | ref_file = ref_files[index].split(" ")
38 | dec_id, ref_id = tokens_to_ids(dec_file, ref_file)
39 | write_id(dec_id, dec_files_id)
40 | write_id(ref_id, ref_files_id)
41 |
42 |
43 | #trans_id(dec_path, ref_path)
44 | #files2rouge.run(os.path.join(result_id_path, "decode_id_tmp.txt"),
45 | #os.path.join(result_id_path, "reference_id_tmp.txt"),
46 | #os.path.join(result_id_path, "results.txt"))
47 |
48 | files2rouge.run(ref_path, dec_path)
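
A worked example of the id mapping (outputs computed by hand): because both lists share one id table, overlapping tokens still match after the rewrite, which is the property ROUGE needs.

    out1, out2 = tokens_to_ids(["the", "cat", "sat"], ["the", "dog", "sat"])
    # out1 == [0, 1, 2]
    # out2 == [0, 3, 2]   ("the" and "sat" map to the same ids in both)
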
--------------------------------------------------------------------------------
/sys_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | print("torch:", torch.__version__)
6 | print("Cuda:", torch.version.cuda)
7 | print("CuDNN:", torch.backends.cudnn.version())
8 |
9 | CPU_CORES = 4
10 | RANDOM_SEED = 1618
11 |
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 |
14 | MODEL_CNF_DIR = os.path.join(BASE_DIR, "model_configs")
15 |
16 | TRAINED_PATH = os.path.join(BASE_DIR, "checkpoints")
17 |
18 | EMBS_PATH = os.path.join(BASE_DIR, "embeddings")
19 |
20 | DATA_DIR = os.path.join(BASE_DIR, 'datasets')
21 |
22 | EXP_DIR = os.path.join(BASE_DIR, 'experiments')
23 |
24 | MODEL_DIRS = ["models", "modules", "utils"]
25 |
26 | VIS = {
27 | "server": "http://localhost",
28 | "enabled": False,
29 | "port": 8097,
30 | "base_url": "/",
31 | "http_proxy_host": None,
32 | "http_proxy_port": None,
33 | "log_to_filename": os.path.join(BASE_DIR, "vis_logger.json")
34 | }
35 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.insert(0, '.')
4 | sys.path.insert(0, '..')
5 | sys.path.insert(0, '../../')
6 | sys.path.insert(0, '../../../')
7 |
--------------------------------------------------------------------------------
/utils/_logging.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import time
4 |
5 | from tabulate import tabulate
6 |
7 |
8 | def erase_line():
9 | sys.stdout.write("\033[K")
10 |
11 |
12 | def asMinutes(s):
13 | m = math.floor(s / 60)
14 | s -= m * 60
15 | return '%dm %ds' % (m, s)
16 |
17 |
18 | def timeSince(since, percent):
19 | now = time.time()
20 | s = now - since
21 | es = s / percent
22 | rs = es - s
23 | return asMinutes(s), asMinutes(rs)
24 |
25 |
26 | def log_seq3_losses(L1_LM, L1_AE, L2_LM, L2_AE,
27 | L1_LMD, L1_TRANSD, L2_LMD, L2_TRANSD,
28 | L1_TRANSG, L2_TRANSG):
29 | losses = []
30 | losses.append(["L1", L1_LM, L1_AE, math.exp(L1_LM), math.exp(L1_AE),
31 | L1_LMD, L1_TRANSD, L1_TRANSG])
32 | losses.append(["L2", L2_LM, L2_AE, math.exp(L2_LM), math.exp(L2_AE),
33 | L2_LMD, L2_TRANSD, L2_TRANSG])
34 | return tabulate(losses,
35 | headers=['Lang', 'LM Loss', 'AE Loss', 'LM PPL', 'AE PPL',
36 | 'LM-D Loss', 'TRANS-D Loss', 'TRANS-G Loss'],
37 | floatfmt=".4f")
38 |
39 |
40 | def progress_bar(percentage, bar_len=20):
41 | filled_len = int(round(bar_len * percentage))
42 | bar = '=' * filled_len + '-' * (bar_len - filled_len)
43 | return "[{}]".format(bar)
44 |
45 |
46 | def epoch_progress(epoch, batch, batch_size, dataset_size, start):
47 | n_batches = math.ceil(float(dataset_size) / batch_size)
48 | percentage = batch / n_batches
49 |
50 | # stats = 'Epoch:{}, Batch:{}/{} ({0:.2f}%)'.format(epoch, batch, n_batches,
51 | # percentage)
52 | stats = f'Epoch:{epoch}, Batch:{batch}/{n_batches} ' \
53 |             f'({100 * percentage:.0f}%)'
54 | # stats = f'Epoch:{epoch}, Batch:{batch} ({100* percentage:.0f}%)'
55 |
56 | elapsed, eta = timeSince(start, batch / n_batches)
57 | time_info = 'Time: {} (-{})'.format(elapsed, eta)
58 |
59 | # clean every line and then add the text output
60 | # log_output = stats + " " + progress_bar + ", " + time_info
61 |
62 | # log_output = " ".join([stats, time_info])
63 | log_output = " ".join([stats, progress_bar(percentage), time_info])
64 |
65 | sys.stdout.write("\r \r\033[K" + log_output)
66 | sys.stdout.flush()
67 | return log_output
68 |
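
A hypothetical usage sketch of the progress line (numbers are illustrative):

    import time
    start = time.time()
    for batch in range(1, 101):
        epoch_progress(epoch=1, batch=batch, batch_size=32,
                       dataset_size=3200, start=start)
    # prints e.g.: Epoch:1, Batch:57/100 (57%) [===========---------] Time: 0m 12s (-0m 9s)
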
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 | from sys_config import DATA_DIR
6 |
7 |
8 | def get_parser():
9 | """Get parser object."""
10 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
11 | parser = ArgumentParser(description=__doc__,
12 | formatter_class=ArgumentDefaultsHelpFormatter)
13 | parser.add_argument("-cfg", "--config",
14 | dest="cfg",
15 | help="experiment definition file",
16 | metavar="FILE",
17 | required=True)
18 | return parser
19 |
20 |
21 | def make_paths(cfg):
22 | """
23 | Make all values for keys ending with `_path` absolute to dir_.
24 | """
25 | for key in cfg.keys():
26 | if key.endswith("_path"):
27 | if cfg[key] is not None:
28 | cfg[key] = os.path.join(DATA_DIR, cfg[key])
29 | cfg[key] = os.path.abspath(cfg[key])
30 |         if isinstance(cfg[key], dict):
31 | cfg[key] = make_paths(cfg[key])
32 | return cfg
33 |
34 |
35 | def load_config(file):
36 | with open(file, 'r') as stream:
37 |         cfg = yaml.safe_load(stream)
38 | cfg = make_paths(cfg)
39 |
40 | return cfg
41 |
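
A hypothetical usage sketch; any key ending in `_path` in the YAML is resolved relative to DATA_DIR, recursively through nested dicts:

    # model_configs/lm_prior.yaml (contents illustrative):
    #   data:
    #     train_path: gigaword/train.article.txt
    cfg = load_config("model_configs/lm_prior.yaml")
    # cfg["data"]["train_path"] -> "<BASE_DIR>/datasets/gigaword/train.article.txt"
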
--------------------------------------------------------------------------------
/utils/data_parsing.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 | import random
5 | from collections import Counter, defaultdict
6 | from matplotlib import pyplot as plt
7 |
8 | import numpy
9 | from glob2 import glob
10 | from sklearn.model_selection import train_test_split
11 | from sklearn.preprocessing import LabelBinarizer
12 |
13 | from sys_config import DATA_DIR
14 |
15 |
16 | def read_amazon(file):
17 | reviews = []
18 | summaries = []
19 | labels = []
20 |
21 | with open(file) as f:
22 | for line in f:
23 | entry = json.loads(line)
24 | reviews.append(entry["reviewText"])
25 | summaries.append(entry["summary"])
26 | labels.append(entry["overall"])
27 |
28 | return reviews, summaries, labels
29 |
30 |
31 | def read_semeval():
32 | def read_dataset(d):
33 | with open(os.path.join(DATA_DIR, "semeval", "E-c",
34 | "E-c-En-{}.txt".format(d))) as f:
35 | reader = csv.reader(f, delimiter='\t')
36 | labels = next(reader)[2:]
37 |
38 | _X = []
39 | _y = []
40 | for row in reader:
41 | _X.append(row[1])
42 | _y.append([int(x) for x in row[2:]])
43 | return _X, _y
44 |
45 | X_train, y_train = read_dataset("train")
46 | X_dev, y_dev = read_dataset("dev")
47 | X_test, y_test = read_dataset("test")
48 |
49 | X_train = X_train + X_test
50 | y_train = y_train + y_test
51 |
52 | return X_train, numpy.array(y_train), X_dev, numpy.array(y_dev)
53 |
54 |
55 | def imdb_get_index():
56 | index = defaultdict(list)
57 |
58 | dirs = ["pos", "neg", "unsup"]
59 | sets = ["train", "test"]
60 |
61 | for s in sets:
62 | for d in dirs:
63 | for file in glob(os.path.join(DATA_DIR, "imdb", s, d) + "/*.txt"):
64 | index["_".join([s, d])].append(file)
65 | return index
66 |
67 |
68 | def get_imdb():
69 | index = imdb_get_index()
70 |
71 | data = []
72 |
73 | for ki, vi in index.items():
74 | for f in vi:
75 |             data.append(" ".join(open(f).readlines()).replace('<br />', ''))
76 |
77 | return data
78 |
79 |
80 | def read_emoji(split=0.1, min_freq=100, max_ex=1000000, top_n=None):
81 | X = []
82 | y = []
83 | with open(os.path.join(DATA_DIR, "emoji", "emoji_1m.txt")) as f:
84 | for i, line in enumerate(f):
85 | if i > max_ex:
86 | break
87 | emoji, text = line.rstrip().split("\t")
88 | X.append(text)
89 | y.append(emoji)
90 |
91 | counter = Counter(y)
92 | top = set(l for l, f in counter.most_common(top_n) if f > min_freq)
93 |
94 | data = [(_x, _y) for _x, _y in zip(X, y) if _y in top]
95 |
96 | total = len(data)
97 |
98 | data = [(_x, _y) for _x, _y in data if
99 | random.random() > counter[_y] / total]
100 |
101 | X = [x[0] for x in data]
102 | y = [x[1] for x in data]
103 |
104 | X_train, X_test, y_train, y_test = train_test_split(X, y,
105 | test_size=split,
106 | stratify=y,
107 | random_state=0)
108 |
109 | lb = LabelBinarizer()
110 | lb.fit(y_train)
111 | y_train = lb.transform(y_train)
112 | y_test = lb.transform(y_test)
113 |
114 | return X_train, y_train, X_test, y_test
115 |
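
The rebalancing step in read_emoji() keeps each example with probability 1 - freq(label)/total, so frequent classes are downsampled harder. A minimal standalone sketch of the same idea on toy data:

    import random
    from collections import Counter

    labels = ["a"] * 90 + ["b"] * 10
    counter = Counter(labels)
    total = len(labels)
    kept = [y for y in labels if random.random() > counter[y] / total]
    # "a" items survive with prob ~0.1, "b" items with prob ~0.9
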
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import pandas
2 |
3 | import rouge
4 | import codecs
5 | import os
6 | import files2rouge
7 | from tabulate import tabulate
8 |
9 | def rouge_lists(refs, hyps):
10 | evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'],
11 | max_n=2,
12 | limit_length=True,
13 | length_limit=100,
14 | length_limit_type='words',
15 | apply_avg=True,
16 | apply_best=False,
17 | alpha=0.5, # Default F1_score
18 | weight_factor=1.2,
19 | stemming=True)
20 | scores = evaluator.get_scores(hyps, refs)
21 |
22 | return scores
23 |
24 |
25 | def tokens_to_ids(token_list1, token_list2):
26 | ids = {}
27 | out1 = []
28 | out2 = []
29 | for token in token_list1:
30 | out1.append(ids.setdefault(token, len(ids)))
31 | for token in token_list2:
32 | out2.append(ids.setdefault(token, len(ids)))
33 | return out1, out2
34 |
35 |
36 | def write_id(id_lst, file):
37 | for id in id_lst:
38 | file.write(str(id)+" ")
39 | file.write("\n")
40 |
41 |
42 | def filter_file(file_lst):
43 | if "" == file_lst[-1]:
44 | del(file_lst[-1])
45 | return file_lst
46 |
47 |
48 | # map tokens to shared integer ids, to sidestep encoding issues with e.g. Chinese characters
49 | def trans_id(path, hyps, ref_path):
50 | #dec_files = codecs.open(dec_path, encoding="utf-8").read().split("\n")
51 | dec_files = hyps
52 | ref_files = codecs.open(ref_path, encoding="utf-8").read().split("\n")
53 | dec_files = filter_file(dec_files)
54 | ref_files = filter_file(ref_files)
55 |
56 | dec_files_id = codecs.open(os.path.join(path, "decode_id_tmp.txt"), 'w')
57 | ref_files_id = codecs.open(os.path.join(path, "reference_id_tmp.txt"), 'w')
58 |
59 | results_path = os.path.join(path, "results.txt")
60 |
61 | sample_num = len(dec_files)
62 | for index in range(sample_num):
63 | dec_file = dec_files[index].split(" ")
64 | ref_file = ref_files[index].split(" ")
65 | dec_id, ref_id = tokens_to_ids(dec_file, ref_file)
66 | write_id(dec_id, dec_files_id)
67 | write_id(ref_id, ref_files_id)
68 |
69 | scores_str = files2rouge.run(os.path.join(path, "decode_id_tmp.txt"), os.path.join(path, "reference_id_tmp.txt"))
70 | return scores_str
71 |
72 |
73 | def rouge_files(path, refs_file, hyps):
74 | #refs = open(refs_file).readlines()
75 | #hyps = open(hyps_file).readlines()
76 | #scores = rouge_lists(refs, hyps)
77 | scores_str = trans_id(path, hyps, refs_file)
78 | result_file = codecs.open(os.path.join(path, "result.txt"), 'a')
79 | result_file.write(scores_str)
80 |
81 | r1_r = scores_str[scores_str.find("ROUGE-1 Average_R:")+19:scores_str.find("ROUGE-1 Average_R:")+26]
82 | r2_r = scores_str[scores_str.find("ROUGE-2 Average_R:")+19:scores_str.find("ROUGE-2 Average_R:")+26]
83 | rl_r = scores_str[scores_str.find("ROUGE-L Average_R:")+19:scores_str.find("ROUGE-L Average_R:")+26]
84 |
85 | r1_p = scores_str[scores_str.find("ROUGE-1 Average_P:")+19:scores_str.find("ROUGE-1 Average_P:")+26]
86 | r2_p = scores_str[scores_str.find("ROUGE-2 Average_P:")+19:scores_str.find("ROUGE-2 Average_P:")+26]
87 | rl_p = scores_str[scores_str.find("ROUGE-L Average_P:")+19:scores_str.find("ROUGE-L Average_P:")+26]
88 |
89 | r1_f = scores_str[scores_str.find("ROUGE-1 Average_F:")+19:scores_str.find("ROUGE-1 Average_F:")+26]
90 | r2_f = scores_str[scores_str.find("ROUGE-2 Average_F:")+19:scores_str.find("ROUGE-2 Average_F:")+26]
91 | rl_f = scores_str[scores_str.find("ROUGE-L Average_F:")+19:scores_str.find("ROUGE-L Average_F:")+26]
92 |
93 | scores = {}
94 | scores['rouge-1'] = {}
95 | scores['rouge-2'] = {}
96 | scores['rouge-l'] = {}
97 |
98 | scores['rouge-1']['r'] = float(r1_r)
99 | scores['rouge-1']['p'] = float(r1_p)
100 | scores['rouge-1']['f'] = float(r1_f)
101 |
102 | scores['rouge-2']['r'] = float(r2_r)
103 | scores['rouge-2']['p'] = float(r2_p)
104 | scores['rouge-2']['f'] = float(r2_f)
105 |
106 | scores['rouge-l']['r'] = float(rl_r)
107 | scores['rouge-l']['p'] = float(rl_p)
108 | scores['rouge-l']['f'] = float(rl_f)
109 | return scores
110 |
111 |
112 | def rouge_files_simple(path, refs_file, hyps):
113 | scores_str = trans_id(path, hyps, refs_file)
114 | result_file = codecs.open(os.path.join(path, "result_nsent.txt"), 'a')
115 | result_file.write(scores_str)
116 |
117 | r1_r = scores_str[scores_str.find("ROUGE-1 Average_R:")+19:scores_str.find("ROUGE-1 Average_R:")+26]
118 | r2_r = scores_str[scores_str.find("ROUGE-2 Average_R:")+19:scores_str.find("ROUGE-2 Average_R:")+26]
119 | rl_r = scores_str[scores_str.find("ROUGE-L Average_R:")+19:scores_str.find("ROUGE-L Average_R:")+26]
120 |
121 | r1_p = scores_str[scores_str.find("ROUGE-1 Average_P:")+19:scores_str.find("ROUGE-1 Average_P:")+26]
122 | r2_p = scores_str[scores_str.find("ROUGE-2 Average_P:")+19:scores_str.find("ROUGE-2 Average_P:")+26]
123 | rl_p = scores_str[scores_str.find("ROUGE-L Average_P:")+19:scores_str.find("ROUGE-L Average_P:")+26]
124 |
125 | r1_f = scores_str[scores_str.find("ROUGE-1 Average_F:")+19:scores_str.find("ROUGE-1 Average_F:")+26]
126 | r2_f = scores_str[scores_str.find("ROUGE-2 Average_F:")+19:scores_str.find("ROUGE-2 Average_F:")+26]
127 | rl_f = scores_str[scores_str.find("ROUGE-L Average_F:")+19:scores_str.find("ROUGE-L Average_F:")+26]
128 |
129 | return r1_f, r2_f, rl_f
130 |
131 |
132 | def rouge_file_list(refs_file, hyps_list):
133 | refs = open(refs_file).readlines()
134 | scores = rouge_lists(refs, hyps_list)
135 |
136 | return scores
137 |
138 |
139 | def pprint_rouge_scores(scores, pivot=False):
140 | pdt = pandas.DataFrame(scores)
141 |
142 | if pivot:
143 | pdt = pdt.T
144 |
145 | table = tabulate(pdt,
146 | headers='keys',
147 | floatfmt=".4f", tablefmt="psql")
148 |
149 | return table
150 |
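
The fixed character offsets above (find(...)+19 to +26) assume the exact layout of the files2rouge report. A hedged alternative sketch that parses the same report with a regex, under the same format assumption:

    import re

    def parse_rouge_report(scores_str):
        pattern = r"(ROUGE-[12L]) Average_([RPF]): (\d+\.\d+)"
        scores = {}
        for metric, kind, value in re.findall(pattern, scores_str):
            scores.setdefault(metric.lower(), {})[kind.lower()] = float(value)
        return scores   # e.g. {'rouge-1': {'r': 0.3064, ...}, ...}
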
--------------------------------------------------------------------------------
/utils/generic.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 |
3 | import numpy
4 | import umap
5 | from sklearn.decomposition import PCA
6 |
7 |
8 | def merge_dicts(a, b):
9 | a.update({k: v for k, v in b.items() if k in a})
10 | return a
11 |
12 |
13 | def number_h(num):
14 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
15 | if abs(num) < 1000.0:
16 | return "%3.1f%s" % (num, unit)
17 | num /= 1000.0
18 | return "%.1f%s" % (num, 'Yi')
19 |
20 |
21 | def group(lst, n):
22 | """group([0,3,4,10,2,3], 2) => [(0,3), (4,10), (2,3)]
23 |
24 | Group a list into consecutive n-tuples. Incomplete tuples are
25 | discarded e.g.
26 |
27 | >>> group(range(10), 3)
28 | [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
29 | """
30 | return zip(*[lst[i::n] for i in range(n)])
31 |
32 |
33 | def pairwise(iterable):
34 | it = iter(iterable)
35 | a = next(it, None)
36 |
37 | for b in it:
38 | yield (a, b)
39 | a = b
40 |
41 |
42 | def concat_multiline_strings(a, b):
43 |     lines = []
44 |     for line1, line2 in zip_longest(a.split("\n"), b.split("\n"),
45 |                                     fillvalue=''):
46 |         lines.append("\t".join([line1, line2]))
47 | 
48 |     return "\n".join(lines)
49 |
50 |
51 | def dim_reduce(data_sets, n_components=2, method="PCA"):
52 | data = numpy.vstack(data_sets)
53 | splits = numpy.cumsum([0] + [len(x) for x in data_sets])
54 | if method == "PCA":
55 | reducer = PCA(random_state=20, n_components=n_components)
56 | embedding = reducer.fit_transform(data)
57 | elif method == "UMAP":
58 | reducer = umap.UMAP(random_state=20,
59 | n_components=n_components,
60 | min_dist=0.5)
61 | embedding = reducer.fit_transform(data)
62 | else:
63 | reducer_linear = PCA(random_state=20, n_components=50)
64 | linear_embedding = reducer_linear.fit_transform(data)
65 | reducer_nonlinear = umap.UMAP(random_state=20,
66 | n_components=n_components,
67 | min_dist=0.5)
68 | embedding = reducer_nonlinear.fit_transform(linear_embedding)
69 |
70 | return [embedding[start:stop] for start, stop in pairwise(splits)]
71 |
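
Worked examples for the small helpers above (outputs verified by hand):

    list(group([0, 3, 4, 10, 2, 3], 2))   # [(0, 3), (4, 10), (2, 3)]
    list(pairwise([1, 2, 3, 4]))          # [(1, 2), (2, 3), (3, 4)]
    number_h(1234567)                     # '1.2M'
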
--------------------------------------------------------------------------------
/utils/load_embeddings.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import pickle
4 |
5 | import numpy
6 |
7 |
8 | def file_cache_name(file):
9 | head, tail = os.path.split(file)
10 | filename, ext = os.path.splitext(tail)
11 | return os.path.join(head, filename + ".p")
12 |
13 |
14 | def write_cache_word_vectors(file, data):
15 | with open(file_cache_name(file), 'wb') as pickle_file:
16 | pickle.dump(data, pickle_file)
17 |
18 |
19 | def load_cache_word_vectors(file):
20 | with open(file_cache_name(file), 'rb') as f:
21 | return pickle.load(f)
22 |
23 |
24 | def load_word_vectors(file, dim):
25 | """
26 | Read the word vectors from a text file
27 | Args:
28 | file (): the filename
29 | dim (): the dimensions of the word vectors
30 |
31 | Returns:
32 | word2idx (dict): dictionary of words to ids
33 | idx2word (dict): dictionary of ids to words
34 | embeddings (numpy.ndarray): the word embeddings matrix
35 |
36 | """
37 | # in order to avoid this time consuming operation, cache the results
38 | try:
39 | cache = load_cache_word_vectors(file)
40 | print("Loaded word embeddings from cache.")
41 | return cache
42 | except OSError:
43 |         print("Didn't find embeddings cache file {}".format(file_cache_name(file)))
44 |
45 | # create the necessary dictionaries and the word embeddings matrix
46 | if os.path.exists(file):
47 | print('Indexing file {} ...'.format(file))
48 |
49 | word2idx = {} # dictionary of words to ids
50 | idx2word = {} # dictionary of ids to words
51 | embeddings = [] # the word embeddings matrix
52 |
53 | # create the 2D array, which will be used for initializing
54 | # the Embedding layer of a NN.
55 | # We reserve the first row (idx=0), as the word embedding,
56 | # which will be used for zero padding (word with id = 0).
57 | embeddings.append(numpy.zeros(dim))
58 |
59 | # flag indicating whether the first row of the embeddings file
60 | # has a header
61 | header = False
62 |
63 | # read file, line by line
64 | with open(file, "r", encoding="utf-8") as f:
65 | for i, line in enumerate(f, 1):
66 |
67 | # skip the first row if it is a header
68 | if i == 1:
69 | if len(line.split()) < dim:
70 | header = True
71 | continue
72 |
73 | values = line.split(" ")
74 | word = values[0]
75 | vector = numpy.asarray(values[1:], dtype='float32')
76 |
77 | index = i - 1 if header else i
78 |
79 | idx2word[index] = word
80 | word2idx[word] = index
81 | embeddings.append(vector)
82 |
83 | # add an unk token, for OOV words
84 |         if "<unk>" not in word2idx:
85 |             idx2word[len(idx2word) + 1] = "<unk>"
86 |             word2idx["<unk>"] = len(word2idx) + 1
87 | embeddings.append(
88 | numpy.random.uniform(low=-0.05, high=0.05, size=dim))
89 |
90 | print(set([len(x) for x in embeddings]))
91 |
92 | print('Found %s word vectors.' % len(embeddings))
93 | embeddings = numpy.array(embeddings, dtype='float32')
94 |
95 | # write the data to a cache file
96 | write_cache_word_vectors(file, (word2idx, idx2word, embeddings))
97 |
98 | return word2idx, idx2word, embeddings
99 |
100 | else:
101 | print("{} not found!".format(file))
102 | raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), file)
103 |
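
A hypothetical usage sketch (the embeddings file and dimensionality are illustrative):

    word2idx, idx2word, embeddings = load_word_vectors(
        "embeddings/glove.6B.100d.txt", dim=100)
    # embeddings has a zero row at index 0 (reserved for padding) and an
    # <unk> row appended for OOV words; a .p cache is written next to the
    # .txt file, so subsequent loads are much faster.
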
--------------------------------------------------------------------------------
/utils/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import signal
4 | import subprocess
5 | import sys
6 |
7 | import torch
8 |
9 | from sys_config import BASE_DIR
10 | from utils.config import load_config
11 |
12 |
13 | def spawn_visdom():
14 | try:
15 | subprocess.run(["visdom > visdom.txt 2>&1 &"], shell=True)
16 |     except Exception:
17 | print("Visdom is already running...")
18 |
19 | def signal_handler(signal, frame):
20 | subprocess.run(["pkill visdom"], shell=True)
21 | print("Killing Visdom server...")
22 | sys.exit(0)
23 |
24 | signal.signal(signal.SIGINT, signal_handler)
25 |
26 |
27 | def train_options():
28 | print(os.getcwd())
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--config', default="model_configs/lm_prior.yaml")
31 | parser.add_argument('--name', default="lm_3w")
32 | parser.add_argument('--desc')
33 | parser.add_argument('--resume')
34 | parser.add_argument('--transfer')
35 | parser.add_argument('--visdom', action='store_true')
36 | parser.add_argument('--vocab')
37 | parser.add_argument('--cp-vocab')
38 | parser.add_argument('--device', default="auto")
39 | parser.add_argument('--cores', type=int, default=1)
40 | parser.add_argument('--source', nargs='*',
41 | default=["models", "modules", "utils"])
42 |
43 | args = parser.parse_args()
44 | config = load_config(args.config)
45 |
46 | if args.name is None:
47 | config_filename = os.path.basename(args.config)
48 | args.name = os.path.splitext(config_filename)[0]
49 |
50 | config["name"] = args.name
51 | config["desc"] = args.desc
52 |
53 | if args.device == "auto":
54 | args.device = torch.device("cuda" if torch.cuda.is_available()
55 | else "cpu")
56 |
57 | if args.source is None:
58 | args.source = []
59 |
60 | args.source = [os.path.join(BASE_DIR, dir) for dir in args.source]
61 |
62 | if args.visdom:
63 | spawn_visdom()
64 |
65 | for arg in vars(args):
66 | print("{}:{}".format(arg, getattr(args, arg)))
67 | print()
68 |
69 | return args, config
70 |
71 |
72 | def seq2seq2seq_options():
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('--config', default="model_configs/ds.full.yaml")
75 | parser.add_argument('--name', default="test-justice-out")
76 | parser.add_argument('--desc')
77 | parser.add_argument('--resume')
78 | parser.add_argument('--visdom', action='store_true')
79 | parser.add_argument('--transfer-lm')
80 | parser.add_argument('--device', default="auto")
81 | parser.add_argument('--cores', type=int, default=4)
82 | parser.add_argument('--source', nargs='*', default=["models", "modules", "utils"])
83 |
84 | args = parser.parse_args()
85 | config = load_config(args.config)
86 |
87 | if args.name is None:
88 | config_filename = os.path.basename(args.config)
89 | args.name = os.path.splitext(config_filename)[0]
90 |
91 | config["name"] = args.name
92 | config["desc"] = args.desc
93 |
94 | if args.device == "auto":
95 | args.device = torch.device("cuda" if torch.cuda.is_available()
96 | else "cpu")
97 |
98 | if args.source is None:
99 | args.source = []
100 |
101 | args.source = [os.path.join(BASE_DIR, dir) for dir in args.source]
102 |
103 | if args.visdom:
104 | spawn_visdom()
105 |
106 | for arg in vars(args):
107 | print("{}:{}".format(arg, getattr(args, arg)))
108 | print()
109 |
110 | return args, config
111 |
--------------------------------------------------------------------------------
/utils/training.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 |
4 | import torch
5 |
6 | from sys_config import BASE_DIR
7 |
8 |
9 | def save_checkpoint(state, name, path=None, timestamp=False, tag=None, verbose=False):
10 | """
11 | Save a trained model, along with its optimizer, in order to be able to
12 | resume training
13 |     Args:
14 |         state (dict): the state to serialize (model, optimizer, ...)
15 |         name (str): base filename; tag and timestamp are appended to it
16 |         path (str): the directory, in which to save the checkpoints
17 |         timestamp (bool): append a timestamp, keeping every checkpoint
18 |             instead of only the latest
19 |     Returns: the final checkpoint filename
20 | """
21 | now = datetime.datetime.now().strftime("%y-%m-%d_%H:%M:%S")
22 |
23 | if tag is not None:
24 | if isinstance(tag, str):
25 | name += "_{}".format(tag)
26 | elif isinstance(tag, list):
27 | for t in tag:
28 | name += "_{}".format(t)
29 | else:
30 | raise ValueError("invalid tag type!")
31 |
32 | if timestamp:
33 | name += "_{}".format(now)
34 |
35 | name += ".pt"
36 |
37 | if path is None:
38 | path = os.path.join(BASE_DIR, "checkpoints")
39 |
40 | file = os.path.join(path, name)
41 |
42 | if verbose:
43 | print("saving checkpoint:{} ...".format(name))
44 |
45 | torch.save(state, file)
46 |
47 | return name
48 |
49 |
50 | def load_checkpoint(name, path=None, device=None):
51 | """
52 | Load a trained model, along with its optimizer
53 | Args:
54 | name (str): the name of the model
55 | path (str): the directory, in which the model is saved
56 |
57 | Returns:
58 | model, optimizer
59 |
60 | """
61 | if path is None:
62 | path = os.path.join(BASE_DIR, "checkpoints")
63 |
64 | model_fname = os.path.join(path, "{}.pt".format(name))
65 |
66 | print("Loading checkpoint `{}` ...".format(model_fname), end=" ")
67 |
68 | with open(model_fname, 'rb') as f:
69 | state = torch.load(f, map_location="cpu")
70 |
71 | print("done!")
72 |
73 | return state
74 |
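
A hypothetical usage sketch (model and optimizer are placeholders):

    state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    save_checkpoint(state, name="lm_prior", tag="best")   # -> checkpoints/lm_prior_best.pt
    restored = load_checkpoint("lm_prior_best")
    model.load_state_dict(restored["model"])

Note that with timestamp=True the timestamp becomes part of the filename, so the name returned by save_checkpoint (minus the .pt suffix) is what load_checkpoint needs.
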
--------------------------------------------------------------------------------
/utils/transfer.py:
--------------------------------------------------------------------------------
1 | def freeze_module(layer, depth=None):
2 | if depth is None:
3 | for param in layer.parameters():
4 | param.requires_grad = False
5 | else:
6 | for weight in layer.all_weights[depth]:
7 | weight.requires_grad = False
8 |
9 |
10 | def train_module(layer, depth=None):
11 | if depth is None:
12 | for param in layer.parameters():
13 | param.requires_grad = True
14 | else:
15 | for weight in layer.all_weights[depth]:
16 | weight.requires_grad = True
17 |
18 |
19 | def dict_rename_by_pattern(from_dict, patterns):
20 | for k in list(from_dict.keys()):
21 | v = from_dict.pop(k)
22 | p = list(filter(lambda x: x in k, patterns.keys()))
23 | if len(p) > 0:
24 | new_key = k.replace(p[0], patterns[p[0]])
25 | from_dict[new_key] = v
26 | else:
27 | from_dict[k] = v
28 |
29 |
30 | def load_state_dict_subset(model, pretrained_dict):
31 | model_dict = model.state_dict()
32 |
33 | # 1. filter out unnecessary keys
34 | pretrained_dict = {k: v for k, v in pretrained_dict.items()
35 | if k in model_dict}
36 | # 2. overwrite entries in the existing state dict
37 | model_dict.update(pretrained_dict)
38 |
39 | # 3. load the new state dict
40 | model.load_state_dict(model_dict)
41 |
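
A hypothetical sketch of transferring a pretrained LM into a larger model whose parameter names differ only by a prefix (the names are illustrative; load_checkpoint is from utils/training.py):

    pretrained = load_checkpoint("lm_prior")["model"]
    dict_rename_by_pattern(pretrained, {"rnn.": "encoder.rnn."})
    load_state_dict_subset(model, pretrained)   # copies only matching keys
    freeze_module(model.encoder.rnn)            # keep transferred weights fixed
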
--------------------------------------------------------------------------------
/utils/viz.py:
--------------------------------------------------------------------------------
1 | from matplotlib.backends.backend_pdf import PdfPages
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from graphviz import Digraph
5 | from torch.autograd import Variable
6 |
7 |
8 | def make_dot_2(var):
9 | node_attr = dict(style='filled',
10 | shape='box',
11 | align='left',
12 | fontsize='12',
13 | ranksep='0.1',
14 | height='0.2')
15 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
16 | seen = set()
17 |
18 | def add_nodes(var):
19 | if var not in seen:
20 | if isinstance(var, Variable):
21 | value = '(' + (', ').join(['%d' % v for v in var.size()]) + ')'
22 | dot.node(str(id(var)), str(value), fillcolor='lightblue')
23 | else:
24 | dot.node(str(id(var)), str(type(var).__name__))
25 | seen.add(var)
26 | if hasattr(var, 'previous_functions'):
27 | for u in var.previous_functions:
28 | dot.edge(str(id(u[0])), str(id(var)))
29 | add_nodes(u[0])
30 |
31 | add_nodes(var.creator)
32 | return dot
33 |
34 |
35 | def attention_heatmap_subplot(src, trg, attention, ax=None):
36 | g = sns.heatmap(attention,
37 | # cmap="Greys_r",
38 | cmap="viridis",
39 | cbar=False,
40 | # annot=True,
41 | vmin=0, vmax=1,
42 | robust=False,
43 | fmt=".2f",
44 | annot_kws={'size': 12},
45 | xticklabels=trg,
46 | yticklabels=src,
47 | # square=True,
48 | ax=ax)
49 | g.set_yticklabels(g.get_yticklabels(), rotation=0, fontsize=12)
50 | g.set_xticklabels(g.get_xticklabels(), rotation=60, fontsize=12)
51 |
52 | # g.set_xticks(numpy.arange(len(src)), src, rotation=0)
53 | # g.set_yticks(numpy.arange(len(trg)), trg, rotation=60)
54 |
55 |
56 | def visualize_translations(lang, prefix_trg2src=False):
57 | for s1, s2, a12, s3, a23 in lang:
58 | # attention_heatmap(i, o, a[:len(o), :len(i)].t().cpu().numpy())
59 | if prefix_trg2src:
60 |             s2_enc = ["<sos>"] + s2[:-1]
61 | else:
62 | s2_enc = s2
63 | attention_heatmap_pair(s1, s2, s2_enc, s3,
64 | a12.t()[:len(s1), :len(s2)].cpu().numpy(),
65 | a23.t()[:len(s2_enc), :len(s3)].cpu().numpy())
66 |
67 |
68 | def visualize_compression(lang, prefix_trg2src=False):
69 | for s1, s2, a12, s3, a23 in lang:
70 | # attention_heatmap(i, o, a[:len(o), :len(i)].t().cpu().numpy())
71 | if prefix_trg2src:
72 |             s2_enc = ["<sos>"] + s2[:-1]
73 | else:
74 | s2_enc = s2
75 | attention_heatmap_pair(s1, s2, s2_enc, s3,
76 | a12.t()[:len(s1), :len(s2)].cpu().numpy(),
77 | a23.t()[:len(s2_enc), :len(s3)].cpu().numpy())
78 |
79 |
80 | def seq3_attentions(sent, file='foo.pdf'):
81 | from matplotlib import rc
82 | rc('font', **{'family': 'serif', 'serif': ['CMU Serif']})
83 | # rc('text', usetex=True)
84 | # rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
85 | # rc('text', usetex=True)
86 |
87 | with PdfPages(file) as pdf:
88 | for s1, s2, a12, s3, a23 in sent:
89 | s1 = s1[:s1.index(".") + 1]
90 |             s12 = s2[:s2.index("<eos>") + 1]
91 |             s23 = s2[:s2.index("<eos>")]
92 | s3 = s3[:len(s1)]
93 |
94 | att12 = a12.t()[:len(s1), :len(s12)].cpu().numpy()
95 | att23 = a23.t()[:len(s23), :len(s3)].cpu().numpy()
96 |
97 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
98 | attention_heatmap_subplot(s1, s12, att12, ax=ax1)
99 | attention_heatmap_subplot(s23, s3, att23, ax=ax2)
100 | ax1.set_title("Source to Compression")
101 | ax2.set_title("Compression to Reconstruction")
102 | fig.tight_layout()
103 |
104 | pdf.savefig(fig)
105 |
106 |
107 | def attention_heatmap(src, trg, attention):
108 | fig, ax = plt.subplots(figsize=(11, 5))
109 | attention_heatmap_subplot(src, trg, attention)
110 | fig.tight_layout()
111 | plt.show()
112 |
113 |
114 | def attention_heatmap_pair(s1, s2, s3, s4, att12, att23):
115 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
116 | attention_heatmap_subplot(s1, s2, att12, ax=ax1)
117 | attention_heatmap_subplot(s3, s4, att23, ax=ax2)
118 | ax1.set_title("RNN1 -> RNN2")
119 | ax2.set_title("RNN2 -> RNN3")
120 | fig.tight_layout()
121 | plt.show()
122 |
--------------------------------------------------------------------------------