├── code
│   ├── data
│   ├── constants.py
│   ├── ReadMe.md
│   ├── utils_lm.py
│   ├── utils.py
│   ├── rhyming_eval.py
│   ├── main.py
│   ├── models.py
│   ├── solvers_pretrain_disc_encoder.py
│   ├── models_lm.py
│   └── solvers_merged.py
├── README.md
└── generated quatrain samples.md
/code/data: -------------------------------------------------------------------------------- 1 | ../data/ -------------------------------------------------------------------------------- /code/constants.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | DISC_TYPE_MATRIX = "matrix" 4 | DISC_TYPE_NON_STRUCTURED = "non-structured" 5 | 6 | SONNET_DATASET_IDENTIFIER='sonnet' 7 | SONNET_ENDINGS_DATASET_IDENTIFIER = 'sonnet_endings' 8 | LIMERICK_DATASET_IDENTIFIER='limerick' 9 | NUM_LINES_LIMERICK = 5 10 | NUM_LINES_QUATRAIN = 4 11 | NUM_LINES_SONNET = 14 12 | 13 | UNKNOWN = '@@UNKNOWN@@' 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Structured-Adversary 2 | 3 | Code for our EMNLP 2019 paper titled 'Learning Rhyming Constraints using Structured Adversaries' 4 | 5 | [Link to EMNLP 2019 paper](https://arxiv.org/abs/1909.06743) 6 | 7 | 8 | ### Data 9 | Data can be downloaded from the following link: [Link to Data](https://drive.google.com/drive/folders/1Cn8biL-K2kSPQ4RbxMfTCI377sDTxgID?usp=sharing)
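The default argument values in `code/main.py` assume the download is unpacked as a `data/` directory at the project root. A rough sketch of the expected layout, reconstructed from those defaults (adjust the `--*_path` flags if your copy of the data is laid out differently):

```
data/
├── resources/cmudict-0.7b.txt
├── pretrained_embeddings/gutenberg_word2vec.txt
├── vocab/limerick_vocabulary_0_threshold_2_tokens.txt
└── splits/
    ├── sonnet/                  # sonnet_{train,valid,test}.txt, vocabulary_v2/tokens.txt, line-ending JSONs
    └── limerick_only_subset/    # {train,val,test}_0.json, line-ending JSONs
```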
10 | DISCLAIMER: NOTE THAT THIS DATASET MAY CONTAIN WORDS THAT ARE ABUSIVE, OFFENSIVE, HURTFUL OR BIASED, OR MAY APPEAR OR BE CONSIDERED ABUSIVE, OFFENSIVE, HURTFUL WITH RESPECT TO AN INDIVIDUAL OR A COMMUNITY. THE AUTHORS DO NOT ENDORSE OR PROMOTE THE USE OF SUCH LANGUAGE OR WORDS, AND THESE HAVE PURELY BEEN INCLUDED AS A MATTER OF SCIENTIFIC ANALYSIS/INVESTIGATION. 11 | 12 | 13 | ### Requirements 14 | 15 | - python 3.6.7 16 | - pytorch 0.4.1 17 | 18 | 19 | ### Usage 20 | 21 | - Download data from the above link to the main project directory 22 | - Refer to code/ReadMe.md for instructions to run code 23 | 24 | 25 | ### Reference 26 | 27 | ``` 28 | @inproceedings{jhamtani2019rhymgan, 29 | title={Learning Rhyming Constraints using Structured Adversaries}, 30 | author={Harsh Jhamtani and Sanket Vaibhav Mehta and Jaime Carbonell and Taylor Berg-Kirkpatrick}, 31 | booktitle = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 32 | month = {November}, 33 | year = {2019} 34 | } 35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /generated quatrain samples.md: -------------------------------------------------------------------------------- 1 | # Samples 2 | 3 | Disclaimer: Some of the poems can be offensive even though we have tried to mask out certain words to mitigate this. 4 | 5 | ### Sample
6 | with balls of silver a chaste knight fired
7 | but now whatever was was much admired
8 | with nice suit him did meet nice
9 | the MASKED inclined to follow what advice
10 |
11 | ### Sample
12 | who is the god in reverence
13 | be large itself is the goal of XXance
14 | and looks forever with a lot of large
15 | his work in straight beheld the word of charge
16 |
17 | ### Sample
18 | and thro the vapours shot the rays on waves
19 | MASKED radiant canvas and MASKED thunders laves
20 | was oer MASKED columns and ascends the main
21 | MASKED near and face and with MASKED captive joys train
22 |
23 | ### Sample
24 | bestow d nature pallas couch d in sorrow s right
25 | downcast and think i fear and see me slight
26 | though in the books i are by griefs confined
27 | i but such being ravished for MASKED kind
28 |
29 | ### Sample
30 | beside the can chant unweeting of the other
31 | only to feel and thrill the word to me
32 | holds from some grave his speed and give his brother
33 | and had invoked from which his soul agree
34 |
35 | ### Sample
36 | spoke in the skies and lift their shining lights
37 | her social laws and human souls that grew
38 | all sacred heaven to earth s delightful dell
39 | the creature journeys through their camels of the sea
40 |
41 | ### Sample
42 | had vanished gone and all must day and day
43 | we name this flower which i did in the story
44 | poets and whom new spirits strove away
45 | i is their lov in wander of the way
46 |
47 | ### Sample
48 | to shine in slumber till she breathes at rest
49 | a mystery to whom the angels bound
50 | takes her her spirit bridegroom for to right
51 | she woos the blushing queen when seems to thee
52 |
53 | ### Sample
54 | if once for thee then word may have leaped on
55 | with ye on cliffs i dwell upon a floor
56 | which meet are mingling in the sounding strand
57 | lest the own arrows beat from the sounding hand
58 |
59 | ### Sample
60 | and i knew wander here its sunny scenes
61 | have all their hopes too lately made its guest
62 | virtue he polish d its all in these
63 | the humble mansion for his mighty fault
64 |
65 | ### Sample
66 | “ death and love of that truth
67 | for paradise on earth thou art the youth
68 | ah darkly waves do me a willing tide
69 | should visit his void doom him turns aside
70 |
71 | ### Sample
72 | to grace and languish to our ravish d wish
73 | this truth shalt fall in any one shall wound
74 | with those white tresses with the moon and wind
75 | i in my senses here doth strain heaven s mind
76 |
77 | 78 | -------------------------------------------------------------------------------- /code/ReadMe.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | #### Pretraining Discriminator's Encoder 4 | 5 | ``` 6 | python main.py --solver_type Endings --model_name rhymgan_ae1 --use_alignment false --model_type encdec --data_type ae | tee logs/rhymgan_ae1_train.log 7 | ``` 8 | 9 | Set 'use_cuda' as false to run without gpu 10 | 11 | #### LM training 12 | 13 | 14 | ``` 15 | python main.py --solver_type Main --mode train_lm --model_name rhymgan_lm1 --load_gutenberg true --emsize 100 --data_type sonnet_endings --vanilla_lm_path tmp/rhymgan_lm1/ | tee logs/rhymgan_lm1_train.log 16 | ``` 17 | 18 | ``` 19 | python main.py --solver_type Main --mode train_lm --model_name rhymgan_limerick_lm1 --load_gutenberg true --emsize 100 --data_type limerick --vanilla_lm_path tmp/rhymgan_limerick_lm1/ > logs/rhymgan_limerick_lm1_train.log 20 | ``` 21 | 22 | #### RHYMEGAN 23 | 24 | ###### RhymeGAN training 25 | ``` 26 | #---- Sonnet 27 | python main.py --solver_type Main --model_name rhymgan1 --vanilla_lm_path tmp/rhymgan_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --reinforce_weight=0.1 --use_all_sonnet_data true --num_samples_at_epoch_end 150 --epochs 80 > logs/rhymgan1_train.log 28 | ``` 29 | 30 | 31 | ``` 32 | #---- Limerick 33 | MODEL_NAME=rhymgan_limerick1 34 | CUDA_VISIBLE_DEVICES=1 python main.py --solver_type Main --model_name $MODEL_NAME --data_type limerick --vanilla_lm_path tmp/rhymgan_limerick_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --reinforce_weight=0.1 --use_all_sonnet_data true --num_samples_at_epoch_end 100 --epochs 100 > logs/"$MODEL_NAME"_train.log 35 | 36 | ``` 37 | 38 | ###### Eval 39 | 40 | ``` 41 | #---- Sonnet 42 | EPOCH_TO_TEST=69 43 | #eval 44 | MODEL_NAME=rhymgan1 45 | python main.py --solver_type Main --model_name $MODEL_NAME --vanilla_lm_path tmp/rhymgan_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --reinforce_weight=0.1 --use_all_sonnet_data true --num_samples_at_epoch_end 100 --epochs 80 --mode eval --epoch_to_test $EPOCH_TO_TEST | tee logs/"$MODELNAME"_eval.log 46 | 47 | #eval with lower temperature 48 | python main.py --solver_type Main --model_name $MODEL_NAME --vanilla_lm_path tmp/rhymgan_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --use_all_sonnet_data true --mode eval --epoch_to_test $EPOCH_TO_TEST --temperature 0.7 | tee logs/"$MODEL_NAME"_temp7_eval.log 49 | ``` 50 | 51 | 52 | 53 | ``` 54 | #---- Limerick 55 | EPOCH_TO_TEST=74 56 | #eval 57 | MODEL_NAME=rhymgan_limerick1 58 | python main.py --solver_type Main --model_name $MODEL_NAME --data_type limerick --vanilla_lm_path tmp/rhymgan_limerick_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --mode eval --epoch_to_test $EPOCH_TO_TEST | tee logs/"$MODELNAME"_eval.log 59 | 60 | #eval with lower temperature 61 | MODEL_NAME=rhymgan_limerick1 62 | python main.py --solver_type Main --model_name $MODEL_NAME --data_type limerick --vanilla_lm_path tmp/rhymgan_limerick_lm1/ --load_vanilla_lm true --g2p_model_name rhymgan_ae1 --trainable_g2p true --mode eval --epoch_to_test $EPOCH_TO_TEST --temperature 0.7 | tee logs/"$MODELNAME"_temp7_eval.log 63 | 64 | ``` 65 | -------------------------------------------------------------------------------- /code/utils_lm.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import codecs 5 | 6 | from typing import Dict 7 | 8 | 9 | def get_text_field_mask(text_field_tensors: torch.Tensor, 10 | num_wrapping_dims: int = 0) -> torch.LongTensor: 11 | # if "mask" in text_field_tensors: 12 | return text_field_tensors["mask"] 13 | 14 | # tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()] 15 | # tensor_dims.sort(key=lambda x: x[0]) 16 | # 17 | # smallest_dim = tensor_dims[0][0] - num_wrapping_dims 18 | # if smallest_dim == 2: 19 | # token_tensor = tensor_dims[0][1] 20 | # return (token_tensor != 0).long() 21 | # elif smallest_dim == 3: 22 | # character_tensor = tensor_dims[0][1] 23 | # return ((character_tensor > 0).long().sum(dim=-1) > 0).long() 24 | # else: 25 | # raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim)) 26 | 27 | 28 | def get_lengths_from_binary_sequence_mask(mask: torch.Tensor): 29 | """ 30 | Compute sequence lengths for each batch element in a tensor using a 31 | binary mask. 32 | 33 | Parameters 34 | ---------- 35 | mask : torch.Tensor, required. 36 | A 2D binary mask of shape (batch_size, sequence_length) to 37 | calculate the per-batch sequence lengths from. 38 | 39 | Returns 40 | ------- 41 | A torch.LongTensor of shape (batch_size,) representing the lengths 42 | of the sequences in the batch. 43 | """ 44 | return mask.long().sum(-1) 45 | 46 | 47 | def sequence_cross_entropy_with_logits(logits: torch.FloatTensor, 48 | targets: torch.LongTensor, 49 | weights: torch.FloatTensor, 50 | average: str = "batch", 51 | label_smoothing: float = None) -> torch.FloatTensor: 52 | """ 53 | Computes the cross entropy loss of a sequence, weighted with respect to 54 | some user provided weights. Note that the weighting here is not the same as 55 | in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting 56 | classes; here we are weighting the loss contribution from particular elements 57 | in the sequence. This allows loss computations for models which use padding. 58 | 59 | Parameters 60 | ---------- 61 | logits : ``torch.FloatTensor``, required. 62 | A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes) 63 | which contains the unnormalized probability for each class. 64 | targets : ``torch.LongTensor``, required. 65 | A ``torch.LongTensor`` of size (batch, sequence_length) which contains the 66 | index of the true class for each corresponding step. 67 | weights : ``torch.FloatTensor``, required. 68 | A ``torch.FloatTensor`` of size (batch, sequence_length) 69 | average: str, optional (default = "batch") 70 | If "batch", average the loss across the batches. If "token", average 71 | the loss across each item in the input. If ``None``, return a vector 72 | of losses per batch element. 73 | label_smoothing : ``float``, optional (default = None) 74 | Whether or not to apply label smoothing to the cross-entropy loss. 75 | For example, with a label smoothing value of 0.2, a 4 class classification 76 | target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was 77 | the correct label. 78 | 79 | Returns 80 | ------- 81 | A torch.FloatTensor representing the cross entropy loss. 82 | If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar. 83 | If ``average is None``, the returned loss is a vector of shape (batch_size,). 
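    Examples
    --------
    A minimal, illustrative call (tensor shapes and values here are arbitrary,
    chosen only for this sketch)::

        >>> logits = torch.randn(2, 5, 10)                  # (batch, sequence_length, num_classes)
        >>> targets = torch.randint(0, 10, (2, 5)).long()   # gold class index per step
        >>> weights = torch.ones(2, 5)                      # 1 = real token, 0 = padding
        >>> loss = sequence_cross_entropy_with_logits(logits, targets, weights)
        >>> loss.dim()                                      # ``average="batch"`` gives a scalar
        0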
84 | 85 | """ 86 | if average not in {None, "token", "batch"}: 87 | raise ValueError("Got average f{average}, expected one of " 88 | "None, 'token', or 'batch'") 89 | 90 | # shape : (batch * sequence_length, num_classes) 91 | logits_flat = logits.view(-1, logits.size(-1)) 92 | # shape : (batch * sequence_length, num_classes) 93 | log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1) 94 | # shape : (batch * max_len, 1) 95 | targets_flat = targets.view(-1, 1).long() 96 | 97 | if label_smoothing is not None and label_smoothing > 0.0: 98 | num_classes = logits.size(-1) 99 | smoothing_value = label_smoothing / num_classes 100 | # Fill all the correct indices with 1 - smoothing value. 101 | one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing) 102 | smoothed_targets = one_hot_targets + smoothing_value 103 | negative_log_likelihood_flat = - log_probs_flat * smoothed_targets 104 | negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True) 105 | else: 106 | # Contribution to the negative log likelihood only comes from the exact indices 107 | # of the targets, as the target distributions are one-hot. Here we use torch.gather 108 | # to extract the indices of the num_classes dimension which contribute to the loss. 109 | # shape : (batch * sequence_length, 1) 110 | negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat) 111 | # shape : (batch, sequence_length) 112 | negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size()) 113 | # shape : (batch, sequence_length) 114 | negative_log_likelihood = negative_log_likelihood * weights.float() 115 | 116 | if average == "batch": 117 | # shape : (batch_size,) 118 | per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13) 119 | num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13) 120 | return per_batch_loss.sum() / num_non_empty_sequences 121 | elif average == "token": 122 | return negative_log_likelihood.sum() / (weights.sum().float() + 1e-13) 123 | else: 124 | # shape : (batch_size,) 125 | per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13) 126 | return per_batch_loss 127 | 128 | 129 | def get_metrics(total_loss: float, num_batches: int, key: str = "loss", metrics: Dict = None) -> Dict[str, float]: 130 | """ 131 | Gets the metrics but sets ``"loss"`` to 132 | the total loss divided by the ``num_batches`` so that 133 | the ``"loss"`` metric is "average loss per batch". 134 | """ 135 | if metrics is None: 136 | metrics = {} 137 | metrics[key] = float(total_loss / num_batches) if num_batches > 0 else 0.0 138 | return metrics 139 | 140 | 141 | # We want to warn people that tqdm ignores metrics that start with underscores 142 | # exactly once. This variable keeps track of whether we have. 
143 | class HasBeenWarned: 144 | tqdm_ignores_underscores = False 145 | 146 | 147 | def description_from_metrics(metrics: Dict[str, float]) -> str: 148 | if (not HasBeenWarned.tqdm_ignores_underscores and 149 | any(metric_name.startswith("_") for metric_name in metrics)): 150 | logger.warning("Metrics with names beginning with \"_\" will " 151 | "not be logged to the tqdm progress bar.") 152 | HasBeenWarned.tqdm_ignores_underscores = True 153 | return ', '.join(["%s: %.4f" % (name, value) 154 | for name, value in 155 | metrics.items() if not name.startswith("_")]) + " ||" -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import codecs 4 | import torch 5 | import pickle 6 | 7 | 8 | #### 9 | UNKNOWN='@@UNKNOWN@@' 10 | START='@start@' 11 | END='@end@' 12 | PAD='@pad@' 13 | EOW='@eow@' 14 | 15 | ################### 16 | def load_gutenberg_word2vec(pth): 17 | data = [r.strip().split() for r in open(pth,"r").readlines()] 18 | word_2_emb = {r[0]:np.array([float(v) for v in r[1:]]) for r in data} 19 | return word_2_emb 20 | 21 | def load_sonnet_vocab(pth): 22 | data = [r.strip() for r in open(pth,"r").readlines()] 23 | return data 24 | 25 | 26 | ######################### 27 | # Modified indexer 28 | class Indexer: 29 | 30 | def __init__(self, args): 31 | # self.w2idx = {'start': 1, 'pad': 0, 'end': 2} 32 | self.w2idx = {START: 1, PAD: 0, END: 2, UNKNOWN:3} 33 | # self.w2idx = {'pad': 0} 34 | if args.data_type == "sonnet_endings": 35 | self.w2idx[EOW] = 4 36 | self.w_cnt = len(self.w2idx) 37 | self.idx2w = None 38 | 39 | def process(self, lst_of_lst): 40 | for item in lst_of_lst: 41 | for v in item: 42 | if v not in self.w2idx: 43 | self.w2idx[v] = self.w_cnt 44 | self.w_cnt += 1 45 | self.idx2w = {i: w for w, i in self.w2idx.items()} 46 | 47 | def w_to_idx(self, w, is_char_level=False): 48 | # print("w=",w) 49 | if UNKNOWN not in self.w2idx: 50 | self.w2idx[UNKNOWN] = self.w2idx[PAD] #TODO: Need a permanent fix for this 51 | if is_char_level: 52 | w = [str(ch) for ch in w if ord(ch) < 128] 53 | return [self.w2idx[ch] if ch in self.w2idx else self.w2idx[UNKNOWN] for ch in w] 54 | 55 | def idx_to_2(self, w): 56 | return [self.idx2w[ch] for ch in w] 57 | 58 | def load(self, prefix): 59 | self.w2idx = pickle.load( open(prefix+'.w2idx.pkl','rb')) 60 | self.idx2w = pickle.load( open(prefix+'.idx2w.pkl','rb')) 61 | self.w_cnt = pickle.load( open(prefix+'.w_cnt.pkl','rb')) 62 | 63 | def save(self, prefix): 64 | pickle.dump(self.w2idx, open(prefix+'.w2idx.pkl','wb')) 65 | pickle.dump(self.idx2w, open(prefix+'.idx2w.pkl','wb')) 66 | pickle.dump(self.w_cnt, open(prefix+'.w_cnt.pkl','wb')) 67 | 68 | 69 | def get_char_seq_from_word_seq(self, lst_of_words, w2idx, use_eow_marker, eow_marker=None): 70 | #lst of words 71 | #each word is list of chars 72 | #consider end of word_sep 73 | # retuerns a lit of units for indexing 74 | ret = [] 75 | for j,w in enumerate(lst_of_words): 76 | ret.extend(w) 77 | if use_eow_marker and j<(len(lst_of_words)-1): 78 | ret.append(eow_marker) 79 | # ret_idx = [self.g_indexer.w2idx[ch] for ch in ret] 80 | ret_idx = [w2idx[ch] for ch in ret] 81 | return {'x':ret, 'indexed_x':ret_idx} 82 | 83 | 84 | ################ 85 | 86 | 87 | def preproLine(line, lower=True, ascii_only=True, remove_punc=False): 88 | if lower: 89 | line = line.lower() 90 | if ascii_only: 91 | line = ''.join([str(ch) for ch in line if ord(ch)<=127]) 
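    # remove_punc keeps only ASCII letters and spaces (digits and punctuation are dropped)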
92 | if remove_punc: 93 | line = ''.join([str(ch) for ch in line if (ch>='A' and ch<='Z') or (ch>='a' and ch<='z') or (ch==' ')]) 94 | return line 95 | 96 | 97 | def loadCMUDict(fname="../data/cmudict-0.7b.txt"): 98 | #d = open(fname,"r").readlines() 99 | d = [] 100 | with codecs.open(fname, "r",encoding='utf-8', errors='ignore') as fdata: 101 | for line in fdata: 102 | d.append(line) 103 | d = [line for line in d if line[0].isalpha()] 104 | d = [line.strip().split(' ') for line in d] 105 | #print d[0] 106 | ret = { val[0].strip().lower():val[1].strip().split(' ') for val in d } 107 | return ret 108 | 109 | 110 | 111 | #following function is due to https://github.com/aparrish/rwet-examples/blob/master/pronouncing/cmudict.py 112 | def get_rhyming_part(phones_list): 113 | """Returns the "rhyming part" of a string with phones. "Rhyming part" here 114 | means everything from the vowel in the stressed syllable nearest the end 115 | of the word up to the end of the word.""" 116 | # return get_rhyming_part_deepspeare(phones_list) ### TAKE A NOTE OF THIS 117 | idx = 0 118 | for i in reversed(range(0, len(phones_list))): 119 | if phones_list[i][-1] in ('1', '2'): 120 | idx = i 121 | break 122 | return ' '.join(phones_list[idx:]) 123 | 124 | def get_rhyming_part_deepspeare(phones_list): 125 | """Returns the "rhyming part" of a string with phones. "Rhyming part" here 126 | means everything from the vowel in the syllable nearest the end 127 | of the word up to the end of the word.""" 128 | idx = 0 129 | for i in reversed(range(0, len(phones_list))): 130 | if phones_list[i][-1] in ('1', '2', '0'): 131 | idx = i 132 | break 133 | return ' '.join(phones_list[idx:]) 134 | 135 | 136 | 137 | def compute_rhyming_pattern(lst_of_words, cmu_dict): 138 | rhyming_part_to_idx = {} 139 | word_to_rhyming_part = [] 140 | not_found = False 141 | for w in lst_of_words: 142 | if w not in cmu_dict: 143 | not_found = True 144 | return None, not_found 145 | rhyming_part = get_rhyming_part(cmu_dict[w]) 146 | word_to_rhyming_part.append(rhyming_part) 147 | if rhyming_part not in rhyming_part_to_idx: 148 | rhyming_part_to_idx[rhyming_part] = len(rhyming_part_to_idx) 149 | return [rhyming_part_to_idx[rm] for rm in word_to_rhyming_part], not_found 150 | 151 | ######### 152 | #Stress 153 | def _extract_stress_pattern_word(word, cmu_dict): 154 | phones_list = cmu_dict[word] 155 | ret = [] 156 | for i in range(0, len(phones_list)): 157 | if phones_list[i][-1] in ('1', '2'): 158 | ret.append(1) 159 | elif phones_list[i][-1] in ('0'): 160 | ret.append(0) 161 | return ret,phones_list 162 | 163 | def _extract_stress_pattern_lst_of_words(lst_of_words, cmu_dict): 164 | ret = [] 165 | for word in lst_of_words: 166 | s,p = _extract_stress_pattern_word(word, cmu_dict) 167 | #print(word,s,p) 168 | ret.extend(s) 169 | return ret 170 | 171 | def _count_violations(seq): 172 | ret = 0 173 | for j in range(len(seq)-1): 174 | if seq[j]==seq[j+1]: 175 | ret+=1 176 | #print("count: seq, ret : ", seq, ret) 177 | return ret 178 | 179 | def count_stress_pattern_violations(lst_of_words, cmu_dict): 180 | violations = 0 181 | total_valid = 0 182 | for i in range(len(lst_of_words)-1): 183 | w1 = lst_of_words[i] 184 | w2 = lst_of_words[i+1] 185 | if w1 in cmu_dict and w2 in cmu_dict: 186 | s1,p1 = _extract_stress_pattern_word(w1, cmu_dict) 187 | #print(w1,s1,p1) 188 | s2,p2 = _extract_stress_pattern_word(w2, cmu_dict) 189 | #print(w2,s2,p2) 190 | violations_i = _count_violations(s1+s2) 191 | #print(violations_i) 192 | violations+=violations_i 193 | 
total_valid+=(len(s1+s2)-1) 194 | #print() 195 | return violations, total_valid 196 | 197 | ########## 198 | 199 | def test_rhyming_pattern(): 200 | cmu_data = loadCMUDict() 201 | print(compute_rhyming_pattern(['read','head','bed'],cmu_data)) 202 | arr = [ ['rose', 'foes', 'descends', 'lends'], 203 | ['pill', 'will', 'west', 'best'], 204 | ['round', 'hand', 'sound', 'mused'], 205 | ['clene', 'hide', 'fears', 'bent'], 206 | ['poles', 'ground', 'found', 'ground'], 207 | ['keep', 'hill', 'even', 'knows'], 208 | ['made', 'knew', 'stage', 'past'], 209 | ['spend', 'death', 'sea', 'calling'], 210 | ['fen', 'den', 'artichoke', 'rock'], 211 | ['wear', 'waste', 'away', 'slay'], 212 | ['strife', 'life', 'best', 'best'] ] 213 | for arri in arr: 214 | print("arr=",arri) 215 | print("pattern=", compute_rhyming_pattern(arri,cmu_data)) 216 | 217 | def test_stress_funcs(): 218 | global cmu_dict 219 | cmu_dict = loadCMUDict() 220 | print(_extract_stress_pattern_word('bat')) 221 | print() 222 | print(_extract_stress_pattern_lst_of_words(['bat','is','near','the','window'])) 223 | print() 224 | i,o = count_stress_pattern_violations(['bat','is','near','the','window']) 225 | print("invalid=",i," out of o=",o, " [remaining valid. ignore hwne not foiund in cmu dic]") 226 | print() 227 | 228 | # test_rhyming_pattern() 229 | # test_stress_funcs() 230 | -------------------------------------------------------------------------------- /code/rhyming_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import utils 4 | import pickle 5 | 6 | tmp_dir = 'tmp/' 7 | model_name = 'lstm_sonnet_reinforce0_1_allsonnet_ae' 8 | epoch = '40' 9 | 10 | from constants import * 11 | 12 | cmu_dict = utils.loadCMUDict('data/resources/cmudict-0.7b.txt') 13 | 14 | class RhymingEval: 15 | 16 | def __init__(self, dataset_identifier:str = SONNET_ENDINGS_DATASET_IDENTIFIER): 17 | self.dataset_identifier = dataset_identifier 18 | assert dataset_identifier in [SONNET_ENDINGS_DATASET_IDENTIFIER, LIMERICK_DATASET_IDENTIFIER] 19 | if self.dataset_identifier == SONNET_ENDINGS_DATASET_IDENTIFIER: 20 | self.interesting_patterns = ['0011','0110','0101'] 21 | else: 22 | self.interesting_patterns = ['00110'] 23 | 24 | def __str__(self): 25 | return 'RhymingEval[dataset_identifier={},interesting_patterns={}]'\ 26 | .format(self.dataset_identifier,self.interesting_patterns) 27 | 28 | def _group_by_rhyming(self): 29 | ret = {} 30 | for w,p in cmu_dict.items(): 31 | pr = utils.get_rhyming_part(p) 32 | #print(p,pr) 33 | if pr not in ret: 34 | ret[pr] = [] 35 | ret[pr].append(w) 36 | #break 37 | return ret 38 | 39 | def setup(self, line_endings_file_location, line_endings_file_location_test): 40 | # rhy_to_words = self._group_by_rhyming() 41 | # print(len(rhy_to_words), list(rhy_to_words.items())[10]) 42 | # rhy_to_words_items = sorted(list(rhy_to_words.items()), key=lambda x: -len(x[1])) 43 | all_sonnet_line_endings = json.load(open(line_endings_file_location,'r')) 44 | all_sonnet_line_endings_test = json.load(open(line_endings_file_location_test,'r')) 45 | # if self.dataset_identifier == SONNET_ENDINGS_DATASET_IDENTIFIER: 46 | # all_sonnet_line_endings = json.load(open('all_line_endings_sonnet.json','r')) 47 | # else: 48 | # all_sonnet_line_endings = json.load(open('all_line_endings_limerick.json','r')) 49 | #6*len(all_sonnet_line_endings) 50 | #TODO: do split-wise. 
load different splits 51 | 52 | def process_endings(data): 53 | rhyming_pairs = [] 54 | non_rhyming = [] 55 | pattern_cnt = {} 56 | for line in data: 57 | pattern, failure = utils.compute_rhyming_pattern(line, cmu_dict) 58 | if not failure: 59 | pattern = ''.join([str(p) for p in pattern]) 60 | if pattern not in pattern_cnt: 61 | pattern_cnt[pattern] = 0 62 | pattern_cnt[pattern] += 1 63 | for i in range(len(line)): 64 | for j in range(i+1,len(line)): 65 | w1 = line[i] 66 | w2=line[j] 67 | if w1 in cmu_dict and w2 in cmu_dict: 68 | p1 = utils.get_rhyming_part(cmu_dict[w1]) 69 | p2 = utils.get_rhyming_part(cmu_dict[w2]) 70 | if p1==p2: 71 | rhyming_pairs.append([w1,w2]) 72 | else: 73 | non_rhyming.append([w1,w2]) 74 | return rhyming_pairs, non_rhyming, pattern_cnt 75 | 76 | rhyming_pairs, non_rhyming, pattern_cnt = process_endings(all_sonnet_line_endings) 77 | self.pattern_cnt_sum = sum(pattern_cnt.values()) 78 | self.rhyming_pairs, self.non_rhyming, self.pattern_cnt = rhyming_pairs, non_rhyming, pattern_cnt 79 | print("------>>> rhyming_eval: len(self.rhyming_pairs) = ", len(self.rhyming_pairs), 80 | "\n -- len(self.non_rhyming) = ", len(self.non_rhyming) ) 81 | 82 | rhyming_pairs_test, non_rhyming_test, pattern_cnt_test = process_endings(all_sonnet_line_endings_test) 83 | self.pattern_cnt_sum_test = sum(pattern_cnt_test.values()) 84 | self.rhyming_pairs_test, self.non_rhyming_test, self.pattern_cnt_test = rhyming_pairs_test, non_rhyming_test, pattern_cnt_test 85 | print("------>>> rhyming_eval: len(self.rhyming_pairs_test) = ", len(self.rhyming_pairs_test), 86 | "\n -- len(self.non_rhyming_test) = ", len(self.non_rhyming_test)) 87 | 88 | 89 | def _get_spelling_baseline(self, w1, w2, spelling_type='last3'): 90 | if spelling_type=='last3': 91 | return w1[-3:] == w2[-3:] 92 | elif spelling_type=='last2': 93 | return w1[-2:] == w2[-2:] 94 | elif spelling_type=='last4': 95 | return w1[-4:] == w2[-4:] 96 | elif spelling_type=='last1': 97 | return w1[-1:] == w2[-1:] 98 | else: 99 | assert False 100 | 101 | 102 | ###### emb 103 | def _load_cosines(self, emb, get_spelling_baseline_for_rhyming, spelling_type, split='dev'): 104 | if split == 'test': 105 | rhyming_pairs = self.rhyming_pairs_test 106 | non_rhyming = self.non_rhyming_test 107 | else: 108 | rhyming_pairs = self.rhyming_pairs 109 | non_rhyming = self.non_rhyming 110 | all_cosines_r = [] 111 | all_cosines_nr = [] 112 | for r in rhyming_pairs: 113 | w1,w2 = r 114 | if w1 in emb and w2 in emb: 115 | if get_spelling_baseline_for_rhyming: 116 | if self._get_spelling_baseline(w1,w2, spelling_type): 117 | all_cosines_r.append(1) 118 | else: 119 | all_cosines_r.append(0) 120 | else: 121 | e1_numpy, e2_numpy = emb[w1], emb[w2] 122 | all_cosines_r.append(np.sum(e1_numpy * e2_numpy)/ np.sqrt( (np.sum(e1_numpy * e1_numpy) * np.sum(e2_numpy * e2_numpy)))) 123 | for r in non_rhyming: 124 | w1,w2 = r 125 | if w1 in emb and w2 in emb: 126 | if get_spelling_baseline_for_rhyming: 127 | if self._get_spelling_baseline(w1,w2, spelling_type): 128 | all_cosines_nr.append(1) 129 | else: 130 | all_cosines_nr.append(0) 131 | else: 132 | e1_numpy, e2_numpy = emb[w1], emb[w2] 133 | all_cosines_nr.append(np.sum(e1_numpy * e2_numpy)/ np.sqrt( (np.sum(e1_numpy * e1_numpy) * np.sum(e2_numpy * e2_numpy)))) 134 | return np.array(all_cosines_r), np.array(all_cosines_nr) 135 | 136 | 137 | def analyze_embeddings_for_rhyming(self, emb_loc, get_spelling_baseline_for_rhyming=False, spelling_type=None): 138 | emb = pickle.load(open(emb_loc,'rb')) 139 | return 
self.analyze_embeddings_for_rhyming_from_dict(emb) 140 | 141 | 142 | def analyze_embeddings_for_rhyming_from_dict(self, emb, get_spelling_baseline_for_rhyming=False, spelling_type=None): 143 | 144 | all_cosines_r,all_cosines_nr = self._load_cosines(emb, get_spelling_baseline_for_rhyming, spelling_type) 145 | maxf1 = 0 146 | thresh_f1 = -1 147 | thresh=0.5 148 | # details = [] 149 | while thresh<0.95: #[0.8,0.75, 0.77, 0.73, 0.70, 0.65, 0.55, 0.60, 0.63, 0.85, 0.86]: 150 | vals = [sum(all_cosines_r>thresh), sum(all_cosines_nr>thresh), sum(all_cosines_r<=thresh), sum(all_cosines_nr<=thresh)] 151 | prec = vals[0]/(vals[0]+vals[1]) 152 | rec = vals[0]/(vals[0]+vals[2]) 153 | f1 = 2*prec*rec/(prec+rec) 154 | if f1>maxf1: 155 | maxf1 = f1 156 | thresh_f1 = thresh 157 | print("thresh, prec, rec, f1 = ", thresh, prec, rec, f1) 158 | thresh+=0.01 #0.001 159 | print("========= len(all_cosines_r), len(all_cosines_nr), thresh_f1, maxf1 =", len(all_cosines_r), len(all_cosines_nr), thresh_f1, maxf1) 160 | 161 | all_cosines_r, all_cosines_nr = self._load_cosines(emb, get_spelling_baseline_for_rhyming, spelling_type, split='test') 162 | thresh = thresh_f1 163 | vals = [sum(all_cosines_r > thresh), sum(all_cosines_nr > thresh), sum(all_cosines_r <= thresh), 164 | sum(all_cosines_nr <= thresh)] 165 | prec = vals[0] / (vals[0] + vals[1]) 166 | rec = vals[0] / (vals[0] + vals[2]) 167 | maxf1_test = 2 * prec * rec / (prec + rec) 168 | print(" -test- len(all_cosines_r), len(all_cosines_nr), thresh, prec, rec, f1 = ", len(all_cosines_r), len(all_cosines_nr), thresh, prec, rec, maxf1_test) 169 | #print(" -test- thresh, prec, rec, f1 = ", thresh, prec, rec, maxf1_test) 170 | print("========= maxf1_test =", maxf1_test) 171 | return {'thresh_f1':thresh_f1, 'maxf1':maxf1,'test_f1':maxf1_test} 172 | 173 | 174 | ## pattern eval 175 | def analyze_samples_from_endings(self, samples): 176 | from collections import defaultdict 177 | epoch_pattern_cnt = defaultdict(lambda: 0) 178 | total_count = 0 179 | for sample in samples: 180 | pattern, failure = utils.compute_rhyming_pattern(sample, cmu_dict) 181 | if not failure: 182 | pattern = ''.join([str(p) for p in pattern]) 183 | if pattern not in epoch_pattern_cnt: 184 | epoch_pattern_cnt[pattern] = 0 185 | epoch_pattern_cnt[pattern] += 1 186 | total_count += 1 187 | print("[ANALYZING SAMPLES] [model patterns] epoch_pattern_count of samples = ", \ 188 | json.dumps(epoch_pattern_cnt, indent=4)) 189 | pattern_dist = {k:v*1.0/total_count for k,v in sorted(epoch_pattern_cnt.items(), key=lambda x:-x[1])} 190 | print("[ANALYZING SAMPLES] [modeel_patterns] epoch_pattern_count of samples (in %) = ", \ 191 | json.dumps(pattern_dist,indent=4)) 192 | data_pattern_dist = {k:v*1.0/self.pattern_cnt_sum for k,v in sorted(self.pattern_cnt.items(), key=lambda x:-x[1])} 193 | print("[ANALYZING SAMPLES] [dataset patterns] epoch_pattern_count of samples (in %) = ", \ 194 | json.dumps(data_pattern_dist,indent=4)) 195 | print() 196 | f = sum([epoch_pattern_cnt[pattern] for pattern in self.interesting_patterns])*1.0/total_count 197 | summary = {} 198 | summary['pattern_dist_model'] = pattern_dist 199 | summary['pattern_dist_data'] = data_pattern_dist 200 | #summary['kl'] = kl 201 | summary['pattern_success_ratio'] = f 202 | summary['pattern_sampling_rate'] = 1.0/f if f!=0 else float("inf") 203 | summary['pattern_cnt_of_samples_tested'] = total_count 204 | summary['patterns_tested_for_success'] = self.interesting_patterns 205 | return summary 206 | 207 | ''' 208 | Assumes a specific format. 
209 | Samples is a list of strings, each string represents one poem 210 | ''' 211 | def analyze_samples_from_poetry_samples(self, samples): 212 | endings = [] 213 | for sample in data: 214 | sample = sample.strip().split(' ')[:-1] 215 | sample = [line.strip().split()[-1] for line in sample] 216 | endings.append(sample) 217 | return analyze_samples_from_endings(endings) 218 | 219 | 220 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import torch 3 | import random 4 | import numpy as np 5 | import argparse 6 | def str2bool(val): 7 | if val.lower() in ['1','true','t']: 8 | return True 9 | if val.lower() in ['0','false','f']: 10 | return False 11 | print("val = ", val) 12 | return 0/0 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--mode', type=str, default="train",help='train/train_lm') 16 | parser.add_argument('--use_cuda', type=str2bool, default='true') 17 | parser.add_argument('--model_name', type=str, default="default",help='') 18 | parser.add_argument('--use_alignment', type=str2bool, default="false",help='') 19 | parser.add_argument('--model_type', type=str, default="encdec",help='') 20 | parser.add_argument('--data_type', type=str, default="sonnet_endings",help='limerick or sonnet_endings') 21 | parser.add_argument('--use_all_sonnet_data', type=str2bool, default="false",help='') 22 | parser.add_argument('--debug', type=str2bool, default="false",help='') 23 | parser.add_argument('--epochs', type=int, default=41,help='') 24 | parser.add_argument('--batch_size', type=int, default=64,help='') 25 | parser.add_argument('--gutenberg_emb_path', default="data/pretrained_embeddings/gutenberg_word2vec.txt",help='') 26 | parser.add_argument('--sonnet_vocab_path', default="data/splits/sonnet/vocabulary_v2/tokens.txt",help='') 27 | parser.add_argument('--limerick_vocab_path', default="data/vocab/limerick_vocabulary_0_threshold_2_tokens.txt",help='') 28 | parser.add_argument('--cmu_data_path', default="data/resources/cmudict-0.7b.txt",help='') 29 | parser.add_argument('--sonnet_data_path', default="data/splits/sonnet/sonnet_",help='') 30 | parser.add_argument('--all_line_endings_sonnet', default="data/splits/sonnet/all_line_endings_sonnet_quatrains_valid.json",help='') 31 | parser.add_argument('--all_line_endings_sonnet_test', default="data/splits/sonnet/all_line_endings_sonnet_quatrains_test.json",help='') 32 | parser.add_argument('--all_line_endings_limerick', default="data/splits/limerick_only_subset/all_line_endings_limerick_val.json",help='') 33 | parser.add_argument('--all_line_endings_limerick_test', default="data/splits/limerick_only_subset/all_line_endings_limerick_test.json",help='') 34 | parser.add_argument('--limerick_data_path', default="data/splits/limerick_only_subset/",help='') 35 | parser.add_argument('--tie_emb', type=str2bool, default="false",help='') 36 | parser.add_argument('--use_eow_in_enc', type=str2bool, default="false",help='MAKE SURE this is consistent with option used in g2p training. this bascally attaches a end marker at of seq of charatcters in a word. 
This should impoact discriminators ugin charatccer level encoders') 37 | 38 | 39 | parser.add_argument('--pretraining_data_type', default='cmu_dict_words', help='cmu_dict_words OR sonnet OR limerick') 40 | parser.add_argument('--g2p_model_name', default=None,help='') 41 | parser.add_argument('--use_reinforce_loss', default='true', type=str2bool, help='') 42 | parser.add_argument('--trainable_g2p', default='true', type=str2bool, help='') 43 | parser.add_argument('--reinforce_weight', default=1.0, type=float, help='') 44 | parser.add_argument('--load_gutenberg', default='false', type=str2bool, help='') 45 | parser.add_argument('--load_gutenberg_path', default="data/pretrained_embeddings/gutenberg_word2vec.txt", help='') 46 | parser.add_argument('--freeze_emb', default='false', type=str2bool, help='') 47 | parser.add_argument('--emsize', default=128, type=int, help='') 48 | parser.add_argument('--H', default=128, type=int, help='') 49 | parser.add_argument('--pretrain_lm', default='false', type=str2bool, help='') 50 | parser.add_argument('--train_lm_supervised', default='false', type=str2bool, help='train lm using superivsed objective also') 51 | parser.add_argument('--add_entropy_regularizer', default='false', type=str2bool, help='') #TODO - not yet implemented 52 | parser.add_argument('--use_score_as_reward', default='false', type=str2bool, help='') 53 | parser.add_argument('--solver_type', default='Endings', type=str, help='Main, Endings') 54 | parser.add_argument('--seed', default=123, type=int, help='') 55 | parser.add_argument('--save_vanilla_lm', default='true', type=str2bool, help='Whether to save vanilla LM') 56 | parser.add_argument('--load_vanilla_lm', default='false', type=str2bool, help='Whether to load pre-trained vanilla LM') 57 | parser.add_argument('--vanilla_lm_path', default='tmp/tmp_best_vanilla_lm', help='Path to folder which has the stored vanilla lm') 58 | 59 | #analysis mode params only 60 | parser.add_argument('--epoch_to_test', default='40', help='') 61 | parser.add_argument('--tmp_dir', default='tmp/', help='') 62 | parser.add_argument('--dump_matrices', default='true', type=str2bool, help='dump_matrices') 63 | parser.add_argument('--learn_g2p_encoder_from_scratch', default='false',type=str2bool,help='') ## when true, does NOT load g2pmodel from g2p_model_name, though still loads indexers from g2p_model_name 64 | parser.add_argument('--temperature', default=1.0, type=float, help='currently being used only in eval mode while generation') ## 65 | parser.add_argument('--disc_type', default=DISC_TYPE_MATRIX, type=str, help=''+DISC_TYPE_MATRIX+' OR '+DISC_TYPE_NON_STRUCTURED) ## 66 | parser.add_argument('--num_samples_at_epoch_end', default=40, type=int, help='') ## 67 | 68 | args = parser.parse_args() 69 | print(" ======== args ====== ") 70 | for arg in vars(args): 71 | print( arg, ":", getattr(args,arg) ) 72 | print("============== \n ") 73 | torch.manual_seed(args.seed) 74 | np.random.seed(args.seed) 75 | random.seed(args.seed) 76 | use_cuda = cuda = args.use_cuda 77 | if args.load_gutenberg: 78 | assert args.emsize==100, "Gutenberg embeddings are of size 100 - use emsize=100" 79 | print() 80 | 81 | 82 | from solvers_merged import * 83 | from solvers_pretrain_disc_encoder import * 84 | # from utils import Indexer 85 | 86 | assert args.disc_type in [DISC_TYPE_MATRIX, DISC_TYPE_NON_STRUCTURED] 87 | 88 | ################ Data 89 | 90 | vocab = None 91 | if args.data_type == 'limerick': 92 | limerick_vocab = load_sonnet_vocab(args.limerick_vocab_path) 93 | 
print(len(limerick_vocab), limerick_vocab[0:5]) 94 | vocab = limerick_vocab 95 | else: 96 | print("[INFO] Loading Sonnet vocab ... ") 97 | sonnet_vocab = load_sonnet_vocab(args.sonnet_vocab_path) 98 | print("len(sonnet_vocab), sonnet_vocab[0:5] = ", len(sonnet_vocab), sonnet_vocab[0:5]) 99 | vocab = sonnet_vocab 100 | print() 101 | 102 | 103 | print("[INFO] Loading CMU dictionary...") 104 | cmu_data = loadCMUDict(args.cmu_data_path) 105 | print("cmu dictionary: list(cmu_data.items())[0:10] = ", list(cmu_data.items())[0:10]) 106 | print() 107 | 108 | 109 | ################ Model dumps save dir 110 | model_dir = 'tmp/tmp_'+args.model_name+'/' 111 | if not os.path.exists(model_dir): 112 | os.makedirs(model_dir) 113 | args.model_dir = model_dir 114 | 115 | ################ Solver choice 116 | if args.solver_type == "Main": # endings model training, LM training 117 | 118 | solver = MainSolver(typ=args.data_type, cmu_dict=cmu_data, args=args, mode=args.mode) 119 | 120 | if args.mode == "train_lm": 121 | 122 | if not os.path.exists(args.vanilla_lm_path): 123 | os.makedirs(args.vanilla_lm_path) 124 | 125 | solver.train_lm(epochs=30, debug=False, args=args) 126 | # solver.train_lm(epochs=1, debug=True, args=args) 127 | 128 | else: 129 | 130 | #### Loading pretrained LM 131 | if args.load_vanilla_lm: 132 | lm_model_state_dict = torch.load(os.path.join(args.vanilla_lm_path, 'model_best')) 133 | solver.lm_model.load_state_dict(lm_model_state_dict) 134 | print("LM model initialized with pre-trained model!") 135 | 136 | ################ Solver sanity check and Save solver indexers, etc 137 | # - only if train mode (to not overwrite in analysis/eval mode) 138 | solver.get_stats_ending_words(split='train', batch_size=32) 139 | print(solver.g_indexer.w_cnt) 140 | 141 | if args.mode=="train": 142 | solver.save(model_dir + 'solver_') 143 | else: 144 | solver.load(model_dir + 'solver_') 145 | 146 | # Some sanity checks 147 | # list(solver.g_indexer.w2idx.items())[0:5] 148 | x_sample,_,x_start = solver.get_batch(i=0, batch_size=3, split='train') #, typ=args.data_type) 149 | print("x_sample = ", x_sample) 150 | 151 | 152 | ################ Pretraining LM Model 153 | if args.pretrain_lm: 154 | args.use_reinforce_loss = False 155 | model_dir = args.model_dir 156 | model_name = args.model_name 157 | args.model_dir = 'tmp/tmp_pretrain_'+args.model_name+'/' 158 | if not os.path.exists(args.model_dir): 159 | os.makedirs(args.model_dir) 160 | args.model_name = 'pretrain_'+args.model_name 161 | solver.train(args.epochs, debug=args.debug) 162 | state_dict_best = torch.load(args.model_dir+'model_best') 163 | solver.model.load_state_dict(state_dict_best) 164 | #reset 165 | args.use_reinforce_loss = True 166 | args.model_dir = model_dir 167 | args.model_name = model_name 168 | 169 | 170 | ################ mode=train:Training Full Model 171 | print(args.mode) 172 | if args.mode=="train": 173 | print("-x"*55) 174 | solver.train(args.epochs, debug=args.debug, train_lm_supervised=args.train_lm_supervised) 175 | elif args.mode=="eval": # || mode=eval eval_type=rhyming 176 | # load the model states 177 | # solver.load_models(args.model_name, args.model_state) 178 | args.model_dir = args.tmp_dir+'tmp_'+args.model_name+'/' 179 | epoch=args.epoch_to_test #'40' #'best' #'40' 180 | solver.load_models(model_dir=args.model_dir, model_epoch=epoch, load_lm=False) 181 | #run the analysis 182 | solver.analysis(epoch=epoch, args=args) 183 | elif args.mode == "lm_eval": 184 | # load the model states 185 | # 
solver.load_models(args.model_name, args.model_state) 186 | print("lm_eval mode ON!!!!!") 187 | args.model_dir = args.tmp_dir + 'tmp_' + args.model_name + '/' 188 | epoch = args.epoch_to_test # '40' #'best' #'40' 189 | solver.load_models(model_dir=args.model_dir, model_epoch=epoch, load_lm=False) 190 | # run the analysis 191 | solver.lm_analysis(epoch=epoch, args=args) 192 | 193 | 194 | elif args.solver_type == "Endings": #pretraining 195 | 196 | assert args.data_type in ["ae","g2p", "sonnet_endings", "g2plast"] 197 | if args.tie_emb: 198 | assert args.data_type=="ae" 199 | use_cuda = cuda = args.use_cuda 200 | 201 | ################ 202 | model_dir = 'tmp/tmp_'+args.model_name+'/' 203 | if not os.path.exists(model_dir): 204 | print("Creating ", model_dir, " ... ") 205 | os.makedirs(model_dir) 206 | print() 207 | 208 | print("[INFO] Creating solver ... ") 209 | solver = EndingsSolver(typ=args.data_type, cmu_dict=cmu_data, args=args) 210 | print() 211 | 212 | print("[INFO] Getting splits ... ") 213 | # solver.get_splits() 214 | solver.get_splits(data_type=args.pretraining_data_type) # 'sonnet', 'cmu_dict_words' 215 | print() 216 | 217 | print("Indexing data ... ") 218 | solver.index() 219 | print("[INFO] solver.g_indexer.w_cnt = ", solver.g_indexer.w_cnt) 220 | print() 221 | 222 | if args.mode=="train": 223 | print("[INFO] Saving indexer at " , model_dir + 'solver_', " ... ") 224 | solver.save(model_dir + 'solver_') 225 | print() 226 | 227 | print("[INFO] Some data samples ... ") 228 | x_sample,y_sample,y_start = solver.get_batch(i=0, batch_size=3, split='train') #, typ=args.data_type) 229 | print("x_sample = ", x_sample) 230 | print("y_sample = ", y_sample) 231 | print("x_sample[0]: idx_to_2: = ", solver.g_indexer.idx_to_2(x_sample[0])) 232 | 233 | print("[INFO] Create model ... ") 234 | solver.init_model(y_start=y_start) 235 | print() 236 | 237 | if args.mode=="train": 238 | print("[INFO] Beginning Training ... 
") 239 | solver.train(args.epochs, debug=args.debug, use_alignment=args.use_alignment, model_dir=model_dir) 240 | 241 | elif args.mode=="eval": 242 | if args.data_type=="ae" or args.data_type=="g2plast" or args.data_type=="g2p" or args.data_type=="aelast": 243 | solver.analyze(vocab=vocab) 244 | else: 245 | assert False 246 | 247 | else: 248 | 249 | assert False 250 | 251 | 252 | -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | from torch.distributions import Categorical 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from utils import Indexer 6 | import utils 7 | from constants import * 8 | 9 | 10 | ################ 11 | class Model(nn.Module): 12 | 13 | def __init__(self, H, i_size=None, o_size=None, emsize=128, start_idx=1, end_idx=2, typ='encdec', args=None, 14 | tie_inp_out_emb=False): 15 | super(Model, self).__init__() 16 | self.H = H 17 | self.emsize = emsize 18 | self.emb_i = nn.Embedding(i_size, emsize) 19 | if tie_inp_out_emb: 20 | self.emb_o = self.emb_i 21 | else: 22 | self.emb_o = nn.Embedding(o_size, emsize) 23 | self.encoder = nn.LSTMCell(emsize, H) 24 | self.decoder = nn.LSTMCell(emsize, H) 25 | self.softmax = nn.Softmax() 26 | self.decoder_layer = nn.Linear(H, o_size) 27 | self.sig = nn.Sigmoid() 28 | self.start_idx = start_idx 29 | self.end_idx = end_idx 30 | self.typ = typ 31 | self.use_cuda = args.use_cuda 32 | 33 | def encode(self, e1): 34 | h = torch.zeros(1, self.H), torch.zeros(1, self.H) 35 | if self.use_cuda: 36 | h = h[0].cuda(), h[1].cuda() 37 | for ch in e1: 38 | ch_idx = torch.tensor(ch) 39 | if self.use_cuda: 40 | ch_idx = ch_idx.cuda() 41 | ch_emb = self.emb_i(ch_idx).view(1, -1) 42 | h = self.encoder(ch_emb, h) 43 | out, c = h 44 | return c 45 | 46 | def decode(self, enc, gt, use_gt=True, max_steps=11): 47 | info = {} 48 | h = enc, enc 49 | sz = len(gt) 50 | if not use_gt: 51 | sz = max_steps 52 | ch = self.start_idx 53 | ch_idx = ch 54 | out_all = [] 55 | prediction = [] 56 | for i in range(sz): 57 | ch_idx = torch.tensor(ch_idx) 58 | if self.use_cuda: 59 | ch_idx = ch_idx.cuda() 60 | ch_emb = self.emb_o(ch_idx).view(1, -1) 61 | h = self.decoder(ch_emb, h) 62 | out, _ = h 63 | out = self.decoder_layer(out) 64 | out_all.append(out) 65 | if not use_gt: 66 | preds = self.softmax(out) 67 | pred = torch.argmax(preds) 68 | if pred.cpu().numpy() == self.end_idx: 69 | break 70 | ch_idx = pred 71 | prediction.append(ch_idx.data.cpu().item()) 72 | else: 73 | ch_idx = gt[i] 74 | if use_gt: 75 | return out_all, info 76 | else: 77 | return prediction, info 78 | 79 | 80 | ################ 81 | class GeneratorModel(nn.Module): 82 | 83 | def __init__(self, H, o_size=None, emsize=128, start_idx=1, end_idx=2, typ='dec', args=None, use_gan=False, 84 | unk_idx=None): 85 | super(GeneratorModel, self).__init__() 86 | self.H = H 87 | self.emsize = emsize 88 | self.emb_o = nn.Embedding(o_size, emsize) 89 | self.phonetic_emb = nn.Embedding(o_size, emsize) 90 | self.decoder = nn.LSTMCell(emsize, H) 91 | self.softmax = nn.Softmax() 92 | self.decoder_layer = nn.Linear(H, o_size) 93 | self.sig = nn.Sigmoid() 94 | self.start_idx = start_idx 95 | self.end_idx = end_idx 96 | self.typ = typ 97 | self.use_cuda = args.use_cuda 98 | self.use_gan = use_gan 99 | assert not use_gan 100 | self.emsize = emsize 101 | self.args = args 102 | self.unk_idx = unk_idx 103 | 104 | def reparametrize(self, mu, logvar): 105 | std = logvar.mul(0.5).exp_() 
106 | eps = torch.FloatTensor(std.size()).normal_() 107 | if self.use_cuda: 108 | eps = eps.cuda() 109 | return eps.mul(std).add_(mu) 110 | 111 | def freeze_emb(self): 112 | self.emb_o.weight.requires_grad = False 113 | print("[GeneratorModel] -- FREEZING word embedding emb_o") 114 | 115 | def display_params(self): 116 | print("[GeneratorModel]: model parametrs") 117 | for name, param in self.named_parameters(): 118 | print("name=", name, " || grad:", param.requires_grad, "| size = ", param.size()) 119 | 120 | def load_gutenberg_we(self, dict_pickle_pth, indexer_idx2w): 121 | sz = len(indexer_idx2w) 122 | emb = np.random.rand(sz, self.emsize) 123 | print(emb.shape) 124 | pretrained = {} 125 | lines = open(dict_pickle_pth, 'r').readlines() 126 | for line in lines: 127 | line = line.strip().split() 128 | w = line[0] 129 | e = np.array([float(val) for val in line[1:]]) 130 | pretrained[w] = e 131 | found = 0 132 | for i in range(sz): 133 | word = indexer_idx2w[i] 134 | if word in pretrained: 135 | emb[i, :] = pretrained[word] 136 | found += 1 137 | print("load_gutenberg_we: found", found, " out of ", sz) 138 | self.emb_o.weight.data.copy_(torch.from_numpy(emb)) 139 | 140 | def forward(self, gt=None, use_gt=True, max_steps=4, temperature=1.0): 141 | if self.use_gan: 142 | enc = torch.cuda.FloatTensor(np.random.normal(0, 1, (1, self.H))) 143 | else: 144 | enc = torch.zeros(1, self.H) 145 | if self.use_cuda: 146 | enc = enc.cuda() 147 | h = enc, enc # torch.zeros(1,self.H), torch.zeros(1,self.H) 148 | if not use_gt: 149 | if self.args.data_type == "sonnet_endings": 150 | sz = max_steps = 4 151 | elif self.args.data_type == "limerick": 152 | sz = max_steps = 5 153 | else: 154 | assert False 155 | else: 156 | sz = max_steps = len(gt) 157 | word = self.start_idx 158 | word_idx = word 159 | out_all = [] 160 | prediction = [] 161 | batch_size = 1 162 | probs, states, actions = [], [], [] 163 | for i in range(sz): 164 | word_idx = torch.tensor(word_idx) 165 | if self.use_cuda: 166 | word_idx = word_idx.cuda() 167 | # print(i,ch_idx) 168 | word_emb = self.emb_o(word_idx).view(1, -1) 169 | h = self.decoder(word_emb, h) 170 | out, _ = h 171 | out = self.decoder_layer(out) 172 | out = out / temperature 173 | out_all.append(out) 174 | action_probs = self.softmax(out) 175 | if not use_gt: 176 | torch_distribution = Categorical(action_probs.view(batch_size, -1)) 177 | not_done = True 178 | attempts = 21 179 | while not_done: 180 | attempts -= 1 181 | action = torch_distribution.sample() 182 | log_prob_action = torch_distribution.log_prob(action) 183 | action_idx = action.data[0] 184 | if action_idx.cpu().item() == self.end_idx: 185 | continue 186 | if action_idx.cpu().item() == self.unk_idx: 187 | continue 188 | word_idx = action_idx 189 | probs.append(log_prob_action) 190 | states.append(h) 191 | actions.append(action_idx) 192 | not_done = False 193 | else: 194 | word_idx = gt[i] 195 | return {'out_all': out_all, 'logprobs': probs, 'states': states, 'actions': actions} 196 | 197 | 198 | 199 | #################### 200 | 201 | 202 | class CNN(nn.Module): 203 | def __init__(self, args, reduced_size=None, info={}): 204 | super(CNN, self).__init__() 205 | # disc_type=DISC_TYPE_MATRIX 206 | self.disc_type = disc_type = args.disc_type 207 | self.layer1 = nn.Sequential( 208 | nn.Conv2d(1, 4, kernel_size=2, padding=0), 209 | nn.ReLU()) 210 | # 1,4,3,3 211 | self.layer2 = nn.Sequential( 212 | nn.Conv2d(4, 8, kernel_size=2), 213 | nn.ReLU()) 214 | # 1,8,2,2 215 | ## but for 5 lines, it is 1,8,3,3 216 | if 
args.data_type == "sonnet_endings": 217 | self.scorer = nn.Linear(2 * 2 * 8, 1) 218 | elif args.data_type == "limerick": 219 | self.scorer = nn.Linear(3 * 3 * 8, 1) 220 | self.predictor = nn.Sigmoid() 221 | self.args = args 222 | self.use_cuda = args.use_cuda 223 | 224 | ## 225 | self.g_indexer = Indexer(args) 226 | self.g_indexer.load('tmp/tmp_' + args.g2p_model_name + '/solver_g_indexer') 227 | self.g2pmodel = Model(H=info['H'], args=args, i_size=self.g_indexer.w_cnt, o_size=self.g_indexer.w_cnt, 228 | start_idx=self.g_indexer.w2idx[utils.START]) 229 | if not args.learn_g2p_encoder_from_scratch: 230 | print("=====" * 7, "LOADING g2p ENCODER PRETRAINED") 231 | model_dir = 'tmp/tmp_' + args.g2p_model_name + '/' 232 | state_dict_best = torch.load(model_dir + 'model_best') 233 | self.g2pmodel.load_state_dict(state_dict_best) 234 | if not args.trainable_g2p: 235 | assert not args.learn_g2p_encoder_from_scratch 236 | for param in self.g2pmodel.parameters(): 237 | param.requires_grad = False 238 | 239 | def display_params(self): 240 | print("=" * 44) 241 | print("[CNN]: model parametrs") 242 | for name, param in self.named_parameters(): 243 | print("name=", name, " || grad:", param.requires_grad, "| size = ", param.size()) 244 | print("=" * 44) 245 | 246 | def _compute_word_reps(self, words_str, deb=False): 247 | if deb: 248 | print("words_str = ", words_str) 249 | use_eow_marker = self.args.use_eow_in_enc 250 | assert not use_eow_marker, "Not yet tested" 251 | word_reps = [self.g_indexer.w_to_idx(s1) for s1 in words_str] 252 | if self.args.use_eow_in_enc: 253 | x_end = self.g_indexer.w2idx[utils.END] 254 | word_reps = [x_i + [x_end] for x_i in word_reps] 255 | word_reps = [self.g2pmodel.encode(w) for w in word_reps] 256 | return word_reps 257 | 258 | def _compute_pairwise_dot(self, measure_encodings_b): 259 | ret = [] 260 | sz = len(measure_encodings_b) 261 | for measure_encodings_b_t in measure_encodings_b: 262 | for measure_encodings_b_t2 in measure_encodings_b: 263 | t1 = torch.sum(measure_encodings_b_t * measure_encodings_b_t2) 264 | t2 = torch.sqrt(torch.sum(measure_encodings_b_t * measure_encodings_b_t)) 265 | t3 = torch.sqrt(torch.sum(measure_encodings_b_t2 * measure_encodings_b_t2)) 266 | assert t2 > 0 267 | assert t3 > 0, "t3=" + str(t3) 268 | ret.append(t1 / (t2 * t3)) 269 | ret = torch.stack(ret) 270 | ret = ret.view(sz, sz) 271 | return ret 272 | 273 | def _score_matrix(self, x, deb=False): 274 | x = x[0].unsqueeze(0).unsqueeze(0) # -> 1,1,ms,ms 275 | if deb: 276 | print("---x.shape = ", x.size()) 277 | out = self.layer1(x) 278 | if deb: 279 | print("---out = ", out.size(), out) 280 | out = self.layer2(out) 281 | if deb: 282 | print("---out = ", out.size(), out) 283 | out = out.view(out.size(0), -1) # arrange by bsz 284 | score = self.scorer(out) 285 | if deb: 286 | print("---out sum = ", torch.sum(out)) 287 | print("---score = ", score) 288 | prob = self.predictor(score) 289 | return {'prob': prob, 'out': out, 'score': score} 290 | 291 | def _compute_rhyming_matrix(self, words_str, deb=False): 292 | word_reps = self._compute_word_reps(words_str) 293 | rhyming_matrix = self._compute_pairwise_dot(word_reps) 294 | return rhyming_matrix, words_str 295 | 296 | def _compute_rnn_on_word_reps(self, word_reps): 297 | h = torch.zeros(1, self.linear_rep_H), torch.zeros(1, self.linear_rep_H) 298 | if self.use_cuda: 299 | h = h[0].cuda(), h[1].cuda() 300 | for w in word_reps: 301 | h = self.linear_rep_encoder(w, h) 302 | out, c = h 303 | return c 304 | 305 | def _run_discriminator(self, 
words_str, deb): 306 | rhyming_matrix, words_str = self._compute_rhyming_matrix(words_str, deb) 307 | vals = self._score_matrix([rhyming_matrix]) 308 | vals.update({'rhyming_matrix': rhyming_matrix, 'linear_rep': None, 'words_str': words_str}) 309 | return vals 310 | 311 | def update_discriminator(self, line_endings_gen, line_endings_train, deb=False, word_idx_to_str_dict=None): 312 | eps = 0.0000000001 313 | ret = {} 314 | dump_info = {} 315 | words_str_train = [word_idx_to_str_dict[word_idx.data.cpu().item()] for word_idx in line_endings_train] 316 | words_str_gen = [word_idx_to_str_dict[word_idx.data.cpu().item()] for word_idx in line_endings_gen] 317 | disc_real = self._run_discriminator(words_str_train, deb) 318 | if deb: 319 | print("rhyming_matrix_trai = ", disc_real['rhyming_matrix'], "|| prob = ", disc_real['prob']) 320 | if self.args.disc_type == DISC_TYPE_MATRIX: 321 | dump_info['rhyming_matrix_trai'] = disc_real['rhyming_matrix'].data.cpu().numpy() 322 | dump_info['real_prob'] = disc_real['prob'].data.cpu().item() 323 | dump_info['real_words_str'] = disc_real['words_str'] 324 | disc_gen = self._run_discriminator(words_str_gen, deb) 325 | if deb: 326 | print("rhyming_matrix_gen = ", disc_gen['rhyming_matrix'], "|| prob = ", disc_gen['prob']) 327 | if self.args.disc_type == DISC_TYPE_MATRIX: 328 | dump_info['rhyming_matrix_gen'] = disc_gen['rhyming_matrix'].data.cpu().numpy() 329 | dump_info['gen_prob'] = disc_gen['prob'].data.cpu().item() 330 | dump_info['gen_words_str'] = disc_gen['words_str'] 331 | prob_real = disc_real['prob'] 332 | prob_gen = disc_gen['prob'] 333 | loss = -torch.log(prob_real + eps) - torch.log(1.0 - prob_gen + eps) 334 | reward = prob_gen 335 | if self.args.use_score_as_reward: 336 | reward = disc_gen['score'] 337 | ret.update({'loss': loss, 'reward': reward, 'dump_info': dump_info}) 338 | return ret 339 | 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /code/solvers_pretrain_disc_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import json 4 | import pickle 5 | import numpy as np 6 | import os 7 | import codecs 8 | import random 9 | from torch import nn 10 | from torch.distributions import Categorical 11 | from models import * 12 | import utils 13 | from utils import Indexer 14 | 15 | 16 | 17 | ################ 18 | # This Solver is used for pretraining word encoder 19 | 20 | class EndingsSolver: 21 | 22 | def __init__(self, typ="ae", cmu_dict=None, args=None): 23 | self.last2_to_words = None 24 | self.typ = typ 25 | self.cmu_dict = cmu_dict 26 | self.args = args 27 | 28 | def init_model(self, y_start): 29 | 30 | #o_size=self.p_indexer.w_cnt, 31 | model= Model(H=128, 32 | i_size=self.g_indexer.w_cnt, 33 | o_size=self.g_indexer.w_cnt, 34 | start_idx=y_start, 35 | typ=self.args.model_type, 36 | args=self.args) 37 | if self.args.use_cuda: 38 | model = model.cuda() 39 | # batch_size = args.batch_size # 64 40 | self.model = model 41 | self.criterion = nn.CrossEntropyLoss(ignore_index=0) 42 | 43 | def get_splits(self, data_type='cmu_dict_words'): 44 | # self.all_cmu_items = all_items = list(self.cmu_data.items()) 45 | if data_type == 'cmu_dict_words': 46 | all_items = list(self.cmu_dict.items()) 47 | for i in range(11): 48 | random.shuffle(all_items) 49 | sz = len(all_items) 50 | train = all_items[:int(0.8*sz)] 51 | val = all_items[int(0.8*sz):int(0.9*sz)] 52 | test = all_items[int(0.9*sz):] 53 | self.splits = {'train':train,'val':val, 
'test':test} 54 | elif data_type == 'sonnet': 55 | self.get_sonnet_splits(data_path=self.args.sonnet_data_path) 56 | self.splits = {} 57 | for k,vals in self.sonnet_splits.items(): 58 | self.splits[k] = [] 59 | for val in vals: 60 | for w in val['ending_words']: 61 | self.splits[k].append([w,None]) 62 | elif data_type == 'limerick': 63 | self.get_limerick_splits(data_path=self.args.limerick_data_path) 64 | self.splits = {} 65 | for k, vals in self.limerick_splits.items(): 66 | self.splits[k] = [] 67 | for val in vals: 68 | for w in val['ending_words']: 69 | self.splits[k].append([w, None]) 70 | else: 71 | assert False 72 | for k,v in self.splits.items(): 73 | print("split ",k," has ", len(v), " items") 74 | print("split ",k," items[0:5] ", v[:5]) 75 | 76 | 77 | def _load_sonnet_data(self, sonnet_data_path, split, skip_not_in_cmu=False): 78 | sonnet_data_file = sonnet_data_path + split + '.txt' 79 | data = open(sonnet_data_file,'r').readlines() 80 | ret = [] 81 | skipped = 0 82 | for sonnet in data: 83 | lines = sonnet.strip().split(' ')[:-1] 84 | last_words = [line.strip().split()[-1] for line in lines] 85 | if not skip_not_in_cmu: 86 | ret.append({'lines':lines, 'ending_words':last_words}) 87 | else: 88 | skip=False 89 | for j,w in enumerate(last_words[:4]): 90 | if w not in self.cmu_dict: 91 | skipped+=1 92 | skip=True 93 | break 94 | if not skip: 95 | ret.append({'lines':lines, 'ending_words':last_words}) 96 | print("[_load_sonnet_data] : split=",split, " skipped ",skipped," out of ", len(data)) 97 | return ret 98 | 99 | def _load_limerick_data(self, limerick_data_path, split): 100 | data_file = os.path.join(limerick_data_path, split + '_0.json') 101 | print("Loading from ", data_file) 102 | data = json.load(open(data_file, 'r')) 103 | ret = [] 104 | skipped = 0 105 | n_tokens = 0 106 | for limerick in data: 107 | # if translator_limerick is not None: 108 | # limerick = limerick['txt'].strip().translate(translator_limerick) 109 | # else: 110 | limerick = limerick['txt'].strip() 111 | lines = limerick.split('|') 112 | if len(lines) != 5: 113 | skipped += 1 114 | continue 115 | last_words = [line.strip().split()[-1] for line in lines] 116 | instance = {} 117 | instance['ending_words'] = last_words 118 | ret.append(instance) 119 | print("[_load_limerick_data] : split=", split, " skipped ", skipped, " out of ", len(data)) 120 | return ret 121 | 122 | def get_sonnet_splits(self, data_path="../data/sonnet_"): 123 | self.sonnet_splits = {k:self._load_sonnet_data(data_path,k) for k in ['train','valid','test'] } 124 | self.sonnet_splits['val'] = self.sonnet_splits['valid'] 125 | # self.sonnet_splits = {k:self._process_sonnet_data(val) for k,val in self.sonnet_splits.items()} 126 | 127 | def get_limerick_splits(self, data_path="/data/limerick_"): 128 | self.limerick_splits = {k: self._load_limerick_data(data_path, k) for k in ['train', 'val', 'test']} 129 | # self.limerick_splits['val'] = self.limerick_splits['valid'] 130 | # self.sonnet_splits = {k:self._process_sonnet_data(val) for k,val in self.sonnet_splits.items()} 131 | 132 | def index(self): 133 | # why not use the vocab for indexing ? 
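# index() builds the grapheme-level Indexer from the CMU dictionary spellings; a separate phoneme-side indexer is created only when the solver is not a plain autoencoder (typ != "ae").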
134 | items = list(self.cmu_dict.items()) #self.all_cmu_items 135 | self.g_indexer = Indexer(self.args) 136 | self.g_indexer.process([i[0] for i in items]) 137 | if self.typ!="ae": 138 | self.p_indexer = Indexer(self.args) 139 | self.p_indexer.process([i[1] for i in items]) 140 | 141 | def save(self, dump_pre): 142 | pickle.dump(self.splits, open(dump_pre+'splits.pkl','wb')) 143 | pickle.dump(self.g_indexer, open(dump_pre+'g_indexer.pkl','wb')) 144 | self.g_indexer.save(dump_pre+'g_indexer') 145 | if self.typ!="ae": 146 | pickle.dump(self.p_indexer, open(dump_pre+'p_indexer.pkl','wb')) 147 | self.p_indexer.save(dump_pre+'p_indexer') 148 | 149 | def load(self, dump_pre): 150 | self.splits = pickle.load( open(dump_pre+'splits.pkl','rb')) 151 | if self.typ!="ae": 152 | self.p_indexer = Indexer(self.args) #pickle.load(open('tmp/tmp_'+args.g2p_model_name+'/solver_p_indexer.pkl','rb')) 153 | self.p_indexer.load(dump_pre+'p_indexer') 154 | self.g_indexer = Indexer(self.args) #pickle.load(open('tmp/tmp_'+args.g2p_model_name+'/solver_g_indexer.pkl','rb')) 155 | self.g_indexer.load(dump_pre+'g_indexer') 156 | 157 | def get_batch(self, i, batch_size, split, last_two=False, add_end_to_y=True): #typ='g2p', 158 | # typ: g2p, ae 159 | # assert typ==self.typ 160 | typ = self.typ 161 | if typ=="g2p": 162 | data = self.splits[split][i*batch_size:(i+1)*batch_size] 163 | x = [self.g_indexer.w_to_idx(g) for g,p in data] 164 | y = [self.p_indexer.w_to_idx(p) for g,p in data] 165 | y_start = self.p_indexer.w2idx[utils.START] 166 | y_end = self.p_indexer.w2idx[utils.END] 167 | if add_end_to_y: 168 | y = [y_i+[y_end] for y_i in y] 169 | x_end = self.g_indexer.w2idx[utils.END] 170 | if self.args.use_eow_in_enc: 171 | x = [x_i+[x_end] for x_i in x] 172 | return x,y,y_start 173 | elif typ=="g2plast": 174 | data = self.splits[split][i*batch_size:(i+1)*batch_size] 175 | x = [self.g_indexer.w_to_idx(g) for g,p in data] 176 | #for g,p in data: 177 | # print(p) 178 | y = [self.p_indexer.w_to_idx(p)[-2:] for g,p in data] 179 | y_start = self.p_indexer.w2idx[utils.START] 180 | y_end = self.p_indexer.w2idx[utils.END] 181 | if add_end_to_y: 182 | y = [y_i+[y_end] for y_i in y] 183 | if self.args.use_eow_in_enc: 184 | x_end = self.g_indexer.w2idx[utils.END] 185 | x = [x_i+[x_end] for x_i in x] 186 | return x,y,y_start 187 | elif typ=="ae": 188 | data = self.splits[split][i*batch_size:(i+1)*batch_size] 189 | x = [self.g_indexer.w_to_idx(g) for g,p in data] 190 | y = [self.g_indexer.w_to_idx(g) for g,p in data] 191 | y_start = self.g_indexer.w2idx[utils.START] 192 | y_end = self.g_indexer.w2idx[utils.END] 193 | if add_end_to_y: 194 | y = [y_i+[y_end] for y_i in y] 195 | if self.args.use_eow_in_enc: 196 | x_end = self.g_indexer.w2idx[utils.END] 197 | x = [x_i+[x_end] for x_i in x] 198 | return x,y,y_start 199 | 200 | def get_num_batches(self, split, batch_size): 201 | return int( ( len(self.splits[split]) + batch_size - 1.0)/batch_size ) 202 | 203 | 204 | ################ 205 | 206 | def get_loss(self, split, batch, batch_size, mode='train', use_alignment=False): 207 | 208 | model = self.model 209 | use_cuda = self.args.use_cuda 210 | batch_size = self.args.batch_size 211 | if mode=="train": 212 | model.train() 213 | else: 214 | model.eval() 215 | len_output = 0 216 | x,y,y_start = self.get_batch(i=batch, batch_size=batch_size, split=split) #, typ=args.data_type) 217 | 218 | batch_loss = torch.tensor(0.0) 219 | batch_align_loss = torch.tensor(0.0) 220 | batch_kl_loss = torch.tensor(0.0) 221 | if use_cuda: 222 | batch_loss = 
batch_loss.cuda() 223 | batch_align_loss = batch_align_loss.cuda() 224 | batch_kl_loss = batch_kl_loss.cuda() 225 | 226 | i=0 227 | all_e = [] 228 | #print(" -- batch=",batch, " || x: ",len(x)) 229 | for x_i,y_i in zip(x,y): 230 | 231 | len_output += len(y_i) 232 | e_i = model.encode(x_i) 233 | #print(i, e_i.size()) 234 | all_e.append(e_i) 235 | out_all_i, info = model.decode(e_i,y_i) 236 | i+=1 237 | out_all_i = torch.stack(out_all_i) 238 | #dist = out_all_i.view(-1, self.p_indexer.w_cnt) 239 | dist = out_all_i.view(-1, self.g_indexer.w_cnt) 240 | targets = np.array(y_i, dtype=np.long) 241 | targets = torch.from_numpy(targets) 242 | if self.args.use_cuda: 243 | targets = targets.cuda() 244 | cur_loss = self.criterion(dist, targets) 245 | 246 | if use_alignment: 247 | y_i_all = self.p_indexer.idx_to_2(y_i) 248 | y_i_last2 = y_i_all[-3:-1] #solver.p_indexer.idx_to_2(y_i[-3:-1]) # remove end token ans use last2 249 | #x_j_word, success = solver.find_word_with_same_phoneme(y_i_last2) 250 | x_j_word, success = self.find_word_with_same_phoneme(y_i_all) 251 | #print(x_j_word, y_i_last2) 252 | if success: 253 | x_j = self.g_indexer.w_to_idx(x_j_word) 254 | e_j = model.encode(x_j) 255 | cur_align_loss = torch.mean((e_i-e_j)*(e_i-e_j)) 256 | batch_align_loss += cur_align_loss 257 | 258 | if model.typ=="vaed": 259 | batch_kl_loss+= info['kl_loss'] 260 | 261 | #print(cur_loss) 262 | batch_loss += cur_loss 263 | 264 | total_loss = batch_align_loss + batch_loss + batch_kl_loss 265 | return total_loss, len_output, {'batch_align_loss':batch_align_loss.data.cpu().item(), \ 266 | 'total_batch_loss':total_loss.data.cpu().item(),\ 267 | 'batch_recon_loss':batch_loss.data.cpu().item(),\ 268 | 'batch_kl_loss': batch_kl_loss.cpu().item(),\ 269 | 'elbo_loss':(batch_loss.data.cpu().item()+batch_kl_loss.cpu().item()) 270 | } 271 | 272 | 273 | def train(self, epochs=11, debug=False, use_alignment=False, args=None, model_dir=None): 274 | learning_rate = 1e-4 275 | model = self.model 276 | batch_size = self.args.batch_size 277 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 278 | best_loss = 999999999999.0 279 | 280 | for epoch in range(epochs): 281 | 282 | print("="*20, "[Training] beginning of epoch = ", epoch) 283 | 284 | all_loss_tracker = {} 285 | all_loss_tracker_val = {} 286 | 287 | num_batches = self.get_num_batches('train', batch_size) 288 | epoch_loss = 0.0 289 | ctr = 0 290 | for batch in range(num_batches): 291 | total_batch_loss, len_output, info = self.get_loss('train', batch, batch_size, mode='train', 292 | use_alignment=use_alignment) 293 | if debug: 294 | print("TRAIN info = ", info) 295 | print() 296 | for k,v in info.items(): 297 | if k.count('loss')>0: 298 | if k not in all_loss_tracker: 299 | all_loss_tracker[k] = 0.0 300 | all_loss_tracker[k] += v 301 | epoch_loss = epoch_loss + total_batch_loss.data.cpu().item() 302 | model.zero_grad() 303 | optimizer.zero_grad() 304 | total_batch_loss.backward() 305 | optimizer.step() 306 | ctr = ctr + len_output 307 | if batch%1000==0: 308 | print("[Training] batch = ", batch, "epoch_loss = ", epoch_loss/ctr) 309 | if debug: 310 | break 311 | print() 312 | print("[Training] epoch perplexity = ", np.exp(all_loss_tracker['elbo_loss']/ctr) ) 313 | print("[Training] epoch all_loss_tracker (norm. 
by num_batches) = ", 314 | {k:v/num_batches for k,v in all_loss_tracker.items()} ) 315 | print() 316 | 317 | num_batches = self.get_num_batches('val', batch_size) 318 | epoch_val_loss = 0.0 319 | ctr = 0 320 | for batch in range(num_batches): 321 | total_batch_loss, len_output, info = self.get_loss('val', batch, batch_size, mode='eval', use_alignment=use_alignment) 322 | if debug: 323 | print("VAL info = ", info) 324 | epoch_val_loss = epoch_val_loss + total_batch_loss.data.cpu().item() 325 | for k,v in info.items(): 326 | if k.count('loss')>0: 327 | if k not in all_loss_tracker_val: 328 | all_loss_tracker_val[k] = 0.0 329 | all_loss_tracker_val[k] += v 330 | ctr = ctr + len_output 331 | if debug: 332 | break 333 | print("[Training] epoch VAL epoch_val_loss = ", epoch_val_loss) 334 | if 'elbo_loss' in all_loss_tracker_val: 335 | print("[Training] epoch VAL perplexity = ", np.exp(all_loss_tracker_val['elbo_loss']/ctr) ) 336 | print("\n[Training] epoch all_loss_tracker_val (norm. by num_batches) = ", {k:v/num_batches for k,v in all_loss_tracker_val.items()} ) 337 | if True: #not debug: 338 | print("\n[Training] Saving model at ", model_dir + 'model_' + str(epoch%5) ) 339 | torch.save(model.state_dict(), model_dir + 'model_' + str(epoch%5)) 340 | if (epoch_val_loss/ctr) < best_loss: 341 | best_loss = epoch_val_loss/ctr 342 | torch.save(model.state_dict(), model_dir + 'model_best') 343 | print("\n[Training] Saving best model till now at ", model_dir + 'model_best') 344 | if debug: 345 | break 346 | print() 347 | 348 | ################ 349 | 350 | def analyze(self, vocab): 351 | 352 | ##load 353 | model = self.model 354 | args = self.args 355 | batch_size = args.batch_size 356 | model_dir = 'tmp/tmp_'+args.model_name+'/' 357 | assert os.path.exists(model_dir) 358 | self.load(model_dir + 'solver_') 359 | state_dict_best = torch.load(model_dir+'model_best') 360 | model.load_state_dict(state_dict_best) 361 | self.load(model_dir+'solver_') 362 | 363 | ##utils 364 | def fnc(s1,s2): 365 | #x,y,y_start = [s1,s2],None,y_start 366 | s1 = self.g_indexer.w_to_idx(s1) 367 | s2 = self.g_indexer.w_to_idx(s2) 368 | model.eval() 369 | e1 = model.encode(s1) 370 | e2 = model.encode(s2) 371 | #y = [self.p_indexer.w_to_idx(p) for g,p in data] 372 | e1_numpy = e1.data.cpu().numpy().reshape(-1) 373 | e2_numpy = e2.data.cpu().numpy().reshape(-1) 374 | return np.sum(e1_numpy * e2_numpy)/ np.sqrt( (np.sum(e1_numpy * e1_numpy) * np.sum(e2_numpy * e2_numpy)) ) 375 | 376 | def pred(s1): 377 | model.eval() 378 | e_1 = model.encode( self.g_indexer.w_to_idx(s1) ) 379 | out_all_i,_ = model.decode(e_1,'',use_gt=False) 380 | print(out_all_i) 381 | return self.p_indexer.idx_to_2(out_all_i) 382 | 383 | ## simple analysis 384 | for word_pair in [['red','head'], ['glue','blue'],['red','red'],['apple','blue'],['table','able'],['tram','cram']]: 385 | if args.data_type!="ae": 386 | print("pred: word_pair[0] ", word_pair[0], " => ", pred(word_pair[0]), " || ", \ 387 | "pred: word_pair[1] ", word_pair[1], " => ", pred(word_pair[1]) ) 388 | print("word_pair = ", word_pair, " --fnc(): ", fnc(word_pair[0],word_pair[1])) 389 | print() 390 | 391 | ## compute loss vals 392 | num_batches = self.get_num_batches('test', batch_size) 393 | epoch_val_loss = 0.0 394 | ctr = 0 395 | all_loss_tracker_val = {} 396 | for batch in range(num_batches): 397 | total_batch_loss, len_output, info = self.get_loss('test', batch, batch_size, mode='eval', use_alignment=args.use_alignment) 398 | epoch_val_loss = epoch_val_loss + total_batch_loss.data.cpu().item() 
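# accumulate each loss component returned by get_loss() so that test perplexity can be derived from the summed elbo_loss below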
399 | for k,v in info.items(): 400 | if k.count('loss')>0: 401 | if k not in all_loss_tracker_val: 402 | all_loss_tracker_val[k] = 0.0 403 | all_loss_tracker_val[k] += v 404 | ctr = ctr + len_output 405 | print("TEST epoch_val_loss = ", epoch_val_loss) 406 | if 'elbo_loss' in all_loss_tracker_val: 407 | print("TEST perplexity = ", np.exp(all_loss_tracker_val['elbo_loss']/ctr) ) 408 | print("TEST all_loss_tracker_val (norm. by num_batches) = ", {k:v/num_batches for k,v in all_loss_tracker_val.items()} ) 409 | 410 | #dump embs 411 | all_embs = {} 412 | dump_file = model_dir + "all_embs.pkl" 413 | print("dumping to ",dump_file) 414 | vocab[0:5], len(vocab) 415 | model.eval() 416 | for w in vocab: 417 | try: 418 | s1 = self.g_indexer.w_to_idx(w.lower()) 419 | except: 420 | print("error for ",w) 421 | continue 422 | e1 = model.encode(s1) 423 | e1_numpy = e1.data.cpu().numpy().reshape(-1) 424 | all_embs[w] = e1_numpy 425 | #break 426 | pickle.dump(all_embs, open(dump_file,'wb')) 427 | 428 | ################ 429 | 430 | 431 | 432 | -------------------------------------------------------------------------------- /code/models_lm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from overrides import overrides 3 | 4 | import numpy as np 5 | import ipdb as pdb 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | from torch.nn.modules.linear import Linear 12 | from torch.nn.modules.rnn import LSTMCell, LSTM 13 | 14 | from typing import Any, Dict, List, Optional, Tuple 15 | from utils import * 16 | from utils_lm import * 17 | 18 | logger = logging.getLogger(__name__) 19 | #UNKNOWN = '@@UNKNOWN@@' 20 | 21 | class LockedDropout(nn.Module): 22 | 23 | def __init__(self): 24 | super(LockedDropout, self).__init__() 25 | 26 | def forward(self, x, dropout=0.5): 27 | 28 | if not self.training or not dropout: 29 | return x 30 | 31 | m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - dropout) 32 | mask = Variable(m, requires_grad=False) / (1 - dropout) 33 | mask = mask.expand_as(x) 34 | 35 | return mask * x 36 | 37 | class VanillaLM(nn.Module): 38 | 39 | def __init__(self, 40 | args, 41 | vocab_indexer, 42 | vocab, 43 | decoder_hidden_size = 600, 44 | emb_size=128, 45 | num_classes=None, 46 | start_idx=1, 47 | end_idx=2, 48 | padding_idx=0, 49 | typ='lstm', 50 | max_decoding_steps=120, 51 | sampling_scheme: str = "first_word", 52 | line_separator_symbol: str = "", 53 | reverse_each_line: str = False, 54 | n_lines_per_sample: int = 14, 55 | tie_weights: bool = True, 56 | dropout_ratio: float = 0.3, 57 | phoneme_embeddings_dim:int =128, 58 | encoder_type: str = None, 59 | encoder_input_size: int = 100, 60 | encoder_hidden_size: int = 100, 61 | encoder_n_layers: int = 1, 62 | n_lines_to_gen:int = 4): 63 | 64 | super(VanillaLM, self).__init__() 65 | 66 | self.args = args 67 | self.vocab_indexer = vocab_indexer 68 | self.vocab = vocab 69 | 70 | self._scheduled_sampling_ratio = 0.0 71 | 72 | self._max_decoding_steps = max_decoding_steps 73 | decoder_input_size = emb_size 74 | 75 | self._decoder_input_dim = decoder_input_size 76 | self._decoder_output_dim = decoder_hidden_size 77 | 78 | self._target_embedder = nn.Embedding(num_classes, emb_size) 79 | 80 | self._context_embedder = nn.Embedding(num_classes, phoneme_embeddings_dim) ## TODO: Not clear why this is phoneme_embeddings_dim 81 | 82 | self.padding_idx = padding_idx 83 | self.start_idx = start_idx 84 | self.end_idx = end_idx 85 
| 86 | self.type = typ 87 | self.use_cuda = args.use_cuda #True 88 | 89 | decoder_embedding_dim = emb_size 90 | self._target_embedding_dim = decoder_embedding_dim 91 | 92 | assert self.type == "lstm", "Incorrect decoder type" 93 | self._lm_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) 94 | 95 | self._intermediate_projection_layer = Linear(self._decoder_output_dim, 96 | self._target_embedding_dim) # , bias=False) 97 | self._activation = torch.tanh 98 | self._num_classes = num_classes 99 | self._output_projection_layer = Linear(self._target_embedding_dim, self._num_classes) 100 | 101 | self._dropout_ratio = dropout_ratio 102 | self._dropout = nn.Dropout(p=dropout_ratio, inplace=False) 103 | self._lockdropout = LockedDropout() 104 | 105 | self._encoder_type = encoder_type 106 | 107 | if self._encoder_type is not None: 108 | self._encoder_input_size = encoder_input_size 109 | self._encoder_hidden_size = encoder_hidden_size 110 | self._encoder_namespace = encoder_namespace 111 | self._encoder = nn.LSTM(input_size=self._encoder_input_size, hidden_size=self._encoder_hidden_size, 112 | batch_first=True, bias=False, num_layers=encoder_n_layers, bidirectional=False) 113 | 114 | if tie_weights: 115 | # assert self._target_embedding_dim == self._target_embedder.token_embedder_tokens.get_output_dim(), "Dimension mis-match!" 116 | self._output_projection_layer.weight = self._target_embedder.weight 117 | 118 | # in the config, make these options consistent with those in the reader 119 | self._sampling_scheme = sampling_scheme # "first_sentence" # "first_word" 120 | self.line_separator = line_separator_symbol 121 | self.reverse_each_line = reverse_each_line 122 | self.n_lines_per_sample = n_lines_per_sample 123 | 124 | self._n_lines_to_gen = n_lines_to_gen 125 | 126 | self._attention = False 127 | self.END_SYMBOL = line_separator_symbol 128 | 129 | # self._sonnet_eval = SonnetMeasures() 130 | 131 | def freeze_emb(self, type="_target_embedder"): 132 | 133 | if type == "_target_embedder": 134 | self._target_embedder.weight.requires_grad = False 135 | print("[VanillaLM-Model] -- FREEZING target embedder") 136 | 137 | def display_params(self): 138 | 139 | print("[VanillaLM-Model]: model parameters") 140 | for name, param in self.named_parameters(): 141 | print("name=", name, " || grad:", param.requires_grad) 142 | 143 | def load_gutenberg_we(self, dict_pickle_pth, indexer_idx2w): 144 | sz = len(indexer_idx2w) 145 | emb = torch.FloatTensor(sz, self._target_embedding_dim) 146 | # self.weight = torch.nn.Parameter(weight, requires_grad=trainable) 147 | torch.nn.init.xavier_uniform_(emb) 148 | # emb = np.random.rand(sz, self._target_embedding_dim) 149 | print(emb.shape) 150 | 151 | pretrained = {} #pickle.load(dict_pickle_pth,'rb') 152 | lines = open(dict_pickle_pth,'r').readlines() 153 | for line in lines: 154 | line = line.strip().split() 155 | w = line[0] 156 | e = np.array([float(val) for val in line[1:]]) 157 | pretrained[w] = e 158 | found = 0 159 | 160 | for i in range(sz): 161 | word = indexer_idx2w[i] 162 | if word == PAD: 163 | emb[0].fill_(0) 164 | elif word in pretrained: 165 | emb[i,:] = torch.from_numpy(pretrained[word]) 166 | found+=1 167 | 168 | print("load_gutenberg_we: found",found, " out of ", sz) 169 | # self._target_embedder.weight.data.copy_(torch.from_numpy(emb)) 170 | self._target_embedder.weight.data.copy_(emb) 171 | 172 | def initialize_hidden(self, batch_size, hidden_dim=None, n_dim=2): 173 | 174 | if hidden_dim is None: 175 | hidden_dim = self._decoder_output_dim 
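# n_dim == 2 returns (batch, hidden) states for the LSTMCell decoder; n_dim == 3 adds a leading layer dimension for the optional nn.LSTM encoder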
176 | 177 | if n_dim == 2: 178 | hidden_state = torch.zeros(batch_size, hidden_dim) 179 | cell_state = torch.zeros(batch_size, hidden_dim) 180 | if torch.cuda.is_available(): 181 | hidden_state = hidden_state.cuda() 182 | cell_state = cell_state.cuda() 183 | else: 184 | hidden_state = torch.zeros(1, batch_size, hidden_dim) 185 | cell_state = torch.zeros(1, batch_size, hidden_dim) 186 | if torch.cuda.is_available(): 187 | hidden_state = hidden_state.cuda() 188 | cell_state = cell_state.cuda() 189 | 190 | return (hidden_state, cell_state) 191 | 192 | def _sample(self, start_token_idx=None, max_decoding_steps=None, 193 | ending_words = None, decoder_hidden=None, 194 | decoder_context=None, conditional=True, debug=False): 195 | 196 | if max_decoding_steps is None: 197 | max_decoding_steps = self._max_decoding_steps 198 | 199 | if decoder_hidden is None: 200 | decoder_hidden, decoder_context = self.initialize_hidden(batch_size=1) 201 | 202 | assert not (conditional and ending_words is None), "--exception-- | conditional is set to True and ending_words are not provided!" 203 | 204 | 205 | # if conditional: 206 | # if not isinstance(ending_words, torch.Tensor): 207 | # inp = torch.LongTensor(ending_words)[0] 208 | # else: 209 | # inp = ending_words[0] 210 | # else: 211 | if start_token_idx is None: 212 | start_token_idx = self.vocab_indexer.w2idx[self.line_separator] 213 | 214 | if not isinstance(start_token_idx, torch.Tensor): 215 | inp = torch.LongTensor([start_token_idx]) # Token(START_SYMBOL) 216 | else: 217 | inp = start_token_idx 218 | 219 | if ending_words is not None: 220 | n_lines_to_gen = len(ending_words) 221 | if isinstance(ending_words, list): 222 | ending_words = torch.LongTensor(ending_words) 223 | if self.args.use_cuda: 224 | ending_words = ending_words.cuda() 225 | else: 226 | n_lines_to_gen = self._n_lines_to_gen 227 | 228 | if self.args.use_cuda: 229 | inp = inp.cuda() 230 | 231 | # if self._augment_phoneme_embeddings: 232 | # emb_t = torch.cat((self._target_embedder(inp).view(-1), self._context_embedder(inp).view(-1)), dim=-1) ### this should be target embedded and context embedded 233 | # else: 234 | emb_t = self._target_embedder(inp).view(-1) ### this should be target embedded 235 | 236 | logprobs = [] 237 | actions = [] 238 | actions_idx = [] 239 | 240 | logprobs_line = [] 241 | actions_line = [] 242 | actions_idx_line = [] 243 | 244 | i2v = self.vocab 245 | 246 | count_lines_gen = 0 247 | 248 | prev_action = self.line_separator 249 | 250 | for t in range(max_decoding_steps): 251 | 252 | decoder_hidden, decoder_context = self._lm_cell(emb_t.unsqueeze(0), (decoder_hidden, decoder_context)) 253 | 254 | # output = self._lockdropout(x=decoder_hidden.unsqueeze(1), dropout=self._dropout_ratio) 255 | output = decoder_hidden.unsqueeze(1) 256 | 257 | pre_decoded_output = self._intermediate_projection_layer(output.view(1, -1)) 258 | decoded_output = self._output_projection_layer(pre_decoded_output) 259 | logits = decoded_output.view(1, decoded_output.size(1)) 260 | 261 | logprobs_line.append(F.log_softmax(logits, dim=-1)) 262 | class_probabilities = F.softmax(logits, dim=-1) 263 | 264 | predicted_action = UNKNOWN 265 | while predicted_action == UNKNOWN: 266 | predicted_action_idx = torch.multinomial(class_probabilities, 1) 267 | predicted_action = i2v[predicted_action_idx.data.item()] 268 | 269 | if prev_action == self.line_separator and conditional: 270 | predicted_action_idx = ending_words[count_lines_gen] 271 | predicted_action = i2v[ending_words[count_lines_gen].data.item()] 272 
| prev_action = predicted_action 273 | else: 274 | predicted_action_idx = predicted_action_idx[0] 275 | prev_action = predicted_action 276 | 277 | actions_line.append(predicted_action) 278 | actions_idx_line.append(predicted_action_idx) 279 | 280 | inp = predicted_action_idx 281 | 282 | # if self._augment_phoneme_embeddings: 283 | # emb_t = torch.cat((self._target_embedder(inp).view(-1), self._context_embedder(inp).view(-1)), dim=-1) 284 | # else: 285 | emb_t = self._target_embedder(inp).view(-1) 286 | 287 | # all_predictions_indices.append(last_predictions) 288 | # last_predictions_str = i2v[last_predictions.data.item()] 289 | 290 | if predicted_action == self.line_separator: 291 | actions.append(actions_line[:-1]) 292 | logprobs.append(logprobs_line[:-1]) 293 | actions_idx.append(actions_idx_line[:-1]) 294 | 295 | actions_line = [] 296 | logprobs_line = [] 297 | actions_idx_line = [] 298 | 299 | count_lines_gen += 1 300 | 301 | if predicted_action == self.line_separator and count_lines_gen == n_lines_to_gen: 302 | break 303 | # all_predictions.append(last_predictions_str) 304 | 305 | if self.args.data_type == "limerick": 306 | sampled_str = ' | '.join([' '.join(actions_line) for actions_line in actions]) 307 | else: 308 | sampled_str = ' '.join([' '.join(actions_line) for actions_line in actions]) 309 | 310 | if debug: 311 | print("[Sample]: sampled_str= ", sampled_str) 312 | 313 | assert len(logprobs) == len(actions_idx), "Length mis-match for logprobs and actions_idx!" 314 | return {'logprobs': logprobs, 'actions': actions_idx} 315 | 316 | @overrides 317 | def forward(self, # type: ignore 318 | source_tokens: [str, torch.LongTensor], 319 | target_tokens: [str, torch.LongTensor] = None, 320 | ending_words: [str, torch.LongTensor] = None, 321 | batch_idx: int = None, 322 | decoder_hidden: torch.FloatTensor = None, 323 | decoder_context: torch.FloatTensor = None, 324 | ending_words_mask: [str, torch.Tensor] = None, 325 | hier_mode: bool = False) -> Dict[str, torch.Tensor]: 326 | 327 | # pylint: disable=arguments-differ 328 | """ 329 | Decoder logic for producing the entire target sequence. 330 | 331 | Parameters 332 | ---------- 333 | source_tokens : Dict[str, torch.LongTensor] 334 | The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be 335 | passed through a ``TextFieldEmbedder`` and then through an encoder. 336 | target_tokens : Dict[str, torch.LongTensor], optional (default = None) 337 | Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the 338 | target tokens are also represented as a ``TextField``. 
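Returns a dictionary containing the logits, class probabilities, greedy predictions, the loss, source/target sentence lengths, and the final per-instance decoder hidden/context states (reused for truncated BPTT).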
339 | """ 340 | # pdb.set_trace() 341 | # source_mask = utils.get_text_field_mask(source_tokens) 342 | 343 | source_mask = get_text_field_mask(source_tokens) 344 | 345 | embedded_input = self._target_embedder(source_tokens["tokens"]) 346 | 347 | targets = target_tokens["tokens"] 348 | # target_mask = util.get_text_field_mask(target_tokens) 349 | target_mask = get_text_field_mask(target_tokens) 350 | 351 | batch_size, time_steps, _ = embedded_input.size() 352 | 353 | embedded_ending_words = self._context_embedder(target_tokens["tokens"]) 354 | 355 | # pdb.set_trace() 356 | # Apply dropout to embeddings 357 | # embedded_input = self._dropout(embedded_input) 358 | embedded_input = self._lockdropout(x=embedded_input, dropout=self._dropout_ratio) 359 | 360 | if self._sampling_scheme == "first_word": 361 | if ((not self.training) and np.random.rand() < 0.05): 362 | 363 | i2v = self.vocab 364 | 365 | ending_words_idx = ending_words["tokens"][0].data.cpu().numpy() 366 | ending_words_str = '|'.join([i2v[int(word)] for word in ending_words_idx]) 367 | print("Ending word sequence : ", ending_words_str) 368 | 369 | start_token_idx = source_tokens["tokens"][0][0] 370 | self._sample(start_token_idx=start_token_idx, ending_words=ending_words["tokens"][0], debug=True) 371 | 372 | if decoder_hidden is None: 373 | (decoder_hidden, decoder_context) = self.initialize_hidden(batch_size=batch_size) 374 | 375 | hiddens = [] 376 | contexts = [] 377 | p_gens = [] 378 | 379 | for t, emb_t in enumerate(embedded_input.chunk(time_steps, dim=1)): 380 | 381 | decoder_hidden, decoder_context = self._lm_cell(emb_t.squeeze(1), (decoder_hidden, decoder_context)) 382 | 383 | hiddens.append(decoder_hidden.unsqueeze(1)) 384 | contexts.append(decoder_context.unsqueeze(1)) 385 | 386 | hidden = torch.cat(hiddens, 1) 387 | # context = torch.cat(contexts, 1) 388 | output = self._lockdropout(x=hidden, dropout=self._dropout_ratio) 389 | 390 | batch_size = output.size(0) 391 | seq_len = output.size(1) 392 | hidden_dim = output.size(2) 393 | 394 | pre_decoded_output = self._intermediate_projection_layer(output.view(batch_size * seq_len, hidden_dim)) 395 | decoded_output = self._output_projection_layer(pre_decoded_output) 396 | logits = decoded_output.view(batch_size, seq_len, decoded_output.size(1)) 397 | 398 | 399 | class_probabilities = F.softmax(logits, dim=-1) 400 | _, predicted_classes = torch.max(class_probabilities, dim=-1) 401 | 402 | output_dict = {"logits": logits, 403 | "class_probabilities": class_probabilities, 404 | "predictions": predicted_classes} 405 | 406 | # This code block masks all line endings (ending words) 407 | if not self.training and hier_mode == True: 408 | # pdb.set_trace() 409 | tmp_mask = (1 - ending_words_mask["tokens"]) 410 | target_mask = target_mask.long() * tmp_mask 411 | 412 | loss = self._get_loss_custom(logits, targets, target_mask, training=self.training) 413 | 414 | output_dict["loss"] = loss 415 | 416 | target_mask = get_text_field_mask(target_tokens) 417 | source_sentence_lengths = get_lengths_from_binary_sequence_mask(mask=source_mask) 418 | target_sentence_lengths = get_lengths_from_binary_sequence_mask(mask=target_mask) 419 | 420 | output_dict["source_sentence_lengths"] = source_sentence_lengths 421 | output_dict["target_sentence_lengths"] = target_sentence_lengths 422 | 423 | # if self.training: 424 | decoder_hidden = [] 425 | decoder_context = [] 426 | 427 | for idx, length in enumerate(source_sentence_lengths): 428 | assert source_sentence_lengths[idx] == target_sentence_lengths[idx], 
"Mis-match!" 429 | decoder_hidden.append(hiddens[length - 1][idx].squeeze(0)) 430 | decoder_context.append(contexts[length - 1][idx].squeeze(0)) 431 | 432 | output_dict["decoder_hidden"] = decoder_hidden 433 | output_dict["decoder_context"] = decoder_context 434 | 435 | return output_dict 436 | 437 | def get_context(self, batch_size, ending_words): 438 | assert False, "No longer used - marked for removal" 439 | # if self._encoder_type is not None: 440 | # ending_words_embedded = self._context_embedder(ending_words) 441 | # encoder_hidden, encoder_context = self.initialize_hidden(batch_size=batch_size, 442 | # hidden_dim=self._encoder_hidden_size, n_dim=3) 443 | # _, (embedded_ending_words, _) = self._encoder(ending_words_embedded, (encoder_hidden, encoder_context)) 444 | # embedded_ending_words = embedded_ending_words.squeeze(0) 445 | # elif self._context_embedder is not None: 446 | # embedded_ending_words = torch.sum(self._context_embedder(ending_words), dim=1) 447 | # else: 448 | # embedded_ending_words = torch.sum(self._target_embedder(ending_words), dim=1) 449 | 450 | # return embedded_ending_words 451 | 452 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 453 | metrics = {} 454 | if not self.training: 455 | metrics.update(self._sonnet_eval.get_metric(reset)) 456 | 457 | return metrics 458 | 459 | @staticmethod 460 | def _get_loss_custom(logits: torch.LongTensor, 461 | targets: torch.LongTensor, 462 | target_mask: torch.LongTensor, 463 | training: bool = True) -> torch.LongTensor: 464 | """ 465 | As opposed to get_loss, logits and targets are of same size 466 | """ 467 | relevant_targets = targets.contiguous() # (batch_size, num_decoding_steps) 468 | relevant_mask = target_mask.contiguous() # (batch_size, num_decoding_steps) 469 | # loss = util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) 470 | 471 | if training: 472 | loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) 473 | else: 474 | loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask, 475 | average=None) 476 | 477 | return loss 478 | 479 | 480 | -------------------------------------------------------------------------------- /code/solvers_merged.py: -------------------------------------------------------------------------------- 1 | import time 2 | import string 3 | import torch 4 | import json 5 | import pickle 6 | import numpy as np 7 | import ipdb as pdb 8 | import os 9 | import codecs 10 | import random 11 | from torch import nn 12 | from torch.distributions import Categorical 13 | from models_lm import * 14 | from models import * 15 | import utils 16 | from utils import * 17 | from utils_lm import * 18 | from allennlp.common.tqdm import Tqdm 19 | import signal 20 | from contextlib import contextmanager 21 | from typing import Tuple 22 | from constants import * 23 | from rhyming_eval import RhymingEval 24 | 25 | 26 | translator = None 27 | translator_limerick = None 28 | # translator = str.maketrans('', '', string.punctuation.replace('<','').replace('>','')) 29 | # translator_limerick = str.maketrans('', '', string.punctuation.replace('|','')) 30 | _couplet_mode = True # Setting it to True uses vanilla-lm for couplets 31 | 32 | def load_sonnet_vocab(pth): 33 | data = [r.strip() for r in open(pth,"r").readlines()] 34 | return data 35 | 36 | class TimeoutException(Exception): pass 37 | 38 | @contextmanager 39 | def time_limit(seconds): 40 | def signal_handler(signum, frame): 41 | raise TimeoutException("Timed out!") 42 | 
signal.signal(signal.SIGALRM, signal_handler) 43 | signal.alarm(seconds) 44 | try: 45 | yield 46 | finally: 47 | signal.alarm(0) 48 | 49 | 50 | 51 | ################ 52 | 53 | class MainSolver: 54 | 55 | def __init__(self, 56 | typ="g2p", 57 | cmu_dict=None, 58 | args=None, 59 | vocab_type="only_ending_words", 60 | mode='train'): 61 | 62 | self.last2_to_words = None 63 | self.typ = typ 64 | self.cmu_dict = cmu_dict 65 | self.args = args 66 | self.sonnet_splits = None 67 | self.limerick_splits = None 68 | self.vocab_type = vocab_type 69 | 70 | self.rhyming_eval = RhymingEval(dataset_identifier = args.data_type) #[SONNET_DATASET_IDENTIFIER, LIMERICK_DATASET_IDENTIFIER] 71 | if args.data_type == SONNET_ENDINGS_DATASET_IDENTIFIER: 72 | line_endings_file_location = args.all_line_endings_sonnet 73 | line_endings_file_location_test = args.all_line_endings_sonnet_test 74 | elif args.data_type == LIMERICK_DATASET_IDENTIFIER: 75 | line_endings_file_location = args.all_line_endings_limerick 76 | line_endings_file_location_test = args.all_line_endings_limerick_test 77 | self.rhyming_eval.setup(line_endings_file_location, line_endings_file_location_test) 78 | print(" ----->> self.rhyming_eval: ", self.rhyming_eval) 79 | 80 | if args.data_type == LIMERICK_DATASET_IDENTIFIER: 81 | if os.path.exists(args.limerick_vocab_path): 82 | self.limerick_vocab = load_sonnet_vocab(args.limerick_vocab_path) 83 | print("No. of tokens loaded from existing vocab : ", len(self.limerick_vocab)) 84 | else: 85 | self.limerick_vocab = None 86 | else: # sonnet_endings 87 | if os.path.exists(args.sonnet_vocab_path): 88 | self.sonnet_vocab = load_sonnet_vocab(args.sonnet_vocab_path) 89 | print("No. of tokens loaded from existing vocab : ", len(self.sonnet_vocab)) 90 | else: 91 | self.sonnet_vocab = None 92 | 93 | self.get_splits() 94 | self.index() 95 | 96 | self.num_lines = NUM_LINES_QUATRAIN 97 | if args.data_type == LIMERICK_DATASET_IDENTIFIER: 98 | self.num_lines = NUM_LINES_LIMERICK 99 | 100 | x_start = self.g_indexer.w2idx[utils.START] 101 | x_end = self.g_indexer.w2idx[utils.END] 102 | x_unk=-999 103 | if utils.UNKNOWN in self.g_indexer.w2idx: 104 | x_unk = self.g_indexer.w2idx[utils.UNKNOWN] 105 | 106 | #### Vanilla LM model initialized 107 | num_classes = len(self.g_indexer.w2idx) 108 | 109 | if args.data_type == LIMERICK_DATASET_IDENTIFIER: 110 | self._trun_bptt_mode = True 111 | self._validation_bptt_mode = False 112 | line_separator_symbol = "|" 113 | n_lines_per_sample = NUM_LINES_LIMERICK 114 | n_lines_to_gen = NUM_LINES_LIMERICK 115 | else: 116 | self._trun_bptt_mode = True 117 | self._validation_bptt_mode = False 118 | # self._couplet_mode = True # Setting it to True uses vanilla-lm for couplets 119 | line_separator_symbol = "" 120 | n_lines_per_sample = NUM_LINES_SONNET 121 | n_lines_to_gen = NUM_LINES_QUATRAIN 122 | 123 | self.lm_model = VanillaLM(args, 124 | decoder_hidden_size=600, 125 | vocab_indexer=self.g_indexer, 126 | vocab=self.g_indexer.idx2w, 127 | emb_size=100, 128 | num_classes=num_classes, 129 | start_idx=1, 130 | end_idx=2, 131 | padding_idx=0, 132 | typ='lstm', 133 | max_decoding_steps=120, 134 | sampling_scheme="first_word", 135 | line_separator_symbol=line_separator_symbol, 136 | reverse_each_line=True, 137 | n_lines_per_sample=n_lines_per_sample, 138 | tie_weights=True, 139 | dropout_ratio=0.3, 140 | phoneme_embeddings_dim=128, 141 | encoder_type=None, 142 | encoder_input_size=100, 143 | encoder_hidden_size=100, 144 | encoder_n_layers=1, 145 | n_lines_to_gen=n_lines_to_gen) 146 | 147 | 
self.lm_model.display_params() 148 | 149 | if args.load_gutenberg: 150 | self.lm_model.load_gutenberg_we(args.load_gutenberg_path, self.g_indexer.idx2w) 151 | 152 | if args.use_cuda: 153 | print("---") 154 | self.lm_model = self.lm_model.cuda() 155 | 156 | if mode not in ["train_lm"]: 157 | 158 | #### Generator Model initialized 159 | H = args.H #128 160 | emsize = args.emsize #128 161 | self.model = GeneratorModel(H=H, emsize=args.emsize, o_size=self.g_indexer.w_cnt, start_idx=x_start, end_idx=x_end, unk_idx=x_unk, typ=args.model_type, args=args) 162 | 163 | #### Rhyming discriminator initialized 164 | disc_info = {'i_size': None, 'emsize': emsize, 'H': H, 'solver_model_name': None} 165 | self.disc = CNN(args, info=disc_info) 166 | 167 | if args.use_cuda: 168 | print("---") 169 | # self.lm_model = self.lm_model.cuda() 170 | self.model = self.model.cuda() 171 | self.disc = self.disc.cuda() 172 | # self.disc_syllable = self.disc_syllable.cuda() 173 | # self.disc_syllable_line = self.disc_syllable_line.cuda() 174 | 175 | if args.load_gutenberg: 176 | self.model.load_gutenberg_we(args.load_gutenberg_path, self.g_indexer.idx2w) 177 | 178 | if args.freeze_emb: 179 | self.model.freeze_emb() 180 | 181 | self.model.display_params() 182 | self.disc.display_params() 183 | 184 | 185 | def get_splits(self): 186 | # self.all_cmu_items = all_items = list(self.cmu_data.items()) 187 | all_items = list(self.cmu_dict.items()) 188 | if self.typ == SONNET_ENDINGS_DATASET_IDENTIFIER: 189 | self.get_sonnet_splits(self.args.sonnet_data_path) 190 | self.splits = self.sonnet_splits 191 | #@NEW 192 | elif self.typ == LIMERICK_DATASET_IDENTIFIER: 193 | self.get_limerick_splits(self.args.limerick_data_path) 194 | self.splits = self.limerick_splits 195 | else: 196 | assert False 197 | for i in range(11): 198 | random.shuffle(all_items) 199 | sz = len(all_items) 200 | train = all_items[:int(0.8 * sz)] 201 | val = all_items[int(0.8 * sz):int(0.9 * sz)] 202 | test = all_items[int(0.9 * sz):] 203 | self.splits = {'train': train, 'val': val, 'test': test} 204 | 205 | 206 | def _split_into_lines(self, txt, split_sym): 207 | return [line.strip() for line in txt.split(split_sym) if len(line.strip()) > 0] 208 | 209 | 210 | def text_to_instance(self, instance_string, 211 | line_separator_symbol="", 212 | trun_bptt_mode=True, 213 | reverse_each_line=True, 214 | n_lines_per_sample=14): 215 | # pylint: disable=arguments-differ 216 | fields = dict() 217 | metadata = {'instance_str': instance_string} 218 | reverse_target_string = ' '.join(reversed(instance_string.split())) 219 | tokenized_target = reverse_target_string.split() 220 | tokenized_target.append(line_separator_symbol) 221 | source_tokens_lm = tokenized_target[:-1] 222 | target_tokens_lm = tokenized_target[1:] 223 | fields['source_tokens_lm'] = source_tokens_lm 224 | fields['target_tokens_lm'] = target_tokens_lm 225 | ending_words_mask = [1 if token == line_separator_symbol else 0 for token in tokenized_target[:-1]] 226 | fields['ending_words_mask'] = np.array(ending_words_mask) 227 | lines = self._split_into_lines(txt=instance_string, split_sym=line_separator_symbol) 228 | ending_words = [line.split()[-1] for line in lines] 229 | if trun_bptt_mode: 230 | lines = self._split_into_lines(txt=instance_string, split_sym=line_separator_symbol) 231 | if reverse_each_line: 232 | lines.reverse() 233 | ending_words.reverse() 234 | 235 | for i in range(int(n_lines_per_sample/2)): 236 | if reverse_each_line: 237 | line1 = ' '.join(reversed(lines[i * 2].split())) + " " + 
line_separator_symbol 238 | line2 = ' '.join(reversed(lines[i * 2 + 1].split())) + " " + line_separator_symbol 239 | else: 240 | line1 = lines[i*2] + " " + line_separator_symbol 241 | line2 = lines[i * 2 + 1] + " " + line_separator_symbol 242 | line = line1 + " " + line2 243 | tokenized_target = line.split() 244 | if reverse_each_line: 245 | tokenized_target.insert(0, line_separator_symbol) 246 | source_tokens_i = tokenized_target[:-1] 247 | target_tokens_i = tokenized_target[1:] 248 | if i == 0: 249 | ending_words_i = ending_words[0: 2] 250 | elif i % 2 == 1: 251 | ending_words_i = ending_words[i * 2: (i + 2) * 2] 252 | else: 253 | ending_words_i = ending_words[(i - 1) * 2: (i + 1) * 2] 254 | ending_words_i_textfield = ending_words_i 255 | fields['source_tokens_'+str(i)] = source_tokens_i 256 | fields['target_tokens_'+str(i)] = target_tokens_i 257 | fields['ending_words_'+str(i)] = ending_words_i_textfield 258 | 259 | ending_words_mask_i = [1 if token == line_separator_symbol else 0 for token in tokenized_target[:-1]] 260 | 261 | # In case of couplet mode, unmask first 2 ending words (remember we have reversed lines) 262 | if _couplet_mode: 263 | n_words_to_unmask = 2 264 | for z in range(len(ending_words_mask_i)): 265 | 266 | if ending_words_mask_i[z] == 1: 267 | ending_words_mask_i[z] = 0 268 | n_words_to_unmask -= 1 269 | 270 | if n_words_to_unmask == 0: 271 | break 272 | 273 | fields['ending_words_mask_'+str(i)] = np.array(ending_words_mask_i) 274 | 275 | reversed_lines = [] 276 | for line in lines: 277 | reversed_line = line.split() 278 | reversed_line.reverse() 279 | reversed_lines.append(reversed_line) 280 | reversed_quatrains = [] 281 | reversed_quatrains.append(reversed_lines[2:6]) 282 | reversed_quatrains.append(reversed_lines[6:10]) 283 | reversed_quatrains.append(reversed_lines[10:14]) 284 | fields['reversed_quatrains'] = reversed_quatrains 285 | fields['ending_words_reversed'] = ending_words 286 | fields['metadata'] = metadata 287 | return fields 288 | 289 | 290 | def _load_sonnet_data(self, sonnet_data_path, split): 291 | sonnet_data_file = sonnet_data_path + split + '.txt' 292 | data = open(sonnet_data_file, 'r').readlines() 293 | ret = [] 294 | skipped = 0 295 | n_tokens = 0 296 | n_endings = 0 297 | for sonnet in data: 298 | if translator is not None: 299 | instance_string = sonnet.strip().translate(translator).replace('‘','') 300 | else: 301 | instance_string = sonnet.strip().replace('‘', '') 302 | instance = self.text_to_instance(instance_string=instance_string) 303 | lines = instance_string.split(' ')[:-1] 304 | last_words = [line.strip().split()[-1] for line in lines] 305 | instance['lines'] = lines 306 | instance['ending_words'] = last_words 307 | instance['skip'] = False 308 | n_tokens += len(instance['target_tokens_lm']) 309 | n_endings += len(last_words) 310 | for j, w in enumerate(last_words[:4]): 311 | if w not in self.cmu_dict: 312 | skipped += 1 313 | instance['skip'] = True 314 | break 315 | # ret.append({'lines': lines, 'ending_words': last_words}) 316 | ret.append(instance) 317 | print("[_load_sonnet_data] : split=", split, " skipped ", skipped, " out of ", len(data)) 318 | print("No. of tokens in sonnet split :", split, " : ", n_tokens) 319 | print("No. 
of endings in sonnet split :", split, " : ", n_endings) 320 | 321 | return ret 322 | 323 | def get_sonnet_splits(self, data_path="../data/sonnet_"): 324 | if self.sonnet_splits is None: 325 | self.sonnet_splits = {k: self._load_sonnet_data(data_path, k) for k in ['train', 'valid', 'test']} 326 | self.sonnet_splits['val'] = self.sonnet_splits['valid'] 327 | # self.sonnet_splits = {k:self._process_sonnet_data(val) for k,val in self.sonnet_splits.items()} 328 | 329 | 330 | def text_to_instance_limerick(self, instance_string, 331 | line_separator_symbol="|", 332 | trun_bptt_mode=True, 333 | reverse_each_line=True, 334 | n_lines_per_sample=5): 335 | # pylint: disable=arguments-differ 336 | fields = dict() 337 | 338 | metadata = {'instance_str': instance_string} 339 | 340 | lines = self._split_into_lines(txt=instance_string, split_sym=line_separator_symbol) 341 | ending_words = [line.split()[-1] for line in lines] 342 | 343 | if trun_bptt_mode: 344 | # lines = self._split_into_lines(txt=target_string, split_sym=self.line_separator_symbol) 345 | if reverse_each_line: 346 | lines.reverse() 347 | ending_words.reverse() 348 | line = '' 349 | for i in range(n_lines_per_sample): 350 | if reverse_each_line: 351 | line1 = ' '.join(reversed(lines[i].split())) + " " + line_separator_symbol 352 | else: 353 | line1 = lines[i] + " " + line_separator_symbol 354 | if i != (n_lines_per_sample - 1): 355 | line = line + line1 + " " 356 | else: 357 | line = line + line1 358 | tokenized_target = line.split() 359 | if reverse_each_line: 360 | tokenized_target.insert(0, line_separator_symbol) 361 | source_tokens_i = tokenized_target[:-1] 362 | target_tokens_i = tokenized_target[1:] 363 | fields['source_tokens_lm'] = source_tokens_i 364 | fields['target_tokens_lm'] = target_tokens_i 365 | ending_words_mask = [1 if token == line_separator_symbol else 0 for token in tokenized_target[:-1]] 366 | fields['ending_words_mask'] = np.array(ending_words_mask) 367 | reversed_lines = [] 368 | for line in lines: 369 | reversed_line = line.split() 370 | reversed_line.reverse() 371 | reversed_lines.append(reversed_line) 372 | fields['reversed_quatrains'] = reversed_lines 373 | fields['ending_words_reversed'] = ending_words 374 | fields['metadata'] = metadata 375 | return fields 376 | 377 | 378 | def _load_limerick_data(self, limerick_data_path, split): 379 | data_file = os.path.join(limerick_data_path, split + '_0.json') 380 | print("Loading from ", data_file) 381 | data = json.load(open(data_file, 'r')) 382 | ret = [] 383 | skipped = 0 384 | n_tokens = 0 385 | for limerick in data: 386 | if translator_limerick is not None: 387 | limerick = limerick['txt'].strip().translate(translator_limerick) 388 | else: 389 | limerick = limerick['txt'].strip() 390 | lines = limerick.split('|') 391 | if len(lines) != 5: 392 | skipped += 1 393 | continue 394 | instance = self.text_to_instance_limerick(instance_string=limerick) 395 | # pdb.set_trace() 396 | last_words = [line.strip().split()[-1] for line in lines] 397 | instance['lines'] = lines 398 | instance['ending_words'] = last_words 399 | instance['skip'] = False 400 | n_tokens += len(instance['target_tokens_lm']) 401 | ret.append(instance) 402 | print("[_load_sonnet_data] : split=", split, " skipped ", skipped, " out of ", len(data)) 403 | print("No. 
of tokens in limerick split ", split, " : ", n_tokens) 404 | return ret 405 | 406 | 407 | def get_limerick_splits(self, data_path="../data/limerick_only_subset/"): 408 | if self.limerick_splits is None: 409 | self.limerick_splits = {k:self._load_limerick_data(data_path, k) for k in ['train','val','test'] } 410 | 411 | 412 | def index(self): # , typ='g2p'): 413 | if False: # typ=="g2p": 414 | self.g_indexer = Indexer(self.args) 415 | items = list(self.cmu_dict.items()) # self.all_cmu_items 416 | self.g_indexer.process([i[0] for i in items]) 417 | elif self.typ=="limerick": 418 | self.get_limerick_splits() 419 | self.g_indexer = Indexer(self.args) 420 | if self.limerick_vocab is not None: 421 | self.g_indexer.process([self.limerick_vocab]) 422 | elif self.vocab_type == "only_ending_words": 423 | for split, items in self.limerick_splits.items(): 424 | self.g_indexer.process([i['ending_words'] for i in items]) 425 | else: 426 | for split, items in self.sonnet_splits.items(): 427 | self.g_indexer.process([i['source_tokens_lm'] for i in items]) 428 | else: 429 | self.get_sonnet_splits() 430 | self.g_indexer = Indexer(self.args) 431 | if self.sonnet_vocab is not None: 432 | self.g_indexer.process([self.sonnet_vocab]) 433 | elif self.vocab_type == "only_ending_words": 434 | for split, items in self.sonnet_splits.items(): 435 | self.g_indexer.process([i['ending_words'] for i in items]) 436 | else: 437 | for split, items in self.sonnet_splits.items(): 438 | self.g_indexer.process([i['source_tokens_lm'] for i in items]) 439 | 440 | 441 | def save(self, dump_pre): 442 | pickle.dump(self.splits, open(dump_pre + 'splits.pkl', 'wb')) 443 | pickle.dump(self.g_indexer, open(dump_pre + 'g_indexer.pkl', 'wb')) 444 | 445 | 446 | def load(self, dump_pre): 447 | self.splits = pickle.load(open(dump_pre + 'splits.pkl', 'rb')) 448 | self.g_indexer = pickle.load(open(dump_pre + 'g_indexer.pkl', 'rb')) 449 | 450 | 451 | def get_dict_by_phoneme(self): 452 | data = self.splits['train'] 453 | self.last2_to_words = {} 454 | for g, p in data: 455 | last2 = '_'.join(p[-2:]) 456 | if last2 not in self.last2_to_words: 457 | self.last2_to_words[last2] = [] 458 | self.last2_to_words[last2].append(g) 459 | 460 | 461 | def batchify_field(self, data, field, use_indexer=True) -> torch.LongTensor: 462 | field_data = [] 463 | max_len = 0 464 | for instance in data: 465 | 466 | field_instance = instance[field] 467 | if len(field_instance) > max_len: 468 | max_len = len(field_instance) 469 | if use_indexer: 470 | field_data.append(self.g_indexer.w_to_idx(w=field_instance)) 471 | else: 472 | if isinstance(field_instance, np.ndarray): 473 | field_data.append(field_instance.tolist()) 474 | else: 475 | field_data.append(field_instance) 476 | 477 | padded_field_data = [] 478 | field_mask = torch.zeros([len(data), max_len], dtype=torch.int32) 479 | for idx, encoded_field_instance in enumerate(field_data): 480 | current_len = len(encoded_field_instance) 481 | field_mask[idx, :current_len] = 1 482 | # field_mask[:current_len] = encoded_field_instance 483 | if current_len < max_len: 484 | padded_field_data.append(np.array(encoded_field_instance+[0]*(max_len - current_len))) 485 | else: 486 | padded_field_data.append(np.array(encoded_field_instance)) 487 | # return padded_field_data, field_mask 488 | # pdb.set_trace() 489 | return torch.LongTensor(np.array(padded_field_data)), field_mask 490 | 491 | 492 | def get_batch_lm(self, i, batch_size, split): 493 | 494 | if self.args.data_type == "limerick": 495 | data = self.limerick_splits[split][i * 
batch_size: (i + 1) * batch_size] 496 | else: 497 | data = self.sonnet_splits[split][i * batch_size: (i + 1) * batch_size] 498 | batch = {} 499 | all_fields = data[0].keys() 500 | for field in all_fields: 501 | # pdb.set_trace() 502 | if field != "metadata" and field != "lines" and field != "skip" and field != "ending_words" and field != "reversed_quatrains": 503 | if 'mask' in field: 504 | field_data, field_mask = self.batchify_field(data=data, field=field, 505 | use_indexer=False) 506 | else: 507 | field_data, field_mask = self.batchify_field(data=data, field=field, 508 | use_indexer=True) 509 | batch[field] = {} 510 | batch[field]["tokens"] = field_data 511 | batch[field]["mask"] = field_mask 512 | else: 513 | batch[field] = [instance[field] for instance in data] 514 | return batch 515 | 516 | 517 | def get_stats_ending_words_batch(self, batch): 518 | if self.args.data_type == "limerick": 519 | total_limericks = 0 520 | total_limericks_ending_not_in_vocab = 0 521 | for ending_words in batch['ending_words']: 522 | total_limericks += 1 523 | for word in ending_words: 524 | if word not in self.g_indexer.w2idx: 525 | total_limericks_ending_not_in_vocab += 1 526 | break 527 | return total_limericks, total_limericks_ending_not_in_vocab 528 | else: 529 | total_quatrains = 0 530 | total_quatrains_ending_not_in_vocab = 0 531 | for ending_words in batch['ending_words']: 532 | total_quatrains += 3 533 | for word in ending_words[0:4]: 534 | if word not in self.g_indexer.w2idx: 535 | total_quatrains_ending_not_in_vocab += 1 536 | break 537 | for word in ending_words[4:8]: 538 | if word not in self.g_indexer.w2idx: 539 | total_quatrains_ending_not_in_vocab += 1 540 | break 541 | for word in ending_words[8:12]: 542 | if word not in self.g_indexer.w2idx: 543 | total_quatrains_ending_not_in_vocab += 1 544 | break 545 | return total_quatrains, total_quatrains_ending_not_in_vocab 546 | 547 | 548 | def get_stats_ending_words(self, split, batch_size): 549 | if self.args.data_type == "limerick": 550 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="limerick") 551 | total_limericks = 0 552 | total_limericks_ending_not_in_vocab = 0 553 | for batch_idx in range(num_batches): 554 | batch = self.get_batch_lm(i=batch_idx, batch_size=batch_size, split=split) 555 | limericks_batch, limericks_ending_not_in_vocab_batch = self.get_stats_ending_words_batch(batch=batch) 556 | total_limericks += limericks_batch 557 | total_limericks_ending_not_in_vocab += limericks_ending_not_in_vocab_batch 558 | print("No. of limericks : ", total_limericks) 559 | print("No. of limericks with endings in vocab : ", total_limericks - total_limericks_ending_not_in_vocab) 560 | else: 561 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="sonnet") 562 | total_quatrains = 0 563 | total_quatrains_ending_not_in_vocab = 0 564 | for batch_idx in range(num_batches): 565 | batch = self.get_batch_lm(i=batch_idx, batch_size=batch_size, split=split) 566 | quatrains_batch, quatrains_ending_not_in_vocab_batch = self.get_stats_ending_words_batch(batch=batch) 567 | total_quatrains += quatrains_batch 568 | total_quatrains_ending_not_in_vocab += quatrains_ending_not_in_vocab_batch 569 | print("No. of quatrains : ", total_quatrains) 570 | print("No. 
of quatrains with endings in vocab : ", total_quatrains - total_quatrains_ending_not_in_vocab) 571 | 572 | 573 | def endings_in_vocab(self, ending_words): 574 | for word in ending_words: 575 | if word not in self.g_indexer.w2idx: 576 | return False 577 | return True 578 | 579 | 580 | def get_batch(self, i, batch_size, split, last_two=False, add_end_to_y=True, skip_unk=True): # typ='g2p', 581 | # typ: g2p, ae 582 | # assert typ==self.typ 583 | typ = self.typ 584 | if typ == "sonnet_endings": 585 | data = self.sonnet_splits[split][i * batch_size:(i + 1) * batch_size] 586 | x = [] 587 | for val in data: 588 | if self.args.use_all_sonnet_data: 589 | for st in [0,4,8]: 590 | if (not skip_unk) or self.endings_in_vocab(ending_words=val['ending_words'][st:st+4]): 591 | x_val = self.prepare_sonnet_x(val['ending_words'][st:st+4]) 592 | x.append(x_val['indexed_x']) 593 | else: 594 | if (not skip_unk) or self.endings_in_vocab(ending_words=val['ending_words'][0:4]): 595 | x_val = self.prepare_sonnet_x(val['ending_words'][0:4]) 596 | x.append(x_val['indexed_x']) 597 | elif (not skip_unk) or self.endings_in_vocab(ending_words=val['ending_words'][4:8]): 598 | x_val = self.prepare_sonnet_x(val['ending_words'][4:8]) 599 | x.append(x_val['indexed_x']) 600 | elif (not skip_unk) or self.endings_in_vocab(ending_words=val['ending_words'][8:12]): 601 | x_val = self.prepare_sonnet_x(val['ending_words'][8:12]) 602 | x.append(x_val['indexed_x']) 603 | else: 604 | continue 605 | x_end = self.g_indexer.w2idx[utils.END] 606 | x = [x_i + [x_end] for x_i in x] 607 | x_start = self.g_indexer.w2idx[utils.START] 608 | return x, None, x_start 609 | elif typ=="limerick": 610 | data = self.limerick_splits[split][i*batch_size:(i+1)*batch_size] 611 | x = [] 612 | for val in data: 613 | if (not skip_unk) or self.endings_in_vocab(ending_words=val['ending_words']): 614 | x_val = self.prepare_sonnet_x(val['ending_words']) 615 | x.append(x_val['indexed_x']) 616 | else: 617 | continue 618 | x_end = self.g_indexer.w2idx[utils.END] 619 | x = [x_i+[x_end] for x_i in x] 620 | x_start = self.g_indexer.w2idx[utils.START] 621 | return x,None,x_start 622 | else: 623 | assert False 624 | 625 | 626 | def get_num_batches(self, split, batch_size, data_type=None): 627 | if data_type == "limerick": 628 | return int((len(self.limerick_splits[split]) + batch_size - 1.0) / batch_size) 629 | elif data_type == "sonnet": 630 | return int((len(self.sonnet_splits[split]) + batch_size - 1.0) / batch_size) 631 | else: 632 | return int((len(self.splits[split]) + batch_size - 1.0) / batch_size) 633 | 634 | 635 | def prepare_sonnet_x(self, lst_of_words): 636 | ret = lst_of_words 637 | ret_idx = [] 638 | for w in ret: 639 | if w not in self.g_indexer.w2idx: 640 | w = UNKNOWN 641 | ret_idx.append( self.g_indexer.w2idx[w] ) 642 | return {'x': ret, 'indexed_x': ret_idx} 643 | 644 | ################ 645 | 646 | def batch_loss_lm_2lines(self, batch, for_training, batch_idx=None, decoder_hidden=None, decoder_context=None, hier_mode=False) -> (torch.Tensor, Dict[str, torch.Tensor]): 647 | 648 | if batch_idx is None: 649 | output_dict = self.lm_model(source_tokens=batch["source_tokens_lm"], 650 | target_tokens=batch["target_tokens_lm"], 651 | ending_words=batch["ending_words_reversed"], 652 | # metadata=batch["metadata"], 653 | batch_idx=batch_idx, 654 | decoder_hidden=decoder_hidden, 655 | decoder_context=decoder_context, 656 | ending_words_mask=batch["ending_words_mask"], 657 | hier_mode = hier_mode) 658 | else: 659 | output_dict = 
self.lm_model(source_tokens=batch["source_tokens_" + str(batch_idx)], 660 | target_tokens=batch["target_tokens_" + str(batch_idx)], 661 | ending_words=batch["ending_words_" + str(batch_idx)], 662 | # metadata=batch["metadata"], 663 | batch_idx=batch_idx, 664 | decoder_hidden=decoder_hidden, 665 | decoder_context=decoder_context, 666 | ending_words_mask=batch["ending_words_mask_" + str(batch_idx)], hier_mode=hier_mode) 667 | 668 | try: 669 | loss = output_dict["loss"] 670 | # if for_training: 671 | # loss += self.model.get_regularization_penalty() 672 | except KeyError: 673 | if for_training: 674 | raise RuntimeError("The model you are trying to optimize does not contain a" 675 | " 'loss' key in the output of model.forward(inputs).") 676 | loss = None 677 | return loss, output_dict 678 | 679 | 680 | def get_batch_loss_lm(self, split, batch, batch_size, optimizer, mode='train'): 681 | 682 | n_batches = 0 683 | train_loss = 0 684 | batch = self.get_batch_lm(i=batch, batch_size=batch_size, split=split) 685 | if self.args.use_cuda: 686 | batch = {k: ( 687 | {k1: (v1.cuda() if isinstance(v1, torch.Tensor) and torch.cuda.is_available() else v1) for k1, v1 in v.items()} if isinstance(v, 688 | dict) else v) 689 | for k, v in batch.items()} 690 | decoder_hidden = None 691 | decoder_context = None 692 | 693 | if self.args.data_type=="limerick": 694 | n_batches += 1 695 | optimizer.zero_grad() 696 | loss, output_dict = self.batch_loss_lm_2lines(batch=batch, for_training=True, batch_idx=None, 697 | decoder_hidden=decoder_hidden, 698 | decoder_context=decoder_context) 699 | if torch.isnan(loss): 700 | raise ValueError("nan loss encountered") 701 | loss.backward() 702 | train_loss += loss.item() 703 | optimizer.step() 704 | 705 | else: 706 | for i in range(7): 707 | n_batches += 1 708 | optimizer.zero_grad() 709 | loss, output_dict = self.batch_loss_lm_2lines(batch=batch, for_training=True, batch_idx=i, 710 | decoder_hidden=decoder_hidden, 711 | decoder_context=decoder_context) 712 | # loss = self.batch_loss(batch_group, for_training=True) 713 | # pdb.set_trace() 714 | if torch.isnan(loss): 715 | raise ValueError("nan loss encountered") 716 | loss.backward() 717 | train_loss += loss.item() 718 | # batch_grad_norm = self.rescale_gradients() 719 | optimizer.step() 720 | if self._trun_bptt_mode: 721 | decoder_hidden = torch.stack(output_dict["decoder_hidden"]).detach() 722 | decoder_context = torch.stack(output_dict["decoder_context"]).detach() 723 | 724 | return train_loss, n_batches 725 | 726 | 727 | def train_epoch_lm(self, epoch, optimizer, batch_size, split='train') -> Dict[str, float]: 728 | """ 729 | Trains one epoch and returns metrics. 730 | """ 731 | train_loss = 0.0 732 | 733 | # Shuffle training data 734 | if self.args.data_type == "limerick": 735 | random.shuffle(self.limerick_splits[split]) 736 | else: 737 | random.shuffle(self.sonnet_splits[split]) 738 | 739 | # Set the model to "train" mode. 
740 | self.lm_model.train() 741 | 742 | last_save_time = time.time() 743 | if self.args.data_type == LIMERICK_DATASET_IDENTIFIER: 744 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="limerick") 745 | else: 746 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="sonnet") 747 | 748 | batches_this_epoch = 0 749 | 750 | print("Training epoch : ", epoch) 751 | 752 | # cumulative_batch_size = 0 753 | 754 | # Get tqdm for the training batches 755 | train_generator_tqdm = Tqdm.tqdm(range(num_batches), 756 | total=num_batches) 757 | # for batch_idx in tqdm(range(num_batches)): 758 | for batch_idx in train_generator_tqdm: 759 | train_loss_batch, n_batches = self.get_batch_loss_lm(split=split, batch=batch_idx, 760 | batch_size=batch_size, optimizer=optimizer) 761 | train_loss += train_loss_batch 762 | batches_this_epoch += n_batches 763 | # Update the description with the latest metrics 764 | metrics = get_metrics(train_loss, batches_this_epoch) 765 | description = description_from_metrics(metrics) 766 | train_generator_tqdm.set_description(description, refresh=False) 767 | metrics = get_metrics(train_loss, batches_this_epoch) 768 | return metrics 769 | 770 | 771 | def validation_loss_lm(self, split, batch_size=64, hier_mode=False) -> Tuple[float, int, int, int]: 772 | """ 773 | Computes the validation loss. Returns it and the number of batches. 774 | """ 775 | print("Evaluating : ", split) 776 | self.lm_model.eval() 777 | if self.args.data_type == LIMERICK_DATASET_IDENTIFIER: 778 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="limerick") 779 | else: 780 | num_batches = self.get_num_batches(split=split, batch_size=batch_size, data_type="sonnet") 781 | # Get tqdm for the validation batches 782 | val_generator_tqdm = Tqdm.tqdm(range(num_batches), 783 | total=num_batches) 784 | batches_this_epoch = 0 785 | val_loss = 0 786 | n_words = 0 787 | n_samples = 0 788 | for batch_idx in val_generator_tqdm: 789 | batch = self.get_batch_lm(i=batch_idx, batch_size=batch_size, split=split) 790 | if self.args.use_cuda: 791 | batch = {k: ( 792 | {k1: (v1.cuda() if isinstance(v1, torch.Tensor) and torch.cuda.is_available() else v1) for k1, v1 in v.items()} if isinstance(v, 793 | dict) else v) 794 | for k, v in batch.items()} 795 | decoder_hidden = None 796 | decoder_context = None 797 | loss_this_batch = None 798 | lengths_this_batch = None 799 | 800 | if self.args.data_type == LIMERICK_DATASET_IDENTIFIER: 801 | # Computing batch-loss for all 5 reversed limerick lines 802 | loss, output_dict = self.batch_loss_lm_2lines(batch, for_training=False, hier_mode=hier_mode) 803 | 804 | lengths_this_batch = np.array([len.detach().cpu().numpy() for len in 805 | output_dict['target_sentence_lengths']]) 806 | if hier_mode: 807 | lengths_this_batch -= 5 808 | 809 | loss_this_batch = np.array([_loss.detach().cpu().numpy() * len for _loss, len in 810 | zip(output_dict['loss'], lengths_this_batch)]) 811 | 812 | else: 813 | # Compute loss for sonnet dataset 814 | if self._validation_bptt_mode: 815 | # Computing batch-loss every 2 reversed sonnet lines 816 | for i in range(7): 817 | 818 | loss, output_dict = self.batch_loss_lm_2lines(batch=batch, for_training=False, batch_idx=i, 819 | decoder_hidden=decoder_hidden, decoder_context=decoder_context, hier_mode=hier_mode) 820 | 821 | lengths_i = np.array([len.detach().cpu().numpy() for len in 822 | output_dict['target_sentence_lengths']]) 823 | 824 | # If we are computing loss/ppl for 
hierarchical model drop two ending words (Note that while computing loss this is taken into account by target_mask) 825 | if hier_mode: 826 | lengths_i -= 2 827 | 828 | loss_i = np.array([_loss.detach().cpu().numpy() * len for _loss, len in zip(output_dict['loss'], lengths_i)]) 829 | 830 | if loss_this_batch is None: 831 | lengths_this_batch = lengths_i 832 | loss_this_batch = loss_i 833 | else: 834 | lengths_this_batch += lengths_i 835 | loss_this_batch += loss_i 836 | 837 | decoder_hidden = torch.stack(output_dict["decoder_hidden"]).detach() 838 | decoder_context = torch.stack(output_dict["decoder_context"]).detach() 839 | 840 | else: 841 | # Computing batch-loss for all 14 reversed sonnet lines 842 | loss, output_dict = self.batch_loss_lm_2lines(batch, for_training=False, hier_mode=hier_mode) 843 | 844 | lengths_this_batch = np.array([len.detach().cpu().numpy() for len in 845 | output_dict['target_sentence_lengths']]) 846 | if hier_mode and _couplet_mode: 847 | lengths_this_batch -= 12 848 | elif hier_mode: 849 | lengths_this_batch -= 14 850 | 851 | loss_this_batch = np.array([_loss.detach().cpu().numpy() * len for _loss, len in zip(output_dict['loss'], lengths_this_batch)]) 852 | 853 | loss = np.sum(loss_this_batch) 854 | if loss is not None: 855 | batches_this_epoch += 1 856 | val_loss += loss 857 | n_words += np.sum(lengths_this_batch) 858 | n_samples += len(lengths_this_batch) 859 | 860 | # Update the description with the latest metrics 861 | val_metrics = get_metrics(val_loss, n_words) 862 | description = description_from_metrics(val_metrics) 863 | val_generator_tqdm.set_description(description, refresh=False) 864 | 865 | print("Total Loss : ", val_loss, " | no. of words : ", n_words, " | samples : ", n_samples, " | PPL : ", 866 | np.exp(val_loss / n_words)) 867 | 868 | return val_loss, batches_this_epoch, n_words, n_samples 869 | 870 | 871 | def shuffle_data(self, split='train'): 872 | 873 | print("Data shuffled based upon source_tokens_lm!") 874 | # Shuffle training data 875 | instance_len = [len(instance["source_tokens_lm"]) for instance in self.sonnet_splits[split]] 876 | sorted_indices = sorted(range(len(instance_len)), key=lambda k: instance_len[k]) 877 | 878 | self.sonnet_splits[split] = [self.sonnet_splits[split][index] for index in sorted_indices] 879 | 880 | 881 | def train_lm(self, epochs=25, debug=True, args=None): 882 | 883 | learning_rate = 0.0025 884 | optimizer = torch.optim.Adam(self.lm_model.parameters(), lr=learning_rate) 885 | best_loss = 999999999999.0 886 | best_lm_model = self.lm_model 887 | best_epoch = 0 888 | batch_size = 32 889 | 890 | training_start_time = time.time() 891 | 892 | all_loss_tracker = {} 893 | all_loss_tracker_val = {} 894 | 895 | # Shuffle training data before training 896 | # self.shuffle_data(split='train') 897 | 898 | for epoch in range(epochs): 899 | 900 | epoch_start_time = time.time() 901 | train_metrics = self.train_epoch_lm(epoch, optimizer, batch_size) 902 | 903 | print("Validation split!") 904 | with torch.no_grad(): 905 | # We have a validation set, so compute all the metrics on it. 906 | val_loss, num_batches, n_words, n_samples = self.validation_loss_lm(split='val', batch_size=64) 907 | 908 | val_metrics = get_metrics(val_loss, n_words) 909 | 910 | # Check validation metric for early stopping 911 | # this_epoch_val_metric = val_metrics[self._validation_metric] 912 | # self._metric_tracker.add_metric(this_epoch_val_metric) 913 | # 914 | # if self._metric_tracker.should_stop_early(): 915 | # logger.info("Ran out of patience. 
Stopping training.") 916 | # break 917 | 918 | print("Test split!") 919 | 920 | with torch.no_grad(): 921 | # We have a validation set, so compute all the metrics on it. 922 | test_loss, num_batches, n_words, n_samples = self.validation_loss_lm(split='test', batch_size=64) 923 | 924 | test_metrics = get_metrics(test_loss, n_words) 925 | 926 | # if debug: 927 | # break 928 | 929 | print("TRAIN epoch perplexity = ", np.exp(train_metrics['loss'])) 930 | print() 931 | 932 | print("VAL perplexity = ", np.exp(val_metrics['loss'])) 933 | print("TEST perplexity = ", np.exp(test_metrics['loss'])) 934 | 935 | if val_metrics['loss'] < best_loss: 936 | best_epoch = epoch 937 | if args.save_vanilla_lm: 938 | torch.save(self.lm_model.state_dict(), os.path.join(self.args.vanilla_lm_path, 'model_' + str(epoch))) 939 | print("Best LM model until epoch ", epoch, "dumped!") 940 | 941 | best_loss = val_metrics['loss'] 942 | 943 | print("Best model from epoch ", best_epoch, " saved!") 944 | best_lm_model_state_dict = torch.load(os.path.join(self.args.vanilla_lm_path, 'model_' + str(best_epoch))) 945 | torch.save(best_lm_model_state_dict, os.path.join(self.args.vanilla_lm_path, 'model_best')) 946 | 947 | 948 | def get_loss(self, split, batch, batch_size, mode='train', criterion=None): 949 | 950 | model = self.model 951 | if mode == "train": 952 | model.train() 953 | else: 954 | model.eval() 955 | len_output = 0 956 | # x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split=split) # , typ=args.data_type) 957 | x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split=split, skip_unk=False) # , typ=args.data_type) 958 | 959 | batch_loss = torch.tensor(0.0) 960 | if self.args.use_cuda: 961 | batch_loss = batch_loss.cuda() 962 | 963 | i = 0 964 | all_e = [] 965 | # print(" -- batch=",batch, " || x: ",len(x)) 966 | for x_i in x: 967 | len_output += len(x_i) 968 | info = model(x_i, use_gt=True) 969 | out_all_i = info['out_all'] 970 | i += 1 971 | # print(" *** out_all_i = ", out_all_i) 972 | out_all_i = torch.stack(out_all_i) 973 | dist = out_all_i.view(-1, self.g_indexer.w_cnt) 974 | targets = np.array(x_i, dtype=np.long) 975 | targets = torch.from_numpy(targets) 976 | if self.args.use_cuda: 977 | targets = targets.cuda() 978 | cur_loss = criterion(dist, targets) 979 | # print(cur_loss) 980 | batch_loss += cur_loss 981 | 982 | total_loss = batch_loss 983 | return total_loss, len_output, { 984 | 'total_batch_loss': total_loss.data.cpu().item(), \ 985 | 'batch_recon_loss': batch_loss.data.cpu().item(), \ 986 | 'elbo_loss': batch_loss.data.cpu().item() 987 | } 988 | 989 | 990 | def train(self, epochs=11, debug=False, train_lm_supervised=False): 991 | 992 | print("train_lm_supervised = ", train_lm_supervised) 993 | 994 | batch_size = 32 995 | 996 | # Ending words generator model 997 | model = self.model 998 | learning_rate = 1e-4 999 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 1000 | best_loss = 999999999999.0 1001 | criterion = nn.CrossEntropyLoss(ignore_index=0) 1002 | 1003 | # Vanilla-LM (Conditional quatrain generator model) 1004 | 1005 | learning_rate = 0.001 1006 | optimizer_lm = torch.optim.Adam(self.lm_model.parameters(), lr=learning_rate) 1007 | 1008 | # Rhyming Discriminator 1009 | learning_rate = 1e-4 1010 | optimizer_disc = torch.optim.Adam(self.disc.parameters(), lr=learning_rate) 1011 | 1012 | # Syllable Discriminator 1013 | learning_rate = 1e-4 1014 | # optimizer_disc_syllable = torch.optim.Adam(self.disc_syllable.parameters(), lr=learning_rate) 1015 | 
learning_rate = 1e-4 1016 | # optimizer_disc_syllable_line = torch.optim.Adam(self.disc_syllable_line.parameters(), lr=learning_rate) 1017 | 1018 | self.train_summary = [] 1019 | 1020 | for epoch in range(epochs): 1021 | 1022 | epoch_summary = {'epoch':epoch} 1023 | 1024 | # Enable train() mode 1025 | model.train() 1026 | self.lm_model.train() 1027 | 1028 | print("Training epoch : ", epoch) 1029 | 1030 | # Shuffle training data 1031 | # random.shuffle(self.sonnet_splits['train']) 1032 | #@NEW 1033 | random.shuffle(self.splits['train']) 1034 | 1035 | all_loss_tracker = {} 1036 | all_loss_tracker_val = {} 1037 | all_loss_tracker_disc = {} 1038 | dump_info_all = [] 1039 | 1040 | num_batches = self.get_num_batches('train', batch_size) 1041 | epoch_loss = 0.0 1042 | train_lm_loss = 0.0 1043 | ctr = 0 1044 | data_point_ctr = 0 1045 | data_point_ctr_syll_gen = 0 1046 | data_point_ctr_syll_gen_line = 0 1047 | data_point_ctr_syll_train = 0 1048 | data_point_ctr_syll_train_line = 0 1049 | 1050 | # Get tqdm for the training batches 1051 | train_generator_tqdm = Tqdm.tqdm(range(num_batches), 1052 | total=num_batches) 1053 | batches_this_epoch = 0 1054 | 1055 | for batch in train_generator_tqdm: 1056 | # for batch in range(num_batches): 1057 | 1058 | 1059 | ############ SUPERVISED ENDING WORD GENERATOR 1060 | total_batch_loss, len_output, info = self.get_loss('train', batch, batch_size, mode='train', 1061 | criterion=criterion) 1062 | # print("batch_loss = ", batch_loss/len(x_i)) 1063 | if debug: 1064 | print("TRAIN info = ", info) 1065 | for k, v in info.items(): 1066 | if k.count('loss') > 0: 1067 | if k not in all_loss_tracker: 1068 | all_loss_tracker[k] = 0.0 1069 | all_loss_tracker[k] += v 1070 | epoch_loss = epoch_loss + total_batch_loss.data.cpu().item() 1071 | 1072 | model.zero_grad() 1073 | optimizer.zero_grad() 1074 | total_batch_loss.backward() 1075 | optimizer.step() 1076 | 1077 | ctr = ctr + len_output 1078 | if batch % 1000 == 0: 1079 | print("TRAIN batch = ", batch, "epoch_loss/ctr = ", epoch_loss / ctr) 1080 | # metrics = get_metrics(epoch_loss, ctr, key="endings_supervised_loss") 1081 | 1082 | 1083 | ############ SUPERVISED LM 1084 | if train_lm_supervised: 1085 | train_lm_loss_batch, n_batches = self.get_batch_loss_lm(split='train', batch=batch, 1086 | batch_size=batch_size, optimizer=optimizer_lm) 1087 | train_lm_loss += train_lm_loss_batch 1088 | batches_this_epoch += n_batches 1089 | 1090 | print("Supervised LM train_lm_loss/batches_this_epoch = ", train_lm_loss/batches_this_epoch) 1091 | # metrics = get_metrics(train_lm_loss, batches_this_epoch, key="lm_loss", metrics=metrics) 1092 | 1093 | # print("dfdf-----exitiiiiiii") 1094 | # Update tqdm description with 'gen_loss' and 'lm_loss' 1095 | # description = description_from_metrics(metrics) 1096 | # train_generator_tqdm.set_description(description, refresh=False) 1097 | 1098 | 1099 | ############# STRUCTURE 1100 | 1101 | if self.args.use_reinforce_loss: 1102 | 1103 | model.eval() 1104 | # x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split='train') # , typ=args.data_type) 1105 | x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split='train', skip_unk=True) # , typ=args.data_type) 1106 | # lm_batch_info = self.get_batch_for_lm(i=batch, batch_size=batch_size, split='train') #, typ=args.data_type) 1107 | data_point_ctr += len(x) 1108 | all_line_endings_gen = [] 1109 | all_line_endings_train = [] 1110 | all_rewards = [] 1111 | all_rewards_syll = [] 1112 | all_rewards_syll_line = [] 1113 | all_log_probs = 
[] 1114 | all_log_probs_syll = [] 1115 | all_log_probs_syll_line = [] 1116 | 1117 | if train_lm_supervised: 1118 | lm_batch = self.get_batch_lm(i=batch, batch_size=batch_size, split='train') 1119 | assert False, "TODO" 1120 | # print(" --- lm_batch = ", lm_batch.keys()) 1121 | # batchsize * 3 * 4 1122 | # print(" --- lm_batch[reversed_quatrains] = ", len(lm_batch['reversed_quatrains'])) 1123 | # print(" --- lm_batch[reversed_quatrains] = ", len(lm_batch['reversed_quatrains'][0])) 1124 | # print(" --- lm_batch[reversed_quatrains] = ", len(lm_batch['reversed_quatrains'][0][0])) 1125 | # print(" --- lm_batch[reversed_quatrains] = ", lm_batch['reversed_quatrains'][0][0]) 1126 | 1127 | # print("--- REINFORCE ---") 1128 | pr = False # Random prints 1129 | if debug or np.random.rand() < 0.05: 1130 | pr = True 1131 | # pr=False 1132 | disc_loss_batch = 0.0 1133 | disc_loss_batch_line = 0.0 1134 | 1135 | for i in range(len(x)): 1136 | 1137 | #####----------------- ENDINGS 1138 | info = model(gt=None, use_gt=False) # Get Sample 1139 | # print("info['actions'] = ", info['actions']) 1140 | # print(self.g_indexer.idx_to_2( [i.data.cpu().item() for i in info['actions']] )) 1141 | line_endings_gen = info[ 1142 | 'actions'] # [i.data.cpu().item() for i in info['actions']] #info['actions'] 1143 | line_endings_train = [] 1144 | all_line_endings_gen.append(line_endings_gen) 1145 | 1146 | # for w in x[i][:4]: 1147 | #@NEW 1148 | #print("x[i] = ",x[i]) 1149 | for w in x[i][:-1]: # just removing end? 1150 | word_idx = torch.tensor(w) 1151 | if self.args.use_cuda: 1152 | word_idx = word_idx.cuda() 1153 | line_endings_train.append(word_idx) 1154 | ### TODO: do not use unks? -- Currently endingwords are returned fro a quarttain only if all ending words are in dictionary. so it is fine to ignore for now. 
just change in sampling part though 1155 | all_line_endings_train.append(line_endings_train) 1156 | 1157 | try: 1158 | # if True: 1159 | disc_info = self.disc.update_discriminator(line_endings_gen, line_endings_train, pr, 1160 | self.g_indexer.idx2w) 1161 | # print("disc_info=", disc_info) 1162 | self.disc.zero_grad() 1163 | optimizer_disc.zero_grad() 1164 | if pr: 1165 | dump_info_all.append(disc_info['dump_info']) 1166 | disc_loss = disc_info['loss'] 1167 | if 'disc_loss' not in all_loss_tracker_disc: 1168 | all_loss_tracker_disc['disc_loss'] = 0.0 1169 | all_loss_tracker_disc['reward'] = 0.0 1170 | all_loss_tracker_disc['disc_loss'] += disc_loss.data.cpu().item() 1171 | # disc_loss.backward(retain_graph=True) 1172 | disc_loss.backward() 1173 | optimizer_disc.step() 1174 | 1175 | reward = disc_info['reward'] 1176 | all_loss_tracker_disc['reward'] += reward.data.cpu().item() 1177 | if debug: 1178 | print("[rhyming] ----> reward = ", reward) 1179 | all_rewards.append(reward.data.cpu().item()) 1180 | # print("info['logprobs'] = ", info['logprobs']) 1181 | log_probs = torch.stack(info['logprobs']) 1182 | all_log_probs.append(log_probs) 1183 | if debug: 1184 | print("[[rhyming] ][reinforce] log_probs =", log_probs) 1185 | except Exception: 1186 | print("--[[rhyming] ] exception--") ##TODO: Need to fix what causes these issues 1187 | # # print(" exception = ", e) 1188 | if debug: 1189 | print() 1190 | 1191 | ## Disc: Rhyming reinforce updates to Generator 1192 | if self.args.use_reinforce_loss: 1193 | all_rewards_numpy = np.array(all_rewards) 1194 | assert len(all_rewards_numpy) == len(all_log_probs) 1195 | # print("all_rewards_numpy = ", all_rewards_numpy) 1196 | all_rewards_numpy = (all_rewards_numpy - all_rewards_numpy.mean()) / (all_rewards_numpy.std()) 1197 | # print("all_rewards_numpy = ", all_rewards_numpy) 1198 | ##### -- TODO: this is inefficient. Just compute the loss for every sample and then backprop once 1199 | for q, log_probs in enumerate(all_log_probs): 1200 | loss = -torch.sum(log_probs * all_rewards_numpy[q]) 1201 | if pr: 1202 | print("[reinforce] loss =", loss) 1203 | loss = loss * self.args.reinforce_weight 1204 | model.zero_grad() 1205 | optimizer.zero_grad() 1206 | loss.backward() 1207 | optimizer.step() 1208 | 1209 | ########## WITHIN EPOCH PRINTS AND OPTIONS 1210 | if batch % 1000 == 0: 1211 | if 'disc_loss' in all_loss_tracker_disc: 1212 | print("TRAIN batch = ", batch, "all_loss_tracker_disc[disc_loss]/data_point_ctr = ", 1213 | all_loss_tracker_disc['disc_loss'] / data_point_ctr) 1214 | if 'disc_loss_syll' in all_loss_tracker_disc: 1215 | print("TRAIN batch = ", batch, "all_loss_tracker_disc['disc_loss_syll']/normalizer = ", 1216 | all_loss_tracker_disc['disc_loss_syll'] / (data_point_ctr_syll_gen +data_point_ctr_syll_train) ) 1217 | 1218 | if debug: 1219 | break 1220 | 1221 | ################### EPOCH END OPTIONS 1222 | # print("epoch_loss = ", epoch_loss/ctr) 1223 | print("TRAIN epoch perplexity = ", np.exp(all_loss_tracker['elbo_loss'] / ctr)) 1224 | epoch_summary.update({'train_elbo_loss':all_loss_tracker['elbo_loss'] , 1225 | 'train_ppl':np.exp(all_loss_tracker['elbo_loss'] / ctr), 1226 | 'ctr':ctr}) 1227 | print("TRAIN epoch all_loss_tracker (norm. 
by num_batches) = ", 1228 | {k: v / num_batches for k, v in all_loss_tracker.items()}) 1229 | all_loss_tracker.update({'ctr': ctr}) 1230 | if 'disc_loss' in all_loss_tracker_disc: 1231 | print("TRAIN all_loss_tracker_disc[disc_loss]/data_point_ctr = ", 1232 | all_loss_tracker_disc.get('disc_loss',-999) / data_point_ctr) 1233 | all_loss_tracker_disc.update({'data_point_ctr': data_point_ctr}) 1234 | print() 1235 | 1236 | ### RUN ON VALIDATION 1237 | # Will skip this when using gan? Only when using initial noise z 1238 | num_batches = self.get_num_batches('val', batch_size) 1239 | epoch_val_loss = 0.0 1240 | ctr = 0 1241 | for batch in range(num_batches): 1242 | total_batch_loss, len_output, info = self.get_loss('val', batch, batch_size, mode='eval', 1243 | criterion=criterion) 1244 | # print("batch_loss = ", batch_loss/len(x_i)) 1245 | if debug: 1246 | print("VAL info = ", info) 1247 | epoch_val_loss = epoch_val_loss + total_batch_loss.data.cpu().item() 1248 | for k, v in info.items(): 1249 | if k.count('loss') > 0: 1250 | if k not in all_loss_tracker_val: 1251 | all_loss_tracker_val[k] = 0.0 1252 | all_loss_tracker_val[k] += v 1253 | # model.zero_grad() 1254 | # optimizer.zero_grad() 1255 | # total_batch_loss.backward() 1256 | # optimizer.step() 1257 | ctr = ctr + len_output 1258 | if debug: 1259 | # print(" ---- all_loss_tracker_val = ", all_loss_tracker_val) 1260 | break 1261 | print("epoch VAL epoch_val_loss = ", epoch_val_loss, " ctr = ", ctr) 1262 | if 'elbo_loss' in all_loss_tracker_val: 1263 | print("epoch VAL perplexity = ", np.exp(all_loss_tracker_val['elbo_loss'] / ctr)) 1264 | all_loss_tracker_val.update({'ctr': ctr}) 1265 | epoch_summary.update({'val_elbo_loss':all_loss_tracker_val['elbo_loss'], 1266 | 'val_ppl':np.exp(all_loss_tracker_val['elbo_loss'] / ctr), 1267 | 'ctr':ctr}) 1268 | print("epoch all_loss_tracker_val (norm. by num_batches) = ", 1269 | {k: v / num_batches for k, v in all_loss_tracker_val.items()}) 1270 | # if not debug: 1271 | # torch.save(model.state_dict(), self.args.model_dir + 'model_' + str(epoch%5)) 1272 | torch.save(model.state_dict(), self.args.model_dir + 'model_' + str(epoch)) 1273 | torch.save(self.lm_model.state_dict(), self.args.model_dir + 'lmmodel_' + str(epoch)) 1274 | torch.save(self.disc.state_dict(), self.args.model_dir + 'disc_' + str(epoch)) 1275 | # torch.save(self.disc_syllable.state_dict(), self.args.model_dir + 'disc_syllable_' + str(epoch)) 1276 | # torch.save(self.disc_syllable_line.state_dict(), self.args.model_dir + 'disc_syllable_line_' + str(epoch)) 1277 | if (epoch_val_loss / ctr) < best_loss: ## TODO: We should probably do this by lm model loss 1278 | best_loss = epoch_val_loss / ctr 1279 | torch.save(model.state_dict(), self.args.model_dir + 'model_best') 1280 | torch.save(self.lm_model.state_dict(), self.args.model_dir + 'lmmodel_best') 1281 | torch.save(self.disc.state_dict(), self.args.model_dir + 'disc_best') 1282 | # torch.save(self.disc_syllable.state_dict(), self.args.model_dir + 'disc_syllable_best') 1283 | # torch.save(self.disc_syllable_line.state_dict(), self.args.model_dir + 'disc_syllable_line_best') 1284 | 1285 | ##### GET ENDINGS SAMPLES 1286 | model.eval() 1287 | samples = [] 1288 | self.lm_model.eval() #TODO: is it turned back to train() ? 
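# Sampling overview for the block that follows: the ending-word generator (model) is sampled to
# produce one ending word per line (info['actions']); those words are reversed and passed to
# self.lm_model._sample(), which generates each line right-to-left conditioned on its ending word,
# so each generated line is reversed back before being joined into a readable string. The same
# conditional sampling is then repeated with gold ending words taken from the validation batch x
# (collected in samples_lm_traincond).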
1289 | samples_lm = [] 1290 | samples_lm_traincond = [] 1291 | print("--- ENDINGS Samples ---") 1292 | x, _, x_start = self.get_batch(i=0, batch_size=331, split='val', skip_unk=True) # , typ=args.data_type) 1293 | random.shuffle(x) 1294 | random.shuffle(x) 1295 | for i in range(self.args.num_samples_at_epoch_end): 1296 | info = model(gt=None, use_gt=False) 1297 | # print("info['actions'] = ", info['actions']) 1298 | sample_str = self.g_indexer.idx_to_2([i.data.cpu().item() for i in info['actions']]) 1299 | # print(sample_str) 1300 | samples.append(sample_str) 1301 | 1302 | ending_words_str = sample_str 1303 | ending_words = info['actions'] 1304 | if self.args.use_cuda: 1305 | ending_words = [e.cuda() for e in ending_words] 1306 | if len(ending_words)!=self.num_lines: 1307 | sample_str='--endingWordsCount !=4 -- ' 1308 | else: 1309 | ending_words_reversed = [w for w in reversed(ending_words)] 1310 | info = self.lm_model._sample(ending_words=ending_words_reversed) 1311 | # info = self.lm_model._sample(ending_words=ending_words) 1312 | quatrain_gen = info['actions'] # [i.data.cpu().item() for i in info['actions']] #info['actions'] 1313 | sample_str = ' '.join([ ' '.join([self.g_indexer.idx2w[widx.data.cpu().item()] for widx in reversed(line)]) for line in quatrain_gen ]) 1314 | samples_lm.append([ending_words_str, sample_str]) 1315 | # print(sample_str) 1316 | line_endings_train = [] 1317 | for w in x[i][:-1]: # just removing end? 1318 | word_idx = torch.tensor(w) 1319 | if self.args.use_cuda: 1320 | word_idx = word_idx.cuda() 1321 | line_endings_train.append(word_idx) 1322 | ending_words = line_endings_train 1323 | ending_words_str = [self.g_indexer.idx2w[widx.data.cpu().item()] for widx in ending_words] 1324 | if len(ending_words)!=self.num_lines: 1325 | sample_str='--endingWordsCount !=4 -- ' 1326 | else: 1327 | ending_words_reversed = [w for w in reversed(ending_words)] 1328 | info = self.lm_model._sample(ending_words=ending_words_reversed) 1329 | quatrain_gen = info['actions'] # [i.data.cpu().item() for i in info['actions']] #info['actions'] 1330 | sample_str = ' '.join([ ' '.join([self.g_indexer.idx2w[widx.data.cpu().item()] for widx in reversed(line)]) for line in quatrain_gen ]) 1331 | # print(sample_str) 1332 | samples_lm_traincond.append([ending_words_str, sample_str]) 1333 | 1334 | #### TODO: 1335 | ## compute rhyming and pattern stats 1336 | 1337 | ### TODO: reorg code along 3 groupings 1338 | ## all data loading stuff 1339 | ## all batch related stuff 1340 | ## all training stuff 1341 | 1342 | if self.args.data_type == LIMERICK_DATASET_IDENTIFIER: 1343 | sonnet_vocab = self.limerick_vocab 1344 | else: 1345 | sonnet_vocab = self.sonnet_vocab 1346 | all_embs = self._get_disc_word_representation_dictionary(sonnet_vocab) 1347 | eval_info = self.rhyming_eval.analyze_embeddings_for_rhyming_from_dict(all_embs) 1348 | thresh_f1, maxf1, test_f1 = eval_info['thresh_f1'], eval_info['maxf1'], eval_info['test_f1'] 1349 | print("[EPOCH] = ", epoch, " ---->> thresh_f1, maxf1, test_f1 = ", thresh_f1, maxf1, test_f1) 1350 | epoch_summary.update({'thresh_f1':thresh_f1, 'maxf1':maxf1, 'test_f1':test_f1}) 1351 | 1352 | pattern_eval_info = self.rhyming_eval.analyze_samples_from_endings(samples) 1353 | print("[EPOCH] = ", epoch, " ---->> pattern_eval_info = ", pattern_eval_info) #pattern_success_ratio 1354 | epoch_summary.update({'pattern_eval_info':pattern_eval_info}) 1355 | 1356 | if not os.path.exists(self.args.model_dir + 'samples/'): 1357 | os.makedirs(self.args.model_dir + 'samples/') 1358 | if 
not os.path.exists(self.args.model_dir + 'samples_lm/'): 1359 | os.makedirs(self.args.model_dir + 'samples_lm/') 1360 | if not os.path.exists(self.args.model_dir + 'logs/'): 1361 | os.makedirs(self.args.model_dir + 'logs/') 1362 | if not os.path.exists(self.args.model_dir + 'dump_info_all/'): 1363 | os.makedirs(self.args.model_dir + 'dump_info_all/') 1364 | json.dump(samples, open(self.args.model_dir + 'samples/' + str(epoch) + '.json', 'w')) 1365 | json.dump(samples_lm, open(self.args.model_dir + 'samples_lm/' + str(epoch) + '.json', 'w')) 1366 | json.dump(samples_lm_traincond, open(self.args.model_dir + 'samples_lm/' + str(epoch) + '_traincond.json', 'w')) 1367 | json.dump(all_loss_tracker_val, 1368 | open(self.args.model_dir + 'logs/all_loss_tracker_val' + str(epoch) + '.json', 'w')) 1369 | json.dump(all_loss_tracker, open(self.args.model_dir + 'logs/all_loss_tracker' + str(epoch) + '.json', 'w')) 1370 | json.dump(all_loss_tracker_disc, 1371 | open(self.args.model_dir + 'logs/all_loss_tracker_disc' + str(epoch) + '.json', 'w')) 1372 | # print("dump_info_all = ", dump_info_all) 1373 | pickle.dump(dump_info_all, 1374 | open(self.args.model_dir + 'dump_info_all/dump_info_all' + str(epoch) + '.pkl', 'wb')) 1375 | # if debug: 1376 | # break 1377 | print() 1378 | 1379 | self.train_summary.append(epoch_summary) 1380 | 1381 | json.dump(self.train_summary, open(self.args.model_dir + 'train_summary' + '.json', 'w')) 1382 | print("Saving train summary to ", self.args.model_dir + 'train_summary' + '.json') 1383 | 1384 | 1385 | def load_models(self, model_dir, model_epoch='best', load_lm=True): 1386 | print() 1387 | print("[load_models] : Loading from ", model_dir+'model_'+model_epoch ) 1388 | self.model.load_state_dict(torch.load(model_dir+'model_'+model_epoch)) 1389 | # self.load(dump_pre=model_dir) 1390 | if load_lm: 1391 | print("[load_models] : Loading LM model from ----------->>>>>>>>> ", model_dir+'lmmodel_'+model_epoch) 1392 | #self.lm_model.load_state_dict(torch.load('tmp/rhymgan_lm1/model_best')) #model_dir+'lmmodel_'+model_epoch)) 1393 | self.lm_model.load_state_dict(torch.load(model_dir+'lmmodel_'+model_epoch)) 1394 | self.disc.load_state_dict(torch.load(model_dir+'disc_'+model_epoch)) 1395 | self.disc.display_params() 1396 | 1397 | 1398 | def _get_disc_word_representation_dictionary(self, vocab): 1399 | all_embs = {} 1400 | for w in vocab: #['head','bread']:#sonnet_vocab: 1401 | if w==utils.UNKNOWN: 1402 | continue 1403 | s1 = self.disc.g_indexer.w_to_idx(w.lower()) 1404 | if self.args.use_eow_in_enc: 1405 | assert False 1406 | e1 = self.disc.g2pmodel.encode(s1) 1407 | e1_numpy = e1.data.cpu().numpy().reshape(-1) 1408 | all_embs[w] = e1_numpy 1409 | return all_embs 1410 | 1411 | 1412 | def analysis(self, epoch='best', args=None): 1413 | 1414 | model = self.model 1415 | model.eval() 1416 | # self.disc.eval() 1417 | model_dir = self.args.model_dir 1418 | criterion = nn.CrossEntropyLoss(ignore_index=0) 1419 | 1420 | # batch_size=32 1421 | # all_line_endings_train = [] 1422 | # for split in ['train','val','test']: 1423 | # #Compute matrices 1424 | # num_batches = self.get_num_batches(split, batch_size) 1425 | # epoch_val_loss = 0.0 1426 | # ctr = 0 1427 | # all_loss_tracker_val = {} 1428 | # dump_info_all = [] 1429 | # for batch in range(num_batches): 1430 | # len_output = 0 1431 | # x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split=split) # , typ=args.data_type) 1432 | # for x_i in x: 1433 | # line_endings_train = [] 1434 | # for w in x_i[:-1]: # just removing end? 
1435 | # word_idx = self.g_indexer.idx2w[w] 1436 | # line_endings_train.append(word_idx) 1437 | # all_line_endings_train.append(line_endings_train) 1438 | # if self.args.data_type=="limerick": 1439 | # json.dump(all_line_endings_train, open('all_line_endings_limerick.json', 'w')) 1440 | # elif self.args.data_type=="sonnet_endings": 1441 | # json.dump(all_line_endings_train, open('all_line_endings_sonnet.json', 'w')) 1442 | # 0/0 1443 | 1444 | # Analyzing learnt representations 1445 | if self.args.data_type == LIMERICK_DATASET_IDENTIFIER: 1446 | sonnet_vocab = self.limerick_vocab 1447 | else: 1448 | sonnet_vocab = self.sonnet_vocab 1449 | all_embs = {} 1450 | dump_file = model_dir + "all_embs_epoch"+epoch+".pkl" 1451 | print("dumping to ",dump_file) 1452 | sonnet_vocab[0:5], len(sonnet_vocab) 1453 | self.model.eval() 1454 | # for w in sonnet_vocab: #['head','bread']:#sonnet_vocab: 1455 | # if w==utils.UNKNOWN: 1456 | # continue 1457 | # s1 = self.disc.g_indexer.w_to_idx(w.lower()) 1458 | # # print(w,s1) 1459 | # if self.args.use_eow_in_enc: 1460 | # assert False 1461 | # e1 = self.disc.g2pmodel.encode(s1) 1462 | # e1_numpy = e1.data.cpu().numpy().reshape(-1) 1463 | # all_embs[w] = e1_numpy 1464 | # #break 1465 | all_embs = self._get_disc_word_representation_dictionary(sonnet_vocab) 1466 | pickle.dump(all_embs, open(dump_file,'wb')) 1467 | print("DUMPING TO ", dump_file) 1468 | # Analyzing dot products can be done separately 1469 | 1470 | #get_spelling_baseline_for_rhyming = True #False 1471 | #eval_info = self.rhyming_eval.analyze_embeddings_for_rhyming_from_dict(all_embs, get_spelling_baseline_for_rhyming, spelling_type='last1') 1472 | eval_info = self.rhyming_eval.analyze_embeddings_for_rhyming_from_dict(all_embs) 1473 | thresh_f1, maxf1, test_f1 = eval_info['thresh_f1'], eval_info['maxf1'], eval_info['test_f1'] 1474 | print("[EPOCH] = ", epoch, " ---->> thresh_f1, maxf1, test_f1 = ", thresh_f1, maxf1, test_f1) 1475 | 1476 | 1477 | ##### GET SAMPLES 1478 | print("=======GET SAMPLES======") 1479 | model.eval() 1480 | samples = [] 1481 | x, _, x_start = self.get_batch(i=0, batch_size=331, split='val') # , typ=args.data_type) 1482 | for i in range(10000): 1483 | try: 1484 | with time_limit(4): 1485 | info = model(gt=None, use_gt=False, temperature=args.temperature) 1486 | sample_str = self.g_indexer.idx_to_2([i.data.cpu().item() for i in info['actions']]) 1487 | samples.append(sample_str) 1488 | except TimeoutException as e: 1489 | print("Timed out!") 1490 | if i%500==0: 1491 | print("Done with ", i+1, " samples") 1492 | if not os.path.exists(self.args.model_dir + 'samples_analysis/'): 1493 | os.makedirs(self.args.model_dir + 'samples_analysis/') 1494 | json.dump(samples, open(self.args.model_dir + 'samples_analysis/' + str(epoch) + '.json', 'w')) 1495 | print("="*99) 1496 | print("DONE WITH SAMPLES") 1497 | 1498 | pattern_eval_info = self.rhyming_eval.analyze_samples_from_endings(samples) 1499 | print(json.dumps(pattern_eval_info, indent=4)) 1500 | 1501 | 1502 | ##### PPL AND INFO 1503 | batch_size = 32 1504 | for split in ['val','test']: 1505 | 1506 | #Computing ppl of model 1507 | num_batches = self.get_num_batches(split, batch_size) 1508 | epoch_val_loss = 0.0 1509 | ctr = 0 1510 | all_loss_tracker_val = {} 1511 | for batch in range(num_batches): 1512 | total_batch_loss, len_output, info = self.get_loss(split, batch, batch_size, mode='eval', 1513 | criterion=criterion) 1514 | epoch_val_loss = epoch_val_loss + total_batch_loss.data.cpu().item() 1515 | for k, v in info.items(): 1516 | if 
k.count('loss') > 0: 1517 | if k not in all_loss_tracker_val: 1518 | all_loss_tracker_val[k] = 0.0 1519 | all_loss_tracker_val[k] += v 1520 | ctr = ctr + len_output 1521 | #break 1522 | print("[ENDINGS MODEL] epoch =" , epoch, "split = ", split, " epoch_val_loss = ", epoch_val_loss, " ctr = ", ctr, \ 1523 | "PPL = ", np.exp(epoch_val_loss/ctr)) 1524 | print("="*99) 1525 | print("DONE WITH PPL") 1526 | 1527 | #Compute matrices 1528 | if args.dump_matrices: 1529 | limit = 5 1530 | num_batches = min(limit,self.get_num_batches(split, batch_size)) 1531 | epoch_val_loss = 0.0 1532 | ctr = 0 1533 | all_loss_tracker_val = {} 1534 | dump_info_all = [] 1535 | for batch in range(num_batches): 1536 | len_output = 0 1537 | x, _, x_start = self.get_batch(i=batch, batch_size=batch_size, split=split) # , typ=args.data_type) 1538 | all_e = [] 1539 | for x_i in x: 1540 | len_output += len(x_i) 1541 | try: 1542 | with time_limit(10): 1543 | # long_function_call() 1544 | info = model(gt=None, use_gt=False) # Get Sample 1545 | line_endings_gen = info['actions'] # [i.data.cpu().item() for i in info['actions']] #info['actions'] 1546 | line_endings_train = [] 1547 | # print("info = ", info) 1548 | for w in x_i[:-1]: # just removing end? 1549 | word_idx = torch.tensor(w) 1550 | if self.args.use_cuda: 1551 | word_idx = word_idx.cuda() 1552 | line_endings_train.append(word_idx) 1553 | disc_info = self.disc.update_discriminator(line_endings_gen, line_endings_train, True, self.g_indexer.idx2w) 1554 | dump_info = disc_info['dump_info'] 1555 | dump_info.update({'line_endings_gen':[self.g_indexer.idx2w[idx.data.cpu().item()] for idx in line_endings_gen], \ 1556 | 'line_endings_train':[self.g_indexer.idx2w[idx.data.cpu().item()] for idx in line_endings_train]}) 1557 | dump_info_all.append(dump_info) 1558 | except TimeoutException as e: 1559 | print("Timed out!") 1560 | print("="*99) 1561 | print("DONE WITH MATRICES") 1562 | 1563 | if not os.path.exists(self.args.model_dir + 'dump_info_all_analysis/'): 1564 | os.makedirs(self.args.model_dir + 'dump_info_all_analysis/') 1565 | pickle.dump(dump_info_all, open(self.args.model_dir + 'dump_info_all_analysis/dump_info_all' + str(split) + '_epoch'+epoch+'.pkl', 'wb')) 1566 | 1567 | print("="*99) 1568 | print("DONE WITH DUMPING MATRICES") 1569 | 1570 | def lm_analysis(self, epoch='best', args=None): 1571 | 1572 | model = self.lm_model 1573 | model.eval() 1574 | model_dir = self.args.model_dir 1575 | criterion = nn.CrossEntropyLoss(ignore_index=0) 1576 | 1577 | ##### PPL AND INFO 1578 | batch_size = 32 1579 | for split in ['val', 'test']: 1580 | 1581 | print("Split: ", split) 1582 | # Computing ppl of model 1583 | num_batches = self.get_num_batches(split, batch_size) 1584 | epoch_val_loss = 0.0 1585 | ctr = 0 1586 | all_loss_tracker_val = {} 1587 | for batch in range(num_batches): 1588 | total_batch_loss, len_output, info = self.get_loss(split, batch, batch_size, mode='eval', 1589 | criterion=criterion) 1590 | epoch_val_loss = epoch_val_loss + total_batch_loss.data.cpu().item() 1591 | for k, v in info.items(): 1592 | if k.count('loss') > 0: 1593 | if k not in all_loss_tracker_val: 1594 | all_loss_tracker_val[k] = 0.0 1595 | all_loss_tracker_val[k] += v 1596 | ctr = ctr + len_output 1597 | # break 1598 | print("[ENDINGS MODEL] epoch =", epoch, "split = ", split, " epoch_val_loss = ", epoch_val_loss, 1599 | " ctr = ", ctr, \ 1600 | "PPL = ", np.exp(epoch_val_loss / ctr)) 1601 | 1602 | 1603 | with torch.no_grad(): 1604 | # We have a validation/test set, so compute all the metrics on 
it. 1605 | # Note that if we set hier_mode=False, code will compute ppl for vanilla-LM 1606 | val_loss, num_batches, n_words, n_samples = self.validation_loss_lm(split=split, batch_size=64, hier_mode=True) 1607 | # val_metrics = get_metrics(val_loss, num_batches) 1608 | 1609 | print("Total val_loss : ", val_loss, " and n_words : ", n_words) 1610 | print("Ppl: ", np.exp((epoch_val_loss + val_loss)/(ctr + n_words))) 1611 | print("DONE WITH PPL") 1612 | print("-" * 50) 1613 | 1614 | 1615 | 1616 | 1617 | 1618 | 1619 | 1620 | --------------------------------------------------------------------------------
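The reward-normalized REINFORCE update applied to the ending-word generator inside `train()` is easier to follow in isolation. Below is a minimal, self-contained sketch of that step, under the assumption that the per-sample log-probabilities and scalar rewards are illustrative stand-ins for `info['logprobs']` (from the ending-word generator) and `disc_info['reward']` (from the rhyming discriminator), and that `reinforce_weight` stands in for `self.args.reinforce_weight`; it is not the repository's actual training loop.

```python
# Minimal sketch of the reward-normalized REINFORCE update used in train().
# The tensors below are stand-ins; in the repository, log-probs come from the
# ending-word generator and rewards from the rhyming discriminator.
import numpy as np
import torch

reinforce_weight = 1.0  # stand-in for self.args.reinforce_weight

# One entry per sampled ending-word sequence in a batch.
all_log_probs = [torch.randn(4, requires_grad=True) for _ in range(8)]  # log-prob per sampled word
all_rewards = [float(np.random.rand()) for _ in range(8)]               # scalar reward per sample

# Normalize rewards across the batch (a simple baseline / variance reduction);
# assumes the rewards are not all identical (std > 0).
rewards = np.array(all_rewards)
rewards = (rewards - rewards.mean()) / rewards.std()

# Policy gradient: minimize -(sum of step log-probs) * normalized reward, per sample.
for q, log_probs in enumerate(all_log_probs):
    loss = -torch.sum(log_probs * rewards[q]) * reinforce_weight
    loss.backward()  # in the repository this is followed by an optimizer step on the generator
```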