├── README.md ├── final-output.txt └── gpt-dev.livemd /README.md: -------------------------------------------------------------------------------- 1 | # gpt-from-scratch 2 | Reimplements the GPT model in Andrej Karpathy's video: https://www.youtube.com/watch?v=kCc8FmEb1nY using Nx/Axon/Livebook 3 | -------------------------------------------------------------------------------- /final-output.txt: -------------------------------------------------------------------------------- 1 | 2 | Second Servingman: 3 | Come, for a pull, and as as to my presence; 4 | Vouchle? 5 | 6 | POMPEY: 7 | Why, lady. 8 | 9 | ROMEO: 10 | Come, sir? 11 | 12 | MERCUTIO: 13 | Have speak, Make my considerer: therefore. 14 | 15 | ROMEO: 16 | Why, the gate? 17 | 18 | MERCUTIO: 19 | And to the pretty took'd doable. 20 | 21 | MERCUTIO: 22 | Nay, be more, and after and you true, villain. 23 | 24 | ROMEO: 25 | Why, Come, that Camillo, 26 | As come valianted to my through in my power; 27 | Nor, if our to God beggar, ever, call arm; 28 | And is idle autor to loss. 29 | 30 | MERCUTIO: 31 | What, do you have no with not make I nothing? 32 | 33 | ROMEO: 34 | 'Tis that honour let honour? what do good he? 35 | 36 | ROMEO: 37 | Thou ha? 38 | 39 | PETRUCHIO: 40 | Have me lord? 41 | 42 | MERCUTIO: 43 | I pray you pretty what bear be give so? 44 | 45 | MERCUTIO: 46 | My lord? 47 | 48 | MERCUTIO: 49 | All I pray you, poor lord, I clond a be done. 50 | 51 | ROMEO: 52 | He throw with insife that I am speak and she consul. 53 | 54 | ROMEO: 55 | I'll cut see not to you have: if your death soon. 56 | If me sent my like you brother. 57 | 58 | TYRREL: 59 | Why, my son your will my continue, to-morrow, 60 | And you do my grace? and, what so do mightly be him 61 | In he rests warm jame a sacre, his with his eEven being; 62 | As blow's it a lack mine ere the court in he sill. 63 | 64 | KING RICHARD III: 65 | 66 | Keeper: 67 | Troth, will and safety, if 'gain. 68 | 69 | QUEEN ELIZABETH: 70 | Then and his comes myself, resign of the hands 71 | In honour sit, have me othing 72 | To six this at I am no beats, made and from the 73 | In may betterfly tarry this the one: set himself 74 | To look'd it of the great he dear: 75 | Where, but a gates beseech none stain'd, 76 | For shine is death very and come in honour'd; 77 | I love as a welcomed your vain a play'd. 78 | For, I claim to not stander the betwixt thousands? 79 | 80 | BUCKINGHAM: 81 | Because by those good my true old I should 82 | To with this my prince and have bite: 83 | Thou wilt noble immed, I wouldering and itself 84 | To have been in pence for young young by should be offend? 85 | To the your a title be shall and drink you, 86 | Too your being you thus next upon the gods wretch! 87 | It is tell despatch you do you to look, 88 | I pay forsake hope, my lord. Come, and for this you? 89 | 90 | ANGELO: 91 | When this is is your amorous dry edies you: 92 | I leave no banished my kind his with the purpose. 93 | 94 | ISABELLA: 95 | The should you not a quarry command expose: 96 | You not he when I need. 97 | 98 | ISABELLA: 99 | Ay, sir, that have worship, so you, that for King 100 | To tooch will perpeted passister him, 101 | And just shall be king it for your fly; 102 | No sorrow yours are of my business are weat 103 | Is not attend in the king. You must need 104 | To prove that is a garmy some a little 105 | I am in thee. 106 | 107 | This prayers to they head; but ell to be 108 | A be so pretty could d be roccursed it. 
109 | 110 | TRANIO: 111 | As elder amongsy: 112 | Pardon father, sir; shall never hither's fack 113 | Thus are head as that does the war from of he 114 | in me possessioner's your love. 115 | 116 | MERCUTIO: 117 | A monger, intent for this protector a knew me 118 | too his sincely and fadies their 119 | luke earn: end to you to coldier thanks 'll. 120 | 121 | ROMEO: 122 | Come, and thy about my come violent. 123 | 124 | MERCUTIO: 125 | Why, too: thou art that my why charge, 126 | as I am date, by my destraight take with decrew. 127 | 128 | ROMEO: 129 | Such I sprepare is the is dead! I dare me devision 130 | solegiance the wish the live hence. 131 | 132 | ERCHIOLO: 133 | Thou art shalt ne'er reason bosom of a stone. 134 | 135 | POMPEY: 136 | He shalt be the gods him go. Do you villain. 137 | Did I shall had ere insmies what many, 138 | It is prince pursuits I have my namel. 139 | 140 | PETRUCHIO: 141 | The so to my lord. 142 | Your grace man arrow it now. 143 | 144 | EDWARD: 145 | Not hate. 146 | 147 | BUCKINGHAM: 148 | My lords, Lord Bohemia: he much have gone. 149 | 150 | GLOUCESTER: 151 | 152 | CLARENCE: 153 | Do now, to him be flyman Salisbury, 154 | And to comfort you this daughter lords your be not 155 | How your kindtily to be caasters of this: 156 | It warrance this like yours so bring: 157 | Grow break like unfant to your as England? 158 | 159 | KING EDWARD IV: 160 | Because to stroke a father. 161 | 162 | GLOUCESTER: 163 | Now, so much as you? what salls. As we's lady? 164 | 165 | LADY GREY: 166 | 'Tis was a trouble of thy good? 167 | 168 | KING EDWARD IV: 169 | Shall thy calls? thou have drops thy for her? 170 | 171 | GLOUCESTER: 172 | Whose come would Christ I be longs, thy lord! 173 | 174 | QUEEN MARGARET: 175 | Why, thou wolf, what forgive rich me? 176 | 177 | CLARENCE: 178 | Why this in Gaunt at being thy mother? 179 | 180 | QUEEN MARGARET: 181 | And dispite revenge thou canst thou, royalty 182 | Is thou both these plant thy hand, thy news: 183 | I thoughts of thy hearts of lovest thy tongue 184 | The from wounds purpose of and enough; 185 | For thy blood this goble and again a prove? 186 | 187 | QUEEN ELIZABETH: 188 | Therefore that I thanks, the were did o'er port, 189 | Hath colour had made mightly humble pierceed 190 | That I several this hath plagued to dispair, 191 | Of this dust thou art thou wilt be matter. 192 | But not speak what thou hast he truth rend 193 | That horses Brottom'd! 194 | Why, not I know in thou redeem, and what to thy love? 195 | 196 | CATESBY: 197 | We is't? 198 | 199 | CATESBY: 200 | He Pray most thou wast thou dost so know thy return'd? 201 | 202 | PRINCE EDWARD: 203 | And I cannot ashe bring did thy counter. 204 | Thou thy succester, not what my souls, 205 | MI'll blood without to heart thy purposes agest by 206 | To late come, thy scorn thy be giving me will 207 | A kein a better, whose man's of thy wants-pert, 208 | Who dares up thy with the and sugger things mights thee same. 209 | Kill's a deceives often that bear, to friend, 210 | To country's with slain my hand, 'Bring subject between in 211 | That mistress heir put draws. 212 | 213 | Shepherd: 214 | And thou not is the babe'st grace in yours. 215 | 216 | POMPEY: 217 | I would be no encountenance your be no shall 218 | question; sit out knows false to the greate 't. 219 | 220 | PERDITA: 221 | Sir, your hated any of your dreamd. 222 | 223 | PETUS: 224 | What, the Tush, pity 225 | What, ortly thou may shalt her; I meal, for your son? 226 | The slip wert thou this prince wish, who but jest? 
227 | 228 | MENENIUS: 229 | He's stay there? 230 | 231 | COMINIUS: 232 | Sir, you speak tone spirit of this; I thoughts 233 | Where not goest the should been 234 | Be though welcome too sancture of Coriolanus, 235 | Exposition, and discover number: your she common 236 | Make done put you men counts prophecy be from father. 237 | 238 | OXFORD: 239 | I think foul this beggar-conquer where against I am to 240 | Lonf, but 'tiss to't: look'd way' humble. 241 | 242 | Second Servingman: 243 | Why, withy cannot is goodly 'tis with the he, 244 | daughter: you must entreat hath use no pale. 245 | 246 | LEONTES: 247 | How? 248 | 249 | More harse! I'll men together carved be his guilty! 250 | 251 | PRINCE EDWARD: 252 | O, why give made this it? what they are than's one? 253 | Come is is country's you will tars, hast staint hire, 254 | That have broughts for that would know dost forew. 255 | 256 | GLOUCESTER: 257 | My heart, Decliff, my father some, and in three. 258 | 259 | ANGELO: 260 | She hath all they lord; for I know my move not friend. 261 | 262 | LADY ANNE: 263 | No, It will to thinks me to not for no love. 264 | 265 | LADY ANNE: 266 | 267 | GLOUCESTER: 268 | An enever me with the king? 269 | 270 | GLOUCESTER: 271 | I did, I mother? methough ofury, we by my heart 272 | Whom the no more of much well as sendings; 273 | It cannot live thinks by my with a peace, 274 | And the right wings fair that we wilt the true.' 275 | 276 | LEONTES: 277 | O Pray you, 278 | To reclaim break, thou hast love. I was flesh all 279 | Accompanion: many three lord and only to 280 | As think. 281 | 282 | MENENIUS: 283 | Take run'd to throne with is truth, to him. 284 | 285 | MENENIUS: 286 | Methy knees: 287 | Shall have him how walkly, that chequality one 288 | And let of your his sondition, to I come go: 289 | My breath-perform your for your good soul's to us. 290 | 291 | MENENIUS: 292 | Must hold, therefore issues! you have destrance. 293 | 294 | MARCIUS: 295 | The clamations' you, you stand we wit. 296 | 297 | VIMILIA: 298 | Speak abhavours? pray none when I came gople. 299 | 300 | VOLUMNIA: 301 | He hath bed me the srunds pluck'd the cause nothing? 302 | 303 | VOLUMNIA: 304 | O my lord. 305 | 306 | PERDITA: 307 | I pray, and that say you know not, you must dieath 308 | shall forth my brother. 309 | 310 | MENENIUS: 311 | Nay, now notha second from that die? 312 | 313 | MENENIUS: 314 | The kill our say 'I' the word, I am,'tis 315 | Of your tiding rest yours, I could scalutchy's my 316 | I will stone. 317 | 318 | CORIOLANUS: 319 | Hath senator possible. 320 | 321 | First Marcius; 322 | will'd the to in't will not the and lawful give 323 | Second all reputy the marting shall thine for die. 324 | 325 | COMINIUS: 326 | Nay, not not he still unto this calarench a good. 327 | 328 | MARCIUS: 329 | Villain, brail, with sin, blow the thee; 330 | But name, let hollow, my lords; and will you have 331 | I'll resign, it is the still I saw young him 332 | some forswear of children: but for well be 333 | My tedious should not to too, sistake too a. 334 | What was not? thy tongue heart: sit but a France? 335 | 336 | CLIFFORD: 337 | Mightier to be thverence! 338 | 339 | FRIAR LAURENCE: 340 | Bid, he should that nor many book'd not in, 341 | Who confesserved chamber'd my scient his gurden me? 342 | 343 | ROMEO: 344 | A good More under of my lordship me, 345 | For what she scratch will out lord words? 346 | 347 | BENVOLIO: 348 | O, crying on, then set thanks being; not of, attrous 349 | Of you are of you weaky, wherefore you to them? 
350 | 351 | ROMEO: 352 | Thou art shighness is your eye; not they soest approve 353 | Of thy me in thy done sad done's a word interch, 354 | Liewis is like a happy 'Twixt downos under, 355 | Lest the mine enemy the head thy mock make thee. 356 | 357 | ROMEO: 358 | Do there in the present daughter: 359 | Madam, good night adversity; he's it my lord. 360 | 361 | HERMIONE: 362 | Marry, when thou know me abstard, my lie? 363 | 364 | MERCUTIO: 365 | To stay, never king, sir, by me, and that 366 | Is love, my life. 367 | 368 | CAPULET: 369 | My lord. 370 | 371 | ROMEO: 372 | Auth hast is too, and served my lord. 373 | 374 | PRINCE EDWARD: 375 | But hear it me harm. 376 | 377 | GREMIO: 378 | Not very that forsaken thy mother, 379 | my name not me thy cannot thy me, 380 | Thy blood in that one sound not to theirs. 381 | 382 | DUKE OF AUMERLE: 383 | Well, or did to be man; but thou love, 384 | Tower it mistory, sdisconcil thou did his most 385 | And more piercharge a tewell, therefore wile 386 | Of so so. Both prison! what thy sir, but thou, 387 | Does that; but ere which is is need, we be 388 | At thou bears wert pergination a poor such to thy friends 389 | So infactorse, here bearts and here is should before? 390 | On plotting but thou see thy fair objected. 391 | This is own confess upon-thou sift; but thy king, 392 | To heard rocker had body I believer yet remember 393 | As if thy war stand traitors: that death blush, 394 | I would me consul! straight, let thee news. 395 | 396 | POLIXENES: 397 | Why, I am so slain that? 398 | Thy hear me, do my life ten'd to thy duty. 399 | 400 | HERMIONE: 401 | The should I reple king! 402 | No through is is a serves for think; to the world: 403 | My lordship to thy king, now heart, Do not cheek nor hand, 404 | And what thou diest thy this sound to pay thee have rank'd; 405 | And their as herefore arms are match'd bring Warwick 406 | O' the nuptain the executions with this thee: 407 | Take candh prophecy me sun schange thou speech fraughters. 408 | And in he on God's draction the kings, 409 | And who can a wword the of king? not lacks, 410 | God deed's daughter the king to dry grave: 411 | H 412 | -------------------------------------------------------------------------------- /gpt-dev.livemd: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Let's build GPT from scratch! w/ Nx and Axon 4 | 5 | ```elixir 6 | Mix.install( 7 | [ 8 | {:nx, "~> 0.5.3"}, 9 | {:req, "~> 0.3.6"}, 10 | {:kino_bumblebee, "~> 0.3.0"}, 11 | {:exla, "~> 0.5.1"}, 12 | {:table_rex, "~> 3.1.1"} 13 | ], 14 | config: [nx: [default_backend: EXLA.Backend]] 15 | ) 16 | ``` 17 | 18 | 19 | 20 | ``` 21 | :ok 22 | ``` 23 | 24 | ## Introduction 25 | 26 | This notebook covers Andrej Karpathy's video [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY) We'll start off building a simple bigram model, and iteratively build up to the decoder-only transformer. 27 | 28 | Note: this notebook was created to experiment with Elixir's ML libraries, so the following code is probably not idiomatic Nx/Axon code and doesn't take full advantage of their capabilities. 
29 | 30 | ### References 31 | 32 | * Karpathy's companion notebook can be found [here](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=wJpXpmjEYC_T) 33 | 34 | * Thanks to Lorenzo Sinisi for the initial livebook [code](https://gist.github.com/lorenzosinisi/bb928554d665bdc53aada98c3710b0d5) 35 | 36 | ## Prepare data 37 | 38 | Let's first prepare our Shakespeare data 39 | 40 | ```elixir 41 | file_path = Path.absname("./input.txt") 42 | 43 | text = 44 | if File.exists?(file_path) do 45 | IO.puts("File loaded from memory: #{file_path}") 46 | File.read!(file_path) 47 | else 48 | IO.puts( 49 | "File loaded from git: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 50 | ) 51 | 52 | Req.get!( 53 | "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 54 | ).body 55 | end 56 | ``` 57 | 58 | 59 | 60 | ``` 61 | File loaded from git: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 62 | ``` 63 | 64 | 65 | 66 | ``` 67 | "First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger for bread, not in thirst for revenge.\n\nSecond Citizen:\nWould you proceed especially against Caius Marcius?\n\nAll:\nAgainst him first: he's a very dog to the commonalty.\n\nSecond Citizen:\nConsider you what services he has done for his country?\n\nFirst Citizen:\nVery well; and could be content to give him good\nreport fort, but that he pays himself with being proud.\n\nSecond Citizen:\nNay, but speak not maliciously.\n\nFirst Citizen:\nI say unto you, what he hath done famously, he did\nit to that end: though soft-conscienced men can be\ncontent to say it was for his country he did it to\nplease his mother and to be partly proud; which he\nis, even till the altitude of his virtue.\n\nSecond Citizen:\nWhat he cannot help in his nature, you account a\nvice in him. You must in no way say he is covetous.\n\nFirst Citizen:\nIf I must not, I need not be barren of accusations;\nhe hath faults, with surplus, to tire in repetition.\nWhat shouts are these? The other side o' the city\nis risen: why stay we prating here? to the Capitol!\n\nAll:\nCome, come.\n\nFirst Citizen:\nSoft! who comes here?\n\nSecond Citizen:\nWorthy Menenius Agrippa; one that hath always loved\nthe people.\n\nFirst Citizen:\nHe's one honest enough: would all the rest were so!\n\nMENENIUS:\nWhat work's, my countrymen, in hand? where go you\nWith bats and clubs? The matter? 
speak, I pray you.\n\nFirst Citizen:\nOur business is not unknown to the senate; they have\nhad inkling this fortnight what we intend to do,\nwhich now we'll show 'em in deeds. They say poor\nsuitors have strong breaths: they shall know we\nhave strong arms too.\n\nMENENIUS:\nWhy, masters, my good friends, mine honest neighbours,\nWill you undo yourselves?\n\nFirst Citizen:\nWe cannot, sir, we are undone already.\n\nMENENIUS:\nI tell you, friends, most charitable care\nHave the patricians of you. For your wants,\nYour suffering in this dearth, you may as well\nStrike at the heaven with your staves as lift them\nAgainst the Roman state, whose course will on\nThe way it takes, cracking ten thousand curbs\nOf more strong link asunder than can ever\nAppear in your impediment. For the dearth,\nThe gods, not the patricians, make it, and\nYour knees to them, not arms, must help. Alack,\nYou are transported by calamity\nThither where more attends you, and you slander\nThe helms o' the state, who care for you like fathers,\nWhen you curse them as enemies.\n\nFirst Citizen:\nCare for us! True, indeed! They ne'er cared for us\nyet: suffer us to famish, and their store-houses\ncrammed with grain; make edicts for usury, to\nsupport usurers; repeal daily any wholesome act\nestablished against the rich, and provide more\npiercing statutes daily, to chain up and restrain\nthe poor. If the wars eat us not up, they will; and\nthere's all the love they bear us.\n\nMENENIUS:\nEither you must\nConfess yourselves wondrous malicious,\nOr be accused of folly. I shall tell you\nA pretty tale: it may be you have heard it;\nBut, since it serves my purpose, I will venture\nTo stale 't a little more.\n\nFirst Citizen:\nWell, I'll hear it, sir: yet you must not think to\nfob off our disgrace with a tale: but, an 't please\nyou, deliver.\n\nMENENIUS:\nThere was a time when all " <> ... 68 | ``` 69 | 70 | ## Basic Encoder / Decoder 71 | 72 | ```elixir 73 | defmodule Minidecoder do 74 | @chars text |> String.codepoints() |> Enum.uniq() |> Enum.sort() 75 | @vocab_size Enum.count(@chars) 76 | def vocab_size, do: @vocab_size 77 | 78 | @stoi Enum.reduce(@chars, %{}, fn ch, acc -> Map.put(acc, ch, Enum.count(acc)) end) 79 | @itos Enum.reduce(@stoi, %{}, fn {ch, i}, acc -> Map.put(acc, i, ch) end) 80 | 81 | def encode_char(char), do: @stoi[char] 82 | 83 | def decode_char(encoded_char), do: @itos[encoded_char] 84 | 85 | def encode(text) do 86 | text |> String.codepoints() |> Enum.map(&encode_char(&1)) 87 | end 88 | 89 | def decode(encoded_list) do 90 | encoded_list |> Enum.map(&decode_char(&1)) |> Enum.join() 91 | end 92 | 93 | def tensor(text) do 94 | Nx.tensor(encode(text)) 95 | end 96 | end 97 | 98 | vocab_size = 99 | Minidecoder.vocab_size() 100 | |> IO.inspect(label: "vocab size is") 101 | 102 | Minidecoder.tensor(text) 103 | ``` 104 | 105 | 106 | 107 | ``` 108 | vocab size is: 65 109 | 110 | 14:50:17.553 [info] TfrtCpuClient created. 111 | 112 | ``` 113 | 114 | 115 | 116 | ``` 117 | #Nx.Tensor< 118 | s64[1115394] 119 | EXLA.Backend 120 | [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, ...] 
121 | > 122 | ``` 123 | 124 | ## Encoded Training + Validation Data 125 | 126 | ```elixir 127 | data = Minidecoder.tensor(text) 128 | n = Kernel.round(Nx.size(data) * 0.9) 129 | # take the first n elements (the first 90%) 130 | train_data = Nx.slice(data, [0], [n]) 131 | # take from index n to the end (the remaining 10%) 132 | val_data = Nx.slice(data, [n], [Nx.size(data) - n]) 133 | {train_data, val_data} 134 | ``` 135 | 136 | 137 | 138 | ``` 139 | {#Nx.Tensor< 140 | s64[1003855] 141 | EXLA.Backend 142 | [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, ...] 143 | >, 144 | #Nx.Tensor< 145 | s64[111539] 146 | EXLA.Backend 147 | [0, 0, 19, 30, 17, 25, 21, 27, 10, 0, 19, 53, 53, 42, 1, 51, 53, 56, 56, 53, 61, 6, 1, 52, 43, 47, 45, 46, 40, 53, 59, 56, 1, 14, 39, 54, 58, 47, 57, 58, 39, 8, 0, 0, 14, 13, 28, 32, ...] 148 | >} 149 | ``` 150 | 151 | ## Training Data 152 | 153 | To speed up training, we're going to batch our training data. 154 | It'll look like this. 155 | 156 | ``` 157 | x = [ 158 | ["h", "e", "l", "l", "o"], 159 | [" ", "w", "o", "r", "l"] 160 | ] 161 | 162 | y = [ 163 | ["e", "l", "l", "o", " "], 164 | ["w", "o", "r", "l", "d"] 165 | ] 166 | ``` 167 | 168 | We'll insert a linear layer between our x and y. After training, the model should learn these associations: 169 | 170 | * "h" -> "e" 171 | * "e" -> "l" 172 | * .. 173 | * "w" -> "o" 174 | * "o" -> "r" 175 | * etc 176 | 177 | 178 | 179 | ### Training Data Generator 180 | 181 | The Axon training loop expects an Enumerable or Stream for its training data. We'll use `Stream.resource/3` to repeatedly generate random slices of our training data. Every time we call it, it'll also keep track of a random key for the next generation. This ensures reproducible model outputs. 182 | 183 | We'll experiment with different batch sizes and block sizes, so we'll wrap this Stream in a closure.
184 | 185 | ```elixir 186 | seed = 1337 187 | 188 | get_batch_stream = fn batch_size, block_size, split -> 189 | Stream.resource( 190 | # initialization function 191 | fn -> 192 | Nx.Random.key(seed) 193 | end, 194 | # generation function 195 | fn key -> 196 | data = if(split == :train, do: train_data, else: val_data) 197 | 198 | {ix, new_key} = 199 | Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32) 200 | 201 | ix = Nx.to_list(ix) 202 | 203 | x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack() 204 | y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack() 205 | 206 | # Reshape y from {b, t} to a single vector 207 | # We do this to match the shape of y_true during training 208 | # https://hexdocs.pm/axon/Axon.Losses.html#categorical_cross_entropy/3 209 | {b, t} = Nx.shape(y) 210 | 211 | # or Nx.flatten 212 | flattened_y = Nx.reshape(y, {b * t}) 213 | 214 | out_data = {x, flattened_y} 215 | 216 | {[out_data], new_key} 217 | end, 218 | # termination function 219 | fn _ -> :ok end 220 | ) 221 | end 222 | 223 | train_batch_stream = get_batch_stream.(4, 8, :train) 224 | train_batch_stream |> Enum.take(1) 225 | ``` 226 | 227 | 228 | 229 | ``` 230 | [ 231 | {#Nx.Tensor< 232 | s64[4][8] 233 | EXLA.Backend 234 | [ 235 | [46, 47, 51, 57, 43, 50, 44, 1], 236 | [26, 19, 1, 30, 21, 15, 20, 13], 237 | [41, 43, 42, 1, 39, 1, 58, 56], 238 | [1, 42, 53, 1, 46, 43, 56, 43] 239 | ] 240 | >, 241 | #Nx.Tensor< 242 | s64[32] 243 | EXLA.Backend 244 | [47, 51, 57, 43, 50, 44, 1, 40, 19, 1, 30, 21, 15, 20, 13, 30, 43, 42, 1, 39, 1, 58, 56, 39, 42, 53, 1, 46, 43, 56, 43, 6] 245 | >} 246 | ] 247 | ``` 248 | 249 | ## Simple Bigram Model 250 | 251 | Let's assume we have a well-trained bigram model. 252 | 253 | Given an input tensor of size `{1, 4}`, the output might look something like this. 254 | 255 | ``` 256 | # batch_size = 1, block_size = 4 257 | input = [[h, e, l, l]] 258 | 259 | # batch_size = 1, block_size = 4, vocab_size = 65 260 | output = [[[65], [65], [65], [65]]] 261 | ``` 262 | 263 | * Each index in these [65]-sized tensors corresponds to an encoded character from our Shakespeare vocabulary 264 | * The likelihood of an encoded character appearing next in a sequence is given by its value inside the [65]-sized tensor. 265 | 266 | To predict the next character in our sequence, we'll look at the last [65]-sized tensor in our output. Right now the values are just some raw, non-normalized predictions for our 65 possible characters. We'll feed this tensor (called logits) into softmax to get a probability distribution that we can sample the next character from.
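To build some intuition before defining the model, here's softmax applied to a single made-up logits vector. This is a sketch only: the logit values are arbitrary, and we use a 5-character vocab instead of 65 for readability.

```elixir
# Hypothetical raw logits for a tiny 5-character vocabulary
logits = Nx.tensor([2.0, 1.0, 0.1, -1.0, 0.5])

# Softmax exponentiates and normalizes, so the result sums to 1.0
# and can be treated as a probability distribution over the vocab
Axon.Activations.softmax(logits, axis: -1)
```

The largest logit gets the largest share of the probability mass, but every character keeps a nonzero chance of being sampled.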
267 | 268 | ```elixir 269 | # Hyperparameters 270 | batch_size = 4 271 | block_size = 8 272 | 273 | bigram_model = 274 | Axon.input("sequence") 275 | |> Axon.embedding(65, 65) 276 | 277 | Axon.Display.as_graph(bigram_model, Nx.template({batch_size, block_size}, :f32), 278 | direction: :top_down 279 | ) 280 | ``` 281 | 282 | 283 | 284 | ```mermaid 285 | graph TD; 286 | 35[/"sequence (:input) {4, 8}"/]; 287 | 36["embedding_0 (:embedding) {4, 8, 65}"]; 288 | 35 --> 36; 289 | ``` 290 | 291 | ## Training the bigram model 292 | 293 | ```elixir 294 | # We'll use this for other models further along in the notebook 295 | defmodule CommonTrain do 296 | import Nx.Defn 297 | 298 | defn custom_predict_fn(model_predict_fn, params, input) do 299 | %{prediction: preds} = out = model_predict_fn.(params, input) 300 | {b, t, c} = Nx.shape(preds) 301 | reshaped = Nx.reshape(preds, {b * t, c}) 302 | %{out | prediction: reshaped} 303 | end 304 | 305 | def custom_loss_fn(y_true, y_pred) do 306 | Axon.Losses.categorical_cross_entropy(y_true, y_pred, 307 | from_logits: true, 308 | sparse: true, 309 | reduction: :mean 310 | ) 311 | end 312 | end 313 | 314 | {init_fn, predict_fn} = Axon.build(bigram_model, mode: :train) 315 | custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2) 316 | custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2) 317 | train_batch_stream = get_batch_stream.(4, 8, :train) 318 | 319 | params = 320 | {init_fn, custom_predict_fn} 321 | |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw()) 322 | |> Axon.Loop.run(train_batch_stream, %{}, epochs: 1, iterations: 10000, compiler: EXLA) 323 | ``` 324 | 325 | 326 | 327 | ``` 328 | 329 | 15:25:05.635 [debug] Forwarding options: [compiler: EXLA] to JIT compiler 330 | Epoch: 0, Batch: 9950, loss: 2.8657069 331 | ``` 332 | 333 | 334 | 335 | ``` 336 | %{ 337 | "embedding_0" => %{ 338 | "kernel" => #Nx.Tensor< 339 | f32[65][65] 340 | EXLA.Backend 341 | [ 342 | [1.954552173614502, -4.385025978088379, -4.377076148986816, -4.385700702667236, -4.394716262817383, -0.9468418955802917, -4.390252590179443, -3.949601650238037, -4.3867998123168945, -2.314349412918091, -4.3820013999938965, -4.39186429977417, -4.384054660797119, 1.2772890329360962, 0.4298079311847687, 0.24700672924518585, -0.18589963018894196, -0.6697914600372314, 0.2670493721961975, -0.1326778531074524, 0.452250212430954, 0.8470256328582764, -1.1264326572418213, -0.38984325528144836, -0.031074173748493195, 0.3561098277568817, -0.02344430610537529, 0.016382480040192604, -0.12044396251440048, -1.0433743000030518, -0.5009018778800964, 0.44456756114959717, 1.4647704362869263, -1.013350009918213, -1.372435212135315, 0.9725027084350586, -4.384570598602295, -0.23563559353351593, -4.383942604064941, -1.2093874216079712, -1.6285748481750488, -1.3509862422943115, -1.6671699285507202, -2.2631280422210693, -1.6090916395187378, -1.9188610315322876, -1.212599277496338, -1.7029807567596436, ...], 343 | ... 344 | ] 345 | > 346 | } 347 | } 348 | ``` 349 | 350 | ## Generating text with the bigram model, w/ argmax 351 | 352 | Let's implement a naive way of generating text using `Nx.argmax`. Every time we make a prediction, argmax will pick the most probable character that our model thinks should be next.
353 | 354 | ```elixir 355 | generate_fn = fn model, params, init_seq, max_new_tokens -> 356 | Enum.reduce(1..max_new_tokens, init_seq, fn _i, acc -> 357 | {_b, t} = Nx.shape(acc) 358 | 359 | # Cap the input sequence length from [t, block size] 360 | context_length = min(t, block_size) 361 | context_range = -context_length..-1 362 | context_slice = acc[[.., context_range]] 363 | 364 | # Predict next char 365 | preds = Axon.predict(model, params, context_slice) 366 | logits = preds[[.., -1, ..]] 367 | probs = Axon.Activations.softmax(logits) 368 | # {b, 1} 369 | batch_char = Nx.argmax(probs, axis: 1, keep_axis: true) 370 | 371 | Nx.concatenate([acc, batch_char], axis: -1) 372 | end) 373 | end 374 | 375 | # init_seq = Nx.broadcast(0, {1, 1}) 376 | init_seq = Nx.iota({1, 5}) 377 | max_new_tokens = 500 378 | 379 | generate_fn.(bigram_model, params, init_seq, max_new_tokens) 380 | # Convert our Nx.tensor to Elixir list 381 | |> Nx.to_list() 382 | # Decode the results 383 | |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end) 384 | # Our input just 1 batch, so grab the first one 385 | |> List.first() 386 | |> IO.puts() 387 | ``` 388 | 389 | 390 | 391 | ``` 392 | 393 | !$&cour the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the 394 | ``` 395 | 396 | 397 | 398 | ``` 399 | :ok 400 | ``` 401 | 402 | Why the repetition? 403 | 404 | Lets look inside our linear layer. We should see something like this 405 | 406 | * "t" is likely to produce "h" 407 | * "h" is likely to produce "e" 408 | * "e" is likely to produce " " 409 | * " " is likely to produce "t" 410 | 411 | ```elixir 412 | get_top_predictions = fn char, table, num_chars -> 413 | encoded_char = Minidecoder.encode_char(char) 414 | predictions = table[encoded_char] 415 | 416 | predictions 417 | |> Nx.to_list() 418 | |> Enum.with_index(fn element, index -> {index, element} end) 419 | |> Enum.map(fn {idx, logit} -> {Minidecoder.decode_char(idx), logit} end) 420 | |> Enum.sort(fn {_x_idx, x_res}, {_y_idx, y_res} -> x_res >= y_res end) 421 | |> Enum.take(num_chars) 422 | end 423 | 424 | table = params["embedding_0"]["kernel"] 425 | 426 | [ 427 | t: get_top_predictions.("t", table, 3), 428 | h: get_top_predictions.("h", table, 3), 429 | e: get_top_predictions.("e", table, 3), 430 | _: get_top_predictions.(" ", table, 3) 431 | ] 432 | ``` 433 | 434 | 435 | 436 | ``` 437 | [ 438 | t: [{"h", 2.5696206092834473}, {" ", 2.204694986343384}, {"o", 1.090488314628601}], 439 | h: [{"e", 2.453754425048828}, {"a", 1.7331098318099976}, {"i", 1.4247827529907227}], 440 | e: [{" ", 2.5171456336975098}, {"r", 1.7074227333068848}, {"n", 1.1833429336547852}], 441 | _: [{"t", 1.9793766736984253}, {"a", 1.3801567554473877}, {"h", 1.2791123390197754}] 442 | ] 443 | ``` 444 | 445 | ## Multinomial 446 | 447 | To avoid repetitive text, we want to randomly sample a character with our model prediction. We can do this with `Nx.Random.choice/4`. But, our model's output shape is `{b, t, vocab_size}`. Pytorch's `torch.multinomial` can work with batches, but (afaik) there's no equivalent function in the Nx library. 
We'll need to write a custom function to stack the results of `Nx.Random.choice/4`. 448 | 449 | ```elixir 450 | defmodule RandomPlus do 451 | import Nx.Defn 452 | 453 | defn multinomial(init_key, input, opts \\ []) do 454 | opts = keyword!(opts, num_samples: 1) 455 | num_samples = opts[:num_samples] 456 | 457 | {b, c} = Nx.shape(input) 458 | initial_tensor = Nx.broadcast(0, {b, num_samples}) 459 | category_iota = Nx.iota({c}, type: :s32) 460 | 461 | {_i, _input, next_key, acc} = 462 | while {i = 0, input, key = init_key, acc = initial_tensor}, i < b do 463 | # Becomes {C}, represents probability distribution 464 | i_batch_prob = input[i] 465 | 466 | {i_samples, next_key} = 467 | Nx.Random.choice(key, category_iota, i_batch_prob, samples: num_samples) 468 | 469 | # Update ith row in acc to hold the new samples 470 | i_samples_reformatted = Nx.reshape(i_samples, {1, :auto}) 471 | acc = Nx.put_slice(acc, [i, 0], i_samples_reformatted) 472 | {i + 1, input, next_key, acc} 473 | end 474 | 475 | {next_key, acc} 476 | end 477 | end 478 | 479 | probs = 480 | Nx.tensor([ 481 | [0.2, 0.3, 0.1, 0.15, 0.25], 482 | [0.10, 0.10, 0.10, 0.10, 0.60], 483 | [0.0, 0.0, 0.0, 0.0, 1.00] 484 | ]) 485 | 486 | # Given some batched probability distribution, sample 5 values 487 | # This is for demonstration purposes (we'll only need to sample 1 char when generating text) 488 | {_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 5) 489 | samples 490 | ``` 491 | 492 | 493 | 494 | ``` 495 | #Nx.Tensor< 496 | s64[3][5] 497 | EXLA.Backend 498 | [ 499 | [1, 4, 0, 4, 4], 500 | [2, 1, 1, 2, 4], 501 | [4, 4, 4, 4, 4] 502 | ] 503 | > 504 | ``` 505 | 506 | ## Generating text with the bigram model, w/ multinomial 507 | 508 | Let's see what happens if we use multinomial now. 509 | 510 | ```elixir 511 | generate_fn = fn model, params, init_seq, key, max_new_tokens -> 512 | Enum.reduce(1..max_new_tokens, {key, init_seq}, fn _i, {key, acc} -> 513 | {_b, t} = Nx.shape(acc) 514 | 515 | # Cap the input sequence length from [t, block size] 516 | context_length = min(t, block_size) 517 | context_range = -context_length..-1 518 | context_slice = acc[[.., context_range]] 519 | 520 | # Predict next batch of chars (when we generate text, batch_size = 1) 521 | preds = Axon.predict(model, params, context_slice) 522 | logits = preds[[.., -1, ..]] 523 | probs = Axon.Activations.softmax(logits) 524 | {next_key, batch_char} = RandomPlus.multinomial(key, probs, num_samples: 1) 525 | 526 | {next_key, Nx.concatenate([acc, batch_char], axis: -1)} 527 | end) 528 | |> then(fn {_next_key, acc} -> acc end) 529 | end 530 | 531 | init_seq = Nx.broadcast(0, {1, 1}) 532 | key = Nx.Random.key(1337) 533 | max_new_tokens = 1000 534 | 535 | generate_fn.(bigram_model, params, init_seq, key, max_new_tokens) 536 | |> Nx.to_list() 537 | |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end) 538 | |> List.first() 539 | |> IO.puts() 540 | ``` 541 | 542 | 543 | 544 | ``` 545 | 546 | S: 547 | 'S: ses t Pis ave, lef. j'ICay thiles Won.AUmoliveng'llofrrouseff he: 548 | Tho ff wous ke 549 | Ker! 550 | lthean areel. 551 | Whon ofilok Alil t thom.Xfabj! 552 | 553 | Ifo hendride mou. 554 | TrvicQu ous d mitesors; YCAns qgar matamES: 555 | RISthore qu. ue thuspipKI K: 556 | We che. nd y manthamean the fo, 557 | Ar! 558 | I oout. cowieayouroflllothalveedrgrme 559 | d patit f 3 560 | CK har 561 | 562 | UELl; 563 | If chankn ourinowoftipor hendvis?u, 564 | WAgorvan; 565 | Ho hos, 566 | EvS: 567 | TOVck fodonrQ$&-bQhons s her, 'd.
568 | 569 | tyolatoresces of$Qy; opy thTho fopum f. 570 | CHNRUCres meowea d s 571 | Thetsos on psth or 572 | PRecon limy t t: 573 | 's by indngrs, pXG bntr f hs: 574 | Thays thomea ELESSus cs; at, 575 | t sanLIt, 576 | MI he: 577 | RElide, oppratrmarorige wW: 578 | I as baiak t eind, 579 | HF&CHAtinNIUSI thenane ou I ke hou arou speou!ARE pere at t my ba'3 h brmin ntr alt; 580 | FLenurthourarait: 581 | 582 | Hor yo h ctind hadinot: 583 | I theoher. tal t is gx?ws o- 584 | CNRI: tetigQMSeay ifrorer be st ost's wn. 585 | CAifls frin tsovim athe her; ys: 586 | 'dWhteyorexBy d y it wegeangr thur y patit, ou.LI3. 587 | 588 | Austhar 589 | Gralie: 590 | Wherrvjul; fise s d arerve be'g; 591 | HZXINGinghy!q&GMBY lce 592 | ``` 593 | 594 | 595 | 596 | ``` 597 | :ok 598 | ``` 599 | 600 | This looks somewhat better with our limited training. We'll improve the text generation by focusing on single-head and multi-head attention next. 601 | 602 | Before we implement the attention models, let's create a reusable text generation function for different block sizes (sequence lengths). block_size is required to cap the sequence context for each prediction. 603 | 604 | ```elixir 605 | defmodule TextGen do 606 | def generate(model, params, init_seq, block_size, opts \\ []) do 607 | opts = Keyword.validate!(opts, key_seed: 1337, max_new_tokens: 1000) 608 | 609 | key = opts[:key_seed] |> Nx.Random.key() 610 | max_new_tokens = opts[:max_new_tokens] 611 | 612 | Enum.reduce(1..max_new_tokens, {key, init_seq}, fn _i, {key, acc} -> 613 | {_b, t} = Nx.shape(acc) 614 | 615 | # Cap the input sequence length from [t, block size] 616 | context_length = min(t, block_size) 617 | context_range = -context_length..-1 618 | context_slice = acc[[.., context_range]] 619 | 620 | # Predict next batch of chars (but for us, batch_size = 1) 621 | preds = Axon.predict(model, params, context_slice) 622 | logits = preds[[.., -1, ..]] 623 | probs = Axon.Activations.softmax(logits) 624 | {next_key, batch_char} = RandomPlus.multinomial(key, probs, num_samples: 1) 625 | 626 | {next_key, Nx.concatenate([acc, batch_char], axis: -1)} 627 | end) 628 | |> then(fn {_next_key, acc} -> acc end) 629 | # Convert our Nx.tensor to Elixir list 630 | |> Nx.to_list() 631 | # Decode the results 632 | |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end) 633 | # Our input is just 1 batch, so grab the first one 634 | |> List.first() 635 | end 636 | end 637 | ``` 638 | 639 | 640 | 641 | ``` 642 | {:module, TextGen, <<70, 79, 82, 49, 0, 0, 15, ...>>, {:generate, 5}} 643 | ``` 644 | 645 | ## The mathematical trick to self attention (version #4) 646 | 647 | To implement attention the way Karpathy does it, we'll create lower triangular matrices filled with ones. Nx doesn't have an equivalent of `torch.tril`, but we can create these matrices using the iota function. 648 | 649 | The iota function creates tensors with consecutive values, starting from zero and incrementing by one along a given axis. We can leverage this to create two tensors (row_iota and column_iota) and compare them to create the attention mask.
650 | 651 | ```elixir 652 | shape = {3, 3} 653 | row_iota = Nx.iota(shape, axis: 0) 654 | ``` 655 | 656 | 657 | 658 | ``` 659 | #Nx.Tensor< 660 | s64[3][3] 661 | EXLA.Backend 662 | [ 663 | [0, 0, 0], 664 | [1, 1, 1], 665 | [2, 2, 2] 666 | ] 667 | > 668 | ``` 669 | 670 | ```elixir 671 | column_iota = Nx.iota(shape, axis: 1) 672 | ``` 673 | 674 | 675 | 676 | ``` 677 | #Nx.Tensor< 678 | s64[3][3] 679 | EXLA.Backend 680 | [ 681 | [0, 1, 2], 682 | [0, 1, 2], 683 | [0, 1, 2] 684 | ] 685 | > 686 | ``` 687 | 688 | ```elixir 689 | Nx.greater_equal(row_iota, column_iota) 690 | ``` 691 | 692 | 693 | 694 | ``` 695 | #Nx.Tensor< 696 | u8[3][3] 697 | EXLA.Backend 698 | [ 699 | [1, 0, 0], 700 | [1, 1, 0], 701 | [1, 1, 1] 702 | ] 703 | > 704 | ``` 705 | 706 | ```elixir 707 | defmodule Tril do 708 | import Nx.Defn 709 | 710 | # Creates a lower triangular matrix of 1s to use as our mask 711 | defn ones(opts \\ []) do 712 | assert_keys(opts, [:shape]) 713 | 714 | shape = opts[:shape] 715 | Nx.greater_equal(Nx.iota(shape, axis: 0), Nx.iota(shape, axis: 1)) 716 | end 717 | end 718 | 719 | Tril.ones(shape: {5, 5}) 720 | ``` 721 | 722 | 723 | 724 | ``` 725 | #Nx.Tensor< 726 | u8[5][5] 727 | EXLA.Backend 728 | [ 729 | [1, 0, 0, 0, 0], 730 | [1, 1, 0, 0, 0], 731 | [1, 1, 1, 0, 0], 732 | [1, 1, 1, 1, 0], 733 | [1, 1, 1, 1, 1] 734 | ] 735 | > 736 | ``` 737 | 738 | ## The mathematical trick to self attention (version #4) cont. 739 | 740 | Here's a rough draft of how attention is computed in a single head. We'll package this up later into a reusable layer. 741 | 742 | ```elixir 743 | {b, t, c} = {4, 8, 32} 744 | {x, key} = Nx.Random.normal(Nx.Random.key(1337), shape: {b, t, c}, type: :f32) 745 | 746 | head_size = 16 747 | # Used for initializing random key, query, value kernels 748 | keys = key |> Nx.Random.split(parts: 3) 749 | 750 | # For some reason the default scale (2.0) produces really high weight values 751 | init_fn = Axon.Initializers.he_uniform(scale: 0.5) 752 | 753 | key_kernel = init_fn.({c, head_size}, {:f, 32}, keys[0]) 754 | query_kernel = init_fn.({c, head_size}, {:f, 32}, keys[1]) 755 | value_kernel = init_fn.({c, head_size}, {:f, 32}, keys[2]) 756 | k = Axon.Layers.dense(x, key_kernel) 757 | q = Axon.Layers.dense(x, query_kernel) 758 | v = Axon.Layers.dense(x, value_kernel) 759 | kT = Nx.transpose(k, axes: [0, -1, -2]) 760 | 761 | # {b, t, t} 762 | wei = Nx.dot(q, [2], [0], kT, [1], [0]) 763 | 764 | # Broadcast tril to {b, t, t} for Nx.select 765 | tril = Tril.ones(shape: {t, t}) 766 | tril = Nx.broadcast(tril, {b, t, t}) 767 | 768 | # Broadcast neg_inf to {b, t, t} for Nx.select 769 | wei_type = Nx.type(wei) 770 | neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(wei_type), wei) 771 | 772 | # lower triangular part of wei has original values 773 | # upper triangular part of wei has neg_inf for its values 774 | wei = Nx.select(tril, wei, neg_inf) 775 | 776 | # {4, 8, 8} 777 | wei = Axon.Activations.softmax(wei, axis: -1) 778 | 779 | # {4,8,8} @ {4,8,16} 780 | # out = Nx.dot(wei, [-1], [0], v, [1], [0]) 781 | # wei[0] 782 | ``` 783 | 784 | 785 | 786 | ``` 787 | #Nx.Tensor< 788 | f32[4][8][8] 789 | EXLA.Backend 790 | [ 791 | [ 792 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 793 | [0.5236416459083557, 0.4763583540916443, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 794 | [0.1460239738225937, 0.8355455994606018, 0.018430398777127266, 0.0, 0.0, 0.0, 0.0, 0.0], 795 | [0.2382892519235611, 0.19527707993984222, 0.5128787755966187, 0.053554967045784, 0.0, 0.0, 0.0, 0.0], 796 | [0.23525264859199524, 0.02861269935965538,
0.12083742767572403, 0.43159112334251404, 0.18370607495307922, 0.0, 0.0, 0.0], 797 | [0.4105880558490753, 0.009884358383715153, 0.19398914277553558, 9.765126160345972e-4, 0.37614840269088745, 0.008413595147430897, 0.0, 0.0], 798 | [0.4806952476501465, 0.09604756534099579, ...], 799 | ... 800 | ], 801 | ... 802 | ] 803 | > 804 | ``` 805 | 806 | ## Single-head attention layer 807 | 808 | Let's package up the computation we did earlier into an Axon layer. Since we need to do some custom calculations, we'll use `Axon.layer`. You can learn more about custom layers in the [Axon docs](https://hexdocs.pm/axon/custom_layers.html). 809 | 810 | ```elixir 811 | defmodule SingleAttention do 812 | import Nx.Defn 813 | 814 | def head_layer(%Axon{} = key, %Axon{} = query, %Axon{} = value, opts \\ []) do 815 | Axon.layer(&head_layer_impl/4, [key, query, value], 816 | name: opts[:name], 817 | op_name: :single_head 818 | ) 819 | end 820 | 821 | defn head_layer_impl(k, q, v, _opts \\ []) do 822 | {_b, t, c} = Nx.shape(k) 823 | tensor_type = Nx.type(k) 824 | 825 | kT = Nx.transpose(k, axes: [0, -1, -2]) 826 | 827 | # {4,8,_16_} @ {4,_16_,8} = {4,8,8} 828 | wei = Nx.dot(q, [2], [0], kT, [1], [0]) 829 | 830 | # Scaled attention 831 | wei = wei * Nx.rsqrt(c) 832 | 833 | # attention masking 834 | tril = Tril.ones(shape: {t, t}) 835 | tril = Nx.broadcast(tril, wei) 836 | neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei) 837 | # tril, wei, and neg_inf have the shape {b, t, t} 838 | wei = Nx.select(tril, wei, neg_inf) 839 | wei = Axon.Activations.softmax(wei, axis: -1) 840 | 841 | Nx.dot(wei, [-1], [0], v, [1], [0]) 842 | end 843 | end 844 | ``` 845 | 846 | 847 | 848 | ``` 849 | {:module, SingleAttention, <<70, 79, 82, 49, 0, 0, 18, ...>>, true} 850 | ``` 851 | 852 | ## Single-head attention model 853 | 854 | The general flow of this simple single-head attention model goes like this: 855 | 856 | * input sequence is some encoded text 857 | * map our input into some n_embd dimensional space 858 | * map some position information into some n_embd dimensional space 859 | * note: the original Transformer uses sine and cosine functions for positional encoding. We won't be implementing that; like GPT, we use a learned position embedding 860 | * our position information is just some iota tensor with values `[0..t)` where t == sequence length of our input 861 | * add the tensors to produce a tensor filled with embedding + position information; we'll call this tensor `x` 862 | * feed x into three different layers: key, value, and query 863 | * compute self attention (this is complicated; this [YouTube video](https://www.youtube.com/watch?v=g2BRIuln4uc) provides a great explanation) 864 | * note: For our single-head attention model, the size of the attention head is equal to n_embd. 865 | * project the self attention output down to vocab_size. Tensors coming out of this layer will look like `{b,t,vocab_size}`. These are our logits. 866 | 867 | Note: The implementation of the positional_embedding_table is a bit hacky. If somebody knows a better solution, I'd be curious to hear about it.
868 | 869 | ```elixir 870 | # Hyperparameters from https://youtu.be/kCc8FmEb1nY?t=4907 871 | batch_size = 32 872 | block_size = 8 873 | n_embd = 32 874 | head_size = n_embd 875 | 876 | # Model definition 877 | input = Axon.input("sequence") 878 | 879 | token_embedding_table = 880 | input 881 | |> Axon.embedding(vocab_size, n_embd) 882 | 883 | # Generate positional encodings for the input sequence (hacky) 884 | positions = 885 | Axon.nx(input, fn input -> 886 | {_batch_size, sequence_length} = Nx.shape(input) 887 | Nx.iota({sequence_length}) 888 | end) 889 | 890 | # Positional encodings get mapped into @n_embd space 891 | position_embedding_table = 892 | Axon.embedding(positions, block_size, n_embd, name: "position_embedding") 893 | 894 | x_layer = Axon.add(token_embedding_table, position_embedding_table) 895 | 896 | he_uniform = Axon.Initializers.he_uniform(scale: 0.5) 897 | key = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "key") 898 | query = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "query") 899 | value = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "value") 900 | 901 | single_head_model = 902 | SingleAttention.head_layer(key, query, value) 903 | |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head") 904 | ``` 905 | 906 | 907 | 908 | ``` 909 | #Axon< 910 | inputs: %{"sequence" => nil} 911 | outputs: "language_modeling_head" 912 | nodes: 11 913 | > 914 | ``` 915 | 916 | ## Single-head attention model graph 917 | 918 | ```elixir 919 | Axon.Display.as_graph(single_head_model, Nx.template({batch_size, block_size}, :f32), 920 | direction: :top_down 921 | ) 922 | ``` 923 | 924 | 925 | 926 | ```mermaid 927 | graph TD; 928 | 48[/"sequence (:input) {32, 8}"/]; 929 | 49["embedding_0 (:embedding) {32, 8, 32}"]; 930 | 50["nx_0 (:nx) {8}"]; 931 | 51["position_embedding (:embedding) {8, 32}"]; 932 | 52["container_0 (:container) {{32, 8, 32}, {8, 32}}"]; 933 | 53["add_0 (:add) {32, 8, 32}"]; 934 | 54["key (:dense) {32, 8, 32}"]; 935 | 55["query (:dense) {32, 8, 32}"]; 936 | 56["value (:dense) {32, 8, 32}"]; 937 | 57["single_head_0 (:single_head) {32, 8, 32}"]; 938 | 58["language_modeling_head (:dense) {32, 8, 65}"]; 939 | 57 --> 58; 940 | 56 --> 57; 941 | 55 --> 57; 942 | 54 --> 57; 943 | 53 --> 56; 944 | 53 --> 55; 945 | 53 --> 54; 946 | 52 --> 53; 947 | 51 --> 52; 948 | 49 --> 52; 949 | 50 --> 51; 950 | 48 --> 50; 951 | 48 --> 49; 952 | ``` 953 | 954 | ## Training the single-head attention model 955 | 956 | I lowered the iterations just to speed up training on my machine. You can increase the numbers to get better results. 
957 | 958 | ```elixir 959 | {init_fn, predict_fn} = Axon.build(single_head_model, mode: :train) 960 | custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2) 961 | custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2) 962 | train_data_stream = get_batch_stream.(batch_size, block_size, :train) 963 | 964 | params = 965 | {init_fn, custom_predict_fn} 966 | |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw()) 967 | |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 3000, compiler: EXLA) 968 | ``` 969 | 970 | 971 | 972 | ``` 973 | 974 | 15:28:28.074 [debug] Forwarding options: [compiler: EXLA] to JIT compiler 975 | Epoch: 0, Batch: 2950, loss: 2.6103194 976 | ``` 977 | 978 | 979 | 980 | ``` 981 | %{ 982 | "embedding_0" => %{ 983 | "kernel" => #Nx.Tensor< 984 | f32[65][32] 985 | EXLA.Backend 986 | [ 987 | [-0.026233026757836342, 0.2256242036819458, 0.3369872272014618, 0.12701235711574554, 0.4724958539009094, 0.4927522838115692, -0.11964629590511322, 0.14982973039150238, 0.1527341604232788, 0.2076384723186493, -0.4839678704738617, 0.5153075456619263, -0.17069879174232483, 0.07722043991088867, -0.024972444400191307, -0.02413080632686615, 0.06961726397275925, -0.12247069180011749, -0.030579620972275734, 0.22726184129714966, -0.29395171999931335, -0.15108881890773773, 0.2248864620923996, 0.30082234740257263, -0.2780728042125702, -0.12545496225357056, -0.1338309943675995, 0.1244623139500618, 0.0377982035279274, 0.10663247853517532, -0.36269432306289673, -0.011278417892754078], 988 | [-0.03844957798719406, -0.07200898230075836, 0.17476718127727509, -0.13438276946544647, 0.0045249746181070805, -0.042801979929208755, 0.031422700732946396, 0.023107800632715225, 0.2904854416847229, -0.1530412882566452, -0.30277305841445923, 0.16146323084831238, -0.6766874194145203, 0.0561353974044323, -0.017431093379855156, 0.4477277100086212, ...], 989 | ... 
990 | ] 991 | > 992 | }, 993 | "key" => %{ 994 | "bias" => #Nx.Tensor< 995 | f32[32] 996 | EXLA.Backend 997 | [5.339144845493138e-4, 0.0014086366863921285, -8.214047993533313e-4, -2.5494268629699945e-4, 1.6062534996308386e-4, -2.7622494962997735e-4, -5.117025575600564e-4, -4.4757919386029243e-4, 9.161723428405821e-4, -9.67931846389547e-5, 2.3132111527957022e-4, 2.641436003614217e-4, 0.0013516810722649097, 7.081329822540283e-4, 0.0016453240532428026, -3.157271712552756e-4, 3.824093146249652e-4, -6.019236170686781e-4, 8.673613774590194e-5, -8.229123777709901e-4, 2.3732471163384616e-4, -6.274062325246632e-4, -0.0019465215737000108, -3.8168931496329606e-5, -3.815832678810693e-5, -1.6878465248737484e-4, 3.705104973050766e-5, -1.038895788951777e-5, -6.699645891785622e-4, 6.461592274717987e-4, 2.1109878434799612e-4, 3.498121222946793e-4] 998 | >, 999 | "kernel" => #Nx.Tensor< 1000 | f32[32][32] 1001 | EXLA.Backend 1002 | [ 1003 | [0.08415064215660095, 0.2879759967327118, -0.24935944378376007, 0.007627225946635008, -0.3204685151576996, -0.899364709854126, 0.0017727279337123036, -0.40223392844200134, -0.3377895653247833, -0.24244733154773712, -0.22738304734230042, 0.22180818021297455, 0.22660820186138153, -0.4816625416278839, 0.23601596057415009, -0.11853193491697311, 0.3503369390964508, -0.18211157619953156, 0.4438214600086212, -0.8183256387710571, -0.298576682806015, 0.044904839247465134, -0.41670191287994385, 0.18856211006641388, -0.7515907883644104, 0.44034701585769653, 0.01015730295330286, 0.46664172410964966, -0.11723655462265015, -0.24692803621292114, 0.079217828810215, -0.2753612697124481], 1004 | [0.5019965767860413, 0.45832908153533936, -0.5738474726676941, -0.4480721056461334, 0.40961453318595886, -0.02529483661055565, -0.29848185181617737, -0.6949413418769836, 0.2270815372467041, 0.006744408048689365, 0.5465198755264282, 0.4061400890350342, 0.6404085159301758, 0.043757569044828415, ...], 1005 | ... 1006 | ] 1007 | > 1008 | }, 1009 | "language_modeling_head" => %{ 1010 | "bias" => #Nx.Tensor< 1011 | f32[65] 1012 | EXLA.Backend 1013 | [0.061188384890556335, 0.03874082490801811, -0.3440072238445282, -0.2977518141269684, -0.2497618943452835, -0.12782810628414154, -0.23234425485134125, -0.05885966122150421, -0.29599529504776, -0.2678046226501465, 0.2830163538455963, -0.29441750049591064, -0.3617924749851227, 0.015615391544997692, 0.05543321743607521, 0.059732530266046524, 2.2681929112877697e-4, 0.14945641160011292, -0.027915235608816147, 0.010921932756900787, 0.007053534034639597, 0.17028307914733887, -0.19982320070266724, -5.594325484707952e-4, 0.09194453060626984, 0.060529813170433044, -0.03193450719118118, 0.011559315957129002, -0.009883790276944637, -0.19082780182361603, 0.16168223321437836, -0.019084393978118896, 0.03505204990506172, 0.11881374567747116, -0.09905489534139633, 0.07346461713314056, -0.2048760950565338, 0.006363728549331427, -0.2438381165266037, 0.028238536790013313, -0.0257986132055521, -0.06388751417398453, 0.015318986028432846, 0.028752142563462257, 0.033384814858436584, -0.13916561007499695, ...] 
1014 | >, 1015 | "kernel" => #Nx.Tensor< 1016 | f32[32][65] 1017 | EXLA.Backend 1018 | [ 1019 | [-0.22565065324306488, -0.16024763882160187, 0.14312513172626495, 0.7375252842903137, 0.6810034513473511, -0.21387973427772522, 0.09498835355043411, -0.16007646918296814, 0.3423072397708893, 0.7293732166290283, 0.4669094383716583, 0.25795578956604004, 0.3606616258621216, -0.15745210647583008, -0.2844543755054474, -0.35039520263671875, 0.06515390425920486, 0.18357262015342712, 0.35768988728523254, -0.12733156979084015, 0.02732028067111969, -0.48360756039619446, -0.03461068868637085, -0.05969618633389473, -0.14592662453651428, 0.060609374195337296, -0.05544782429933548, -0.2138560712337494, 0.11534108221530914, 0.1720675826072693, 0.20025299489498138, 0.125869020819664, 0.020521463826298714, 0.38288405537605286, -0.19809487462043762, -0.06728346645832062, 0.22307513654232025, -0.3827643394470215, 0.056872107088565826, -0.47563377022743225, 0.12778759002685547, -0.13875481486320496, -0.29747429490089417, -0.21030990779399872, -0.021548548713326454, ...], 1020 | ... 1021 | ] 1022 | > 1023 | }, 1024 | "position_embedding" => %{ 1025 | "kernel" => #Nx.Tensor< 1026 | f32[8][32] 1027 | EXLA.Backend 1028 | [ 1029 | [0.23624907433986664, -0.021182071417570114, 0.08677957952022552, -0.10049070417881012, -0.05136618763208389, -0.1526702344417572, 0.281761109828949, 0.14189034700393677, -0.02184602990746498, 0.3414129316806793, 0.007780917454510927, -0.17955872416496277, -4.2804042459465563e-4, -0.10944950580596924, -0.11458509415388107, 0.09803333878517151, -0.013304184190928936, 0.2933603525161743, -0.05242357775568962, 0.13128063082695007, -0.07299972325563431, -0.052072543650865555, -0.018173260614275932, 0.04962927848100662, -0.001993720419704914, -0.01199309527873993, 0.13491109013557434, -0.02931121736764908, -0.09115011245012283, -0.05176560580730438, 0.04005185514688492, 0.0767948254942894], 1030 | [0.17488150298595428, -0.037036482244729996, 0.05705530196428299, -0.11671741306781769, -0.039749953895807266, -0.11059848219156265, 0.2114427387714386, 0.04657561331987381, -0.020749395713210106, 0.289633572101593, 0.040833499282598495, -0.13450975716114044, 0.03456944227218628, ...], 1031 | ... 
1032 | ] 1033 | > 1034 | }, 1035 | "query" => %{ 1036 | "bias" => #Nx.Tensor< 1037 | f32[32] 1038 | EXLA.Backend 1039 | [-0.034071750938892365, -0.28555038571357727, 0.29531151056289673, 0.00276755727827549, 0.4126724600791931, 0.354778915643692, -0.04076027125120163, 0.112078957259655, -0.01396627351641655, 0.010131004266440868, 0.09511833637952805, -0.334270179271698, -0.23874659836292267, 0.11469864100217819, -0.2345762699842453, -0.016118774190545082, -0.08340506255626678, -0.16744185984134674, -0.43535640835762024, 0.4058022201061249, 0.11186989396810532, 0.060189343988895416, 0.4751230776309967, -0.4158278703689575, 0.2699109613895416, 0.026102382689714432, 0.04701479151844978, -0.15357209742069244, 0.11870138347148895, 0.1022055521607399, -0.05558203160762787, 0.2189713716506958] 1040 | >, 1041 | "kernel" => #Nx.Tensor< 1042 | f32[32][32] 1043 | EXLA.Backend 1044 | [ 1045 | [0.18731792271137238, 0.2485392838716507, -0.1639232486486435, -0.11857520788908005, 0.3277408480644226, -0.2164953500032425, -0.061783526092767715, -0.12976013123989105, -0.07613621652126312, -0.22404593229293823, -0.1035078838467598, -0.23175814747810364, -0.5759645700454712, -0.31334614753723145, -0.2222266048192978, 0.0027542654424905777, -0.11652868986129761, -0.3812398314476013, -0.22661450505256653, 0.10356166958808899, 0.1277807503938675, -0.020671674981713295, -8.294901927001774e-4, -0.29352620244026184, -0.14790385961532593, -0.07956714928150177, -0.16293348371982574, 0.14328938722610474, 0.14654004573822021, 0.3405868113040924, -0.056984834372997284, 0.0869172215461731], 1046 | [0.15397128462791443, 0.11139403283596039, -0.14481408894062042, -0.0590481162071228, -0.21032772958278656, -0.25429296493530273, -0.03941408917307854, 0.05786168947815895, -0.37902113795280457, -0.13650669157505035, -0.0036288651172071695, ...], 1047 | ... 
1048 | ] 1049 | > 1050 | }, 1051 | "value" => %{ 1052 | "bias" => #Nx.Tensor< 1053 | f32[32] 1054 | EXLA.Backend 1055 | [-0.21108631789684296, -9.746256982907653e-4, 0.12640845775604248, -0.02774987183511257, -0.172824889421463, 0.11715316772460938, 0.11779537796974182, 0.03865457698702812, 0.01370930578559637, 0.08419355005025864, 0.05815295875072479, -0.1265273094177246, -0.042062751948833466, -0.13843882083892822, 0.1025078222155571, -0.12344934791326523, -0.019868047907948494, 0.1059509813785553, 0.08318553864955902, 0.007491858210414648, 0.07789982110261917, -0.05841176211833954, -0.08457578718662262, -0.06597407907247543, -0.06657955795526505, -0.100681371986866, -0.2102871686220169, -0.0254219900816679, 0.034235186874866486, -0.20986993610858917, -0.18147148191928864, 4.758408176712692e-4] 1056 | >, 1057 | "kernel" => #Nx.Tensor< 1058 | f32[32][32] 1059 | EXLA.Backend 1060 | [ 1061 | [0.025911198928952217, -0.5517228841781616, -0.3961946666240692, -0.6812283396720886, -0.4301234483718872, 0.06663434207439423, 0.23885639011859894, -0.3583765923976898, 0.20131313800811768, 0.27871912717819214, 0.5112564563751221, -0.7567049264907837, 0.6266830563545227, -0.10098028928041458, -0.24578642845153809, 8.617832645541057e-5, 0.21004417538642883, -0.3550078272819519, -0.19582337141036987, 0.6274057030677795, 0.3725995421409607, -0.2865595817565918, -0.12431655079126358, -0.0691077709197998, -0.05854438990354538, -0.2952299416065216, 0.6136438250541687, 0.2543925344944, 0.37929919362068176, -0.436009019613266, -0.012473562732338905, 0.4972887635231018], 1062 | [0.034896500408649445, -0.011460673995316029, -0.30139535665512085, 0.4306725859642029, 0.3097185790538788, 0.2907136082649231, -0.33310818672180176, -0.06842639297246933, 0.3035053312778473, -0.5042290091514587, ...], 1063 | ... 1064 | ] 1065 | > 1066 | } 1067 | } 1068 | ``` 1069 | 1070 | ## Generating text w/ single-head attention model 1071 | 1072 | ```elixir 1073 | init_seq = Nx.broadcast(0, {1, 1}) 1074 | 1075 | TextGen.generate(single_head_model, params, init_seq, block_size, max_new_tokens: 1000) 1076 | |> IO.puts() 1077 | ``` 1078 | 1079 | 1080 | 1081 | ``` 1082 | 1083 | S: 1084 | Be tins te st ave ind goh bkew sis it ay dterimevend mou, titheap heakino gh wous in 1085 | Mou duteand bored sthom ome th Cor If thof win m d lso herave sanll istticaveerk fangs wirs hal rt poir mer! 1086 | 1087 | PAROMRUSgeadaveprals, tled ha bes dche the yan, wig, he tir go ait ch orsu, dry, 1088 | I: thillllorilay, knesea whe tou g B fiI br Spe thant de iss ouse, wofrorfe hen wot wacs noru k des hsf D wacer sse hikinpe V: 1089 | Anons se; 1090 | We Afe aythalpre, k tem five: my th nt gres ald 1091 | FO: 1092 | TOrer mers iMe s 1093 | Thesthe on prse or 1094 | Sde, iche whicome thay imererohth haghrs hant our wt tit ne Halldou, uncer Ly opiley cilane gn th on: 1095 | Loucrse ourm hes I lade be I Tof nechan: 1096 | Ce toh ish. 1097 | IINPENOERUCALER: 1098 | Wie ous st st on. 1099 | 1100 | DULADUCUCHol I a haerire 1101 | Ovif ft fimoturs, ss on we I th ys hakord he, 1102 | This me ssen mperima t hif ydsw od ghe houlth fhee yon ut, s ch rueler pavik g o-rt hthe susto I the hes owwe benes yseay whady hs whe ingr thur yos toucken whe faivite n hul nd burativit. 
1103 | 
1104 | Boun rd farist ae gh aned onkd, who fifa le,
1105 | ```
1106 | 
1107 | 
1108 | 
1109 | ```
1110 | :ok
1111 | ```
1112 | 
1113 | ## Going from single head to multi head attention
1114 | 
1115 | If you're following along with the video, we're currently at this [part](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=4919s).
1116 | 
1117 | Karpathy uses the following OOP code to create multiple heads of attention, but we can do the same thing by reshaping our original key, query, and value layers.
1118 | 
1119 | ```
1120 | self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
1121 | ```
1122 | 
1123 | ### Reshaping our key, query, value layers
1124 | 
1125 | Remember our single-head attention layer had k, q, v of size `{b, t, n_embd}`. In order to compute multi-head attention, we'll split this `n_embd` portion into `{n_heads, head_size}`. (What we call `n_embd` here goes by `hidden_size` in other places like Bumblebee.) This lets us view the tensor as having multiple heads. For example, let x be some key/query/value tensor (see the sketch after this list).
1126 | 
1127 | 1. `{b, t, n_embd} = Nx.shape(x)`
1128 | 2. Reshape the last axis of x to become `{b, t, n_heads, head_size}`
1129 | 3. Transpose x so that the `t` and `n_heads` axes are swapped: `{b, n_heads, t, head_size}`
1130 | 4. x is now ready to multiply with some other key/query/value tensor that has undergone the same reshaping. Contracting two `{b, n_heads, t, head_size}` tensors over their `head_size` axes produces a tensor of shape `{b, n_heads, t, t}`. This `{t, t}` portion is where we apply our attention mask.
1131 | 
1132 | This is also how it's done in the Bumblebee library:
1133 | 
1134 | * https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L525
1135 | * https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L194
1136 | * https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L220
1137 | 
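To make those four steps concrete, here's a minimal shape-check sketch on a dummy tensor. The sizes (`b = 1`, `t = 8`, `n_embd = 32`, `n_heads = 4`, `head_size = 8`) are only for illustration:

```elixir
x = Nx.iota({1, 8, 32}, type: :f32)

# Steps 2-3: split n_embd into {n_heads, head_size}, then swap the t and n_heads axes
split =
  x
  |> Nx.reshape({1, 8, 4, 8})
  |> Nx.transpose(axes: [0, 2, 1, 3])

# {1, 4, 8, 8}, i.e. {b, n_heads, t, head_size}
Nx.shape(split)

# Step 4: contract over head_size (axis 3), batching over axes 0 and 1
scores = Nx.dot(split, [3], [0, 1], split, [3], [0, 1])

# {1, 4, 8, 8}, i.e. {b, n_heads, t, t}: the {t, t} part is where the mask goes
Nx.shape(scores)
```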
1138 | ```elixir
1139 | # simple multi-head attention without any optimizations like layernorm / feedforward
1140 | defmodule MultiAttention do
1141 |   import Nx.Defn
1142 | 
1143 |   @doc """
1144 |   Modified from Bumblebee's transformer.ex
1145 |   https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L525
1146 | 
1147 |   Splits the hidden dimension into the given number of attention heads.
1148 | 
1149 |   In other words, the input with shape `{batch_size, sequence_length, hidden_size}`
1150 |   is reshaped to `{batch_size, sequence_length, num_heads, head_size}`,
1151 |   then transposed to `{batch_size, num_heads, sequence_length, head_size}`.
1152 |   """
1153 |   def split_heads(states, num_heads, opts \\ []) do
1154 |     opts = Keyword.validate!(opts, name: "split_heads")
1155 | 
1156 |     Axon.nx(
1157 |       states,
1158 |       fn states ->
1159 |         batch_size = Nx.axis_size(states, 0)
1160 |         sequence_length = Nx.axis_size(states, 1)
1161 |         new_shape = {batch_size, sequence_length, num_heads, :auto}
1162 | 
1163 |         states
1164 |         |> Nx.reshape(new_shape)
1165 |         |> Nx.transpose(axes: [0, 2, 1, 3])
1166 |       end,
1167 |       name: opts[:name]
1168 |     )
1169 |   end
1170 | 
1171 |   def multi_head_layer(%Axon{} = x, num_heads, head_size, opts \\ []) do
1172 |     default_initializer = Axon.Initializers.he_uniform(scale: 0.5)
1173 |     opts = Keyword.validate!(opts, kernel_initializer: default_initializer)
1174 |     initializer = opts[:kernel_initializer]
1175 | 
1176 |     key =
1177 |       x
1178 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "key")
1179 |       |> split_heads(num_heads)
1180 | 
1181 |     query =
1182 |       x
1183 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "query")
1184 |       |> split_heads(num_heads)
1185 | 
1186 |     value =
1187 |       x
1188 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "value")
1189 |       |> split_heads(num_heads)
1190 | 
1191 |     Axon.layer(&multi_head_layer_impl/4, [key, query, value], name: "multi_head_attention")
1192 |   end
1193 | 
1194 |   # Custom layers require the opts argument
1195 |   # https://hexdocs.pm/axon/custom_layers.html#creating-custom-layers
1196 |   defn multi_head_layer_impl(k, q, v, _opts \\ []) do
1197 |     {b, h, t, c} = Nx.shape(k)
1198 |     tensor_type = Nx.type(k)
1199 | 
1200 |     # {b, h, t, c} @ {b, h, c, t} -> {b, h, t, t}
1201 |     wei = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])
1202 | 
1203 |     # Scaled attention: multiply by 1/sqrt(head_size)
1204 |     wei = wei * Nx.rsqrt(c)
1205 | 
1206 |     # Attention masking
1207 |     tril = Tril.ones(shape: {t, t})
1208 |     tril = Nx.broadcast(tril, wei)
1209 |     neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei)
1210 |     wei = Nx.select(tril, wei, neg_inf)
1211 |     wei = Axon.Activations.softmax(wei, axis: -1)
1212 | 
1213 |     # {b, h, t, t} @ {b, h, t, head_size} -> {b, h, t, head_size}
1214 |     out = Nx.dot(wei, [3], [0, 1], v, [2], [0, 1])
1215 | 
1216 |     # Transpose so we can stack the heads on top of each other
1217 |     # {b, h, t, c} -> {b, t, h, c}
1218 |     out = Nx.transpose(out, axes: [0, 2, 1, 3])
1219 | 
1220 |     # Our output tensor is now enriched with attention information
1221 |     # We reshape it back to {b, t, h * c}, i.e. {b, t, n_embd}
1222 |     # This gives us the proper shape to add to our original input x
1223 |     Nx.reshape(out, {b, t, h * c})
1224 |   end
1225 | end
1226 | ```
1227 | 
1228 | 
1229 | 
1230 | ```
1231 | {:module, MultiAttention, <<70, 79, 82, 49, 0, 0, 26, ...>>, true}
1232 | ```
1233 | 
1234 | ## Multi-head attention model
1235 | 
1236 | ```elixir
1237 | # We'll reuse the hyperparameters from the single-head attention model but change head size
1238 | # batch_size = 32
1239 | # block_size = 8
1240 | # n_embd = 32
1241 | n_heads = 4
1242 | head_size = div(n_embd, n_heads)
1243 | 
1244 | multi_head_model =
1245 |   Axon.input("sequence")
1246 |   |> then(fn input ->
1247 |     # Create an embedding for the input data
1248 |     token_embedding_table = Axon.embedding(input, vocab_size, n_embd, name: "token_embedding")
1249 | 
1250 |     # Generate positional encodings for the input sequence (hacky, couldn't find alternative)
1251 |     positions =
1252 |       Axon.nx(input, fn input ->
1253 |         {_batch_size, sequence_length} = Nx.shape(input)
1254
| Nx.iota({sequence_length}) 1255 | end) 1256 | 1257 | # Positional encodings get mapped into @n_embd space 1258 | position_embedding_table = 1259 | Axon.embedding(positions, block_size, n_embd, name: "position_embedding") 1260 | 1261 | # Add the two layers above to produce tensors containing embedding + position info 1262 | Axon.add(token_embedding_table, position_embedding_table, name: "x_positional_encoding") 1263 | end) 1264 | |> MultiAttention.multi_head_layer(n_heads, head_size) 1265 | |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head") 1266 | ``` 1267 | 1268 | 1269 | 1270 | ``` 1271 | #Axon< 1272 | inputs: %{"sequence" => nil} 1273 | outputs: "language_modeling_head" 1274 | nodes: 14 1275 | > 1276 | ``` 1277 | 1278 | ## Multi-head attention model graph 1279 | 1280 | Notice how the key, query, value layers split to become 4-dimensional. 1281 | 1282 | ```elixir 1283 | Axon.Display.as_graph(multi_head_model, Nx.template({batch_size, block_size}, :f32), 1284 | direction: :top_down 1285 | ) 1286 | ``` 1287 | 1288 | 1289 | 1290 | ```mermaid 1291 | graph TD; 1292 | 59[/"sequence (:input) {32, 8}"/]; 1293 | 60["token_embedding (:embedding) {32, 8, 32}"]; 1294 | 61["nx_0 (:nx) {8}"]; 1295 | 62["position_embedding (:embedding) {8, 32}"]; 1296 | 63["container_0 (:container) {{32, 8, 32}, {8, 32}}"]; 1297 | 64["x_positional_encoding (:add) {32, 8, 32}"]; 1298 | 65["key (:dense) {32, 8, 32}"]; 1299 | 66["split_heads (:nx) {32, 4, 8, 8}"]; 1300 | 67["query (:dense) {32, 8, 32}"]; 1301 | 68["split_heads (:nx) {32, 4, 8, 8}"]; 1302 | 69["value (:dense) {32, 8, 32}"]; 1303 | 70["split_heads (:nx) {32, 4, 8, 8}"]; 1304 | 71["multi_head_attention (:custom) {32, 8, 32}"]; 1305 | 72["language_modeling_head (:dense) {32, 8, 65}"]; 1306 | 71 --> 72; 1307 | 70 --> 71; 1308 | 68 --> 71; 1309 | 66 --> 71; 1310 | 69 --> 70; 1311 | 64 --> 69; 1312 | 67 --> 68; 1313 | 64 --> 67; 1314 | 65 --> 66; 1315 | 64 --> 65; 1316 | 63 --> 64; 1317 | 62 --> 63; 1318 | 60 --> 63; 1319 | 61 --> 62; 1320 | 59 --> 61; 1321 | 59 --> 60; 1322 | ``` 1323 | 1324 | ## Training the multi-head attention model 1325 | 1326 | ```elixir 1327 | {init_fn, predict_fn} = Axon.build(multi_head_model, mode: :train) 1328 | custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2) 1329 | custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2) 1330 | train_data_stream = get_batch_stream.(batch_size, block_size, :train) 1331 | 1332 | params = 1333 | {init_fn, custom_predict_fn} 1334 | |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw()) 1335 | |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 3000, compiler: EXLA) 1336 | ``` 1337 | 1338 | 1339 | 1340 | ``` 1341 | 1342 | 15:05:05.341 [debug] Forwarding options: [compiler: EXLA] to JIT compiler 1343 | Epoch: 0, Batch: 2950, loss: 2.5504496 1344 | ``` 1345 | 1346 | 1347 | 1348 | ``` 1349 | %{ 1350 | "key" => %{ 1351 | "bias" => #Nx.Tensor< 1352 | f32[32] 1353 | EXLA.Backend 1354 | [-0.002009622985497117, 0.0030817119404673576, 0.002637566765770316, 8.874150225892663e-4, -0.0036697951145470142, 0.0026014624163508415, -8.77177866641432e-4, -8.539229747839272e-4, -0.010210449807345867, 0.0027200556360185146, -0.004561145324259996, 3.5793684219243005e-5, -0.0037038149312138557, -4.817320223082788e-5, 7.681822753511369e-4, -5.375476903282106e-4, -0.004867691546678543, 8.054355857893825e-4, 0.0029539004899561405, -0.0027021823916584253, 0.0010204947320744395, 0.00340187456458807, -0.0045007625594735146, -5.496059893630445e-4, 
-3.7173699820414186e-4, 0.001366788404993713, 0.0013142818352207541, 4.228678299114108e-4, -0.005013817455619574, 0.006275582127273083, -0.004706717096269131, -7.051487336866558e-4] 1355 | >, 1356 | "kernel" => #Nx.Tensor< 1357 | f32[32][32] 1358 | EXLA.Backend 1359 | [ 1360 | [0.6940197348594666, -0.09190584719181061, -0.7657436728477478, 0.13503895699977875, 0.44908562302589417, -0.8816263675689697, -0.05520651862025261, 0.15477292239665985, 0.569190263748169, -0.18743273615837097, 0.07919623702764511, -0.4803193509578705, 0.3279154300689697, 0.38622552156448364, -0.5753573775291443, -0.5578794479370117, -0.2834005355834961, 0.33334460854530334, 0.26077544689178467, -0.45049938559532166, 0.150202214717865, 0.3185662031173706, -0.030324669554829597, -0.3234182894229889, -0.1026887372136116, -0.19870352745056152, 0.3214745819568634, 0.5666719675064087, -0.3206053674221039, 0.23958717286586761, -0.0799722820520401, 0.4348626732826233], 1361 | [-0.1126786544919014, -0.09792473912239075, 0.9408704042434692, -0.6795709133148193, -0.19755850732326508, 0.2359877973794937, -0.01693005859851837, 0.10840979963541031, -0.250237375497818, -0.8673456907272339, 0.5119004845619202, 0.16578194499015808, -0.11029595881700516, -0.2691425681114197, 0.793876051902771, ...], 1362 | ... 1363 | ] 1364 | > 1365 | }, 1366 | "language_modeling_head" => %{ 1367 | "bias" => #Nx.Tensor< 1368 | f32[65] 1369 | EXLA.Backend 1370 | [0.08884071558713913, 0.06958547234535217, -0.32006266713142395, -0.30319324135780334, -0.26640912890434265, 0.11584974825382233, -0.14481262862682343, -0.17550241947174072, -0.323355495929718, -0.3057439625263214, 0.1047697365283966, -0.26720258593559265, -0.27166327834129333, 0.11763342469930649, 2.9158510733395815e-4, 0.07884888350963593, -0.03157912194728851, 0.19413268566131592, -0.15487632155418396, 0.010517282411456108, -0.04269542172551155, 0.2341887205839157, -0.1754981130361557, -0.11905181407928467, 0.13274210691452026, 0.0029528785962611437, 0.07284557819366455, 0.1600860059261322, -0.08310021460056305, -0.2266455441713333, 0.25225019454956055, 0.023659436032176018, -0.10211014747619629, 0.1650671660900116, -0.08954048156738281, -0.1098952367901802, -0.2547689974308014, -0.07372691482305527, -0.1818724125623703, 0.12592221796512604, -0.040537815541028976, -0.031494904309511185, 0.027326008304953575, 0.060592930763959885, 0.14337053894996643, -0.13098326325416565, -0.05378049239516258, ...] 1371 | >, 1372 | "kernel" => #Nx.Tensor< 1373 | f32[32][65] 1374 | EXLA.Backend 1375 | [ 1376 | [-0.7125491499900818, -0.46839943528175354, -0.2697147727012634, 0.42317184805870056, 0.1936987042427063, -0.38992878794670105, -0.06144484877586365, -0.1245880275964737, -0.14166134595870972, 0.33889245986938477, -0.11005877703428268, -0.20245349407196045, -0.36581820249557495, -0.20540089905261993, -0.21097524464130402, 0.3619268238544464, 0.3771439492702484, 0.06898072361946106, -0.18923085927963257, -0.019612403586506844, 0.2706209719181061, 0.316038578748703, 0.04826759919524193, 0.5845080018043518, -0.23023158311843872, -0.18634116649627686, -0.137278214097023, -0.08182564377784729, 0.5525726675987244, 0.2282651960849762, 0.29154810309410095, -0.18772584199905396, 0.1221810132265091, 0.006088509690016508, 0.20453433692455292, 0.10518457740545273, -0.23493321239948273, 0.36128759384155273, 0.3855167627334595, 0.2623154819011688, -0.2311507612466812, -0.2841350734233856, 0.05035056173801422, 0.12662070989608765, 0.13019651174545288, 0.32261139154434204, ...], 1377 | ... 
1378 | ] 1379 | > 1380 | }, 1381 | "position_embedding" => %{ 1382 | "kernel" => #Nx.Tensor< 1383 | f32[8][32] 1384 | EXLA.Backend 1385 | [ 1386 | [-0.1628115475177765, 0.09981940686702728, -0.3936239182949066, 0.13354669511318207, 0.1114703118801117, -0.2746349275112152, 0.07311611622571945, -0.08590993285179138, 0.20255635678768158, 0.09707771986722946, -0.07302407920360565, -0.09127600491046906, -0.08702699840068817, -0.15735533833503723, -0.282254159450531, -0.240857794880867, -0.12304611504077911, -0.21165424585342407, 8.840659284032881e-4, -0.012054992839694023, -0.06108580157160759, -0.03755871579051018, 0.3343445956707001, -0.16978250443935394, 0.12578986585140228, -0.0284324511885643, -0.03782382979989052, -0.18365296721458435, 0.02894110605120659, -0.012323955073952675, -0.05722039192914963, -0.34137609601020813], 1387 | [-0.14392606914043427, 0.08419756591320038, -0.4101002812385559, 0.1324985772371292, -0.001000418676994741, -0.15875868499279022, 0.055551934987306595, 0.021683527156710625, 0.0783766582608223, 0.06663555651903152, -0.09983355551958084, -0.1321704238653183, -0.08909580856561661, -0.10228505730628967, ...], 1388 | ... 1389 | ] 1390 | > 1391 | }, 1392 | "query" => %{ 1393 | "bias" => #Nx.Tensor< 1394 | f32[32] 1395 | EXLA.Backend 1396 | [0.30744561553001404, -0.40035614371299744, -0.600159227848053, 0.14172907173633575, 0.5269912481307983, -0.4752598702907562, 0.04822060838341713, 0.27492329478263855, 1.0218167304992676, -0.03702900931239128, 0.2820287048816681, -0.2207508534193039, 0.18530860543251038, 0.32630085945129395, -0.33638694882392883, -0.1988237202167511, -0.17499062418937683, 0.6527702808380127, -0.07235650718212128, -0.09215840697288513, -0.11274632811546326, 0.10494064539670944, -0.2722526788711548, -0.5865284204483032, -0.26377955079078674, 0.22697702050209045, -0.48575764894485474, -0.28045541048049927, 0.23497842252254486, -0.31878966093063354, 0.04959775507450104, 0.42344802618026733] 1397 | >, 1398 | "kernel" => #Nx.Tensor< 1399 | f32[32][32] 1400 | EXLA.Backend 1401 | [ 1402 | [0.3361887037754059, 0.6659984588623047, -0.01123120915144682, 0.15361960232257843, -0.8631391525268555, -0.2230737805366516, -0.6785314083099365, 0.3036029636859894, -0.3403639793395996, -0.01705137826502323, -0.30506935715675354, -0.22626237571239471, 0.02393539436161518, 0.22523565590381622, 0.04351476952433586, -0.38311371207237244, 0.24531424045562744, 0.1956530660390854, -0.20927302539348602, 0.24935917556285858, -0.05999097228050232, -0.15577024221420288, 0.4829252362251282, 0.030667277052998543, -0.21252982318401337, -0.3893415629863739, 0.2501755654811859, 0.08658578991889954, -0.07877915352582932, -0.21714217960834503, 0.36494797468185425, 0.41473671793937683], 1403 | [0.6690793037414551, -0.5781891942024231, -0.45157963037490845, 0.06540126353502274, 0.2878626883029938, -0.34218332171440125, -0.06560727953910828, 0.5365062355995178, -0.005141665227711201, -0.9268913269042969, 0.36916038393974304, 0.6808081269264221, ...], 1404 | ... 
1405 | ] 1406 | > 1407 | }, 1408 | "token_embedding" => %{ 1409 | "kernel" => #Nx.Tensor< 1410 | f32[65][32] 1411 | EXLA.Backend 1412 | [ 1413 | [0.26007893681526184, -0.34626492857933044, 0.02932133339345455, -0.20337288081645966, 0.3459436595439911, 0.04238921031355858, 0.12850745022296906, 0.1433255523443222, -0.14038851857185364, -0.2880373001098633, -0.049702197313308716, -0.06888625770807266, 0.3151685297489166, 0.34581026434898376, 0.05190388113260269, -0.0947747752070427, 0.18888181447982788, 0.1588648557662964, 0.005344870965927839, 0.31078749895095825, 0.35703420639038086, -0.4355168342590332, -0.315531849861145, 0.014955342747271061, 0.20505090057849884, 0.37664178013801575, 0.3216913640499115, -0.02835514210164547, -0.17066606879234314, -0.5279961228370667, -0.07544543594121933, 2.824735129252076e-4], 1414 | [0.19091172516345978, 0.30593302845954895, 0.02575504221022129, -0.3912547528743744, 0.3915358781814575, -0.08316769450902939, 0.23849405348300934, -0.31539011001586914, -0.13143017888069153, 0.014384283684194088, -0.014175496995449066, 0.1446840763092041, ...], 1415 | ... 1416 | ] 1417 | > 1418 | }, 1419 | "value" => %{ 1420 | "bias" => #Nx.Tensor< 1421 | f32[32] 1422 | EXLA.Backend 1423 | [-0.06650947034358978, -0.07498624920845032, -0.10041021555662155, -0.07517606765031815, -0.017928874120116234, -0.018348190933465958, 0.025849999859929085, 0.024769101291894913, 0.11090978980064392, 0.1781088411808014, 0.1638837605714798, -0.07431018352508545, 0.013822168111801147, 0.07023237645626068, 0.21814927458763123, 0.11975128948688507, 0.03838825970888138, -0.2842266857624054, -0.1748587042093277, 0.1080666333436966, -0.09363164007663727, 0.06730080395936966, -0.05205147713422775, 0.15445540845394135, 0.1409216672182083, 0.02535940520465374, 0.05307505652308464, -0.10651924461126328, 0.08264414966106415, 0.012388779781758785, -0.03457631170749664, 0.054180171340703964] 1424 | >, 1425 | "kernel" => #Nx.Tensor< 1426 | f32[32][32] 1427 | EXLA.Backend 1428 | [ 1429 | [0.1357257217168808, -0.3345215320587158, -0.0030786781571805477, 0.9356716871261597, 0.3679511547088623, 0.9085726141929626, -0.2707551121711731, 0.1832733303308487, -0.09216760098934174, 0.46946004033088684, 0.19642551243305206, 0.3595341444015503, -0.005219850689172745, -0.0030760210938751698, 0.14641784131526947, -0.3859373927116394, -0.5769176483154297, -0.4275362491607666, -0.4593873620033264, -0.3144780099391937, -0.08384311199188232, -0.05492587760090828, 0.3388808071613312, -0.40142741799354553, -0.3029492497444153, 0.04506917670369148, -0.2736053764820099, 0.21720239520072937, -0.19470778107643127, 0.4152255356311798, -0.5121222138404846, 0.19692230224609375], 1430 | [0.03740651533007622, -0.3748472332954407, 0.28940436244010925, 0.3966611325740814, 0.23191681504249573, -0.014880786649882793, -0.3704621493816376, -0.08446381241083145, 0.23995298147201538, 0.20541788637638092, ...], 1431 | ... 1432 | ] 1433 | > 1434 | } 1435 | } 1436 | ``` 1437 | 1438 | ## Generating text w/ multi-head attention model 1439 | 1440 | ```elixir 1441 | init_seq = Nx.broadcast(0, {1, 1}) 1442 | 1443 | TextGen.generate(multi_head_model, params, init_seq, block_size, max_new_tokens: 1000) 1444 | |> IO.puts() 1445 | ``` 1446 | 1447 | 1448 | 1449 | ``` 1450 | 1451 | TG: 1452 | Bath thod so aug is ngon bisu: kincs awimuoth my so mouctisss, mand dorn go wove in 1453 | Mth cuth in bot oferes bole tha pmabueres by. 1454 | 1455 | I 1456 | F hou helese s; tie. 1457 | Tqua ubmsil, 1458 | Mut wiss hal st queraper-MI? 
1459 | 
1460 | NHesctaquou, u foruss, be Mant de:
1461 | Hot by me, of he bush go ait cock: you tus itheepcllkire ewor: ourd
1462 | belathe ga De;
1463 | Gem
1464 | An'dd elu de iss othe trudngeeg heinton, abud thu le it houc hy alrd?
1465 | A RRES:
1466 | GLD, I tho se, the
1467 | Go ayur ford wher ne ith onw thars grhor gh!
1468 | 
1469 | HNY wit ment ibe s
1470 | Thesthe of prrid m; horin lomy terar thes in. Sor tham bout he whis duvenes ol Galldss dre?
1471 | PLAved in,, con hin res,
1472 | Se omir tom prir her I k bre aice the.
1473 | 
1474 | I chak bye'd the thir is ou Pear hou arribpreng ok I hin brat my baacan broul nut aly fillexrt: so pe, ge ith yran dror, he this me this moflila tane hy yu of gin, to wol hecexhadpredd
1475 | Bo rucker quus fich rramy matshrs I the her: yrin hist yount whe yous whe ingr thur you tor one whe faly-thar hon ne.
1476 | Nomsowoun hecle se farerve be!
1477 | 
1478 | ARNK:
1479 | Horecwes fol,
1480 | I d
1481 | ```
1482 | 
1483 | 
1484 | 
1485 | ```
1486 | :ok
1487 | ```
1488 | 
1489 | ## CheckpointHelper
1490 | 
1491 | Training the final multi-layer, multi-head attention model is going to take some time. I got GPT4 to create this `CheckpointHelper` module to help resume training with the latest checkpoint.
1492 | 
1493 | Checkpoints are stored by default in the "checkpoint" directory. In my case, `"/Users/[user]/checkpoint/"`.
1494 | 
1495 | ```elixir
1496 | defmodule CheckpointHelper do
1497 |   def load_last_checkpoint(%Axon.Loop{} = loop, checkpoint_path) do
1498 |     with {:ok, last_checkpoint} <- get_most_recent_checkpoint(checkpoint_path) do
1499 |       last_state =
1500 |         (checkpoint_path <> "/" <> last_checkpoint)
1501 |         |> IO.inspect(label: "Resuming training from this checkpoint")
1502 |         |> File.read!()
1503 |         |> Axon.Loop.deserialize_state()
1504 | 
1505 |       Axon.Loop.from_state(loop, last_state)
1506 |     else
1507 |       _ ->
1508 |         IO.puts("Starting training from scratch")
1509 |         loop
1510 |     end
1511 |   end
1512 | 
1513 |   defp get_most_recent_checkpoint(dir_path) do
1514 |     # File.ls returns {:error, _} when the directory doesn't exist yet; the `with`
1515 |     # passes that through, and load_last_checkpoint falls back to a fresh start
1516 |     with {:ok, filenames} <- File.ls(dir_path) do
1517 |       filenames
1518 |       |> Enum.filter(&String.starts_with?(&1, "gpt_checkpoint_"))
1519 |       |> Enum.map(fn filename ->
1520 |         [_, epoch, iter] = Regex.run(~r/gpt_checkpoint_(\d+)_(\d+)/, filename)
1521 |         {String.to_integer(epoch), String.to_integer(iter), filename}
1522 |       end)
1523 |       |> Enum.max_by(fn {epoch, iter, _} -> {epoch, iter} end, fn -> nil end)
1524 |       |> case do
1525 |         nil ->
1526 |           {:error, "No checkpoint file found"}
1527 | 
1528 |         {_, _, filename} ->
1529 |           {:ok, filename}
1530 |       end
1531 |     end
1532 |   end
1533 | end
1534 | 
1535 | checkpoint_path = "checkpoint"
1536 | 
1537 | checkpoint_file_pattern = fn %Axon.Loop.State{epoch: epoch, iteration: iter} ->
1538 |   "gpt_checkpoint_#{epoch}_#{iter}"
1539 | end
1540 | ```
1541 | 
1542 | 
1543 | 
1544 | ```
1545 | #Function<42.3316493/1 in :erl_eval.expr/6>
1546 | ```
1547 | 
1548 | ## Scaled up, multi-layer, multi-head attention model
1549 | 
1550 | This is the final model w/ the following optimizations:
1551 | 
1552 | * feed-forward layers
1553 | * residual connections
1554 | * layer norms
1555 | * multiple stacked blocks
1556 | 
1557 | This section covers everything onwards from this point in the [video](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5089s).
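The main new ingredient relative to the multi-head model above is the residual ("skip") connection: each sub-layer's output gets added back onto its input, with a layer norm applied before the sub-layer (the pre-norm arrangement, as in the video). In Axon this is just `Axon.add/2` around a normalized branch. Here's a minimal sketch of the pattern; `sub_layer` is a stand-in for attention or the feed-forward network, and the input shape is arbitrary:

```elixir
# out = x + sub_layer(layer_norm(x)), the pre-norm residual pattern used in block/4 below
residual = fn %Axon{} = x, sub_layer ->
  Axon.add(x, x |> Axon.layer_norm() |> sub_layer.())
end

# e.g. wrapping a plain dense layer that preserves the last dimension
input = Axon.input("x", shape: {nil, 8, 32})
residual.(input, &Axon.dense(&1, 32))
```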
1558 | 
1559 | ```elixir
1560 | defmodule Transformer do
1561 |   import Nx.Defn
1562 | 
1563 |   # NOTE: Axon keys parameters by layer name, and the names below ("key", "query",
1564 |   # "block_ln_1", ...) repeat in every block -- so all n_blocks share a single set
1565 |   # of weights (the trained params map below has one "key"/"query"/"value").
1566 |   def blocks(%Axon{} = x, n_blocks, n_embd, n_head, opts \\ []) do
1567 |     opts = Keyword.validate!(opts, dropout_rate: 0.0)
1568 | 
1569 |     x =
1570 |       for _ <- 1..n_blocks, reduce: x do
1571 |         x -> block(x, n_embd, n_head, opts)
1572 |       end
1573 | 
1574 |     # final layer norm
1575 |     x |> Axon.layer_norm(name: "final_block_ln")
1576 |   end
1577 | 
1578 |   def block(%Axon{} = x, n_embd, n_head, opts \\ []) do
1579 |     head_size = div(n_embd, n_head)
1580 | 
1581 |     x =
1582 |       Axon.add(
1583 |         x,
1584 |         x |> Axon.layer_norm(name: "block_ln_1") |> multi_head(n_head, head_size, opts),
1585 |         name: "x_multihead_attention"
1586 |       )
1587 | 
1588 |     Axon.add(
1589 |       x,
1590 |       x |> Axon.layer_norm(name: "block_ln_2") |> feed_forward(n_embd, opts),
1591 |       name: "x_feed_forward"
1592 |     )
1593 |   end
1594 | 
1595 |   def feed_forward(%Axon{} = model, n_embd, opts \\ []) do
1596 |     opts = Keyword.validate!(opts, dropout_rate: 0.0)
1597 |     dropout_rate = opts[:dropout_rate]
1598 | 
1599 |     model
1600 |     |> Axon.dense(4 * n_embd, kernel_initializer: :he_uniform, name: "feed_forward_dense_1")
1601 |     |> Axon.relu(name: "feed_forward_relu")
1602 |     |> Axon.dense(n_embd, kernel_initializer: :he_uniform, name: "feed_forward_dense_2")
1603 |     |> Axon.dropout(rate: dropout_rate, name: "feed_forward_dropout")
1604 |   end
1605 | 
1606 |   @doc """
1607 |   Modified from Bumblebee's transformer.ex
1608 |   https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L525
1609 |   Splits the hidden dimension into the given number of attention heads.
1610 |   In other words, the input with shape `{batch_size, sequence_length, hidden_size}`
1611 |   is reshaped to `{batch_size, sequence_length, num_heads, head_size}`,
1612 |   then transposed to `{batch_size, num_heads, sequence_length, head_size}`.
1613 |   """
1614 |   def split_heads(states, num_heads, opts \\ []) do
1615 |     opts = Keyword.validate!(opts, name: "split_heads")
1616 | 
1617 |     Axon.nx(
1618 |       states,
1619 |       fn states ->
1620 |         batch_size = Nx.axis_size(states, 0)
1621 |         sequence_length = Nx.axis_size(states, 1)
1622 |         new_shape = {batch_size, sequence_length, num_heads, :auto}
1623 | 
1624 |         states
1625 |         |> Nx.reshape(new_shape)
1626 |         |> Nx.transpose(axes: [0, 2, 1, 3])
1627 |       end,
1628 |       name: opts[:name]
1629 |     )
1630 |   end
1631 | 
1632 |   def multi_head(%Axon{} = x, num_heads, head_size, opts \\ []) do
1633 |     default_initializer = Axon.Initializers.he_uniform(scale: 0.5)
1634 |     opts = Keyword.validate!(opts, kernel_initializer: default_initializer, dropout_rate: 0.0)
1635 |     initializer = opts[:kernel_initializer]
1636 |     dropout_rate = opts[:dropout_rate]
1637 | 
1638 |     key =
1639 |       x
1640 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "key")
1641 |       |> split_heads(num_heads)
1642 | 
1643 |     query =
1644 |       x
1645 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "query")
1646 |       |> split_heads(num_heads)
1647 | 
1648 |     value =
1649 |       x
1650 |       |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "value")
1651 |       |> split_heads(num_heads)
1652 | 
1653 |     Axon.layer(&multi_head_layer_impl/4, [key, query, value],
1654 |       name: "multi_head_attention",
1655 |       dropout_rate: dropout_rate
1656 |     )
1657 |     |> Axon.dense(num_heads * head_size, name: "multi_head_dense")
1658 |     |> Axon.dropout(rate: dropout_rate, name: "multi_head_dropout")
1659 |   end
1660 | 
1661 |   # Custom layers require the opts argument
1662 |   # https://hexdocs.pm/axon/custom_layers.html#creating-custom-layers
1663 |   defn multi_head_layer_impl(k, q, v, opts \\ []) do
1664 |     opts = keyword!(opts, mode: :train, dropout_rate: 0.0)
1665 |     dropout_rate = opts[:dropout_rate]
1666 | 
1667 |     {b, h, t, c} = Nx.shape(k)
1668 |     tensor_type = Nx.type(k)
1669 | 
1670 |     # {b, h, t, c} @ {b, h, c, t} -> {b, h, t, t}
1671 |     #
1672 |     # Alternatively we could have done
1673 |     # kT = Nx.transpose(k, axes: [0, 1, 3, 2])
1674 |     # wei = Nx.dot(q, [3], [0, 1], kT, [2], [0, 1])
1675 |     wei = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])
1676 | 
1677 |     # Scaled attention: multiply by 1/sqrt(head_size)
1678 |     wei = wei * Nx.rsqrt(c)
1679 | 
1680 |     # Attention masking
1681 |     tril = Tril.ones(shape: {t, t})
1682 |     tril = Nx.broadcast(tril, wei)
1683 |     neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei)
1684 |     # tril, wei, and neg_inf have the shape {b, h, t, t}
1685 |     # Nx.select will look at tril, and if true it'll pick the value from wei, else -infinity
1686 |     wei = Nx.select(tril, wei, neg_inf)
1687 |     wei = Axon.Activations.softmax(wei, axis: -1)
1688 |     # NOTE: opts[:mode] is accepted above but never checked, so this fixed-key dropout runs in every mode
1689 |     wei = Axon.Layers.dropout(wei, Nx.Random.key(1337), rate: dropout_rate)
1690 |     # {b, h, t, t} @ {b, h, t, head_size} -> {b, h, t, head_size}
1691 |     out = Nx.dot(wei, [3], [0, 1], v, [2], [0, 1])
1692 | 
1693 |     # Transpose so we can stack the heads on top of each other
1694 |     # {b, h, t, c} -> {b, t, h, c}
1695 |     out = Nx.transpose(out, axes: [0, 2, 1, 3])
1696 | 
1697 |     # Our output tensor is now enriched with attention information
1698 |     # We reshape it back to {b, t, h * c}, i.e. {b, t, n_embd}
1699 |     # This gives us the proper shape to add to our original input x
1700 |     Nx.reshape(out, {b, t, h * c})
1701 |   end
1702 | end
1703 | ```
1704 | 
1705 | 
1706 | 
1707 | ```
1708 | {:module, Transformer, <<70, 79, 82, 49, 0, 0, 36, ...>>, true}
1709 | ```
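One detail worth calling out: the `for ... reduce:` comprehension in `Transformer.blocks/5` is the Elixir counterpart of Karpathy's `nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])`: it threads the model through `block/4` `n_blocks` times. The same idiom on plain values, just to show the mechanics:

```elixir
# Pipe a starting value through the same function n times
stack = fn init, n, f ->
  for _ <- 1..n, reduce: init do
    acc -> f.(acc)
  end
end

# 1 doubled three times
stack.(1, 3, &(&1 * 2))
#=> 8
```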
1710 | 1711 | Because our model is much larger now, `learning_rate` is lowered to 3.0e-4 . 1712 | 1713 | ```elixir 1714 | # Hyperparameters 1715 | n_embd = 384 1716 | n_heads = 6 1717 | n_layer = 6 1718 | batch_size = 64 1719 | block_size = 256 1720 | learning_rate = 3.0e-4 1721 | dropout_rate = 0.2 1722 | 1723 | final_model = 1724 | Axon.input("sequence") 1725 | |> then(fn input -> 1726 | # Create an embedding for the input data 1727 | token_embedding_table = Axon.embedding(input, vocab_size, n_embd, name: "token_embedding") 1728 | 1729 | # Generate positional encodings for the input sequence (hacky, couldn't find alternative) 1730 | positions = 1731 | Axon.nx(input, fn input -> 1732 | {_batch_size, sequence_length} = Nx.shape(input) 1733 | Nx.iota({sequence_length}) 1734 | end) 1735 | 1736 | # Positional encodings get mapped into @n_embd space 1737 | position_embedding_table = 1738 | Axon.embedding(positions, block_size, n_embd, name: "position_embedding") 1739 | 1740 | # Add the two layers above to produce tensors containing embedding + position info 1741 | Axon.add(token_embedding_table, position_embedding_table, name: "x_positional_encoding") 1742 | end) 1743 | |> Transformer.blocks(n_layer, n_embd, n_heads, dropout_rate: dropout_rate) 1744 | |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head") 1745 | ``` 1746 | 1747 | 1748 | 1749 | ``` 1750 | #Axon< 1751 | inputs: %{"sequence" => nil} 1752 | outputs: "language_modeling_head" 1753 | nodes: 122 1754 | > 1755 | ``` 1756 | 1757 | ## Training the final model 1758 | 1759 | With the current hyperparameters, each checkpoint comes out to be ~35mb. 1760 | 1761 | ```elixir 1762 | {init_fn, predict_fn} = Axon.build(final_model, mode: :train) 1763 | custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2) 1764 | custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2) 1765 | train_data_stream = get_batch_stream.(batch_size, block_size, :train) 1766 | 1767 | params = 1768 | {init_fn, custom_predict_fn} 1769 | |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw(learning_rate)) 1770 | |> CheckpointHelper.load_last_checkpoint(checkpoint_path) 1771 | |> Axon.Loop.checkpoint( 1772 | event: :iteration_completed, 1773 | filter: [every: 99], 1774 | path: checkpoint_path, 1775 | file_pattern: checkpoint_file_pattern 1776 | ) 1777 | |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 5000, compiler: EXLA) 1778 | ``` 1779 | 1780 | 1781 | 1782 | ``` 1783 | Resuming training from this checkpoint: "checkpoint/gpt_checkpoint_0_3018" 1784 | 1785 | 09:10:13.246 [debug] Forwarding options: [compiler: EXLA] to JIT compiler 1786 | Epoch: 0, Batch: 4998, loss: 1.6417795 1787 | ``` 1788 | 1789 | 1790 | 1791 | ``` 1792 | %{ 1793 | "block_ln_1" => %{ 1794 | "beta" => #Nx.Tensor< 1795 | f32[384] 1796 | EXLA.Backend 1797 | [-0.00848830584436655, 0.00920316856354475, -0.011059202253818512, 0.020641149953007698, -0.004101151134818792, -0.007295509800314903, 0.008314329199492931, -0.008258385583758354, -8.087782771326602e-4, -0.007050180342048407, -0.01070280559360981, -0.023406347259879112, -0.0016787968343123794, -0.02222963236272335, -0.018540432676672935, -0.014302549883723259, 0.0010177168296650052, -0.01171213947236538, -0.03551258519291878, 0.005983751267194748, -0.006882576737552881, 9.761084947967902e-5, -0.016758816316723824, -0.0011584166204556823, 0.0017746005905792117, 0.00811840035021305, 0.01556096225976944, 0.033429522067308426, 4.5731314457952976e-4, 0.03519672155380249, 0.006238017231225967, 
0.009631684981286526, 4.665410378947854e-4, 0.01423694659024477, 0.002984068589285016, -0.015089771710336208, -0.009606064297258854, -0.00905038882046938, 0.0021043112501502037, -0.0035458712372928858, -0.00983726978302002, 2.740136696957052e-4, 0.017948541790246964, 0.0016870932886376977, 0.020347923040390015, 0.004460370168089867, 0.0073865666054189205, 0.004740085918456316, ...] 1798 | >, 1799 | "gamma" => #Nx.Tensor< 1800 | f32[384] 1801 | EXLA.Backend 1802 | [-0.9019647836685181, 0.2373071312904358, -1.074188232421875, 0.3899938464164734, 1.0010205507278442, -1.3070733547210693, 0.8236031532287598, -1.5869724750518799, 0.3163173794746399, 0.37490928173065186, -1.0106483697891235, -1.6686797142028809, -0.5751791000366211, -0.16168655455112457, -0.9891197085380554, 1.5283355712890625, 0.2562783658504486, -0.8124969601631165, 0.08245331048965454, 1.4095937013626099, 1.1405019760131836, -1.6745530366897583, 1.2952544689178467, 1.4876749515533447, 1.5021872520446777, 0.6464712023735046, -0.9438899755477905, 1.5004093647003174, -1.1250152587890625, 1.5461266040802002, 1.5317260026931763, 0.5766931772232056, 1.33409583568573, -0.8737964630126953, 1.491490364074707, 1.16155207157135, 1.2941070795059204, -0.34952089190483093, -1.407163143157959, 0.8760281205177307, 1.24746572971344, 0.38950851559638977, -0.05163421109318733, -0.3753258287906647, -0.7462378740310669, -0.33452585339546204, 0.5205916166305542, ...] 1803 | > 1804 | }, 1805 | "block_ln_2" => %{ 1806 | "beta" => #Nx.Tensor< 1807 | f32[384] 1808 | EXLA.Backend 1809 | [0.09962860494852066, -0.017158204689621925, 0.064473956823349, -0.02154291793704033, -0.1019430160522461, 0.013523696921765804, -0.012999583967030048, -0.03896573930978775, 0.1074555367231369, -0.1628476083278656, -0.05008017644286156, 0.0365133099257946, -0.07730609178543091, -0.02009524218738079, -0.03905002772808075, 0.09997482597827911, 0.02380087599158287, -0.10680756717920303, -0.07425318658351898, -0.004472446162253618, -0.1200273409485817, 0.006688457913696766, 0.06873513758182526, -0.06357701867818832, 0.2842353880405426, -0.07767970114946365, -0.0018610170809552073, 0.0703798159956932, 0.00796227715909481, -0.08657971024513245, -0.041982825845479965, -0.0027350035961717367, -0.022983459755778313, 0.009639413096010685, 0.013106227852404118, 0.07589122653007507, 0.04959215968847275, 0.011667381040751934, 0.04105760157108307, -0.00228147697634995, 0.003996665123850107, 0.009335983544588089, -0.03350013121962547, -0.03160631284117699, -0.07728160172700882, 0.018959442153573036, 0.0063129304908216, ...] 
1810 | >, 1811 | "gamma" => #Nx.Tensor< 1812 | f32[384] 1813 | EXLA.Backend 1814 | [-0.3861040472984314, 1.3981317281723022, -0.6260970830917358, -0.6142591238021851, -0.8400639891624451, 1.2417875528335571, 1.3667564392089844, -1.4321521520614624, 0.6886146068572998, -0.6659373641014099, -1.0708991289138794, -1.5967121124267578, 0.6946384906768799, -1.511633038520813, -1.143491268157959, -0.28134071826934814, 1.3280084133148193, 0.4598590135574341, 1.032454013824463, 1.591091275215149, -1.516690969467163, -0.6980851292610168, 0.31910091638565063, 1.2580981254577637, 0.21601411700248718, 0.9868366122245789, -1.5262973308563232, -1.100521206855774, 1.4690897464752197, 1.5698038339614868, -1.5286346673965454, 1.5577421188354492, -0.9915542602539062, 1.2179222106933594, -0.8456666469573975, -1.5218002796173096, 0.7244951725006104, -1.3286036252975464, -0.7319031953811646, -0.41700422763824463, 1.5427435636520386, 0.3690219521522522, -1.1647601127624512, 0.3217754364013672, -0.486890971660614, -1.1719893217086792, ...] 1815 | > 1816 | }, 1817 | "feed_forward_dense_1" => %{ 1818 | "bias" => #Nx.Tensor< 1819 | f32[1536] 1820 | EXLA.Backend 1821 | [-0.01969454623758793, -0.03141949325799942, -0.025403395295143127, -0.02523775026202202, -0.033035438507795334, -0.018221471458673477, -0.027275601401925087, -0.02603720873594284, -0.014015775173902512, -0.024843944236636162, -0.03239646926522255, 0.005450837314128876, -0.006688214372843504, -0.023177066817879677, -0.033121153712272644, -0.012314669787883759, -0.022113187238574028, -0.014615594409406185, -0.023663213476538658, -0.031118979677557945, -0.033788520842790604, -0.027982935309410095, -0.03504092991352081, -0.006818012334406376, -0.018050672486424446, -0.021601490676403046, -0.02082575485110283, -0.011461683548986912, -0.03900982439517975, -0.023099983111023903, -0.03569479286670685, -0.016298258677124977, -0.01670851558446884, -0.031987082213163376, -0.027163656428456306, -0.03729814663529396, -0.02999606914818287, -0.03916657716035843, -0.025162482634186745, -0.0090325390920043, -0.04438036307692528, -0.013018188066780567, -0.029345678165555, -0.020096443593502045, -0.04386697709560394, -0.03509372100234032, ...] 1822 | >, 1823 | "kernel" => #Nx.Tensor< 1824 | f32[384][1536] 1825 | EXLA.Backend 1826 | [ 1827 | [-0.10871043801307678, -0.09207568317651749, -0.08655396848917007, -0.10224799066781998, 0.041993580758571625, 0.12063898146152496, -0.046707216650247574, -0.051344893872737885, -0.12153489142656326, -0.06452394276857376, 0.08251994103193283, 0.0873091071844101, -0.20266391336917877, 0.048171039670705795, -0.00336627708747983, 0.049141667783260345, -0.019730960950255394, -0.00961564015597105, -0.011603367514908314, 0.004130497574806213, 0.0031170370057225227, -0.16046537458896637, -0.08627355843782425, -0.08215445280075073, -0.006033711135387421, 0.06354783475399017, -0.046216171234846115, -0.011816021986305714, -0.12118562310934067, -0.14242969453334808, -0.1388159543275833, 0.045520272105932236, 0.014241612516343594, 0.01466137170791626, -0.1673451066017151, -0.007058283314108849, 0.06803500652313232, 0.0638517439365387, 0.08884952962398529, -0.07472492754459381, -0.14217063784599304, -0.007931041531264782, 0.0704694539308548, -0.08046635240316391, -0.07672581821680069, ...], 1828 | ... 
1829 | ] 1830 | > 1831 | }, 1832 | "feed_forward_dense_2" => %{ 1833 | "bias" => #Nx.Tensor< 1834 | f32[384] 1835 | EXLA.Backend 1836 | [0.0016378882573917508, -0.004477905575186014, -1.8500685109756887e-4, -0.0018909699283540249, -0.0018954131519421935, -9.799797553569078e-4, 0.0021089836955070496, -0.0036012891214340925, -0.005732030142098665, -0.013728820718824863, -0.00945020467042923, 2.4839743855409324e-4, 0.005067904945462942, 6.146890227682889e-4, -0.004659520462155342, -0.0016149275470525026, 0.01016254909336567, 0.01562158390879631, -0.003967766184359789, 0.0025086686946451664, -0.003208652837201953, -0.00295314472168684, 8.695174474269152e-4, 0.003091400722041726, -0.020229540765285492, 0.0010084941750392318, -9.657987975515425e-4, -4.3867313070222735e-4, 3.0992846586741507e-4, 0.008426363579928875, -0.002736120019108057, 1.1287703091511503e-4, -6.791693158447742e-4, 4.408113891258836e-4, -0.001615250133909285, 0.007265539839863777, -0.004101641941815615, 0.0012738561490550637, 0.004863182548433542, 0.005860744044184685, -0.004074892494827509, -0.002494568470865488, 0.003334584180265665, -5.73037366848439e-4, -0.0058544655330479145, ...] 1837 | >, 1838 | "kernel" => #Nx.Tensor< 1839 | f32[1536][384] 1840 | EXLA.Backend 1841 | [ 1842 | [0.07813902199268341, 0.00755926501005888, -0.01507661771029234, -0.008303034119307995, -0.02480701357126236, -0.04982953518629074, 0.0688544288277626, -0.00950665958225727, 0.009715151973068714, 0.04819542542099953, 0.0017397699411958456, -0.027666132897138596, -1.672387879807502e-4, -0.02711641602218151, -0.02916313335299492, -0.004612234886735678, 0.04952307417988777, -0.013753768056631088, 0.0031526777893304825, -0.009431369602680206, -0.03181751072406769, 0.008109191432595253, 0.02372041903436184, -0.0030595185235142708, -0.019107308238744736, 0.035727448761463165, 0.03700413927435875, -0.020368410274386406, -0.012909585610032082, -0.004915925208479166, -0.003388757584616542, -0.03510107100009918, 0.010471213608980179, -0.030938591808080673, -0.010781565681099892, 0.05147552490234375, -0.004744974430650473, -0.05841310694813728, 0.05155558884143829, 0.02477145381271839, -0.06363201886415482, 0.06570444256067276, -0.043092936277389526, 0.0038584712892770767, ...], 1843 | ... 1844 | ] 1845 | > 1846 | }, 1847 | "feed_forward_dropout" => %{ 1848 | "key" => #Nx.Tensor< 1849 | u32[2] 1850 | EXLA.Backend 1851 | [722951111, 1788778939] 1852 | > 1853 | }, 1854 | "final_block_ln" => %{ 1855 | "beta" => #Nx.Tensor< 1856 | f32[384] 1857 | EXLA.Backend 1858 | [-0.06223426014184952, 0.020014429464936256, 0.04373729228973389, -0.02258773148059845, -0.07341495156288147, 0.038936272263526917, -0.008139075711369514, 0.05396490544080734, -0.08630179613828659, -0.03686339408159256, -0.027823925018310547, -0.03715555742383003, -0.019566871225833893, -0.05063457041978836, -0.0014520023250952363, 0.003898404538631439, 0.021072369068861008, -0.07472241669893265, -0.08343469351530075, 0.06747688353061676, -0.022054431959986687, -0.02674838900566101, -0.012110492214560509, 0.016023002564907074, 0.09339968115091324, -0.07446654140949249, 0.006534852087497711, 0.024841567501425743, -0.01911127381026745, -0.09875895082950592, 0.052235424518585205, -0.04918821528553963, -0.01366699393838644, 0.023073343560099602, 0.0037924626376479864, -0.05260546877980232, 0.007962469011545181, -0.011324295774102211, -0.029835710301995277, -0.025004083290696144, -0.07888715714216232, -0.0416419580578804, -0.020500805228948593, ...] 
1859 | >, 1860 | "gamma" => #Nx.Tensor< 1861 | f32[384] 1862 | EXLA.Backend 1863 | [0.6214079856872559, 0.7886702418327332, -1.302459478378296, 1.4649631977081299, -0.5696701407432556, 0.525750458240509, -1.1974927186965942, -1.1313458681106567, -0.8776533603668213, -0.31107765436172485, 1.6163456439971924, -0.300855427980423, 0.7801929116249084, 1.1120021343231201, -1.1623111963272095, -1.131649136543274, -1.1163313388824463, 0.38161373138427734, 0.7284241914749146, -0.48665928840637207, -0.9602446556091309, 0.40392425656318665, -1.0368990898132324, -1.0550838708877563, 0.1489740014076233, 0.7694688439369202, -1.1132980585098267, -0.1496853530406952, -1.478727102279663, 1.6444339752197266, 1.2406306266784668, -0.6676476001739502, 1.1959826946258545, -0.7487120628356934, -0.44440358877182007, 0.6873772740364075, -0.852936863899231, 1.1269590854644775, 1.5348412990570068, 1.0658866167068481, -1.0694810152053833, -1.0851964950561523, ...] 1864 | > 1865 | }, 1866 | "key" => %{ 1867 | "bias" => #Nx.Tensor< 1868 | f32[384] 1869 | EXLA.Backend 1870 | [1.1622644524322823e-4, 6.414170638890937e-5, 8.051748591242358e-5, 1.2625248928088695e-4, -2.94942146865651e-5, -1.432016579201445e-4, 9.895324183162302e-6, -1.1200238986930344e-5, -4.871409691986628e-5, -7.922281656647101e-5, 4.16673174186144e-5, 1.8111472309101373e-4, 1.1267856461927295e-4, 1.5345354040618986e-4, 2.838678192347288e-4, 4.03459052904509e-5, 1.536956369818654e-5, 1.5929706569295377e-4, -1.9008279195986688e-4, -2.74067249847576e-4, 1.6437079466413707e-5, -5.9939222410321236e-5, -1.0931974247796461e-4, 2.542962320148945e-4, -1.5804167196620256e-4, -5.212277756072581e-5, 5.324201993062161e-5, 1.9963234080933034e-4, -2.5228006416000426e-4, -8.598788554081693e-5, 4.858178726863116e-5, 1.8963949696626514e-4, -1.971950987353921e-4, 6.570507684955373e-5, 1.5816971426829696e-4, 3.819910125457682e-5, 1.2098293518647552e-4, 3.7578868796117604e-4, -3.22144478559494e-4, -1.3360384036786854e-4, -1.2937923020217568e-4, 2.4096581910271198e-4, ...] 1871 | >, 1872 | "kernel" => #Nx.Tensor< 1873 | f32[384][384] 1874 | EXLA.Backend 1875 | [ 1876 | [-0.0012871787184849381, 0.05325806885957718, -0.0329965278506279, -9.909607470035553e-4, 0.010785852558910847, 0.009750726632773876, 0.04176070913672447, -0.054770056158304214, 0.008932334370911121, 0.022378653287887573, -0.04128154367208481, 0.05506803095340729, 0.0077568707056343555, -0.027852976694703102, 0.0058920662850141525, -0.0761546641588211, 0.017340153455734253, 0.022978365421295166, 0.025580445304512978, -0.018654203042387962, 0.04613232612609863, 0.04237061366438866, 0.012818633578717709, -0.011534139513969421, 0.07500981539487839, -0.05386031046509743, 0.03983011469244957, 0.06280945241451263, -0.05516345426440239, -0.007531187497079372, -0.01810508966445923, -0.05361419916152954, 0.04882905259728432, 0.06248072162270546, -0.028620455414056778, -0.03912936523556709, -0.039756715297698975, -0.004793907981365919, 0.03730688616633415, 0.031572699546813965, 0.03633453696966171, ...], 1877 | ... 
1878 | ] 1879 | > 1880 | }, 1881 | "language_modeling_head" => %{ 1882 | "bias" => #Nx.Tensor< 1883 | f32[65] 1884 | EXLA.Backend 1885 | [-0.002946221036836505, 0.024192655459046364, -0.05523858591914177, -0.17861443758010864, -0.15369747579097748, 0.02711259014904499, -0.007696119602769613, -0.03897085785865784, -0.04317887872457504, -0.1140795648097992, -0.018143698573112488, -0.08973734080791473, -0.08171650767326355, -0.019000938162207603, -0.0456165112555027, -0.03262511268258095, -0.05419778451323509, -0.01939394138753414, -0.056382980197668076, -0.0516597144305706, -0.041264262050390244, -0.009129352867603302, -0.03057103417813778, -0.07836691290140152, -0.012997974641621113, -0.019970541819930077, -0.05059399455785751, -0.034343983978033066, -0.017009710893034935, -0.11167945712804794, -0.015476089902222157, -0.016781035810709, -0.0243771243840456, -0.07288465648889542, -0.102816142141819, -0.04766138643026352, -0.2592124044895172, -0.05409780517220497, -0.14828692376613617, 0.03934876248240471, 0.018621230497956276, ...] 1886 | >, 1887 | "kernel" => #Nx.Tensor< 1888 | f32[384][65] 1889 | EXLA.Backend 1890 | [ 1891 | [0.015758411958813667, -0.051811583340168, 0.1745842546224594, 0.03205735236406326, 0.14337486028671265, 0.06512445956468582, 0.08498262614011765, 0.01941721700131893, 0.04333372786641121, 0.11364075541496277, -0.043731216341257095, 0.09409287571907043, 0.06420766562223434, 0.0015034792013466358, 0.009646364487707615, 0.08266307413578033, 0.12922188639640808, 0.12079586833715439, -0.00773721095174551, 0.11592798680067062, 0.0960366353392601, 0.0038370585534721613, 0.2200704663991928, 0.09468023478984833, 0.15574730932712555, 0.14135101437568665, 0.05444718524813652, -0.049071311950683594, 0.19212448596954346, 0.2701287865638733, -0.014314115978777409, 0.1655369997024536, 0.1066943034529686, 0.11247802525758743, 0.1863957941532135, 0.13626587390899658, 0.15401582419872284, 0.10896278917789459, 0.12454552948474884, -0.09278301894664764, ...], 1892 | ... 1893 | ] 1894 | > 1895 | }, 1896 | "multi_head_dense" => %{ 1897 | "bias" => #Nx.Tensor< 1898 | f32[384] 1899 | EXLA.Backend 1900 | [-0.013091623783111572, -0.0025277109816670418, -0.015406734310090542, -0.0055812350474298, 0.0070198820903897285, 0.013889962807297707, 0.005282912403345108, -0.005262902472168207, 0.007651640567928553, 0.0013492725556716323, -6.565082585439086e-4, -0.009683044627308846, -0.020997583866119385, -0.0025598604697734118, 0.0017333345022052526, 1.242513917532051e-6, 0.0014633577084168792, -0.0045132143422961235, -0.011305914260447025, -0.0038131410256028175, -4.433437716215849e-4, -7.106841658242047e-4, 0.013996715657413006, -0.0048048049211502075, 0.016807060688734055, 0.005172553937882185, -0.008687887340784073, 0.011892814189195633, -0.0028974039014428854, 0.0022457861341536045, 0.007068112958222628, 0.002055160701274872, -0.002763577038422227, 0.014206371270120144, -0.005901937372982502, -0.0054986197501420975, -0.007799994666129351, 0.0023439626675099134, 0.009106948971748352, -0.012797119095921516, ...] 
1901 | >, 1902 | "kernel" => #Nx.Tensor< 1903 | f32[384][384] 1904 | EXLA.Backend 1905 | [ 1906 | [0.07698854804039001, 0.009023042395710945, -0.026208246126770973, -0.006611100863665342, 0.036692507565021515, 0.05781245604157448, 0.04712959751486778, 0.024427007883787155, 0.05523138493299484, 0.01334020122885704, -0.014303312636911869, -0.029595771804451942, 0.032298460602760315, 0.036912381649017334, -0.09651049226522446, 0.061845824122428894, -0.05959019064903259, -0.02616739273071289, -0.07068630307912827, -0.06196698546409607, -0.05527180805802345, -0.05376365780830383, -0.06438268721103668, -0.06433121114969254, -0.017974428832530975, -0.008260016329586506, 0.03823176771402359, 0.09679777175188065, -0.03601941838860512, 0.017709065228700638, -0.07428305596113205, -0.047941479831933975, -0.027466345578432083, -0.042206089943647385, 0.010596836917102337, 0.048028383404016495, 0.06414726376533508, -0.028881270438432693, 0.06266971677541733, ...], 1907 | ... 1908 | ] 1909 | > 1910 | }, 1911 | "multi_head_dropout" => %{ 1912 | "key" => #Nx.Tensor< 1913 | u32[2] 1914 | EXLA.Backend 1915 | [3215365877, 2835071777] 1916 | > 1917 | }, 1918 | "position_embedding" => %{ 1919 | "kernel" => #Nx.Tensor< 1920 | f32[256][384] 1921 | EXLA.Backend 1922 | [ 1923 | [0.032874882221221924, 0.017140092328190804, 0.003328746184706688, -0.005748175550252199, 0.017624981701374054, -0.01301295030862093, -0.006106921937316656, -0.006924390327185392, -0.0665484294295311, -0.008325970731675625, -0.01804683916270733, 0.010912414640188217, 0.03563772514462471, 0.01872558705508709, -0.004042694810777903, 0.025369824841618538, 0.0023279483430087566, -0.035076919943094254, -0.005552679765969515, -0.004224944394081831, 0.017052853479981422, -0.009900875389575958, -0.0076743196696043015, -0.023836085572838783, 0.014323941431939602, -0.0012540258467197418, 0.00770062068477273, -0.05086039751768112, -0.0807521864771843, -0.024993926286697388, -0.011912805959582329, 0.00384755851700902, 0.03233299031853676, -0.015088425949215889, 0.030430546030402184, 0.0033960388973355293, 0.02266317792236805, -0.02247624099254608, ...], 1924 | ... 1925 | ] 1926 | > 1927 | }, 1928 | "query" => %{ 1929 | "bias" => #Nx.Tensor< 1930 | f32[384] 1931 | EXLA.Backend 1932 | [-0.019059835001826286, 0.007212303578853607, -0.0253163930028677, 0.02480863407254219, 0.030279699712991714, 0.044939201325178146, -0.010317720472812653, -0.03053799644112587, 0.01236018631607294, 0.01944728195667267, -0.05433334782719612, -0.1049325242638588, 0.021181730553507805, -0.026806436479091644, 0.004234259016811848, 0.00791642814874649, 0.0014239175943657756, 0.024933788925409317, 0.012788806110620499, 0.050823960453271866, 0.0026838139165192842, 2.7709786081686616e-4, -6.114842108217999e-7, -0.049711503088474274, -0.04625536501407623, 0.018391326069831848, -0.03626689687371254, -0.07920512557029724, 0.10758330672979355, -0.012113719247281551, -0.04403815045952797, -0.05166914314031601, 0.012584244832396507, -0.019165927544236183, 0.0462692454457283, 0.021622730419039726, 0.0011790632270276546, ...] 
1933 | >, 1934 | "kernel" => #Nx.Tensor< 1935 | f32[384][384] 1936 | EXLA.Backend 1937 | [ 1938 | [0.0058362302370369434, 0.018301762640476227, -0.018645629286766052, -0.038221556693315506, 0.06795700639486313, 0.027598747983574867, -0.03910763934254646, -0.02521691471338272, -0.011773718520998955, 0.0824030414223671, -0.0028445073403418064, 0.006800362840294838, -0.02214733138680458, -0.09010455012321472, -0.032811108976602554, 0.08154575526714325, 0.025752056390047073, 0.08471326529979706, -0.008311114273965359, 0.052244335412979126, 0.036650706082582474, 0.04471902549266815, -0.023855609819293022, -0.004607589449733496, -0.011002322658896446, 0.0587620735168457, 0.005509072449058294, -0.01659712754189968, 0.026653604581952095, 0.08233808726072311, -9.325456921942532e-4, 0.03881856054067612, -0.035048775374889374, 0.032658956944942474, -0.035093218088150024, 0.0045005446299910545, ...], 1939 | ... 1940 | ] 1941 | > 1942 | }, 1943 | "token_embedding" => %{ 1944 | "kernel" => #Nx.Tensor< 1945 | f32[65][384] 1946 | EXLA.Backend 1947 | [ 1948 | [-0.020142383873462677, -0.014539672993123531, 0.0122062424197793, 0.039099209010601044, 0.03575901687145233, -0.00782727263867855, -0.01719738356769085, 0.07246894389390945, 0.02138376235961914, 0.032816097140312195, 0.012480519711971283, -0.03317803516983986, 0.027605293318629265, -0.017775360494852066, 4.3115095468237996e-4, 0.0056412131525576115, 0.036179959774017334, -0.010553312487900257, 5.860661040060222e-4, 0.03476720303297043, -0.00231738667935133, 0.005250687710940838, -0.014498109929263592, 0.010408789850771427, -0.016012923792004585, -0.012880049645900726, 0.02018360234797001, 0.007029877044260502, 0.00606964435428381, -0.0016668542521074414, 0.007597022689878941, -0.012955783866345882, 0.015751812607049942, 0.002011312637478113, -0.0200498066842556, 0.0019235415384173393, ...], 1949 | ... 1950 | ] 1951 | > 1952 | }, 1953 | "value" => %{ 1954 | "bias" => #Nx.Tensor< 1955 | f32[384] 1956 | EXLA.Backend 1957 | [-0.0030991090461611748, -0.002844809088855982, -0.009626881219446659, 4.500289505813271e-4, 0.00581051129847765, 1.5979257295839489e-4, 0.01220391783863306, -0.00541482400149107, 0.003090923186391592, -8.587435586377978e-4, 0.006976135075092316, 0.009952674619853497, -0.007886230014264584, 0.006909910123795271, 0.0038270035292953253, -0.003139507956802845, -0.0021739653311669827, 0.00220508617348969, -0.0015421948628500104, -0.004189274273812771, -0.007480615749955177, 0.004374688025563955, 0.0025047780945897102, -0.004456990864127874, -0.0062192585319280624, -0.009940357878804207, 0.007896310649812222, -0.008275106549263, 0.011483363807201385, 0.012841667048633099, -0.00428217276930809, 0.0053891753777861595, 0.0035085384733974934, 0.0021641673520207405, -5.1330698624951765e-6, ...] 
1958 | >, 1959 | "kernel" => #Nx.Tensor< 1960 | f32[384][384] 1961 | EXLA.Backend 1962 | [ 1963 | [0.009906159713864326, 0.019326014444231987, 0.013990432024002075, 0.010629404336214066, -0.004779709968715906, -0.04365978389978409, 0.02479981817305088, 0.03628024086356163, -0.010248202830553055, 0.03433370590209961, -0.06440325081348419, 0.01333934348076582, 0.05220233276486397, -0.07441692799329758, 0.07505417615175247, -0.08211889117956161, -0.015163464471697807, -0.04528651386499405, -0.026584822684526443, 0.005662917625159025, 0.004619475454092026, 0.049791473895311356, 0.011024191975593567, -0.03539152815937996, 0.0017584519227966666, -0.016710760071873665, 0.07067379355430603, 0.006168636493384838, 0.041405051946640015, 0.024044036865234375, 0.04775714501738548, -0.04292962700128555, 0.010696981102228165, 0.015617704018950462, ...], 1964 | ... 1965 | ] 1966 | > 1967 | } 1968 | } 1969 | ``` 1970 | 1971 | ## Generating text with the final model 1972 | 1973 | ```elixir 1974 | init_seq = Nx.broadcast(0, {1, 1}) 1975 | 1976 | TextGen.generate(final_model, params, init_seq, block_size, max_new_tokens: 10000) 1977 | |> IO.puts() 1978 | ``` 1979 | 1980 | 1981 | 1982 | ``` 1983 | 1984 | Second Servingman: 1985 | Come, for a pull, and as as to my presence; 1986 | Vouchle? 1987 | 1988 | POMPEY: 1989 | Why, lady. 1990 | 1991 | ROMEO: 1992 | Come, sir? 1993 | 1994 | MERCUTIO: 1995 | Have speak, Make my considerer: therefore. 1996 | 1997 | ROMEO: 1998 | Why, the gate? 1999 | 2000 | MERCUTIO: 2001 | And to the pretty took'd doable. 2002 | 2003 | MERCUTIO: 2004 | Nay, be more, and after and you true, villain. 2005 | 2006 | ROMEO: 2007 | Why, Come, that Camillo, 2008 | As come valianted to my through in my power; 2009 | Nor, if our to God beggar, ever, call arm; 2010 | And is idle autor to loss. 2011 | 2012 | MERCUTIO: 2013 | What, do you have no with not make I nothing? 2014 | 2015 | ROMEO: 2016 | 'Tis that honour let honour? what do good he? 2017 | 2018 | ROMEO: 2019 | Thou ha? 2020 | 2021 | PETRUCHIO: 2022 | Have me lord? 2023 | 2024 | MERCUTIO: 2025 | I pray you pretty what bear be give so? 2026 | 2027 | MERCUTIO: 2028 | My lord? 2029 | 2030 | MERCUTIO: 2031 | All I pray you, poor lord, I clond a be done. 2032 | 2033 | ROMEO: 2034 | He throw with insife that I am speak and she consul. 2035 | 2036 | ROMEO: 2037 | I'll cut see not to you have: if your death soon. 2038 | If me sent my like you brother. 2039 | 2040 | TYRREL: 2041 | Why, my son your will my continue, to-morrow, 2042 | And you do my grace? and, what so do mightly be him 2043 | In he rests warm jame a sacre, his with his eEven being; 2044 | As blow's it a lack mine ere the court in he sill. 2045 | 2046 | KING RICHARD III: 2047 | 2048 | Keeper: 2049 | Troth, will and safety, if 'gain. 2050 | 2051 | QUEEN ELIZABETH: 2052 | Then and his comes myself, resign of the hands 2053 | In honour sit, have me othing 2054 | To six this at I am no beats, made and from the 2055 | In may betterfly tarry this the one: set himself 2056 | To look'd it of the great he dear: 2057 | Where, but a gates beseech none stain'd, 2058 | For shine is death very and come in honour'd; 2059 | I love as a welcomed your vain a play'd. 2060 | For, I claim to not stander the betwixt thousands? 2061 | 2062 | BUCKINGHAM: 2063 | Because by those good my true old I should 2064 | To with this my prince and have bite: 2065 | Thou wilt noble immed, I wouldering and itself 2066 | To have been in pence for young young by should be offend? 
2067 | To the your a title be shall and drink you, 2068 | Too your being you thus next upon the gods wretch! 2069 | It is tell despatch you do you to look, 2070 | I pay forsake hope, my lord. Come, and for this you? 2071 | 2072 | ANGELO: 2073 | When this is is your amorous dry edies you: 2074 | I leave no banished my kind his with the purpose. 2075 | 2076 | ISABELLA: 2077 | The should you not a quarry command expose: 2078 | You not he when I need. 2079 | 2080 | ISABELLA: 2081 | Ay, sir, that have worship, so you, that for King 2082 | To tooch will perpeted passister him, 2083 | And just shall be king it for your fly; 2084 | No sorrow yours are of my business are weat 2085 | Is not attend in the king. You must need 2086 | To prove that is a garmy some a little 2087 | I am in thee. 2088 | 2089 | This prayers to they head; but ell to be 2090 | A be so pretty could d be roccursed it. 2091 | 2092 | TRANIO: 2093 | As elder amongsy: 2094 | Pardon father, sir; shall never hither's fack 2095 | Thus are head as that does the war from of he 2096 | in me possessioner's your love. 2097 | 2098 | MERCUTIO: 2099 | A monger, intent for this protector a knew me 2100 | too his sincely and fadies their 2101 | luke earn: end to you to coldier thanks 'll. 2102 | 2103 | ROMEO: 2104 | Come, and thy about my come violent. 2105 | 2106 | MERCUTIO: 2107 | Why, too: thou art that my why charge, 2108 | as I am date, by my destraight take with decrew. 2109 | 2110 | ROMEO: 2111 | Such I sprepare is the is dead! I dare me devision 2112 | solegiance the wish the live hence. 2113 | 2114 | ERCHIOLO: 2115 | Thou art shalt ne'er reason bosom of a stone. 2116 | 2117 | POMPEY: 2118 | He shalt be the gods him go. Do you villain. 2119 | Did I shall had ere insmies what many, 2120 | It is prince pursuits I have my namel. 2121 | 2122 | PETRUCHIO: 2123 | The so to my lord. 2124 | Your grace man arrow it now. 2125 | 2126 | EDWARD: 2127 | Not hate. 2128 | 2129 | BUCKINGHAM: 2130 | My lords, Lord Bohemia: he much have gone. 2131 | 2132 | GLOUCESTER: 2133 | 2134 | CLARENCE: 2135 | Do now, to him be flyman Salisbury, 2136 | And to comfort you this daughter lords your be not 2137 | How your kindtily to be caasters of this: 2138 | It warrance this like yours so bring: 2139 | Grow break like unfant to your as England? 2140 | 2141 | KING EDWARD IV: 2142 | Because to stroke a father. 2143 | 2144 | GLOUCESTER: 2145 | Now, so much as you? what salls. As we's lady? 2146 | 2147 | LADY GREY: 2148 | 'Tis was a trouble of thy good? 2149 | 2150 | KING EDWARD IV: 2151 | Shall thy calls? thou have drops thy for her? 2152 | 2153 | GLOUCESTER: 2154 | Whose come would Christ I be longs, thy lord! 2155 | 2156 | QUEEN MARGARET: 2157 | Why, thou wolf, what forgive rich me? 2158 | 2159 | CLARENCE: 2160 | Why this in Gaunt at being thy mother? 2161 | 2162 | QUEEN MARGARET: 2163 | And dispite revenge thou canst thou, royalty 2164 | Is thou both these plant thy hand, thy news: 2165 | I thoughts of thy hearts of lovest thy tongue 2166 | The from wounds purpose of and enough; 2167 | For thy blood this goble and again a prove? 2168 | 2169 | QUEEN ELIZABETH: 2170 | Therefore that I thanks, the were did o'er port, 2171 | Hath colour had made mightly humble pierceed 2172 | That I several this hath plagued to dispair, 2173 | Of this dust thou art thou wilt be matter. 2174 | But not speak what thou hast he truth rend 2175 | That horses Brottom'd! 2176 | Why, not I know in thou redeem, and what to thy love? 2177 | 2178 | CATESBY: 2179 | We is't? 
2180 | 2181 | CATESBY: 2182 | He Pray most thou wast thou dost so know thy return'd? 2183 | 2184 | PRINCE EDWARD: 2185 | And I cannot ashe bring did thy counter. 2186 | Thou thy succester, not what my souls, 2187 | MI'll blood without to heart thy purposes agest by 2188 | To late come, thy scorn thy be giving me will 2189 | A kein a better, whose man's of thy wants-pert, 2190 | Who dares up thy with the and sugger things mights thee same. 2191 | Kill's a deceives often that bear, to friend, 2192 | To country's with slain my hand, 'Bring subject between in 2193 | That mistress heir put draws. 2194 | 2195 | Shepherd: 2196 | And thou not is the babe'st grace in yours. 2197 | 2198 | POMPEY: 2199 | I would be no encountenance your be no shall 2200 | question; sit out knows false to the greate 't. 2201 | 2202 | PERDITA: 2203 | Sir, your hated any of your dreamd. 2204 | 2205 | PETUS: 2206 | What, the Tush, pity 2207 | What, ortly thou may shalt her; I meal, for your son? 2208 | The slip wert thou this prince wish, who but jest? 2209 | 2210 | MENENIUS: 2211 | He's stay there? 2212 | 2213 | COMINIUS: 2214 | Sir, you speak tone spirit of this; I thoughts 2215 | Where not goest the should been 2216 | Be though welcome too sancture of Coriolanus, 2217 | Exposition, and discover number: your she common 2218 | Make done put you men counts prophecy be from father. 2219 | 2220 | OXFORD: 2221 | I think foul this beggar-conquer where against I am to 2222 | Lonf, but 'tiss to't: look'd way' humble. 2223 | 2224 | Second Servingman: 2225 | Why, withy cannot is goodly 'tis with the he, 2226 | daughter: you must entreat hath use no pale. 2227 | 2228 | LEONTES: 2229 | How? 2230 | 2231 | More harse! I'll men together carved be his guilty! 2232 | 2233 | PRINCE EDWARD: 2234 | O, why give made this it? what they are than's one? 2235 | Come is is country's you will tars, hast staint hire, 2236 | That have broughts for that would know dost forew. 2237 | 2238 | GLOUCESTER: 2239 | My heart, Decliff, my father some, and in three. 2240 | 2241 | ANGELO: 2242 | She hath all they lord; for I know my move not friend. 2243 | 2244 | LADY ANNE: 2245 | No, It will to thinks me to not for no love. 2246 | 2247 | LADY ANNE: 2248 | 2249 | GLOUCESTER: 2250 | An enever me with the king? 2251 | 2252 | GLOUCESTER: 2253 | I did, I mother? methough ofury, we by my heart 2254 | Whom the no more of much well as sendings; 2255 | It cannot live thinks by my with a peace, 2256 | And the right wings fair that we wilt the true.' 2257 | 2258 | LEONTES: 2259 | O Pray you, 2260 | To reclaim break, thou hast love. I was flesh all 2261 | Accompanion: many three lord and only to 2262 | As think. 2263 | 2264 | MENENIUS: 2265 | Take run'd to throne with is truth, to him. 2266 | 2267 | MENENIUS: 2268 | Methy knees: 2269 | Shall have him how walkly, that chequality one 2270 | And let of your his sondition, to I come go: 2271 | My breath-perform your for your good soul's to us. 2272 | 2273 | MENENIUS: 2274 | Must hold, therefore issues! you have destrance. 2275 | 2276 | MARCIUS: 2277 | The clamations' you, you stand we wit. 2278 | 2279 | VIMILIA: 2280 | Speak abhavours? pray none when I came gople. 2281 | 2282 | VOLUMNIA: 2283 | He hath bed me the srunds pluck'd the cause nothing? 2284 | 2285 | VOLUMNIA: 2286 | O my lord. 2287 | 2288 | PERDITA: 2289 | I pray, and that say you know not, you must dieath 2290 | shall forth my brother. 2291 | 2292 | MENENIUS: 2293 | Nay, now notha second from that die? 
2294 | 2295 | MENENIUS: 2296 | The kill our say 'I' the word, I am,'tis 2297 | Of your tiding rest yours, I could scalutchy's my 2298 | I will stone. 2299 | 2300 | CORIOLANUS: 2301 | Hath senator possible. 2302 | 2303 | First Marcius; 2304 | will'd the to in't will not the and lawful give 2305 | Second all reputy the marting shall thine for die. 2306 | 2307 | COMINIUS: 2308 | Nay, not not he still unto this calarench a good. 2309 | 2310 | MARCIUS: 2311 | Villain, brail, with sin, blow the thee; 2312 | But name, let hollow, my lords; and will you have 2313 | I'll resign, it is the still I saw young him 2314 | some forswear of children: but for well be 2315 | My tedious should not to too, sistake too a. 2316 | What was not? thy tongue heart: sit but a France? 2317 | 2318 | CLIFFORD: 2319 | Mightier to be thverence! 2320 | 2321 | FRIAR LAURENCE: 2322 | Bid, he should that nor many book'd not in, 2323 | Who confesserved chamber'd my scient his gurden me? 2324 | 2325 | ROMEO: 2326 | A good More under of my lordship me, 2327 | For what she scratch will out lord words? 2328 | 2329 | BENVOLIO: 2330 | O, crying on, then set thanks being; not of, attrous 2331 | Of you are of you weaky, wherefore you to them? 2332 | 2333 | ROMEO: 2334 | Thou art shighness is your eye; not they soest approve 2335 | Of thy me in thy done sad done's a word interch, 2336 | Liewis is like a happy 'Twixt downos under, 2337 | Lest the mine enemy the head thy mock make thee. 2338 | 2339 | ROMEO: 2340 | Do there in the present daughter: 2341 | Madam, good night adversity; he's it my lord. 2342 | 2343 | HERMIONE: 2344 | Marry, when thou know me abstard, my lie? 2345 | 2346 | MERCUTIO: 2347 | To stay, never king, sir, by me, and that 2348 | Is love, my life. 2349 | 2350 | CAPULET: 2351 | My lord. 2352 | 2353 | ROMEO: 2354 | Auth hast is too, and served my lord. 2355 | 2356 | PRINCE EDWARD: 2357 | But hear it me harm. 2358 | 2359 | GREMIO: 2360 | Not very that forsaken thy mother, 2361 | my name not me thy cannot thy me, 2362 | Thy blood in that one sound not to theirs. 2363 | 2364 | DUKE OF AUMERLE: 2365 | Well, or did to be man; but thou love, 2366 | Tower it mistory, sdisconcil thou did his most 2367 | And more piercharge a tewell, therefore wile 2368 | Of so so. Both prison! what thy sir, but thou, 2369 | Does that; but ere which is is need, we be 2370 | At thou bears wert pergination a poor such to thy friends 2371 | So infactorse, here bearts and here is should before? 2372 | On plotting but thou see thy fair objected. 2373 | This is own confess upon-thou sift; but thy king, 2374 | To heard rocker had body I believer yet remember 2375 | As if thy war stand traitors: that death blush, 2376 | I would me consul! straight, let thee news. 2377 | 2378 | POLIXENES: 2379 | Why, I am so slain that? 2380 | Thy hear me, do my life ten'd to thy duty. 2381 | 2382 | HERMIONE: 2383 | The should I reple king! 2384 | No through is is a serves for think; to the world: 2385 | My lordship to thy king, now heart, Do not cheek nor hand, 2386 | And what thou diest thy this sound to pay thee have rank'd; 2387 | And their as herefore arms are match'd bring Warwick 2388 | O' the nuptain the executions with this thee: 2389 | Take candh prophecy me sun schange thou speech fraughters. 2390 | And in he on God's draction the kings, 2391 | And who can a wword the of king? 
not lacks, 2392 | God deed's daughter the king to dry grave: 2393 | H 2394 | ``` 2395 | 2396 | 2397 | 2398 | ``` 2399 | :ok 2400 | ```
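2401 | 
2402 | The trailing `:ok` above is just the return value of `IO.puts/1`. As a last step, the sample can be written to disk rather than only printed. A minimal sketch, assuming the bindings from the cell above (`final_model`, `params`, `block_size`) and that `TextGen.generate/5` returns the generated binary (which the pipe into `IO.puts/1` implies):
2403 | 
2404 | ```elixir
2405 | # Seed generation with token 0 again; in the sorted 65-character vocabulary
2406 | # this is the newline character, which is why the sample opens with a blank line.
2407 | init_seq = Nx.broadcast(0, {1, 1})
2408 | 
2409 | # Sample 10,000 new tokens from the trained model.
2410 | sample =
2411 |   TextGen.generate(final_model, params, init_seq, block_size, max_new_tokens: 10000)
2412 | 
2413 | # Persist the raw sample; this is how a file like final-output.txt can be produced.
2414 | File.write!("final-output.txt", sample)
2415 | ```
2416 | --------------------------------------------------------------------------------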