├── finetune.sh ├── viz.css ├── clean_pofo_corpus.py ├── banalify.py ├── generate_modifier.py ├── viz.js ├── README.md └── visions.py /finetune.sh: -------------------------------------------------------------------------------- 1 | export TRAIN_FILE=pofo.corpus 2 | 3 | python ../../transformers/examples/language-modeling/run_language_modeling.py \ 4 | --output_dir=pofo \ 5 | --model_type=bert \ 6 | --model_name_or_path=bert-base-uncased \ 7 | --do_train \ 8 | --train_data_file=$TRAIN_FILE \ 9 | --mlm 10 | -------------------------------------------------------------------------------- /viz.css: -------------------------------------------------------------------------------- 1 | .box { 2 | position: fixed; 3 | top: 200px; 4 | right: 100px; 5 | border: 1px solid black; 6 | padding: 5px; 7 | background-color: white; 8 | } 9 | .graph-holder { 10 | border-spacing: 10px; 11 | border-collapse: separate; 12 | } 13 | .bar { 14 | background-color: red; 15 | } 16 | -------------------------------------------------------------------------------- /clean_pofo_corpus.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | 4 | of = codecs.open('pofo-corpus-bert-noauthors.txt', 'w', 'utf-8') 5 | of.write('[CLS]') 6 | for filename in os.listdir('pofo-corpus'): 7 | f = codecs.open(os.path.join('pofo-corpus', filename), 'r', 'utf-8') 8 | poet = None 9 | title = None 10 | lines = [] 11 | for i, line in enumerate(f.readlines()): 12 | if i == 0: 13 | poet = line.strip() 14 | elif i == 2: 15 | title = line.strip() 16 | elif i > 4 and line != '~~~~!~~~\n': 17 | lines.append(line) 18 | of.write(f'Title: {title} / Text: \n') 19 | of.write(''.join(lines).strip()) 20 | of.write('[SEP]') 21 | 22 | -------------------------------------------------------------------------------- /banalify.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from visions import * 3 | f = codecs.open('mobydick.txt', 'r', 'utf-8') 4 | text = f.read() 5 | output = banalify(text, window_size=10, context_size=20, max_iterations=500, 6 | match_meter=False, match_rhyme=False, 7 | title=None, author=None, 8 | randomize=False, cooldown=0.01, modifier=None, 9 | forbid_reversions=True, 10 | preserve_punctuation=True, 11 | allow_punctuation=False, 12 | strong_topic_bias=False, stop_score=1.0, 13 | model_type='bert-large-uncased-whole-word-masking', model_path=None, 14 | sequential=False, verbose=False) 15 | of = open('mobydick-banalified-10-100.txt', 'w') 16 | of.write(output) 17 | -------------------------------------------------------------------------------- /generate_modifier.py: -------------------------------------------------------------------------------- 1 | # This script analyzes a text file to determine how frequently it uses certain 2 | # words relative to the usual frequencies in English according to the Brown 3 | # corpus. You can use this with the modifier parameter to skew the results toward 4 | # the vocabulary of a certain author or text. 
5 | 6 | import codecs 7 | import json 8 | import math 9 | import nltk 10 | import re 11 | import sys 12 | 13 | if len(sys.argv) != 3: 14 | print(f'Usage: {sys.argv[0]} ') 15 | 16 | infile = codecs.open(sys.argv[1], 'r', 'utf8') 17 | outfile = codecs.open(sys.argv[2], 'w', 'utf8') 18 | 19 | re_word = re.compile(r"^[a-zA-Z']+$") 20 | 21 | corpus = infile.read() 22 | corpus_tokens = nltk.tokenize.word_tokenize(corpus) 23 | corpus_tokens = [w.lower() for w in corpus_tokens 24 | if re_word.match(w)] 25 | corpus_freq = nltk.FreqDist(corpus_tokens) 26 | n_corpus = sum(corpus_freq.values()) 27 | 28 | brown_tokens = [w.lower() for w in nltk.corpus.brown.words() 29 | if re_word.match(w)] 30 | brown_freq = nltk.FreqDist(brown_tokens) 31 | n_brown = sum(brown_freq.values()) 32 | 33 | scores = {} 34 | for w in corpus_freq.keys() & brown_freq.keys(): 35 | scores[w] = math.log((corpus_freq[w] / n_corpus) / (brown_freq[w] / n_brown)) 36 | 37 | json.dump(scores, outfile) 38 | -------------------------------------------------------------------------------- /viz.js: -------------------------------------------------------------------------------- 1 | function hex(i) { 2 | var str = Number(i).toString(16); 3 | return str.length == 1 ? "0" + str : str; 4 | } 5 | 6 | function showEntropy() { 7 | $('.tok').each(function (index, el) { 8 | el = $(el); 9 | var entropy = el.attr('data-entropy-relative'); 10 | var val = Math.floor(255 - parseFloat(entropy) * 255); 11 | var color = "#FF" + hex(val) + hex(val); 12 | el.css('background-color', color); 13 | }); 14 | } 15 | 16 | function showScore() { 17 | $('.tok').each(function (index, el) { 18 | el = $(el); 19 | var score = el.attr('data-score-relative'); 20 | var val = Math.floor((1.0 - parseFloat(score) * 0.5) * 255); 21 | var color = "#" + hex(val) + hex(val) + "FF"; 22 | el.css('background-color', color); 23 | }); 24 | } 25 | 26 | function clearHighlighting() { 27 | $('.tok').each(function (index, el) { 28 | el = $(el); 29 | el.css('background-color', ''); 30 | }); 31 | } 32 | 33 | function indicateChanges() { 34 | $('.changed-tok').each(function (index, el) { 35 | el = $(el); 36 | el.css('font-weight', 'bold'); 37 | }); 38 | } 39 | 40 | function hideChanges() { 41 | $('.changed-tok').each(function (index, el) { 42 | el = $(el); 43 | el.css('font-weight', ''); 44 | }); 45 | } 46 | 47 | function showPopup(e) { 48 | var el = $(e.target); 49 | var score = el.attr('data-score'); 50 | var replacements = JSON.parse(el.attr('data-replacements')); 51 | $(".box").remove(); 52 | var html = "
<div class='box'>Top prediction: " + replacements.join('/') + "<br/>";
53 |     html += "Score: " + Number(score).toFixed(3) + "<br/>";
54 |     for (var k = 1; k <= 3; k++) {
55 |       var options = JSON.parse(el.attr("data-options" + k));
56 |       var entropy = el.attr('data-entropy' + k);
57 |       if (options == null) continue;
58 |       html += "<table class='graph-holder'>";
59 |       if (k == 1) html += "No topic:<br/>";
60 |       if (k == 2) html += "Raw:<br/>";
61 |       if (k == 3) html += "Constrained:<br/>";
62 |       var max = 0.0;
63 |       for (var i = 0; i < options.length; i++) {
64 |         var p = parseFloat(options[i][1]);
65 |         if (p > max) max = p;
66 |       }
67 |       for (var i = 0; i < options.length; i++) {
68 |         html += "<tr><td>";
69 |         html += options[i][0];
70 |         html += "</td><td><div class='bar' style='width: " + Math.floor(parseFloat(options[i][1]) / max * 100) + "px'>&nbsp;</div>";
71 |         html += "Entropy: " + Number(entropy).toFixed(3) + "</td></tr>";
72 |       }
73 |       html += "</table>";
74 |     }
75 |     html += "</div>
"; 76 | $("body").append(html); 77 | } 78 | 79 | $(function () { 80 | showScore(); 81 | $("#changes").change(function (e) { 82 | if ($("#changes").prop('checked')) { 83 | indicateChanges(); 84 | } else { 85 | hideChanges(); 86 | } 87 | }); 88 | $("#highlighting").change(function (e) { 89 | if ($("#highlighting").val() == 'Entropy') { 90 | showEntropy(); 91 | } else if ($("#highlighting").val() == 'Score') { 92 | showScore(); 93 | } else { 94 | clearHighlighting(); 95 | } 96 | }); 97 | $(".tok").on("dblclick", function (e) { 98 | showPopup(e); 99 | }); 100 | }) 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Hundred Visions and Revisions 2 | 3 | The recognition of their presence in a tree: 4 | Sitting on the long, thick branch. 5 | 6 | "A Hundred Visions and Revisions" is a computer program that alters poems using a neural-network language model. It works by replacing the individual words of the text, one by one, with other words that are more probable according to the BERT language model, while preserving rhyme and meter; in effect, this process banalifies the poem, replacing its linguistic distinctiveness with normativity. The program can also attempt to revise a poem to be about a different topic. 7 | 8 | [Update 2021: Building on some of the ideas I developed for this project, I am now working on a more general system for controlling text generators called [PromptArray](https://github.com/jeffbinder/promptarray).] 9 | 10 | As an example, I started with the poem "The Sick Rose" by William Blake: 11 | 12 | > O Rose thou art sick. 13 | > The invisible worm, 14 | > That flies in the night 15 | > In the howling storm: 16 | > 17 | > Has found out thy bed 18 | > Of crimson joy: 19 | > And his dark secret love 20 | > Does thy life destroy. 21 | 22 | Here is the revision: 23 | 24 | > By God thou art blessed. 25 | > The invisible man, 26 | > Who walks in the night 27 | > In a hooded cloak: 28 | > 29 | > Has found both his source 30 | > Of body heat: 31 | > And his own power that 32 | > Makes his life complete. 33 | 34 | I have also tried finetuning the neural network on a corpus of about 10,000 poems so as to improve the predictions. Here are the results: 35 | 36 | > Thank God you are safe. 37 | > The emotional wind, 38 | > That blows in the east 39 | > In a brutal gale: 40 | > 41 | > Tears leaked out like mud 42 | > From driving rain: 43 | > In your eyes only strands 44 | > Of your hair remain. 45 | 46 | As you can see, the program can produce very different output depend on how it is set up. Here is an alternative version of "The Sick Rose," based on different language model called RoBERTa: 47 | 48 | > They are his cold hands. 49 | > the aluminum shards, 50 | > his hands in the snow 51 | > and the melting ice: 52 | > 53 | > that carve out his heart 54 | > from molten clay: 55 | > and his cold fingers that 56 | > take his life away. 57 | 58 | As another example, here is Blake's poem "Jerusalem": 59 | 60 | > And did those feet in ancient time 61 | > Walk upon Englands mountains green: 62 | > And was the holy Lamb of God, 63 | > On Englands pleasant pastures seen! 64 | > 65 | > And did the Countenance Divine, 66 | > Shine forth upon our clouded hills? 67 | > And was Jerusalem builded here, 68 | > Among these dark Satanic Mills? 69 | > 70 | > Bring me my Bow of burning gold: 71 | > Bring me my arrows of desire: 72 | > Bring me my Spear: O clouds unfold! 
73 | > Bring me my Chariot of fire! 74 | > 75 | > I will not cease from Mental Fight, 76 | > Nor shall my sword sleep in my hand: 77 | > Till we have built Jerusalem, 78 | > In Englands green & pleasant Land. 79 | 80 | Here is the result after fifty iterations: 81 | 82 | > And did he who in our time 83 | > Shone upon our pleasant land: 84 | > And by the holy Grace of God, 85 | > Let our pleasant country stand! 86 | > 87 | > And did the Glorious Divine, 88 | > Shine forth upon our pleasant land? 89 | > And was Jerusalem standing here, 90 | > Upon his great Almighty Hand? 91 | > 92 | > Bring me my Sword of shining light: 93 | > Bring me my weapon of desire: 94 | > Bring me my Spear: O lord above! 95 | > Bring me my Instrument of fire! 96 | > 97 | > I will not die from Our Fight, 98 | > Nor will my spear be in my hand: 99 | > For we have reached Jerusalem, 100 | > In our time & pleasant land. 101 | 102 | ## Changing the topic of a text 103 | 104 | [UPDATE June 2020: I have redone the "Tyger" examples in this section to use the new, finetuned model.] 105 | 106 | It is also possible to have the program revise a poem to be about a different topic while retaining rhyme, meter, and some other, subtler traces of the original. When I created the finetuned neural network, I included annotations indicating the title and author of each poem. This enables the AI to pick up on patterns in the relation between title and poem. You can then feed in hints about the poem's title, and the AI will alter the text accordingly. 107 | 108 | For example, here are some computer-generated variants of the first stanza of another famous Blake poem, "The Tyger," with various different titles. The following examples were generated with some degree of randomness, which generally leads to better results (in part since it helps the program avoid [local optima](https://en.wikipedia.org/wiki/Local_optimum)); I also set the program to leave the first two words of the text alone. 109 | 110 | Original text: 111 | 112 | > Tyger Tyger, burning bright, 113 | > In the forests of the night; 114 | > What immortal hand or eye, 115 | > Could frame thy fearful symmetry? 116 | 117 | title="Strolling through a lovely flower garden": 118 | 119 | > Tyger Tyger, lovely tree, 120 | > In a garden by the sea; 121 | > What admiring ear or eye, 122 | > Can see such lovely scenery? 123 | 124 | title="Charging into glorious battle": 125 | 126 | > Tyger Tyger, shining sword, 127 | > In the battle stands the lord; 128 | > Yet unable hand or foot, 129 | > Can fight this mighty warrior? 130 | 131 | title="The aesthetics of Gothic cathedrals": 132 | 133 | > Tyger Tyger, gothic spire, 134 | > In the middle of the shire; 135 | > Which cathedral bore that spire, 136 | > Which bore that gothic pinnacle? 137 | 138 | title="Urban planning and traffic management": 139 | 140 | > Tyger Tyger, traffic light, 141 | > In the middle of the night; 142 | > Does policeman check his lights, 143 | > Not check his traffic cameras? 144 | 145 | title="A little, fluffy kitty cat": 146 | 147 | > Tyger Tyger, kitty cat, 148 | > In the middle of a chat; 149 | > What electric freak was he, 150 | > To send such nasty messages? 151 | 152 | title="Leaves of Grass": 153 | 154 | > Tyger Tyger, broken glass, 155 | > In the middle of the grass; 156 | > Leaves forever blown by wind, 157 | > How much more petty agony? 
158 | 159 | As the latter example shows, the model appears to be going off of the words in the title, not recalling any specific information about the titles of pre-existing poems. The titles thus seem to function mainly as topical hints. I need, however, to do more research into exactly what is happening with the titles. 160 | 161 | All of these revisions retain the rhyme, meter, and punctuation of the original (excepting the slant-rhyme of "eye" and "symmetry", which the current code cannot detect). If these formal constraints are lifted, the poem will degenerate into prose that bears little relation to the original, a fact best illustrated by the full sequence of steps by which this stanza is transformed into a text about "computational language modeling with artificial neural networks": 162 | 163 | > tyger tyger, burning bright, in the forests of the night; what immortal hand or eye, could frame such fearful symmetry? 164 | > tyger tyger, burning bright, in the forests of the night; what immortal hand or eye, could frame **create** fearful symmetry? 165 | > tyger tyger, burning bright, in the forests of the night; what immortal hand or eye, could **you** create fearful symmetry? 166 | > tyger tyger, burning bright, in the **middle** of the night; what immortal hand or eye, could you create fearful symmetry? 167 | > tyger tyger, burning bright, in the middle of the night; what immortal hand or eye, could you create **such** symmetry? 168 | > tyger tyger, burning bright, in the middle of the night; what immortal hand or eye, could you create such **things**? 169 | > tyger tyger, burning bright, in the middle of the night; what **artificial** hand or eye, could you create such things? 170 | > **john** tyger, burning bright, in the middle of the night; what artificial hand or eye, could you create such things? 171 | > john **lennon**, burning bright, in the middle of the night; what artificial hand or eye, could you create such things? 172 | > john lennon, burning **bridges**, in the middle of the night; what artificial hand or eye, could you create such things? 173 | > john lennon, burning bridges, in the middle of the night; what artificial **ear** or eye, could you create such things? 174 | > john lennon, burning bridges, in the middle of the night; **an** artificial ear or eye, could you create such things? 175 | > john lennon, burning bridges, in the middle of the night **with** an artificial ear or eye, could you create such things? 176 | > john lennon, burning bridges, in the middle of the night with an artificial ear or **something**, could you create such things? 177 | > john lennon, burning bridges, in the middle of the night with an artificial **intelligence** or something, could you create such things? 178 | > john lennon, **building** bridges, in the middle of the night with an artificial intelligence or something, could you create such things? 179 | > john lennon, building **computers**, in the middle of the night with an artificial intelligence or something, could you create such things? 180 | > john lennon, **using** computers, in the middle of the night with an artificial intelligence or something, could you create such things? 181 | > john lennon, using computers, in the middle of the night with an artificial intelligence or **computer**, could you create such things? 182 | > john lennon, using computers, in the middle of the night with an artificial intelligence **working** computer, could you create such things? 
183 | > john lennon, using computers, in the middle of the night with an artificial intelligence working **nearby**, could you create such things? 184 | > john lennon, using computers, in the middle of the night with an artificial intelligence working nearby, could you **predict** such things? 185 | > john lennon, using computers, in the middle of the night with an artificial intelligence **expert** nearby, could you predict such things? 186 | > john lennon, using computers, in the middle of the night with an artificial intelligence expert **asking**, could you predict such things? 187 | > john lennon, using computers, in the middle of the night **hears** an artificial intelligence expert asking, could you predict such things? 188 | 189 | ## Changing the author 190 | 191 | The finetuned model also incorporates information about the authors of the poems it is trained on. Based on this, it is possible to give the program hints about the author of a poem as well as the title. 192 | 193 | For example, here is the first stanza of "I Wandered Lonely as a Cloud" by William Wordsworth: 194 | 195 | > I wandered lonely as a cloud 196 | > That floats on high o'er vales and hills, 197 | > When all at once I saw a crowd, 198 | > A host, of golden daffodils, 199 | > Beside the lake, beneath the trees, 200 | > Fluttering and dancing in the breeze. 201 | 202 | I told the AI that this poem was called "Rime of the Ancient Mariner" by Samuel Taylor Coleridge. Here is its revision: 203 | 204 | > He fired blindly at a whale 205 | > That shot straight out like a dart and died, 206 | > Till all at once he blew a gale, 207 | > A blast, till nothing unified, 208 | > Except the blast, engulfed the whale, 209 | > Shattering and snapping in the gale. 210 | 211 | Entering "The Raven" by Edgar Allan Poe gave very different results: 212 | 213 | > While standing staring at a bird 214 | > I watched it tick like a moth in space, 215 | > Till all at once I heard a word, 216 | > A phrase, no longer commonplace, 217 | > Without a rhyme, without a plot, 218 | > Fragmented and twisted in a knot. 219 | 220 | ## BERT-rimés 221 | 222 | [Bouts-rimés](https://en.wikipedia.org/wiki/Bouts-Rimés) is an old French game in which one person selects a series of rhyming words and another person composes a poem using them. The idea is to pick words that are tricky to use–words that don't seem to make sense together–so that it is a challenge to create a coherent, natural-sounding poem. BERT, it turns out, is able to play this game, sort of. 223 | 224 | Doing this was a little tricky, since BERT is not ideal for generating wholly new text. To start with, I wrote a function that generates words in order, following a specified meter and inserting the rhyming words at the ends of the lines. The results at this point are generally not so good; however, this stage is necessary so as to choose a syllable structure that fits the meter while roughly suiting the patterns of the English language. I then run the revision procedure to turn this initial text into something more coherent. 225 | 226 | For example, I entered the rhyming words from the first stanza of "I Wandered Lonely as a Cloud": cloud, hills, crowd, daffodils, trees, breeze. I set the verse form to iambic quadrimeter (which is the form used by the original poem) and also entered the title and author of the original. Here is the almost totally nonsensical initial output: 227 | 228 | > Opponents simulator cloud, 229 | > The education rocked a hills. 
230 | > Recapture to a hui crowd, 231 | > Royale disposed the daffodils. 232 | > A vickers macy shrines and trees, 233 | > And overseeing of a breeze. 234 | 235 | Here is the output after 35 rounds of revision: 236 | 237 | > Another solitary cloud. 238 | > The disappearance of the hills. 239 | > Returning to the evening crowd, 240 | > Alone among the daffodils, 241 | > The flowers drifting through the trees, 242 | > In expectation of a breeze. 243 | 244 | As another example, I asked for an iambic pentameter couplet using the rhyming words "storm" and "form." This time I did not specify any title or author. Here is the result: 245 | 246 | > Results included project solar storm, 247 | > The first proposal for another form. 248 | 249 | ## Modifiers 250 | 251 | I also included a feature that enables you to bias the output toward an arbitrary vocabulary. I tested this out using the data from Iain Barr's [analysis of the vocabulary of heavy metal lyrics](https://github.com/ijmbarr/pythonic-metal). Suppose, for instance, "I Wandered Lonely as a Cloud" is not metal enough for your tastes. Perhaps you would prefer this machine-generated alternative: 252 | 253 | > Rage flooded slowly through her veins 254 | > That burned as cold as the sky and ground, 255 | > Then all at once she dropped her chains, 256 | > And spit, spit hatred underground, 257 | > Into her flesh, into the stone, 258 | > Vibrated and rattled in her bone. 259 | 260 | To use this feature, you can run the `generate_modifer.py` script to analyze the vocabulary 261 | of a given text, then supply the resulting JSON file by adding the parameter 262 | `modifier=json_modifier('filename')` to any of the text rewriter functions. For 263 | instance, I used the DeBERTa xxlarge model to generate a version of Alfred, Lord Tennyson's 264 | "The Kraken" using the vocabulary of the PyTorch library documentation. Here 265 | is the original: 266 | 267 | > Below the thunders of the upper deep, 268 | > Far, far beneath in the abysmal sea, 269 | > His ancient, dreamless, uninvaded sleep 270 | > The Kraken sleepeth: faintest sunlights flee 271 | > About his shadowy sides; above him swell 272 | > Huge sponges of millennial growth and height; 273 | > And far away into the sickly light, 274 | > From many a wondrous grot and secret cell 275 | > Unnumbered and enormous polypi 276 | > Winnow with giant arms the slumbering green. 277 | > There hath he lain for ages, and will lie 278 | > Battening upon huge sea worms in his sleep, 279 | > Until the latter fire shall heat the deep; 280 | > Then once by man and angels to be seen, 281 | > In roaring he shall rise and on the surface die. 282 | 283 | Here is the rewritten version, which replaces the deep sea with deep neural networks: 284 | 285 | > Within the matrix of the hidden deep, 286 | > Deep, deep below in the forgotten sea, 287 | > Through hidden, hidden, universal sleep 288 | > The Matrix passes: transient sunlights flee 289 | > Across its infinite folds; within them swell 290 | > Dense layers of infinity width and height; 291 | > And deep below beyond the transient light, 292 | > From many a hidden torch and hidden cell 293 | > Persistent and persistent arises 294 | > Outside each hidden torch the infinite green. 295 | > Thus has it passed for ever, and will lie 296 | > Heavily within its dense folds in deep sleep, 297 | > Until the hidden matrix can float the deep; 298 | > Then known to none and never to be seen, 299 | > In safety it will float and let the matrix die. 
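
To make this workflow concrete, here is a rough sketch of how the pieces fit together. The file names are placeholders, the keyword arguments follow the `depoeticize` signature in `visions.py`, the model is the whole-word-masking BERT used in `banalify.py`, and `json_modifier` is the helper the text above refers to for loading the generated JSON file:

```python
# First, analyze the vocabulary of the text you want to imitate:
#   python generate_modifier.py pytorch_docs.txt pytorch_docs.json

from visions import depoeticize, json_modifier

# Read the poem to be rewritten (placeholder file name).
text = open('kraken.txt', encoding='utf-8').read()

# Rewrite it, biasing word choices toward the analyzed vocabulary
# while preserving the original rhyme and meter.
output = depoeticize(text,
                     max_iterations=100,
                     match_meter=True,
                     match_rhyme=True,
                     modifier=json_modifier('pytorch_docs.json'),
                     model_type='bert-large-uncased-whole-word-masking')
print(output)
```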
300 | 301 | ## How it works 302 | 303 | This program works with the [BERT](https://arxiv.org/pdf/1810.04805v2.pdf) language model, which is based on the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) architecture. (BERT is related to the [GPT-2](https://openai.com/blog/better-language-models/) model used in [Talk to Transformer](https://talktotransformer.com/), although it uses a different network configuration and data set; whereas GPT-2 is trained on text from the internet, BERT is trained on books and Wikipedia articles.) 304 | 305 | The BERT model is capable of guessing a word that is "masked"—that is, hidden from the model. To pick an example from the [documentation](https://pytorch.org/hub/huggingface_pytorch-transformers/) for the implementation I used, one could enter "Who was Jim Henson? Jim Henson was a [MASK]"; the model predicts that the masked word is "puppeteer." The point of this is to enable the computer to perform question-answering tasks, language modeling standing as a surrogate for more general intelligence. But it is also possible to use the model's predictions to alter an existing text. 306 | 307 | To do this, my program tries masking each word in the text and guessing what word should be in that position. For instance, suppose we are looking at this text: 308 | 309 | > Tyger Tyger, burning bright, in the forests of the night 310 | 311 | We try masking each word in order; for instance, at one point we will end up with this: 312 | 313 | > Tyger Tyger, burning bright, in the [MASK] of the night 314 | 315 | The program uses the neural network to predict what word appears in the masked position, subject to various constraints such as rhyme and meter. In this case, the BERT model guesses "middle," with probability 0.6762. On the other hand, the word that is actually in that position—"forests"—gets probability 0.000076159. We divide the latter by the former to get a score for this potential change: 0.0001126. Since this score happens to be the lowest for any word in the text, the program selects the word "forests" for replacement, giving us this revision: 316 | 317 | > Tyger Tyger, burning bright, in the middle of the night 318 | 319 | The program then repeats this process until there are no more "improvements" to be made. 320 | 321 | To alter the topic of a text, the program adds additional text intended to influence the language model's predictions. 322 | 323 | > \{The following poem is titled flowers: 324 | > ****\} 325 | > Tyger Tyger, burning bright, in the forests of the night 326 | > \{**** 327 | > The preceding poem is by Charles Baudelaire.\} 328 | 329 | I also added similar text to the collection of poems on which I finetuned the model, so that the neural network would learn to recognize this type of annotation. The brackets indicate that the program is not allowed to alter that text. If the "strong topic bias" feature is turned on, the program computes the probabilities both with and without these annotations and computes the scores using the formula ```probability with annotations / probability without annotations ** n```, where n is a factor indicating the strength of the topic bias (0.5 is recommended). In this case, the topic annotations cause the program to produce a different prediction: 330 | 331 | > Tyger Tyger, burning leaves, in the middle of the night 332 | 333 | For more details about how it all works, see the code. 
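
As a rough illustration of the scoring step described above, the sketch below masks each word in turn and compares the probability of the existing word with the probability of the model's top prediction. It is deliberately bare-bones: it uses the Hugging Face `transformers` API directly (assuming a recent version, where the output object exposes `.logits`) and omits the rhyme, meter, fixed-text, and batching machinery handled by `compute_probs_for_masked_tokens` and `compute_score_for_tokens` in `visions.py`:

```python
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

text = "tyger tyger, burning bright, in the forests of the night"
tokens = tokenizer.tokenize(text)

scores = {}
for i, original_tok in enumerate(tokens):
    # Mask one token and ask the model what belongs in that position.
    masked = tokens.copy()
    masked[i] = tokenizer.mask_token
    ids = tokenizer.convert_tokens_to_ids(['[CLS]'] + masked + ['[SEP]'])
    with torch.no_grad():
        logits = model(torch.tensor([ids])).logits[0, i + 1]  # +1 to skip [CLS]
    probs = torch.softmax(logits, dim=0)
    p_top = probs.max().item()                                    # probability of the model's favorite word
    p_orig = probs[tokenizer.convert_tokens_to_ids(original_tok)].item()
    scores[i] = p_orig / p_top        # low score = word the model most wants to change

# Replace the lowest-scoring word with the model's top prediction,
# then repeat until no score falls below the stopping threshold.
worst = min(scores, key=scores.get)
print(tokens[worst], scores[worst])
```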
334 | 335 | ## Other experiments 336 | 337 | The program can also perform an alternative procedure that replaces the words in the order in which they appear in the text, rather than choosing which words to replace based on their scores. This is exponentially faster than the default procedure, but the results are generally not as compelling, especially when a topic is specified. This, for instance, is the output for "The Sick Rose": 338 | 339 | > Where else thou be thou. 340 | > The mysterious man, 341 | > Who slept in a bed 342 | > In a stormy night: 343 | > 344 | > He drew from his heart 345 | > His secret friend: 346 | > And whose own secret friend 347 | > Did his life depend. 348 | 349 | The sequential procedure does a bit better with "Jerusalem": 350 | 351 | > And let your God in our hearts 352 | > Shine upon our pastures clean: 353 | > And say the mighty Word of Christ, 354 | > On our fertile pastures green! 355 | > 356 | > In may the Glorious Divine, 357 | > Come down upon many fertile lands? 358 | > And is Jerusalem hidden here, 359 | > Beneath these green Eternal Sands? 360 | > 361 | > Bring I my Spear of shining light: 362 | > Bring I my weapon of devotion: 363 | > Bring I my Sword: Watch it again! 364 | > Bring I my Instrument of motion! 365 | > 366 | > We shall not die of Our Wounds, 367 | > Nor let my heart rest in your hands: 368 | > For I have found Jerusalem, 369 | > And many rich & fertile lands. 370 | 371 | I have also done some experiments with GPT-2, although my word-by-word revision technique does not work with GPT-style models. The problem is that, whereas BERT is able to look both forward and backward when predicting a word, GPT-2, like the Angel of History, only looks backward; accordingly, it is not good at generating words that fit into a pre-existing structure. I did, however, include a function that generates GPT-2 text constrained by the meter and rhyme scheme of a given poem. 
The results are so prosaic that it is difficult even to detect the rhyme and meter, although the output can, indeed, be read with the same rhythms as the original: 372 | 373 | > Tyger Tyger, also known 374 | > in the English as "the lone 375 | > wolf," created this cat, named 376 | > "T-Rex," by writing poetry 377 | -------------------------------------------------------------------------------- /visions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import pickle 4 | import random 5 | import re 6 | import time 7 | import unicodedata 8 | 9 | import numpy 10 | import scipy 11 | from scipy import sparse 12 | import torch 13 | from transformers import BertTokenizer, BertForMaskedLM 14 | from transformers import DistilBertTokenizer, DistilBertForMaskedLM 15 | from transformers import RobertaTokenizer, RobertaForMaskedLM 16 | from transformers import DebertaTokenizer, DebertaForMaskedLM 17 | from transformers import DebertaV2Tokenizer, DebertaV2ForMaskedLM 18 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 19 | 20 | #import logging 21 | #logging.basicConfig(level=logging.INFO) 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | tokenizer = None 26 | model = None 27 | loaded_model_type = None 28 | loaded_model_path = None 29 | 30 | from nltk.corpus import stopwords 31 | from nltk.corpus import cmudict 32 | dictionary = cmudict.dict() 33 | from g2p_en import G2p 34 | g2p = G2p() 35 | 36 | m = torch.nn.Softmax(dim=0) 37 | 38 | re_word = re.compile(r"^[▁a-zA-Z' ]+$") 39 | re_space = re.compile(r"^[ \n\t]+$") 40 | re_vowel = re.compile(r"[aeiouy]") 41 | re_space_and_brackets = re.compile(r"^[\s{}]+$") 42 | def get_pron(tok): 43 | if tok.startswith('madeupword'): 44 | return [] 45 | try: 46 | tok = tokenizer.convert_tokens_to_string([tok]) 47 | except KeyError: 48 | pass 49 | if tok.startswith('##'): 50 | tok = tok[2:] 51 | if tok.startswith(' ') or tok.startswith('▁'): 52 | tok = tok[1:] 53 | if not re_word.match(tok): 54 | # Punctuation 55 | return [] 56 | if tok in dictionary: 57 | pron = dictionary[tok][0] 58 | else: 59 | # Word not in CMU dict: guess using g2p_en 60 | pron = g2p(tok) 61 | return pron 62 | 63 | def get_meter(pron): 64 | if pron == []: 65 | return 'p' 66 | meter = '' 67 | for ph in pron: 68 | # We ignore stress levels in favor of poetic scansion 69 | if ph[-1].isdigit(): 70 | meter += 'u' if ph[-1] == '0' else '-' 71 | return meter 72 | 73 | def get_rhyme(pron): 74 | if pron == []: 75 | return 'p' 76 | rhyme = '' 77 | for ph in reversed(pron): 78 | rhyme = ph.replace('1', '').replace('2', '') + rhyme 79 | if ph[-1].isdigit() and int(ph[-1]) > 0: 80 | break 81 | return rhyme 82 | 83 | def is_word_piece(model, tok): 84 | if model.startswith('bert') or model.startswith('distilbert'): 85 | return tok.startswith('##') 86 | elif model.startswith('microsoft/deberta') and '-v2' in model: 87 | return re_word.match(tok) and not tok.startswith('▁') 88 | elif model.startswith('roberta') or model.startswith('gpt2') or (model.startswith('microsoft/deberta') and '-v2' not in model): 89 | try: 90 | tok = tokenizer.convert_tokens_to_string([tok]) 91 | except ValueError: 92 | pass 93 | except KeyError: 94 | pass 95 | return re_word.match(tok) and not tok.startswith(' ') 96 | def join_word_pieces(toks): 97 | word = '' 98 | for tok in toks: 99 | if tok.startswith('##'): 100 | tok = tok[2:] 101 | word += tok 102 | return word 103 | 104 | def is_full_word(model_type, tok): 105 | if 
model_type.startswith('bert') or model_type.startswith('distilbert'): 106 | return re_word.match(tok) and not tok.startswith('##') 107 | elif model.startswith('microsoft/deberta') and '-v2' in model: 108 | return re_word.match(tok) and tok.startswith('▁') 109 | elif model_type.startswith('roberta') or model.startswith('gpt2') or (model.startswith('microsoft/deberta') and '-v2' not in model): 110 | try: 111 | tok = tokenizer.convert_tokens_to_string([tok]) 112 | except ValueError: 113 | pass 114 | except KeyError: 115 | pass 116 | return re_word.match(tok) and (tok.startswith(' ') or tok.startswith('▁')) 117 | 118 | def is_punctuation(tok): 119 | if tok == mask_token: 120 | return False 121 | try: 122 | tok = tokenizer.convert_tokens_to_string([tok]) 123 | except ValueError: 124 | pass 125 | except KeyError: 126 | pass 127 | if tok.startswith('▁'): 128 | tok = tok[1:] 129 | return not re_word.match(tok) 130 | 131 | def is_space(tok): 132 | if tok == mask_token: 133 | return False 134 | try: 135 | tok = tokenizer.convert_tokens_to_string([tok]) 136 | except ValueError: 137 | pass 138 | except KeyError: 139 | pass 140 | if tok.startswith('▁'): 141 | tok = tok[1:] 142 | return re_space.match(tok) 143 | 144 | # Scan a text to determine spacing and capitalization so that they can be 145 | # preserved after detokenization. 146 | def scan_tokenization(model, text, toks): 147 | spacing = [] 148 | capitalization = [] 149 | char_idx = 0 150 | tok_idx = 0 151 | tok_char_idx = 0 152 | current_spacing = '' 153 | current_capitalization = None 154 | after_apostrophe = False 155 | after_double_quote = False 156 | start_of_text = True 157 | while char_idx < len(text): 158 | char = text[char_idx] 159 | if char == '{' or char == '}': 160 | char_idx += 1 161 | continue 162 | word_piece = False 163 | try: 164 | tok = toks[tok_idx] 165 | if model.startswith('microsoft/deberta') and '-v2' not in model: 166 | tok = tokenizer.convert_tokens_to_string([tok]) 167 | if (is_word_piece(model, tok) and tok_idx > 0 and not after_double_quote) or tok == "'" or \ 168 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")): 169 | tok = join_word_pieces([tok]) 170 | word_piece = True 171 | except IndexError: 172 | tok = '' 173 | if tok_char_idx == 0 and (tok.startswith('Ġ') or tok.startswith('Ċ') or tok.startswith(' ') or tok.startswith('▁')) and len(tok) > 1: 174 | if char != ' ': 175 | # Advance the counter when the token contains an extraneous space 176 | tok_char_idx += 1 177 | elif current_spacing.endswith('\n') or start_of_text: 178 | # We have to do this because the tokenizer always adds a space to tokens at 179 | # the start of a line, which is stripped out in detokenize(). To account 180 | # for this, we have to add an extra space in cases where a space really does 181 | # exist at the start of a line. 182 | current_spacing += ' ' 183 | try: 184 | tok_char = tok[tok_char_idx] 185 | except IndexError: 186 | tok_char = '' 187 | if tok_char in ('Ġ', 'Ċ', '▁'): 188 | tok_char = ' ' 189 | # print(f'{char_idx}: \'{char}\' \'{tok}\' \'{tok_char}\'{" word_piece" if word_piece else ""}'); time.sleep(0.001) 190 | # RoBERTa uses '▁' for both space and newline. 
191 | if not char.isspace() or char == tok_char or tok == '▁': 192 | if tok_char_idx == 1 if (tok.startswith('Ġ') or tok.startswith('Ċ') or tok.startswith(' ') or tok.startswith('▁')) else tok_char_idx == 0: 193 | if char.isupper(): 194 | current_capitalization = 'upper_ambiguous' 195 | else: 196 | current_capitalization = 'lower' 197 | elif current_capitalization in ('upper_ambiguous', 'upper_all'): 198 | if char.isupper(): 199 | current_capitalization = 'upper_all' 200 | else: 201 | current_capitalization = 'upper_initial' 202 | char_idx += 1 203 | start_of_text = False 204 | tok_char_idx += 1 205 | if tok_char_idx == len(tok): 206 | tok_idx += 1 207 | tok_char_idx = 0 208 | after_apostrophe = tok == "'" 209 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 210 | if not word_piece: 211 | spacing.append(current_spacing) 212 | capitalization.append(current_capitalization) 213 | current_spacing = '' 214 | current_capitalization = None 215 | elif tok_char_idx == 0 or ((tok.startswith('Ġ') or tok.startswith('Ċ') or tok.startswith(' ') or tok.startswith('▁')) and tok_char_idx == 1): 216 | current_spacing += char 217 | char_idx += 1 218 | start_of_text = False 219 | else: 220 | print("WARNING: Text scanner found an unexpected character. This probably indicates a bug in the detokenizer.") 221 | char_idx += 1 222 | print(f'Character {char_idx}: \'{char}\' / token: \'{tok}\' / token character: \'{tok_char}\'{" word_piece" if word_piece else ""}'); time.sleep(0.1) 223 | spacing.append(current_spacing) 224 | return (spacing, capitalization) 225 | 226 | def detokenize(model, toks, spacing, capitalization, html=False, start_of_line=True): 227 | text = '' 228 | i = 0 229 | j = 0 230 | current_capitalization = None 231 | after_apostrophe = False 232 | after_double_quote = False 233 | while i < len(toks): 234 | tok = toks[i] 235 | if model.startswith('microsoft/deberta') and '-v2' not in model: 236 | if tok.startswith('')+1 238 | i2 = tok[1:].index('<') 239 | try: 240 | tok = tok[:i1] + tokenizer.convert_tokens_to_string([tok[i1:i2]]) + tok[i2:] 241 | except ValueError: 242 | pass 243 | except KeyError: 244 | pass 245 | elif tok.startswith('<'): 246 | try: 247 | tok = '<' + tokenizer.convert_tokens_to_string([tok[1:-1]]) + '>' 248 | except ValueError: 249 | pass 250 | except KeyError: 251 | pass 252 | else: 253 | try: 254 | tok = tokenizer.convert_tokens_to_string([tok]) 255 | except ValueError: 256 | pass 257 | except KeyError: 258 | pass 259 | if (is_word_piece(model, tok) and i > 0 and not after_double_quote) or tok == "'" or \ 260 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")): 261 | tok = join_word_pieces([tok]) 262 | if current_capitalization == 'upper_all': 263 | tok = tok.upper() 264 | else: 265 | current_spacing = spacing[j] 266 | tok = tok.replace('Ġ', ' ') 267 | tok = tok.replace('Ċ', '\n') 268 | tok = tok.replace('▁', ' ') 269 | if (i == 0 and start_of_line) or '\n' in current_spacing: 270 | # Remove the extra space created by the tokenizer if we are at the start of a line. 
271 | if '< ' in tok: 272 | tok = tok.replace('< ', '<') 273 | if '> ' in tok: 274 | tok = tok.replace('> ', '>') 275 | if tok.startswith(' ') and len(tok) > 1: 276 | tok = tok[1:] 277 | if html: 278 | current_spacing = current_spacing.replace(' ', ' ') 279 | text += current_spacing 280 | current_capitalization = capitalization[j] 281 | if current_capitalization in ('upper_initial', 'upper_ambiguous'): 282 | if tok.startswith('')+1 285 | if tok[i1] == ' ' and i1 < len(tok)-1: 286 | i1 += 1 287 | tok = tok[:i1] + tok[i1].upper() + tok[i1+1:] 288 | elif tok[0] == '<' and tok[-1] == '>': 289 | # Special case for tokens marked as just modified 290 | if tok[1] == ' ' and len(tok) > 2: 291 | tok = tok[0:2] + tok[2].upper() + tok[3:] 292 | else: 293 | tok = tok[0] + tok[1].upper() + tok[2:] 294 | elif tok[0] == ' ' and len(tok) > 1: 295 | tok = tok[0] + tok[1].upper() + tok[2:] 296 | else: 297 | tok = tok[0].upper() + tok[1:] 298 | elif current_capitalization == 'upper_all': 299 | tok = tok.upper() 300 | elif current_capitalization == 'lower': 301 | tok = tok.lower() 302 | j += 1 303 | text += tok 304 | i += 1 305 | after_apostrophe = tok == "'" 306 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 307 | text += spacing[-1] 308 | return text 309 | 310 | def create_meter_dict(model_type): 311 | print("Generating " + model_type.replace('/', '_') + '_meter_dict.pkl') 312 | vocab = tokenizer.get_vocab() 313 | meter_dict = {} 314 | word_pieces = torch.zeros([vocab_size]) 315 | for tok in vocab: 316 | i = vocab[tok] 317 | pron = get_pron(tok) 318 | meter = get_meter(pron) 319 | if meter not in meter_dict: 320 | meter_dict[meter] = torch.zeros([vocab_size]) 321 | meter_dict[meter][i] = 1.0 322 | if is_word_piece(model_type, tok): 323 | word_pieces[i] = 1.0 324 | 325 | pickle.dump((word_pieces, meter_dict), 326 | open(model_type.replace('/', '_') + '_meter_dict.pkl', 'wb')) 327 | 328 | def create_rhyme_matrix(model_type): 329 | print("Generating " + model_type.replace('/', '_') + '_rhyme_matrix.pkl') 330 | vocab = tokenizer.get_vocab() 331 | rhyme_matrix = sparse.lil_matrix((vocab_size, vocab_size)) 332 | rhymable_words = torch.zeros([vocab_size]) 333 | rhyme_groups = {} 334 | for tok in vocab: 335 | i = vocab[tok] 336 | pron = get_pron(tok) 337 | rhyme = get_rhyme(pron) 338 | if rhyme not in rhyme_groups: 339 | rhyme_groups[rhyme] = [] 340 | rhyme_groups[rhyme].append((i, pron)) 341 | for rhyme in rhyme_groups: 342 | if len(rhyme_groups[rhyme]) < 2: 343 | continue 344 | for i, pron1 in rhyme_groups[rhyme]: 345 | rhymable = False 346 | for j, pron2 in rhyme_groups[rhyme]: 347 | # Words with identical pronunciations can't be used as rhymes 348 | if pron1 != pron2: 349 | rhyme_matrix[i,j] = 1.0 350 | rhymable = True 351 | if rhymable: 352 | rhymable_words[i] = 1.0 353 | 354 | rhyme_matrix = sparse.csc_matrix(rhyme_matrix) 355 | pickle.dump((rhymable_words, rhyme_matrix), open(model_type.replace('/', '_') + '_rhyme_matrix.pkl', 'wb')) 356 | 357 | vocab = None 358 | vocab_size = None 359 | meter_dict = {} 360 | word_pieces = None 361 | rhymable_words = None 362 | rhyme_matrix = None 363 | rhyme_tensors = {} 364 | rhyme_and_meter_loaded = None 365 | def initialize_rhyme_and_meter(model, meter=False, rhymes=False): 366 | global vocab, vocab_size, word_pieces, meter_dict, rhymable_words, rhyme_matrix, rhyme_and_meter_loaded 367 | if rhyme_and_meter_loaded == model: 368 | return 369 | else: 370 | rhyme_and_meter_loaded = model 371 | vocab = tokenizer.get_vocab() 372 | if meter: 373 | try: 374 | f = 
open(model.replace('/', '_') + '_meter_dict.pkl', 'rb') 375 | except FileNotFoundError: 376 | create_meter_dict(model) 377 | f = open(model.replace('/', '_') + '_meter_dict.pkl', 'rb') 378 | word_pieces, meter_dict = pickle.load(f) 379 | word_pieces = word_pieces.to(device) 380 | meter_dict = {k: v.to(device) for k, v in meter_dict.items()} 381 | else: 382 | try: 383 | f = open(model.replace('/', '_') + '_meter_dict.pkl', 'rb') 384 | except FileNotFoundError: 385 | create_meter_dict(model) 386 | f = open(model.replace('/', '_') + '_meter_dict.pkl', 'rb') 387 | word_pieces, _ = pickle.load(f) 388 | word_pieces = word_pieces.to(device) 389 | if rhymes: 390 | global rhyme_matrix 391 | try: 392 | f = open(model.replace('/', '_') + '_rhyme_matrix.pkl', 'rb') 393 | except FileNotFoundError: 394 | create_rhyme_matrix(model) 395 | f = open(model.replace('/', '_') + '_rhyme_matrix.pkl', 'rb') 396 | rhymable_words, rhyme_matrix = pickle.load(f) 397 | 398 | def initialize_model(model_type, model_path): 399 | global tokenizer, model, loaded_model_type, loaded_model_path, bos_token, eos_token, mask_token, pad_token_id, vocab_size 400 | if loaded_model_type != model_type or loaded_model_path != model_path: 401 | loaded_model_type = model_type 402 | loaded_model_path = model_path 403 | if model_type.startswith('distilbert'): 404 | tokenizer = DistilBertTokenizer.from_pretrained(model_path or model_type) 405 | model = DistilBertForMaskedLM.from_pretrained(model_path or model_type) 406 | bos_token = '[CLS]' 407 | eos_token = '[SEP]' 408 | mask_token = '[MASK]' 409 | elif model_type.startswith('bert'): 410 | tokenizer = BertTokenizer.from_pretrained(model_path or model_type) 411 | model = BertForMaskedLM.from_pretrained(model_path or model_type) 412 | bos_token = '[CLS]' 413 | eos_token = '[SEP]' 414 | mask_token = '[MASK]' 415 | elif model_type.startswith('roberta'): 416 | tokenizer = RobertaTokenizer.from_pretrained(model_path or model_type) 417 | model = RobertaForMaskedLM.from_pretrained(model_path or model_type) 418 | bos_token = tokenizer.bos_token 419 | eos_token = tokenizer.eos_token 420 | mask_token = tokenizer.mask_token 421 | elif model_type.startswith('microsoft/deberta') and '-v2' in model_type: 422 | tokenizer = DebertaV2Tokenizer.from_pretrained(model_path or model_type) 423 | model = DebertaV2ForMaskedLM.from_pretrained(model_path or model_type) 424 | bos_token = tokenizer.cls_token 425 | eos_token = tokenizer.sep_token 426 | mask_token = tokenizer.mask_token 427 | elif model_type.startswith('microsoft/deberta'): 428 | tokenizer = DebertaTokenizer.from_pretrained(model_path or model_type) 429 | model = DebertaForMaskedLM.from_pretrained(model_path or model_type) 430 | bos_token = tokenizer.cls_token 431 | eos_token = tokenizer.sep_token 432 | mask_token = tokenizer.mask_token 433 | vocab_size = model.config.vocab_size 434 | pad_token_id = model.config.pad_token_id 435 | model = torch.nn.DataParallel(model) 436 | model.to(device) 437 | model.eval() 438 | 439 | # Computes the model's predictions for a text with a given set of ranges 440 | # masked. 
441 | def compute_probs_for_masked_tokens(model, tokenized_texts, masked_index_lists, batch_size, 442 | replacements_only=False): 443 | indexed_tokens = [] 444 | tensor_indices = {} 445 | wwm_tensor_indices = {} 446 | for j1, (tokenized_text, masked_indices) in enumerate(zip(tokenized_texts, masked_index_lists)): 447 | for j2, masked_index_set in enumerate(masked_indices): 448 | n = len(masked_index_set) 449 | multipart_words = False 450 | tokens = tokenized_text.copy() 451 | for i1, i2 in masked_index_set: 452 | if i2 > i1: 453 | multipart_words = True 454 | if replacements_only: 455 | break 456 | tokens[i1:i2+1] = [mask_token] * (i2 - i1 + 1) 457 | if not replacements_only or not multipart_words: 458 | tensor_indices[(j1, j2)] = len(indexed_tokens) 459 | indexed_tokens.append(tokenizer.convert_tokens_to_ids(tokens)) 460 | # If one of the ranges covers a multipart word, we need to compute probabilities 461 | # both for the text with individual tokens masked and with the whole word masked. 462 | if multipart_words: 463 | wwm_tokens = tokenized_text.copy() 464 | shift = 0 465 | for i1, i2 in masked_index_set: 466 | i1 -= shift 467 | i2 -= shift 468 | shift += (i2 - i1) 469 | wwm_tokens[i1:i2+1] = [mask_token] 470 | wwm_tensor_indices[(j1, j2)] = len(indexed_tokens) 471 | indexed_tokens.append(tokenizer.convert_tokens_to_ids(wwm_tokens)) 472 | 473 | # Add padding so all index sequences are the same length. 474 | max_len = 0 475 | for indices in indexed_tokens: 476 | n = len(indices) 477 | if n > max_len: 478 | max_len = n 479 | attention_mask = [] 480 | for i in range(len(indexed_tokens)): 481 | n = len(indexed_tokens[i]) 482 | if n < max_len: 483 | indexed_tokens[i] = indexed_tokens[i] + [pad_token_id]*(max_len-n) 484 | attention_mask.append([1]*n + [0]*(max_len-n)) 485 | tokens_tensor = torch.tensor(indexed_tokens, device='cpu') 486 | attention_mask = torch.tensor(attention_mask, device='cpu') 487 | 488 | all_predictions = [] 489 | ntexts = tokens_tensor.shape[0] 490 | nbatches = math.ceil(ntexts / batch_size) 491 | for batchnum in range(nbatches): 492 | batch_start = batchnum * batch_size 493 | batch_end = min(batch_start + batch_size, ntexts) 494 | toks_slice = tokens_tensor[batch_start:batch_end].to(device) 495 | mask_slice = attention_mask[batch_start:batch_end].to(device) 496 | with torch.no_grad(): 497 | outputs = model(toks_slice, 498 | attention_mask=mask_slice) 499 | del toks_slice 500 | del mask_slice 501 | all_predictions.append(outputs[0].to('cpu')) 502 | del outputs 503 | del tokens_tensor 504 | del attention_mask 505 | if len(all_predictions) == 0: 506 | return [None]*len(tokenized_texts), [None]*len(tokenized_texts) 507 | 508 | all_probs = [] 509 | all_replacement_probs = [] 510 | for j1, (tokenized_text, masked_indices) in enumerate(zip(tokenized_texts, masked_index_lists)): 511 | probs = [] 512 | replacement_probs = [] 513 | for j2, masked_index_set in enumerate(masked_indices): 514 | n = len(masked_index_set) 515 | multipart_words = (j1, j2) in wwm_tensor_indices 516 | if not replacements_only or not multipart_words: 517 | j = tensor_indices[(j1, j2)] 518 | index_set_probs = [None] * n 519 | for k, (i1, i2) in enumerate(masked_index_set): 520 | word_probs = [] 521 | for i in range(i1, i2+1): 522 | jbatch = j // batch_size 523 | jpreds = j % batch_size 524 | word_probs.append(all_predictions[jbatch][jpreds, i, :]) 525 | index_set_probs[k] = word_probs 526 | if not replacements_only: 527 | probs.append(index_set_probs) 528 | if multipart_words: 529 | j = 
wwm_tensor_indices[(j1, j2)] 530 | index_set_probs = [None] * n 531 | shift = 0 532 | for k, (i1, i2) in enumerate(masked_index_set): 533 | i1 -= shift 534 | i2 -= shift 535 | shift += (i2 - i1) 536 | jbatch = j // batch_size 537 | jpreds = j % batch_size 538 | index_set_probs[k] = [all_predictions[jbatch][jpreds, i1, :]] 539 | replacement_probs.append(index_set_probs) 540 | all_probs.append(probs) 541 | all_replacement_probs.append(replacement_probs) 542 | for predictions in all_predictions: 543 | del predictions 544 | 545 | if replacements_only: 546 | return [None]*len(tokenized_texts), all_replacement_probs 547 | else: 548 | return all_probs, all_replacement_probs 549 | 550 | # Find words that could, if chosen for the masked indices, take us back to an 551 | # arrangement that has already been tried. Because we are compiling independent 552 | # lists of forbidden words for each index, this method can overcorrect. 553 | def find_forbidden_words(tokenized_text, masked_indices, forbidden_texts): 554 | forbidden_words = [torch.ones((vocab_size,)) 555 | for i in range(len(masked_indices))] 556 | d = forbidden_texts 557 | def f(d, start): 558 | i = start 559 | for tok in tokenized_text[start:]: 560 | mask_num = None 561 | mask_len = 0 562 | for k, (i1, i2) in enumerate(masked_indices): 563 | if i == i1: 564 | mask_num = k 565 | mask_len = i2 - i1 + 1 566 | break 567 | if mask_num is not None: 568 | reached_end = False 569 | for option_tok in d.keys(): 570 | if f(d[option_tok], i+mask_len): 571 | option_idx = tokenizer \ 572 | .convert_tokens_to_ids([option_tok])[0] 573 | forbidden_words[mask_num][option_idx] = 0.0 574 | reached_end = True 575 | return reached_end 576 | else: 577 | if tok in d: 578 | d = d[tok] 579 | else: 580 | return False 581 | i += 1 582 | return True 583 | if f(d, 0): 584 | return forbidden_words 585 | else: 586 | return None 587 | 588 | # Function to adjust the output of the model based on various options. 589 | def adjust_probs(model, probs, tokenized_text, start, end, masked_indices, 590 | modifier=None, match_meter=None, forbidden_texts=None, 591 | random_factor=False, discouraged_words=None, 592 | rhyme_with=None, rhymable_only=False, rhymable_with_meters=False, 593 | allow_punctuation=None, no_word_pieces=False, 594 | strong_topic_bias=False, topicless_probs=None): 595 | 596 | if forbidden_texts is not None: 597 | forbidden_words = find_forbidden_words(tokenized_text, 598 | masked_indices, 599 | forbidden_texts) 600 | else: 601 | forbidden_words = None 602 | 603 | adj_probs = [[u.clone().to(device) for u in t] for t in probs] 604 | for k in range(len(adj_probs)): 605 | for j in range(len(adj_probs[k])): 606 | if random_factor: 607 | noise = torch.randn_like(adj_probs[k][j]) 608 | noise = noise * random_factor + 1.0 609 | adj_probs[k][j] *= noise 610 | 611 | adj_probs[k][j] = m(adj_probs[k][j]) 612 | 613 | # Do not produce word pieces. There is no way to keep the model 614 | # behaving reliably if we allow it to produce words that are not 615 | # actually in its vocabulary. 
616 | if no_word_pieces: 617 | adj_probs[k][j] *= (1.0 - word_pieces) 618 | 619 | if rhymable_only: 620 | adj_probs[k][j] *= rhymable_words 621 | if rhymable_with_meters: 622 | for rhyme_meter in rhymable_with_meters: 623 | test_meter = get_meter(get_pron(rhyme_meter)) 624 | meter_tensor = meter_dict[test_meter].to('cpu') 625 | meter_matrix = sparse.dia_matrix((meter_tensor, [0]), 626 | shape=(vocab_size, vocab_size)) 627 | # Take the dot product of meter and rhyme 628 | mat = meter_matrix.dot(rhyme_matrix) 629 | vec = torch.from_numpy(mat.sum(0)).squeeze().to(dtype=bool) 630 | adj_probs[k][j] *= vec.to(device) 631 | 632 | if forbidden_words is not None: 633 | adj_probs[k][j] *= forbidden_words[k].to(device) 634 | 635 | if allow_punctuation is False: 636 | adj_probs[k][j] *= (1.0 - meter_dict['p']) 637 | 638 | if match_meter is not None: 639 | test_meter = get_meter(get_pron(match_meter[k])) 640 | meter_tensor = meter_dict[test_meter] 641 | if allow_punctuation is True: 642 | adj_probs[k][j] *= (meter_tensor + meter_dict['p']) 643 | else: 644 | adj_probs[k][j] *= meter_tensor 645 | 646 | if modifier is not None: 647 | adj_probs[k][j] *= modifier 648 | if discouraged_words is not None: 649 | adj_probs[k][j] *= discouraged_words 650 | 651 | if rhyme_with is not None: 652 | for rhyme_word in rhyme_with: 653 | rhyme_idx = tokenizer.convert_tokens_to_ids([rhyme_word])[0] 654 | rhyme_tensor = rhyme_matrix[rhyme_idx, :].todense() 655 | rhyme_tensor = torch.from_numpy(rhyme_tensor) 656 | rhyme_tensor = rhyme_tensor.squeeze() 657 | adj_probs[k][j] *= rhyme_tensor.to(device) 658 | 659 | if strong_topic_bias: 660 | adj_probs[k][j] /= m(topicless_probs[k][j].to(device)) ** strong_topic_bias 661 | # Sometimes funky scores can arise from this division; we just avoid 662 | # choosing those words. 663 | nan_mask = adj_probs[k][j].isnan() 664 | adj_probs[k][j].masked_fill_(nan_mask, 0.0) 665 | inf_mask = adj_probs[k][j].isinf() 666 | adj_probs[k][j].masked_fill_(inf_mask, 0.0) 667 | 668 | return adj_probs 669 | 670 | # Compute a score indicating how well the model's predictions improve the 671 | # probability for certain words. If multiple words are chosen, it is 672 | # assumed that they are supposed to rhyme. 
673 | def compute_score_for_tokens(probs1, probs2, tokenized_text, 674 | indices, require_replacement, relative): 675 | n = len(indices) 676 | dim = [vocab_size] * n 677 | 678 | mask_token_id = tokenizer.convert_tokens_to_ids([mask_token])[0] 679 | 680 | existing_token_ids = [None] * n 681 | for k, (i1, i2) in enumerate(indices): 682 | existing_token_ids[k] = [] 683 | for i in range(i1, i2+1): 684 | token = tokenized_text[i] 685 | index = tokenizer.convert_tokens_to_ids([token])[0] 686 | existing_token_ids[k].append(index) 687 | 688 | existing_words_prob = 1.0 689 | if probs1: 690 | for k in range(n): 691 | existing_word_prob = 0.0 692 | if len(existing_token_ids[k]) > 1: 693 | for i, tok_id in enumerate(existing_token_ids[k]): 694 | prob_tensor = probs1[k][i] 695 | existing_word_prob += prob_tensor[tok_id].log() 696 | existing_word_prob /= len(existing_token_ids[k]) 697 | existing_words_prob *= existing_word_prob.exp() 698 | else: 699 | existing_words_prob = probs1[k][0][existing_token_ids[k]] 700 | 701 | if require_replacement: 702 | for k in range(n): 703 | probs2[k][0][existing_token_ids[k][0]] = 0.0 704 | 705 | if n == 1: 706 | prob_tensor = probs2[0][0] 707 | prediction_prob = torch.max(prob_tensor) 708 | idx = prob_tensor.argmax().item() 709 | predicted_token_ids = [idx] 710 | 711 | elif n == 2: 712 | # We compute scores for possible rhyme pairs using sparse matrix 713 | # arithmetic. We use scipy instead of torch because torch's sparse 714 | # tensors do not support the .max() function. 715 | left_mat = sparse.dia_matrix((probs2[0][0].to('cpu'), [0]), shape=dim) 716 | mat = left_mat.dot(rhyme_matrix) 717 | right_mat = sparse.dia_matrix((probs2[1][0].to('cpu'), [0]), shape=dim) 718 | mat = mat.dot(right_mat) 719 | prediction_prob = mat.max() 720 | idx = mat.argmax() 721 | 722 | # Hack to deal with int32 overflow when the vocab is large 723 | if idx < 0: 724 | idx += 1 << 32 725 | predicted_token_ids = list(numpy.unravel_index(idx, dim)) 726 | while mat[predicted_token_ids[0], predicted_token_ids[1]] < prediction_prob: 727 | idx += 1 << 32 728 | predicted_token_ids = list(numpy.unravel_index(idx, dim)) 729 | 730 | if probs1: 731 | if relative: 732 | score = existing_words_prob / prediction_prob 733 | else: 734 | score = existing_words_prob 735 | else: 736 | score = prediction_prob 737 | 738 | predicted_tokens = [None] * n 739 | for i in range(n): 740 | predicted_tokens[i] \ 741 | = tokenizer.convert_ids_to_tokens([predicted_token_ids[i]])[0] 742 | 743 | return predicted_tokens, float(score) 744 | 745 | # Tokenize a text and figure out (as best we can) its rhyme scheme. 746 | def process_text(model, text, start, end, match_rhyme, strip_punctuation=False): 747 | lines = text.split('\n') 748 | 749 | tok_index = start 750 | toks = [] 751 | rhyme_types = {} 752 | multipart_words = {} 753 | fixed = False 754 | fixed_toks = set() 755 | line_ends = set() 756 | first_line = True 757 | for line in lines: 758 | if (model.startswith('roberta') or model.startswith('gpt2') or (model.startswith('microsoft/deberta') and '-v2' not in model)) and not first_line and not line.startswith(' '): 759 | line = ' ' + line 760 | first_line = False 761 | 762 | # Check for the special '{}' characters that indicate fixed text. 
763 | line_new = '' 764 | shift = 0 765 | fixed_chars = set() 766 | for i, ch in enumerate(line): 767 | if (model.startswith('bert') or model.startswith('distilbert') or (model.startswith('microsoft/deberta') and '-v2' in model)) and ch == ' ': 768 | # BERT tokenizer strips spaces, so we must account for that. 769 | shift += 1 770 | if ch == '{': 771 | fixed = True 772 | shift += 1 773 | elif ch == '}': 774 | fixed = False 775 | shift += 1 776 | else: 777 | line_new += ch 778 | if fixed: 779 | fixed_chars.add(i - shift) 780 | 781 | line_toks = tokenizer.tokenize(line_new) 782 | line_fixed_toks = set() 783 | i = 0 784 | for j, tok in enumerate(line_toks): 785 | if model.startswith('microsoft/deberta') and '-v2' not in model: 786 | tok = tokenizer.convert_tokens_to_string([tok]) 787 | if tok.startswith('##'): 788 | tok = tok[2:] 789 | if tok.startswith('▁'): 790 | tok = tok[1:] 791 | nchars = len(tok) 792 | for k in range(nchars): 793 | if i+k in fixed_chars: 794 | line_fixed_toks.add(j + tok_index) 795 | break 796 | i += nchars 797 | 798 | if strip_punctuation: 799 | stripped_line_toks = [] 800 | stripped_fixed_toks = set() 801 | shift = 0 802 | for j, tok in enumerate(line_toks): 803 | if is_punctuation(tok): 804 | shift += 1 805 | else: 806 | stripped_line_toks.append(tok) 807 | if j + tok_index in line_fixed_toks: 808 | stripped_fixed_toks.add(j + tok_index - shift) 809 | line_toks = stripped_line_toks 810 | line_fixed_toks = stripped_fixed_toks 811 | 812 | toks += line_toks 813 | fixed_toks.update(line_fixed_toks) 814 | 815 | # Check for multipart words. 816 | word_bounds = [] 817 | after_apostrophe = False 818 | after_double_quote = False 819 | for i, tok in enumerate(line_toks): 820 | if model.startswith('microsoft/deberta') and '-v2' not in model: 821 | tok = tokenizer.convert_tokens_to_string([tok]) 822 | if (is_word_piece(model, tok) and not after_double_quote) or tok == "'" or \ 823 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")): 824 | if not word_bounds: 825 | word_bounds.append([i, i]) 826 | else: 827 | word_bounds[-1][1] = i 828 | else: 829 | word_bounds.append([i, i]) 830 | after_apostrophe = tok == "'" 831 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 832 | for i1, i2 in word_bounds: 833 | if i1 == i2: 834 | continue 835 | for i in range(i1, i2+1): 836 | multipart_words[i + tok_index] = (i1 + tok_index, 837 | i2 + tok_index) 838 | 839 | if match_rhyme: 840 | rhyme_type = None 841 | # Only check rhyme for the last non-punctuation word of a line. 
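# (Illustrative aside, not part of visions.py.) Once the loop below has found the rhyme
# type of each line's final word, lines that share a rhyme type are paired off two at a
# time and rhyme_groups maps every member index back to its pair. With made-up token
# indices, rhyme_types == {'A': [7, 15], 'B': [23, 31]} yields
# rhyme_groups == {7: [7, 15], 15: [7, 15], 23: [23, 31], 31: [23, 31]}; groups of more
# than two rhyming lines are simply split into successive pairs.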
842 | word = '' 843 | i = len(line_toks) - 1 844 | while i >= 0: 845 | if i + tok_index in multipart_words: 846 | i1, i2 = multipart_words[i + tok_index] 847 | word = join_word_pieces(line_toks[i1-tok_index:i2-tok_index+1]) 848 | i = multipart_words[i + tok_index][0] - tok_index 849 | else: 850 | word = line_toks[i] 851 | 852 | pron = get_pron(word) 853 | if pron != []: 854 | rhyme_type = get_rhyme(pron) 855 | if rhyme_type is not None: 856 | if not rhyme_type in rhyme_types: 857 | rhyme_types[rhyme_type] = [] 858 | rhyme_types[rhyme_type].append(tok_index + i) 859 | break 860 | 861 | i -= 1 862 | 863 | tok_index += len(line_toks) 864 | line_ends.add(tok_index) 865 | 866 | if match_rhyme: 867 | rhyme_groups = {} 868 | for rhyme in rhyme_types: 869 | tok_list = rhyme_types[rhyme] 870 | # Rhyme groups of more than two not currently supported, so we 871 | # split the groups up into pairs 872 | for i in range(0, len(tok_list), 2): 873 | group = tok_list[i:i+2] 874 | for index in group: 875 | rhyme_groups[index] = group 876 | 877 | return toks, fixed_toks, multipart_words, rhyme_groups, line_ends 878 | 879 | else: 880 | return toks, fixed_toks, multipart_words, {}, line_ends 881 | 882 | # Alters a text iteratively, word by word, using the model to pick 883 | # replacements. 884 | def depoeticize(text, max_iterations=100, batch_size=10, 885 | match_meter=False, match_rhyme=False, title=None, author=None, 886 | randomize=False, cooldown=0.01, modifier=None, 887 | forbid_reversions=True, preserve_punctuation=False, 888 | strong_topic_bias=False, stop_score=1.0, 889 | discourage_repetition=False, stopwords=stopwords.words('english'), 890 | model_type='bert-base-uncased', model_path=None, 891 | preserve_spacing_and_capitalization=True, 892 | allow_punctuation=None, sequential=False, verbose=True, 893 | outfile=None, top_n=10, require_new_rhymes=False, 894 | num_changes_per_iter=1): 895 | stopwords = set(stopwords) 896 | 897 | # The detokenizer doesn't properly handle cases where the input is all space. It is necessary 898 | # to implement good behavior in this case because it can arise when this function is called 899 | # by banalify. 900 | if re_space_and_brackets.match(text): 901 | return text.replace('{', '').replace('}', '') 902 | 903 | # Stripping smart quotes because some of the models don't seem to handle them properly. 
904 | text = text.replace('“','"').replace('”','"').replace('‘','\'').replace('’','\'').replace('\r\n', '\n') 905 | 906 | initialize_model(model_type, model_path) 907 | initialize_rhyme_and_meter(model_type, meter=match_meter or allow_punctuation is not None, 908 | rhymes=match_rhyme) 909 | 910 | if modifier is not None: 911 | modifier = modifier().to(device) 912 | 913 | topicless_toks1 = tokenizer.tokenize(f'{bos_token}Title: {mask_token} / Author: {mask_token} {mask_token} / Text: \n\n') 914 | if title and author: 915 | toks1 = tokenizer.tokenize(f'{bos_token}Title: {title} / Author: {author} / Text: \n\n') 916 | elif title: 917 | toks1 = tokenizer.tokenize(f'{bos_token}Title: {title} / Author: {mask_token} {mask_token} / Text: \n\n') 918 | elif author: 919 | toks1 = tokenizer.tokenize(f'{bos_token}Title: {mask_token} / Author: {author} / Text: \n\n') 920 | else: 921 | toks1 = [bos_token] 922 | toks3 = [eos_token] 923 | start = len(toks1) 924 | end = len(toks3) 925 | 926 | toks2, fixed_toks, multipart_words, rhyme_groups, line_ends \ 927 | = process_text(model_type, text, start, end, match_rhyme) 928 | tokenized_text = toks1 + toks2 + toks3 929 | n = len(tokenized_text) 930 | 931 | if preserve_spacing_and_capitalization: 932 | spacing, capitalization = scan_tokenization(model_type, text, toks2) 933 | 934 | forbidden_texts = {} 935 | 936 | if outfile is not None: 937 | outfile = open(outfile, 'w') 938 | html = f''' 939 | 940 | 941 | 942 | 943 | 944 | 945 | 946 | 947 | Model: {model_type}{' "' + model_path + '"' if model_path else ''}
948 | Max iterations: {max_iterations}
949 | '''
950 | if strong_topic_bias: 951 | html += f'Strong topic bias: {strong_topic_bias}<br/>'
952 | if randomize: 953 | html += f'Randomizing, cooldown={cooldown}<br/>'
954 | if stop_score != 1.0: 955 | html += f'Stop score: {stop_score}<br/>'
956 | if match_meter: 957 | html += 'Matching meter<br/>'
958 | if match_rhyme: 959 | html += 'Matching rhyme<br/>'
960 | if require_new_rhymes: 961 | html += 'Requiring new rhymes<br/>'
962 | if forbid_reversions: 963 | html += 'Forbidding reversions<br/>'
964 | if preserve_punctuation: 965 | html += 'Preserving punctuation<br/>'
966 | if discourage_repetition: 967 | html += 'Discouraging repetition<br/>'
968 | if allow_punctuation is True: 969 | html += 'Always allowing punctuation<br/>'
970 | if allow_punctuation is False: 971 | html += 'Never allowing punctuation<br/>'
972 | if modifier is not None: 973 | html += 'Modifier provided<br/>'
974 | if sequential: 975 | html += 'Running in sequential mode<br/>'
976 | html += '<br/>'
977 | if title: 978 | html += f'Title: {title}<br/>'
979 | if author: 980 | html += f'Author: {author}<br/>'
981 | html += '''
982 | Highlight: 987 | 988 | Double-click on words to see predictions
''' 989 | 990 | outfile.write(html) 991 | 992 | if sequential: 993 | max_iterations = len(toks2) 994 | new_token_indices = [] 995 | if require_new_rhymes: 996 | original_rhymes = {} 997 | for i in rhyme_groups: 998 | original_rhymes[i] = tokenized_text[i] 999 | for k in range(max_iterations): 1000 | last_score = 0.0 1001 | if verbose: 1002 | iter_start_time = time.time() 1003 | 1004 | if sequential and k >= len(tokenized_text) - start - end: 1005 | break 1006 | 1007 | if require_new_rhymes: 1008 | fallback_indices = None 1009 | fallback_predicted_tokens = None 1010 | fallback_score = None 1011 | 1012 | # Discourage the selection of words already in the text, save for stopwords. 1013 | if discourage_repetition is not False: 1014 | discouraged_words = torch.ones((vocab_size,)) 1015 | for i in range(start, n-end): 1016 | tok = tokenized_text[i] 1017 | if tok in stopwords: 1018 | continue 1019 | idx = tokenizer.convert_tokens_to_ids([tok])[0] 1020 | discouraged_words[idx] = discourage_repetition 1021 | discouraged_words = discouraged_words.to(device) 1022 | else: 1023 | discouraged_words = None 1024 | 1025 | # Compute the scores used to choose which word to change 1026 | outputs = [] 1027 | if sequential: 1028 | test_range = [start + k] 1029 | else: 1030 | test_range = range(start, n-end) 1031 | 1032 | # First, figure out the indices to test for replacements. This is non-trivial because 1033 | # we want to replace multipart words as whole units and rhyme groups together. 1034 | masked_indices = [] 1035 | if strong_topic_bias: 1036 | topicless_masked_indices = [] 1037 | for i in test_range: 1038 | if preserve_punctuation: 1039 | if is_punctuation(tokenized_text[i]): 1040 | continue 1041 | if is_space(tokenized_text[i]): 1042 | continue 1043 | if i in fixed_toks: 1044 | continue 1045 | if i in multipart_words and i != multipart_words[i][0]: 1046 | # Only try the first part of a multipart word 1047 | continue 1048 | 1049 | if match_rhyme and i in rhyme_groups: 1050 | if i != rhyme_groups[i][0]: 1051 | # Only try each rhyme group once 1052 | continue 1053 | indices = rhyme_groups[i] 1054 | else: 1055 | indices = [i] 1056 | 1057 | indices = [multipart_words.get(idx, [idx, idx]) 1058 | for idx in indices] 1059 | masked_indices.append(indices) 1060 | if strong_topic_bias: 1061 | topicless_indices = [(i1-start+len(topicless_toks1), i2-start+len(topicless_toks1)) 1062 | for (i1, i2) in indices] 1063 | topicless_masked_indices.append(topicless_indices) 1064 | 1065 | # Next, run all the predictions in batches. 1066 | if strong_topic_bias: 1067 | topicless_tokenized_text = topicless_toks1 + tokenized_text[start:-end] + toks3 1068 | (all_probs1, all_topicless_probs1), (all_probs2, all_topicless_probs2) \ 1069 | = compute_probs_for_masked_tokens(model, 1070 | (tokenized_text, topicless_tokenized_text), 1071 | (masked_indices, topicless_masked_indices), 1072 | batch_size, 1073 | replacements_only=sequential) 1074 | else: 1075 | (all_probs1,), (all_probs2,) \ 1076 | = compute_probs_for_masked_tokens(model, 1077 | (tokenized_text,), 1078 | (masked_indices,), 1079 | batch_size, 1080 | replacements_only=sequential) 1081 | all_topicless_probs1 = None 1082 | all_topicless_probs2 = None 1083 | 1084 | # Finally, adjust the probabilities and compute the final scores. 
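# (Illustrative aside, not part of visions.py.) With relative=True, the score returned by
# compute_score_for_tokens() is the probability of the words already in the text divided
# by the probability of the model's best permitted replacement, so values well below 1.0
# mark the positions the model is most eager to change. Conceptually, for a single
# position: score = probs_original[existing_token_id] / adjusted_probs.max(), where these
# variable names are stand-ins rather than names used in this file. The selection step
# further down sorts candidates by this score, replaces the lowest-scoring ones first,
# and stops once every remaining score reaches stop_score.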
1085 | for i, indices in enumerate(masked_indices): 1086 | probs1 = all_probs1 and all_probs1[i] 1087 | probs2 = all_probs2 and all_probs2[i] 1088 | topicless_probs1 = all_topicless_probs1 and all_topicless_probs1[i] 1089 | topicless_probs2 = all_topicless_probs2 and all_topicless_probs2[i] 1090 | 1091 | if match_meter: 1092 | meter = [join_word_pieces(tokenized_text[i1:i2+1]) 1093 | for (i1, i2) in indices] 1094 | else: 1095 | meter = None 1096 | 1097 | if require_new_rhymes and len(indices) > 1: 1098 | require_replacement = True 1099 | for i1, i2 in indices: 1100 | if i1 in original_rhymes and original_rhymes[i1] != tokenized_text[i1]: 1101 | require_replacement = False 1102 | break 1103 | else: 1104 | require_replacement = False 1105 | 1106 | raw_probs = m(probs2[0][0]).to('cpu') 1107 | raw_topicless_probs = topicless_probs2 and m(topicless_probs2[0][0]).to('cpu') 1108 | if not sequential: 1109 | probs1 = adjust_probs(model, probs1, tokenized_text, start, 1110 | end, indices, modifier, 1111 | random_factor=randomize, 1112 | allow_punctuation=allow_punctuation, 1113 | strong_topic_bias=strong_topic_bias, 1114 | topicless_probs=strong_topic_bias and topicless_probs1) 1115 | probs2 = adjust_probs(model, probs2, tokenized_text, start, 1116 | end, indices, modifier, 1117 | meter, forbidden_texts if num_changes_per_iter == 1 else {}, 1118 | discouraged_words=discouraged_words, 1119 | random_factor=randomize, 1120 | allow_punctuation=allow_punctuation, 1121 | no_word_pieces=True, 1122 | strong_topic_bias=strong_topic_bias, 1123 | topicless_probs=strong_topic_bias and topicless_probs2) 1124 | 1125 | adjusted_probs = probs2[0][0].to('cpu') 1126 | predicted_tokens, score \ 1127 | = compute_score_for_tokens(probs1, probs2, 1128 | tokenized_text, indices, 1129 | relative=True, 1130 | require_replacement=require_replacement) 1131 | 1132 | if require_replacement: 1133 | fallback_indices = indices 1134 | fallback_predicted_tokens = predicted_tokens 1135 | fallback_score = score 1136 | outputs.append((indices, predicted_tokens, score, raw_topicless_probs, 1137 | raw_probs, adjusted_probs)) 1138 | 1139 | del all_probs1 1140 | del all_probs2 1141 | del all_topicless_probs1 1142 | del all_topicless_probs2 1143 | 1144 | # Output an HTML visualization. 1145 | if outfile is not None: 1146 | vals = {} 1147 | min_entropy = float("inf") 1148 | max_entropy = 0 1149 | min_score = float("inf") 1150 | max_score = 0 1151 | for indices, predicted_tokens, score, probs1, probs2, probs3 in outputs: 1152 | def get_entropy(probs): 1153 | if probs is None: 1154 | return 0 1155 | else: 1156 | return scipy.stats.entropy(probs.cpu()) 1157 | entropy1 = get_entropy(probs1) 1158 | entropy2 = get_entropy(probs2) 1159 | entropy3 = get_entropy(probs3) 1160 | if title is not None: 1161 | selected_entropy = entropy1 1162 | else: 1163 | selected_entropy = entropy2 1164 | if selected_entropy < min_entropy: 1165 | min_entropy = selected_entropy 1166 | if selected_entropy > max_entropy: 1167 | max_entropy = selected_entropy 1168 | if score == float("inf"): 1169 | score_val = 0 1170 | elif score == 0: 1171 | score_val = -float("inf") 1172 | else: 1173 | score_val = -math.log(score) 1174 | if score_val < min_score: 1175 | min_score = score_val 1176 | if score_val > max_score: 1177 | max_score = score_val 1178 | for i1, i2 in indices: 1179 | for i in range(i1, i2+1): 1180 | vals[i] = (entropy1, entropy2, entropy3, probs1, probs2, probs3, score_val, predicted_tokens) 1181 | 1182 | html = "
" 1183 | viz_toks = [] 1184 | for i in range(start, len(tokenized_text)-end): 1185 | if i in vals: 1186 | entropy1, entropy2, entropy3, probs1, probs2, probs3, score_val, predicted_tokens = vals[i] 1187 | else: 1188 | entropy1, entropy2, entropy3, probs1, probs2, probs3, score_val, predicted_tokens = 0.0, 0.0, 0.0, None, None, None, 0.0, None 1189 | if i in multipart_words: 1190 | i1, i2 = multipart_words[i] 1191 | if i > i1: 1192 | continue 1193 | else: 1194 | i1 = i 1195 | i2 = i 1196 | s = tokenizer.convert_tokens_to_string(tokenized_text[i1:i2+1]).replace(" ' ", "'") 1197 | if tokenized_text[i1][0] in ('Ġ', 'Ċ', '▁', ' '): 1198 | s = ' ' + s 1199 | if max_entropy == min_entropy: 1200 | entropy_relative = 0.0 1201 | else: 1202 | if title is not None: 1203 | selected_entropy = entropy1 1204 | else: 1205 | selected_entropy = entropy2 1206 | entropy_relative = (selected_entropy - min_entropy) / (max_entropy - min_entropy) 1207 | if max_score == min_score: 1208 | score_relative = 0.0 1209 | else: 1210 | score_relative = (score_val - min_score) / (max_score - min_score) 1211 | changed = i in new_token_indices 1212 | changed = " changed-tok" if changed else "" 1213 | raw_probs = probs1 1214 | raw_topicless_probs = probs2 1215 | adjusted_probs = probs3 1216 | def get_top(probs): 1217 | if probs is None: 1218 | return 'null' 1219 | out = torch.topk(probs, top_n) 1220 | top_options = zip(out.indices, out.values) 1221 | top_options = [(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([j])).replace(' ', ''), 1222 | float(p)) 1223 | for j, p in top_options] 1224 | top_options = json.dumps(top_options) 1225 | return top_options 1226 | options1 = get_top(raw_probs) 1227 | options2 = get_top(raw_topicless_probs) 1228 | options3 = get_top(adjusted_probs) 1229 | if predicted_tokens is not None: 1230 | replacement_tokens = json.dumps([tokenizer.convert_tokens_to_string([s]) for s in predicted_tokens]) 1231 | else: 1232 | replacement_tokens = 'null' 1233 | viz_toks.append(f"{s}") 1234 | if preserve_spacing_and_capitalization: 1235 | html += detokenize(model_type, viz_toks, spacing, capitalization, html=True) 1236 | else: 1237 | html += tokenizer.clean_up_tokenization(tokenizer.convert_tokens_to_string(viz_toks)) 1238 | html = html.replace('\n', '
<br/>') 1239 | html += "

\n" 1240 | 1241 | outfile.write(html) 1242 | outfile.flush() 1243 | 1244 | # Choose words to change 1245 | outputs.sort(key=lambda t: t[2]) 1246 | chosen_index_lists = [] 1247 | chosen_token_lists = [] 1248 | i = 0 1249 | for indices, predicted_tokens, score, _, _, _ in outputs: 1250 | if score >= stop_score or i >= num_changes_per_iter: 1251 | break 1252 | if predicted_tokens is None: 1253 | continue 1254 | i += 1 1255 | chosen_index_lists.append(indices) 1256 | chosen_token_lists.append(predicted_tokens) 1257 | if i == 1: 1258 | lowest_score = score 1259 | 1260 | if not chosen_index_lists: 1261 | if sequential: 1262 | continue 1263 | elif require_new_rhymes and fallback_indices is not None: 1264 | chosen_index_lists = [fallback_indices] 1265 | chosen_token_lists = [fallback_predicted_tokens] 1266 | last_score = fallback_score 1267 | else: 1268 | break 1269 | 1270 | # To prevent loops, we forbid the model from reverting to texts that it 1271 | # has already tried. The texts are stored in a trie (prefix tree) for 1272 | # efficient searchability. 1273 | if forbid_reversions: 1274 | d = forbidden_texts 1275 | for tok in tokenized_text: 1276 | if not tok in d: 1277 | d[tok] = {} 1278 | d = d[tok] 1279 | 1280 | # Make the actual revision and make note of what we've done. 1281 | change_made = False 1282 | new_token_indices = [] 1283 | new_tokenized_text = tokenized_text.copy() 1284 | for change_num in range(len(chosen_index_lists)): 1285 | chosen_indices = chosen_index_lists[change_num] 1286 | chosen_tokens = chosen_token_lists[change_num] 1287 | shift = 0 1288 | for j, (i1, i2) in enumerate(chosen_indices): 1289 | i1 -= shift 1290 | i2 -= shift 1291 | shift += (i2 - i1) 1292 | n -= (i2 - i1) 1293 | token = chosen_tokens[j] 1294 | if i2 > i1: 1295 | change_made = True 1296 | new_tokenized_text[i1:i2+1] = [token] 1297 | new_token_indices.append(i1) 1298 | elif tokenized_text[i1] != token: 1299 | change_made = True 1300 | new_tokenized_text[i1] = token 1301 | new_token_indices.append(i1) 1302 | 1303 | fixed_toks_new = set() 1304 | for fixed_tok in fixed_toks: 1305 | if fixed_tok > i1 and fixed_tok <= i2: 1306 | pass 1307 | elif fixed_tok > i2: 1308 | fixed_toks_new.add(fixed_tok - (i2 - i1)) 1309 | else: 1310 | fixed_toks_new.add(fixed_tok) 1311 | fixed_toks = fixed_toks_new 1312 | 1313 | for change_num2 in range(len(chosen_index_lists)): 1314 | replacement = [] 1315 | for i, (j1, j2) in enumerate(chosen_index_lists[change_num2]): 1316 | if j1 > i2: 1317 | replacement.append((j1 - (i2 - i1), 1318 | j2 - (i2 - i1))) 1319 | else: 1320 | replacement.append((j1, j2)) 1321 | chosen_index_lists[change_num2] = replacement 1322 | 1323 | for i in range(i1, i2+1): 1324 | if i in multipart_words: 1325 | del multipart_words[i] 1326 | replacements = {} 1327 | for i in list(multipart_words.keys()): 1328 | if i > i2: 1329 | j1, j2 = multipart_words[i] 1330 | del multipart_words[i] 1331 | replacements[i - (i2 - i1)] = (j1 - (i2 - i1), 1332 | j2 - (i2 - i1)) 1333 | for i in replacements: 1334 | multipart_words[i] = replacements[i] 1335 | 1336 | replacements = {} 1337 | for i_old in list(rhyme_groups.keys()): 1338 | group = rhyme_groups[i_old].copy() 1339 | if i_old > i1: 1340 | i_new = i_old - (i2 - i1) 1341 | else: 1342 | i_new = i_old 1343 | group = [(idx - (i2 - i1) if idx > i1 else idx) 1344 | for idx in group] 1345 | replacements[i_new] = group 1346 | rhyme_groups = replacements 1347 | 1348 | if require_new_rhymes: 1349 | replacements = {} 1350 | for i_old in list(original_rhymes.keys()): 1351 | rhyme_tok = 
original_rhymes[i_old] 1352 | if i_old > i1: 1353 | i_new = i_old - (i2 - i1) 1354 | else: 1355 | i_new = i_old 1356 | replacements[i_new] = rhyme_tok 1357 | original_rhymes = replacements 1358 | 1359 | if forbid_reversions and num_changes_per_iter > 1: 1360 | # There's no clear way to implement this when choosing tokens if we 1361 | # are going to change multiple tokens at once, so we just stop 1362 | # when we reach a reversion. 1363 | d = forbidden_texts 1364 | for tok in new_tokenized_text: 1365 | if tok in d: 1366 | d = d[tok] 1367 | else: 1368 | break 1369 | else: 1370 | change_made = False 1371 | 1372 | if change_made: 1373 | tokenized_text = new_tokenized_text 1374 | else: 1375 | if sequential: 1376 | continue 1377 | else: 1378 | break 1379 | 1380 | if verbose: 1381 | iter_end_time = time.time() 1382 | sample = tokenized_text[start:-end].copy() 1383 | for i in new_token_indices: 1384 | sample[i-start] = '<' + sample[i-start] + '>' 1385 | if preserve_spacing_and_capitalization: 1386 | text = detokenize(model_type, sample, spacing, capitalization) 1387 | else: 1388 | text = tokenizer.clean_up_tokenization(tokenizer.convert_tokens_to_string(sample)) 1389 | print('-----------------------') 1390 | print('Iteration {0}, lowest score = {1}, running time = {2}s'.format(k+1, lowest_score, 1391 | iter_end_time - iter_start_time)) 1392 | print(text) 1393 | 1394 | if randomize and cooldown: 1395 | randomize *= (1.0 - cooldown) 1396 | 1397 | if outfile is not None: 1398 | outfile.write("\n") 1399 | outfile.flush() 1400 | 1401 | if preserve_spacing_and_capitalization: 1402 | text = detokenize(model_type, tokenized_text[start:-end], spacing, capitalization) 1403 | else: 1404 | text = tokenizer.clean_up_tokenization(tokenizer.convert_tokens_to_string(tokenized_text[start:-end])) 1405 | 1406 | return text 1407 | 1408 | # Generates a wholly new text by running a decoder model forward with the specified 1409 | # constraints. This doesn't work very well. 1410 | def parody(text, match_meter=False, match_rhyme=False, topic=None, 1411 | randomize=False, modifier=None, verbose=True, 1412 | topic_prefix="", model='gpt2'): 1413 | model_type = model 1414 | 1415 | if modifier is not None: 1416 | modifier = modifier() 1417 | 1418 | global tokenizer 1419 | tokenizer = GPT2Tokenizer.from_pretrained(model_type) 1420 | model = GPT2LMHeadModel.from_pretrained(model_type) 1421 | model.to(device) 1422 | model.eval() 1423 | eos_token = tokenizer.eos_token 1424 | 1425 | initialize_rhyme_and_meter(model_type, meter=True, rhymes=match_rhyme) 1426 | eol_token = tokenizer.convert_tokens_to_ids(['Ġ'])[0] 1427 | 1428 | if topic: 1429 | toks1 = tokenizer.tokenize("{0} {1} {2}. " 1430 | .format(eos_token, topic_prefix, topic)) 1431 | else: 1432 | toks1 = [eos_token] 1433 | start = len(toks1) 1434 | 1435 | # We strip punctuation because, not being able to look ahead, the GPT-2 1436 | # model cannot reliably produce text that matches the punctuation of the 1437 | # original; the only way to get coherent output is to let the model decide 1438 | # on the punctuation. 1439 | toks2, fixed_toks, multipart_words, rhyme_groups, line_ends \ 1440 | = process_text(model_type, text, start, 0, match_rhyme, 1441 | strip_punctuation=True) 1442 | 1443 | tokenized_text = toks1 + toks2 1444 | n = len(tokenized_text) 1445 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 1446 | 1447 | # As in Beckett's "The Unnamable," we force the model to keep writing 1448 | # even when it wants to stop. 
1449 | discouraged_words = torch.ones((vocab_size,)).to(device) 1450 | eos_token_id = tokenizer.convert_tokens_to_ids([eos_token])[0] 1451 | discouraged_words[eos_token_id] = 0.0 1452 | newline_token_id = tokenizer.convert_tokens_to_ids(['\n'])[0] 1453 | discouraged_words[newline_token_id] = 0.0 1454 | 1455 | out_toks = indexed_tokens[:start] 1456 | i = start 1457 | just_added_punctuation = False 1458 | just_rhymed = False 1459 | while i < n: 1460 | if i in fixed_toks: 1461 | tok = indexed_tokens[i] 1462 | out_toks.append(tok) 1463 | tok = tokenizer.convert_ids_to_tokens([tok])[0] 1464 | just_added_punctuation = is_punctuation(tok) 1465 | i += 1 1466 | continue 1467 | 1468 | if indexed_tokens[i] == eol_token: 1469 | out_toks.append(eol_token) 1470 | i += 1 1471 | continue 1472 | 1473 | if match_rhyme and i in rhyme_groups: 1474 | rhyming = True 1475 | # We can't look ahead with this model, so the rhyming constraint 1476 | # only looks at words already chosen. 1477 | rhyme_words = [tokenizer.convert_ids_to_tokens([indexed_tokens[idx]])[0] 1478 | for idx in rhyme_groups[i] if idx < i] 1479 | # ...but we do need to make sure to choose a word that can be rhymed 1480 | # with at least one word of the meter of later rhyming words. 1481 | rhyme_meters = [tokenizer.convert_ids_to_tokens([indexed_tokens[idx]])[0] 1482 | for idx in rhyme_groups[i] if idx > i] 1483 | else: 1484 | rhyming = False 1485 | rhyme_words = None 1486 | rhyme_meters = None 1487 | 1488 | i1, i2 = multipart_words.get(i, [i, i]) 1489 | if match_meter: 1490 | meter = [join_word_pieces(tokenized_text[i1:i2+1])] 1491 | else: 1492 | meter = None 1493 | 1494 | with torch.no_grad(): 1495 | tokens_tensor = torch.tensor([out_toks]).to(device) 1496 | outputs = model(tokens_tensor) 1497 | predictions = outputs[0] 1498 | 1499 | no_word_pieces = (i == start) or rhyme_words or just_added_punctuation or just_rhymed 1500 | 1501 | probs = [[predictions[0, -1, :]]] 1502 | probs = adjust_probs(model, probs, None, 0, 0, None, 1503 | modifier, meter, 1504 | random_factor=randomize, 1505 | discouraged_words=discouraged_words, 1506 | allow_punctuation=not just_added_punctuation and not rhyme_words, 1507 | no_word_pieces=no_word_pieces, 1508 | rhyme_with=rhyme_words, 1509 | rhymable_only=not match_meter and rhyming, 1510 | rhymable_with_meters=match_meter and rhyme_meters) 1511 | 1512 | idx = probs[0][0].argmax().item() 1513 | if idx == eos_token_id: 1514 | break 1515 | 1516 | tok = tokenizer.convert_ids_to_tokens([idx])[0] 1517 | out_toks.append(idx) 1518 | 1519 | just_rhymed = not not rhyme_words 1520 | 1521 | # Only proceed to the next input token if the output is a 1522 | # non-punctuation token. 1523 | if meter_dict['p'][idx] == 0.0: 1524 | if verbose and i in line_ends: 1525 | print('') 1526 | # Record the chosen token for rhyming purposes. 1527 | indexed_tokens[i] = idx 1528 | i += i2 - i1 + 1 1529 | just_added_punctuation = False 1530 | else: 1531 | # We don't allow multiple punctuation tokens in a row. This is 1532 | # because the model can potentially get stuck in a loop where it 1533 | # generates nothing but punctuation, in which case the process 1534 | # would never end. 
1535 | just_added_punctuation = True 1536 | 1537 | if verbose: 1538 | string = tokenizer.convert_tokens_to_string([tok]) 1539 | print(string, end='') 1540 | 1541 | if verbose: 1542 | print('') 1543 | out = tokenizer.convert_ids_to_tokens(out_toks[start:]) 1544 | text = tokenizer.convert_tokens_to_string(out) 1545 | return tokenizer.clean_up_tokenization(text) 1546 | 1547 | # Add modifier=json_modifier('') to bias the results in favor of certain 1548 | # words, as read from a JSON file. The file should contain an object mapping words to 1549 | # numbers. You can generate file like this using generate_modifier.py. Lower the 1550 | # factor parameter to decrease the effect. 1551 | def json_modifier(filename, factor=1.0, default_score=-10.0): 1552 | def modifier(): 1553 | f = open(filename, 'r') 1554 | scores = json.load(f) 1555 | f.close() 1556 | vocab = tokenizer.get_vocab() 1557 | score_vector = [default_score] * vocab_size 1558 | for tok in vocab: 1559 | i = vocab[tok] 1560 | if tok.startswith('Ġ') or tok.startswith('Ċ') or tok.startswith(' ') or tok.startswith('▁'): 1561 | tok = tok[1:] 1562 | tok = tok.lower() 1563 | if tok in scores: 1564 | score_vector[i] = scores[tok] 1565 | score_tensor = m(torch.tensor(score_vector)) 1566 | mean_score = 1.0 / score_tensor.shape[0] 1567 | return (1.0 - factor) * mean_score + factor * score_tensor 1568 | return modifier 1569 | 1570 | # Add modifier=metalness_modifier to bias the results toward words that occur 1571 | # frequently in heavy metal lyrics. First you will need to download the data set 1572 | # available at https://github.com/ijmbarr/pythonic-metal. 1573 | metalness_modifier = json_modifier('metalness.json') 1574 | 1575 | # Depoeticizes a text piece by piece examining only an n-word window at a time, with 1576 | # a certain amount of context to the left and right. This procedure can handle longer 1577 | # texts than depoeticize(). 1578 | def banalify(text, window_size=10, context_size=10, batch_size=10, 1579 | max_iterations=100, match_meter=False, match_rhyme=False, 1580 | title=None, author=None, randomize=False, cooldown=0.01, modifier=None, 1581 | forbid_reversions=True, preserve_punctuation=False, 1582 | strong_topic_bias=False, stop_score=1.0, 1583 | discourage_repetition=False, stopwords=stopwords.words('english'), 1584 | model_type='bert-base-uncased', model_path=None, 1585 | allow_punctuation=None, sequential=False, verbose=True): 1586 | initialize_model(model_type, model_path) 1587 | initialize_rhyme_and_meter(model_type, meter=match_meter or allow_punctuation is not None, 1588 | rhymes=match_rhyme) 1589 | 1590 | # Stripping characters that some of the models don't seem to handle properly. 
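# (Illustrative aside, not part of visions.py.) Each pass of the loop below hands
# depoeticize() a string of the form '{left context} window {right context}': the window
# is the only part that gets rewritten, while the surrounding contexts are wrapped in
# braces and therefore treated as fixed text, using the '{}' convention handled by
# process_text(). The rewritten window then joins the left context for the next pass.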
1591 | text = text.replace('“','"').replace('”','"').replace('‘','\'').replace('’','\'').replace('\r', '').replace('—', '').replace('…', '...').replace(' ', ' ') 1592 | text = "".join([c for c in unicodedata.normalize('NFKD', text) if not unicodedata.combining(c)]) 1593 | 1594 | toks, fixed_toks, _, _, _ = process_text(model_type, text, 0, 0, False, False) 1595 | spacing, capitalization = scan_tokenization(model_type, text, toks) 1596 | 1597 | if model_type.startswith('microsoft/deberta') and '-v2' not in model_type: 1598 | open_bracket = tokenizer.tokenize('{')[0] 1599 | close_bracket = tokenizer.tokenize('}')[0] 1600 | else: 1601 | open_bracket = '{' 1602 | close_bracket = '}' 1603 | 1604 | out_text = '' 1605 | left_context_toks = [] 1606 | left_context_text = '' 1607 | left_context_size = 0 1608 | 1609 | i = 0 1610 | spacing_idx = 0 1611 | bracket_left_open = False 1612 | while i < len(toks): 1613 | # Count the current_window 1614 | window_end = i 1615 | num_non_word_pieces = 0 1616 | bracket_indices = set() 1617 | spaced_bracket_indices = set() 1618 | j = 0 1619 | after_apostrophe = False 1620 | after_double_quote = False 1621 | while j < window_size and window_end < len(toks): 1622 | tok = toks[window_end] 1623 | if model_type.startswith('microsoft/deberta') and '-v2' not in model_type: 1624 | tok = tokenizer.convert_tokens_to_string([tok]) 1625 | if not ((is_word_piece(model_type, tok) and not after_double_quote) or tok == "'" or \ 1626 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m"))): 1627 | j += 1 1628 | num_non_word_pieces += 1 1629 | window_end += 1 1630 | after_apostrophe = tok == "'" 1631 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 1632 | 1633 | # Extend the window if it ends in the middle of a word 1634 | after_apostrophe = False 1635 | after_double_quote = False 1636 | while window_end < len(toks): 1637 | tok = toks[window_end] 1638 | if model_type.startswith('microsoft/deberta') and '-v2' not in model_type: 1639 | tok = tokenizer.convert_tokens_to_string([tok]) 1640 | if (is_word_piece(model_type, tok) and not after_double_quote) or tok == "'" or \ 1641 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")): 1642 | window_end += 1 1643 | else: 1644 | break 1645 | after_apostrophe = tok == "'" 1646 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 1647 | 1648 | window_toks = toks[i:window_end] 1649 | window_spacing = spacing[spacing_idx:spacing_idx+num_non_word_pieces] + [''] 1650 | window_capitalization = capitalization[spacing_idx:spacing_idx+num_non_word_pieces] 1651 | 1652 | # Add { and } around fixed text within the window 1653 | bracketed_window_toks = [] 1654 | fixed = False 1655 | for k, tok in enumerate(window_toks): 1656 | tok_idx = i + k 1657 | if tok_idx in fixed_toks and not fixed: 1658 | fixed = True 1659 | bracketed_window_toks.append(open_bracket) 1660 | if tok_idx not in fixed_toks and fixed: 1661 | fixed = False 1662 | bracketed_window_toks.append(close_bracket) 1663 | bracketed_window_toks.append(tok) 1664 | if fixed: 1665 | bracketed_window_toks.append(close_bracket) 1666 | window_text = tokenizer.convert_tokens_to_string(bracketed_window_toks) 1667 | 1668 | # Count the right context, as above 1669 | right_context_end = window_end 1670 | j = 0 1671 | after_apostrophe = False 1672 | after_double_quote = False 1673 | while j < context_size and right_context_end < len(toks): 1674 | tok = toks[right_context_end] 1675 | if model_type.startswith('microsoft/deberta') and '-v2' not in 
model_type: 1676 | tok = tokenizer.convert_tokens_to_string([tok]) 1677 | if not ((is_word_piece(model_type, tok) and not after_double_quote) or tok == "'" or \ 1678 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m"))): 1679 | j += 1 1680 | right_context_end += 1 1681 | after_apostrophe = tok == "'" 1682 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 1683 | after_apostrophe = False 1684 | after_double_quote = False 1685 | while right_context_end < len(toks): 1686 | tok = toks[right_context_end] 1687 | if model_type.startswith('microsoft/deberta') and '-v2' not in model_type: 1688 | tok = tokenizer.convert_tokens_to_string([tok]) 1689 | if (is_word_piece(model_type, tok) and not after_double_quote) or tok == "'" or \ 1690 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")): 1691 | right_context_end += 1 1692 | else: 1693 | break 1694 | after_apostrophe = tok == "'" 1695 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 1696 | 1697 | right_context_toks = toks[window_end:right_context_end] 1698 | right_context_text = tokenizer.convert_tokens_to_string(right_context_toks) 1699 | if model_type.startswith('bert') or (model_type.startswith('microsoft/deberta') and '-v2' in model_type): 1700 | maybe_space = ' ' 1701 | else: 1702 | maybe_space = '' 1703 | if left_context_text: 1704 | contextualized_text = f'{{{left_context_text}}}{maybe_space}{window_text}{maybe_space}{{{right_context_text}}}' 1705 | else: 1706 | contextualized_text = f'{window_text}{maybe_space}{{{right_context_text}}}' 1707 | 1708 | contextualized_text = depoeticize(contextualized_text, max_iterations, batch_size, 1709 | match_meter, match_rhyme, title, author, 1710 | randomize, cooldown, modifier, 1711 | forbid_reversions, preserve_punctuation, 1712 | strong_topic_bias, stop_score, 1713 | discourage_repetition, stopwords, 1714 | model_type, model_path, False, 1715 | allow_punctuation, sequential, verbose) 1716 | 1717 | # Trim off the previous and next windows from the output 1718 | window_toks = tokenizer.tokenize(contextualized_text + ' x')[:-1] 1719 | if model_type.startswith('microsoft/deberta') and '-v2' in model_type and tokenizer.convert_tokens_to_string(window_toks[:2]) in ('."', ".'"): 1720 | # Special case involving how the DeBERTa v2 tokenizer handles '."' 1721 | window_toks = window_toks[1:] 1722 | window_toks[0] = tokenizer.tokenize('."')[0] 1723 | window_toks = window_toks[len(left_context_toks) if left_context_toks else 0: 1724 | -len(right_context_toks) if right_context_toks else None] 1725 | 1726 | start_of_line = out_text == '' or out_text.endswith('\n') 1727 | window_text = detokenize(model_type, window_toks, window_spacing, window_capitalization, start_of_line=start_of_line) 1728 | out_text += window_text 1729 | print(window_text, end='') 1730 | 1731 | # Advance the window and the end of the left context 1732 | left_context_toks += window_toks 1733 | #print(left_context_toks) 1734 | #print(left_context_size) 1735 | #print(window_size) 1736 | left_context_size = len(left_context_toks) 1737 | spacing_idx += num_non_word_pieces - len(bracket_indices) 1738 | i = window_end 1739 | 1740 | # Advance the beginning of the left context 1741 | after_apostrophe = False 1742 | after_double_quote = False 1743 | while left_context_size > context_size: 1744 | left_context_toks = left_context_toks[1:] 1745 | tok = left_context_toks[0] 1746 | if model_type.startswith('microsoft/deberta') and '-v2' not in model_type: 1747 | tok = 
tokenizer.convert_tokens_to_string([tok]) 1748 | if not ((is_word_piece(model_type, tok) and not after_double_quote) or tok == "'" or \ 1749 | (after_apostrophe and tok in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m"))): 1750 | left_context_size -= 1 1751 | after_apostrophe = tok == "'" 1752 | after_double_quote = tok in ('"', ' "', '▁"', 'Ġ"') 1753 | after_apostrophe = False 1754 | after_double_quote = False 1755 | while left_context_toks and ( 1756 | (is_word_piece(model_type, left_context_toks[0]) and not after_double_quote) 1757 | or left_context_toks[0] == "'" or 1758 | (after_apostrophe and left_context_toks[0] in ("s", "d", "st", "ve", "re", "nt", "ll", "t", "m")) 1759 | ): 1760 | after_apostrophe = left_context_toks[0] == "'" 1761 | after_double_quote = left_context_toks[0] in ('"', ' "', '▁"', 'Ġ"') 1762 | left_context_toks = left_context_toks[1:] 1763 | left_context_text = tokenizer.convert_tokens_to_string(left_context_toks) 1764 | return out_text 1765 | 1766 | # Bouts-rimés (rhymed ends) is an old French pastime in which one person selects 1767 | # series of rhyming words and another person composes a poem using them. This 1768 | # function gets the depoeticizer to play this game by generating metered verse. 1769 | def bouts_rimés(rhymes, meter='u-u-u-u-u-', 1770 | max_iterations=100, title=None, author=None, 1771 | randomize=0.5, cooldown=0.3, modifier=None, 1772 | forbid_reversions=True, 1773 | strong_topic_bias=False, stop_score=1.0, 1774 | discourage_repetition=False, stopwords=stopwords.words('english'), 1775 | model_type='bert-base-uncased', model_path=None, 1776 | sequential=False, verbose=True): 1777 | 1778 | initialize_model(model_type, model_path) 1779 | initialize_rhyme_and_meter(model_type, meter=True, rhymes=True) 1780 | 1781 | # Start by generating words in order that match the meter. 1782 | comma = True 1783 | first = True 1784 | toks = [] 1785 | lines = [] 1786 | for rhyme_word in rhymes: 1787 | rhyme_word_meter = get_meter(get_pron(rhyme_word)) 1788 | rhyme_word_nsyls = len(rhyme_word_meter) 1789 | required_meter = meter[:-rhyme_word_nsyls] 1790 | required_nsyls = len(required_meter) 1791 | 1792 | line_toks = [] 1793 | while required_nsyls > 0: 1794 | permitted_words = torch.zeros([vocab_size]).to(device) 1795 | for i in range(1, len(required_meter)+1): 1796 | test_meter = required_meter[:i] 1797 | if test_meter in meter_dict: 1798 | permitted_words += meter_dict[test_meter] 1799 | permitted_words *= 1.0 - word_pieces 1800 | permitted_words = permitted_words.cpu() 1801 | 1802 | if first: 1803 | probs = permitted_words / sum(permitted_words) 1804 | tok_id = torch.multinomial(probs, 1) 1805 | else: 1806 | _, probs = compute_probs_for_masked_tokens(model, [toks + [{mask_token}]], [[[[len(toks), len(toks)]]]], 1, replacements_only=True) 1807 | probs = probs[0][0][0][0] 1808 | noise = torch.randn_like(probs) 1809 | noise = noise * 0.75 + 1.0 1810 | probs *= noise 1811 | probs = m(probs) 1812 | probs *= permitted_words 1813 | tok_id = probs.argmax() 1814 | 1815 | tok = tokenizer.convert_ids_to_tokens([tok_id])[0] 1816 | line_toks.append(tok) 1817 | toks.append(tok) 1818 | nsyls = len(get_meter(get_pron(tok)).replace('p', '')) 1819 | required_meter = required_meter[nsyls:] 1820 | required_nsyls -= nsyls 1821 | first = False 1822 | 1823 | if comma: 1824 | punct = ',' 1825 | else: 1826 | punct = '.' 
1827 | comma = not comma 1828 | toks += [rhyme_word, punct] 1829 | lines.append(' '.join(line_toks + ['{' + rhyme_word + '}']) + punct) 1830 | 1831 | text = '\n'.join(lines) 1832 | print(text) 1833 | 1834 | return depoeticize(text, max_iterations, match_meter=True, 1835 | match_rhyme=False, title=title, author=author, 1836 | randomize=randomize, cooldown=cooldown, 1837 | modifier=modifier, forbid_reversions=forbid_reversions, 1838 | preserve_punctuation=False, 1839 | strong_topic_bias=strong_topic_bias, 1840 | stop_score=stop_score, 1841 | discourage_repetition=discourage_repetition, 1842 | stopwords=stopwords, 1843 | model_type=model_type, model_path=model_path, 1844 | allow_punctuation=None, 1845 | sequential=sequential, verbose=verbose) 1846 | 1847 | --------------------------------------------------------------------------------