├── .gitignore ├── README.md ├── data ├── data.tar.gz └── translation │ ├── eng-fra.txt │ ├── test.csv │ ├── train.csv │ └── val.csv ├── models ├── luong_attention │ └── luong_attention.py ├── luong_attention_batch │ └── luong_attention_batch.py └── luong_attention_manual_mask │ ├── luong_attention_manual_mask.py │ └── masked_rnn.py ├── train_luong_attention.py └── utils ├── __pycache__ ├── batches.cpython-36.pyc ├── embeddings.cpython-36.pyc ├── masked_cross_entropy.cpython-36.pyc └── tokens.cpython-36.pyc ├── batches.py ├── embeddings.py ├── load_and_preprocessing └── translation.py ├── masked_cross_entropy.py └── tokens.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | data/clickbait_coarse 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PytorchLuongAttention 2 | This is a batched implementation of Luong attention. The code uses batched matrix multiplication to calculate the attention scores, instead of calculating them one at a time. 3 | 4 | To run: 5 | `python train_luong_attention.py --train_dir data/translation --dataset_module translation --log_level INFO --batch_size 50 --use_cuda --hidden_size 500 --input_size 500 --different_vocab` -------------------------------------------------------------------------------- /data/data.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rawmarshmellows/pytorch-batch-luong-attention/768c5b809e7cd0a5dee97c91a2c534f8f3d66f69/data/data.tar.gz -------------------------------------------------------------------------------- /data/translation/test.csv: -------------------------------------------------------------------------------- 1 | source,target 2 | That bridge is anything but safe.,Ce pont est tout sauf sûr. 3 | That meant no.,Ça signifiait non. 4 | I'll phone you as soon as I get to the airport.,Je t'appellerai dès que j'arrive à l'aéroport. 5 | It's not your style.,Ce n'est pas ton style. 6 | I guarantee I'll get you a job.,Je t'assure que je te trouverai un travail. 7 | You're very curious.,Vous êtes fort curieuses. 8 | "I made a deposit of $1,000 at the bank.",J'ai fait un dépôt de 1000 dollars à la banque. 9 | I have no home to return to.,Je n'ai pas de maison à laquelle rentrer. 10 | Creationism is a pseudo-science.,Le créationnisme est une pseudo-science. 11 | "I slept all day yesterday, because it was Sunday.",Hier j'ai dormi toute la journée car c'était dimanche. 12 | You can take your time.,Tu peux prendre ton temps. 13 | He is in love with her.,Il est amoureux d'elle. 14 | Do you remember what to do?,Te rappelles-tu quoi faire ? 15 | He came to Tokyo in search of employment.,Il vint à Tokyo à la recherche d'un emploi. 16 | Tom is studying in Boston.,Tom étudie à Boston. 17 | I have always considered you a close friend.,Je t'ai toujours considéré comme un ami intime. 18 | You'll find the book in the library.,Vous trouverez le livre à la bibliothèque. 19 | I didn't think I'd meet you here.,Je ne pensais pas vous rencontrer ici. 20 | "Look forward, please.","Regardez devant vous, je vous prie." 21 | This makes me curious.,Ça excite ma curiosité. 22 | Stop being so nice.,Arrêtez d'être si gentilles ! 23 | She needed the entire afternoon to complete the work.,Elle eut besoin de toute l'après-midi pour accomplir le travail. 24 | What a bunch of idiots!,Quelle bande d'idiots ! 
25 | I prefer walking.,Je préfère marcher. 26 | Let me show you an example.,Laissez-moi vous montrer un exemple. 27 | He is too smart not to know it.,Il est trop intelligent pour ne pas le savoir. 28 | He arrived as soon as he could.,Il est arrivé aussitôt qu'il a pu. 29 | I'm good at science.,Je suis bon en Sciences. 30 | It's crawling with spiders.,Ça grouille d'araignées. 31 | She started for Kyoto yesterday.,Elle est partie hier pour Kyoto. 32 | Could you please tell me again where you put the key?,Pourrais-tu me répéter où tu as mis la clé ? 33 | Get in the car.,Grimpez dans la voiture ! 34 | It only takes about fifteen minutes to walk to the station from here.,Cela ne prend qu'environ quinze minutes d'aller d'ici à la gare à pied. 35 | How long will I have to wait?,Combien de temps vais-je devoir attendre ? 36 | She finished her errand and returned home.,Elle a terminé ses courses et elle est rentrée chez elle. 37 | Ignore Tom's request.,Ignore la demande de Tom. 38 | I eat lunch here two or three times a week.,Je déjeune ici deux ou trois fois par semaine. 39 | The soldiers were ready to die for their country.,Les soldats étaient prêts à mourir pour leur pays. 40 | I hope this data is wrong.,J'espère que ces données sont fausses. 41 | Is it my turn?,Est-ce mon tour ? 42 | I find this difficult to believe.,Je trouve ceci difficile à croire. 43 | I know you don't like me very much.,Je sais que tu ne m'apprécies guère. 44 | We're both wrong.,Nous avons tous deux tort. 45 | We have been waiting here for hours.,Ça fait des heures qu'on attend ici. 46 | Can you believe it?,Pouvez-vous le croire ? 47 | Give me a hammer.,Donnez-moi un marteau. 48 | Were you even tempted?,As-tu même été tenté ? 49 | Would someone please wake me up at 2:30?,"Quelqu'un peut-il me réveiller à 2:30, s'il vous plaît ?" 50 | The disagreement between the union and management could lead to a strike.,Le désaccord entre le syndicat et la direction pourrait mener à la grève. 51 | I was a little put out by this.,J'en fus un peu irrité. 52 | Would you like to play with us?,Voudriez-vous jouer avec nous ? 53 | Do you like Wagner?,Appréciez-vous Wagner ? 54 | I must speak with you alone.,Il me faut vous parler seul. 55 | I no longer have the energy to talk.,Je n'ai plus la force de discuter. 56 | "I have a lot to do today, so if you don't mind, I'd like to have this discussion at another time.","J'ai fort à faire aujourd'hui alors, si ça ne vous dérange pas, j'aimerais avoir cette discussion une autre fois." 57 | Is this seat vacant?,Cette place est-elle libre ? 58 | "We have to do this now. ""There's no time.""",« Nous devons faire ça maintenant. » « Il n'y a pas le temps. » 59 | Don't leave town.,Ne quitte pas la ville. 60 | It's a very difficult situation.,C'est une situation très difficile. 61 | Does he like China?,Apprécie-t-il la Chine ? 62 | Did anyone else help you?,Qui que ce soit d'autre t'a-t-il aidé ? 63 | He saved a hundred dollars.,Il économisa cent dollars. 64 | "Can I talk to you for a second, please?","Puis-je vous parler une seconde, s'il vous plaît ?" 65 | Bring me a moist towel.,Apporte-moi une serviette humide. 66 | No need to worry.,Pas de quoi s'inquiéter. 67 | "It was a dry year, and many animals starved.",C'était une année de sécheresse et de nombreux animaux furent affamés. 68 | This is the restaurant that I often eat at.,J'ai l'habitude de manger dans ce restaurant. 69 | It is now necessary to add another rule.,Il est désormais nécessaire de rajouter une nouvelle règle. 
70 | You look happy.,Tu as l'air heureux. 71 | I don't really understand what just happened.,Je ne comprends pas vraiment ce qui vient de se passer. 72 | It smells good!,Ça sent bon ! 73 | Spend your time wisely and you'll always have enough of it.,Emploie ton temps judicieusement et tu en auras toujours assez. 74 | Is there something in particular that you want?,Y a-t-il quelque chose de particulier que vous vouliez ? 75 | Beethoven was a great musician.,Beethoven était un grand musicien. 76 | She took him to the store.,Elle l'a emmené au magasin. 77 | I think we can make it on time.,Je pense que nous pouvons le faire à temps. 78 | How much time did we lose?,Combien de temps avons-nous perdu ? 79 | Have you been mistreated?,As-tu été maltraité ? 80 | Routine exercise is great for your health.,Un exercice régulier est excellent pour la santé. 81 | Do you have a retirement plan?,Avez-vous une retraite ? 82 | He quit.,Il laissa tomber. 83 | I feel wonderful.,Je me sens à merveille. 84 | You're not too late.,Tu n'es pas trop en retard. 85 | I agreed with her.,Je fus d'accord avec elle. 86 | She gave me good advice.,Elle m'a donné des conseils judicieux. 87 | I caught them in the act.,Je les ai attrapés la main dans le sac. 88 | Who was the book written by?,Par qui le livre fut-il écrit ? 89 | I don't know what I want to do with my life.,J'ignore ce que je veux faire de ma vie. 90 | Everyone is here now.,Tout le monde est désormais là. 91 | I want this matter taken care of immediately.,Je veux qu'on prenne immédiatement soin de cette affaire. 92 | This bread smells really good.,Ce pain sent vraiment très bon. 93 | I sure wish you would leave.,J'espère certainement que tu partes. 94 | Minnesota's state bird is the mosquito.,L'oiseau symbole de l'État du Minnesota est le moustique. 95 | "All things considered, I think you should go back home and take care of your parents.","Tout bien considéré, je pense que tu devrais retourner chez toi pour t'occuper de tes parents." 96 | How many times have you lied to me?,Combien de fois m'avez-vous menti ? 97 | Our school was reduced to ashes.,Notre école a été réduite en cendres. 98 | I don't know much about that.,Je n'en sais pas grand-chose. 99 | He left Japan for Europe.,Il a quitté le Japon pour l'Europe. 100 | I can't believe you did this by yourself.,Je n'arrive pas à croire que tu aies fait cela par toi-même. 101 | Tom ate up all the cookies.,Tom a mangé tous les biscuits. 102 | What have you bought your girlfriend for Christmas?,"Qu'avez-vous acheté à votre petite-amie, pour Noël ?" 103 | What's your favorite sport?,Quel est ton sport préféré ? 104 | Why are you doing this to me?,Pourquoi me faites-vous ça ? 105 | I'd be delighted.,J'en serais ravi. 106 | I fully support your proposal.,Je soutiens complètement votre proposition. 107 | Your book is here.,Ton livre est ici. 108 | We've got a bigger problem.,Nous avons un plus gros problème. 109 | Hurry up! I don't have all day.,Dépêchez-vous ! Je n'ai pas toute la journée ! 110 | You said you were happy.,Tu as dit que vous étiez heureux. 111 | You're annoying.,Tu es chiante. 112 | We can't go there.,Nous ne pouvons pas aller là-bas. 113 | She went to the museum by taxi.,Elle se rendit au musée en taxi. 114 | This house is too big for us.,Cette maison est trop grande pour nous. 115 | I want to buy them all.,Je veux tous les acheter. 116 | She is more pretty than beautiful.,Elle est davantage mignonne que belle. 117 | Do you know what this is all about?,Sais-tu de quoi tout ça retourne ? 
118 | I'll get rid of them.,Je m'en débarrasserai. 119 | It's not as cold today as it was yesterday.,Il ne fait pas aussi froid aujourd'hui qu'hier. 120 | I should've worn a coat.,J'aurai dû mettre un manteau. 121 | Not everyone agrees with you.,Tout le monde n'est pas de ton avis. 122 | He raised his arm.,Il leva le bras. 123 | Do you know what that means?,Sais-tu ce que cela signifie ? 124 | Drug addiction is a cancer in modern society.,La toxicomanie est un cancer au sein de la société moderne. 125 | I need a pair of scissors to cut this paper.,J'ai besoin d'une paire de ciseaux pour couper ce papier. 126 | Don't tell me to go home.,Ne me dis pas de m'en aller chez moi ! 127 | Someone stole something from my bag.,Quelqu'un a dérobé quelque chose de mon sac. 128 | I live upstate.,Je réside au nord de l'état. 129 | When would it be convenient for you?,Quand est-ce que ça vous arrangera ? 130 | Are you dating anyone?,Sors-tu avec quelqu'un ? 131 | Ignore Tom.,Ignore Tom. 132 | Try to make it last a little longer.,Essayez de le faire durer un peu plus longtemps. 133 | We play football every Saturday.,Nous jouons au football chaque samedi. 134 | "Mary looks like her mother, but her personality is different.",Marie ressemble à sa mère mais elle a une personnalité différente. 135 | I hate talking about politics.,Je déteste parler de politique. 136 | I thought you were taller.,Je pensais que vous étiez plus grand. 137 | Don't expect too much.,N'attends pas trop. 138 | -------------------------------------------------------------------------------- /data/translation/val.csv: -------------------------------------------------------------------------------- 1 | source,target 2 | What the devil are you doing?,Mais que diable es-tu en train de faire ? 3 | They smiled at each other.,Ils se sourirent l'un à l'autre. 4 | How many hours a day does she spend in the kitchen?,Combien d'heures passe-t-elle quotidiennement à la cuisine ? 5 | Did you catch the first train?,Avez-vous eu le premier train ? 6 | No one could see us.,Personne ne pouvait nous voir. 7 | His soldiers feared and respected him.,Ses soldats le craignaient et le respectaient. 8 | "When I was in school, I really hated writing essays.","Lorsque j'étais à l'école, j'avais vraiment horreur d'écrire des dissertations." 9 | I wish you'd never been born.,J'aimerais que tu ne sois jamais née. 10 | Please tell me your opinion.,Donnez-moi votre opinion s'il vous plaît. 11 | The huge tanker has just left the dock.,Le pétrolier géant vient de quitter le bassin. 12 | The cows are eating grass.,Les vaches paissent. 13 | I got my son to repair the door.,J'ai fait réparer la porte par mon fils. 14 | The crowd applauded for several minutes.,La foule a applaudi pendant plusieurs minutes. 15 | She caught me by the arm and stopped me from going home.,Elle m'attrapa par le bras et me retint d'aller chez moi. 16 | He went to America to study American literature.,Il est allé en Amérique pour étudier la littérature américaine. 17 | Take a walk.,Va marcher ! 18 | He told me to meet him at his apartment.,Il m'a dit de le rencontrer à son appartement. 19 | You can't just come in here and start ordering people around.,Tu ne peux pas juste venir ici et commencer à donner des ordres aux gens alentour. 20 | Are you sure you're up for it?,Êtes-vous sûr d'y être prêt ? 21 | All is well that ends well.,Tout est bien qui finit bien. 22 | Tom seems courteous.,Tom a l'air courtois. 23 | I'm going to tell you a secret.,Je vais vous conter un secret. 
24 | Would you like to be my apprentice?,Aimeriez-vous être mon apprenti ? 25 | I don't regret a thing.,Je ne regrette rien. 26 | I will always love Mary.,Je serai toujours amoureux de Mary. 27 | I believe things will get better.,Je crois que les choses vont s'améliorer. 28 | Tom didn't miss a thing.,Tom n'a rien manqué. 29 | You may take this book as long as you keep it clean.,Tu pourras prendre ce livre tant que tu le gardes propre. 30 | I wonder whose car this is.,Je me demande à qui est cette voiture. 31 | He wants to sell his old car to a man in Kobe.,Il veut vendre sa vieille voiture à un homme à Kobe. 32 | Give me some more tea.,Donnez-moi un peu plus de thé. 33 | They are Russian.,Elles sont russes. 34 | Basketball is a lot of fun.,Le basket-ball est très distrayant. 35 | I must go alone.,Il faut que je m'y rende seule. 36 | I don't have any evidence.,Je ne dispose d'aucune preuve. 37 | We owe part of our success to luck.,Nous devons une part de notre succès à la chance. 38 | "This looks pretty interesting, Hiroshi says.",Ca a l'air plutôt intéressant dit Hiroshi. 39 | Keep next Sunday free.,Garde dimanche prochain de libre. 40 | What don't you want me to do?,Que ne voulez-vous pas que je fasse ? 41 | "If you're going to do this, you'd better hurry.","Si tu vas faire ça, tu ferais mieux de te grouiller !" 42 | It's not enough.,Ce n'est pas suffisant. 43 | Tom certainly appeared to be enjoying himself.,Tom avait l'air de bien s'amuser. 44 | I heard coughing.,J'ai entendu tousser. 45 | I lost my camera.,J'ai perdu mon appareil photo. 46 | Is this pure gold?,Est-ce de l'or pur ? 47 | I wrote to her last month.,Je lui ai écrit le mois dernier. 48 | Tom has already finished the book he started reading last night.,Tom a déjà fini le livre qu'il avait commencé à lire la nuit dernière. 49 | I used to think no one cared.,Je pensais que personne ne s'en souciait. 50 | I've had a very busy morning.,J'ai été très occupée toute la matinée. 51 | I didn't mean to surprise you.,Je n'ai pas eu l'intention de vous surprendre. 52 | Tom plays the trombone.,Tom joue du trombone. 53 | Take off your cap.,Ôte ta casquette. 54 | Don't make me regret this.,Ne me le fais pas regretter. 55 | I can't get the car to start.,Je n'arrive pas à faire démarrer la voiture. 56 | She disguised herself as him.,Elle se déguisa en lui. 57 | It was very painful.,Ça a été très douloureux. 58 | I wouldn't want anything to happen to you.,Je ne voudrais pas qu'il vous arrive quelque chose. 59 | Tom is on leave.,Tom est en congé. 60 | How long will the train stop here?,Combien de temps le train s'arrête-t-il ici ? 61 | I think we're even.,Je pense que nous sommes à égalité. 62 | You must take care of the dog.,Tu dois t'occuper du chien. 63 | A walk before breakfast is refreshing.,Une promenade avant le petit déjeuner est rafraîchissante. 64 | "As a rule of thumb, you should plan on one pound of beef for every two guests.","À grosse maille, vous devriez prévoir une livre de bœuf pour deux invités." 65 | Is that why they died?,Est-ce là pourquoi ils sont morts ? 66 | It's almost seven. We have to go to school.,Il est presque sept heures. Il nous faut aller à l'école. 67 | I'm not uncomfortable.,Je ne me sens pas mal à l'aise. 68 | I pretended that I was sleeping.,J'ai fait semblant d'être en train de dormir. 69 | I'm sorry to be so late. The meeting completely slipped my mind.,Je suis désolée d'être tellement en retard. J'avais complètement oublié le rendez-vous. 70 | He is photogenic.,Il est photogénique. 
71 | Give a man a fish and you feed him for a day. Teach a man to fish and you feed him for the rest of his life.,"Donne un poisson à un homme, il mangera un jour. Apprends-lui à pêcher, il mangera toute sa vie." 72 | I've finally got the whole set!,Je possède enfin un set complet ! 73 | "If you can digest them, raw vegetables are good for your health.","Si on peut les digérer, les légumes crus font du bien à la santé." 74 | Both brothers are musicians.,Les deux frères sont musiciens. 75 | How are you today?,"Comment vous portez-vous, aujourd'hui ?" 76 | He parties too much.,Il fait trop la bamboula. 77 | She advised him to go there.,Elle lui conseilla d'y aller. 78 | I think you might be overreacting.,Je pense qu'il se pourrait que tu dramatises. 79 | He inherited an old wooden chest.,Il a hérité d'un vieux coffre en bois. 80 | I like French food very much.,La cuisine française me plaît énormément. 81 | I plan to try reading some other books.,J'ai l'intention de lire quelques autres livres. 82 | He rejected our offer.,Il a rejeté notre proposition. 83 | She lives on a small pension.,Elle vit avec une petite pension. 84 | Are you aware that Okinawa is closer to China than to Honshu?,Sais-tu qu'Okinawa est plus près de la Chine que de Honshu ? 85 | Reading is to the mind what exercise is to the body.,La lecture est à l'esprit ce que l'exercice physique est au corps. 86 | The cup has a crack.,La tasse a une fêlure. 87 | We're on schedule.,On est dans les clous. 88 | She'd just begun to read the book when someone knocked on the door.,Elle venait de commencer à lire le livre quand quelqu'un frappa à la porte. 89 | My bicycle needs to be repaired.,Mon vélo a besoin d'être réparé. 90 | He's an author.,C'est un auteur. 91 | I didn't think you'd be coming.,Je ne pensais pas que tu viendrais. 92 | Do you think your parents spent enough time with you when you were in your teens?,Pensez-vous que vos parents ont passé suffisamment de temps avec vous lorsque vous étiez adolescent ? 93 | What do you want to do about it?,Que veux-tu qu'on y fasse ? 94 | You have a message.,Tu as un message. 95 | You're too trusting.,Vous êtes trop confiante. 96 | What you're saying isn't logical.,Ce que vous dites n'est pas logique. 97 | Exercise is to the body what thinking is to the brain.,Le sport est au corps ce que la pensée est au cerveau. 98 | They all have drinks.,Elles ont toutes des boissons. 99 | I'm between jobs.,Je suis en train de changer de poste. 100 | She's in the garden planting roses.,Elle plante des roses dans le jardin. 101 | Piercings can be a form of rebellion for teenagers.,Les piercings peuvent être une forme de rébellion pour les adolescents. 102 | He had the right idea.,Il a eu la bonne idée. 103 | Thank you for calling.,Merci d'avoir appelé. 104 | Do you have a book written in English?,As-tu un livre écrit en anglais ? 105 | Lincoln died in 1865.,Lincoln est mort en 1865. 106 | His hobby is painting pictures.,Son passe-temps est de peindre des tableaux. 107 | You're very stylish.,Vous êtes très élégante. 108 | Can you still remember the time we first met?,Arrives-tu à te rappeler le moment où nous nous sommes rencontrés la première fois ? 109 | Who put that there?,Qui a mis ça là ? 110 | Was it cloudy in Tokyo yesterday?,Le ciel était-il couvert à Tokyo hier ? 111 | Reading is to the mind what food is to the body.,La lecture est à l'esprit ce que la nourriture est au corps. 112 | There is a very old temple in the town.,Il y a un très vieux temple dans la ville. 
113 | I've been misunderstood.,J'ai été mal comprise. 114 | This is the latest fashion.,C'est la dernière mode. 115 | Tom won.,Tom a gagné. 116 | I think watching TV is a waste of time.,Je pense que regarder la télévision est une perte de temps. 117 | He stared at me from head to foot.,Il m'observa de la tête aux pieds. 118 | What was it that we were told to do?,C'est quoi qu'on nous a dit de faire ? 119 | I can't protect you.,Je ne peux vous protéger. 120 | I gave the beggar what money I had.,J'ai donné au mendiant tout l'argent que j'avais. 121 | You have some books.,Vous avez quelques livres. 122 | You said you were happy.,Tu as dit que vous étiez heureuses. 123 | Where is Tom working now?,Où travaille Tom maintenant? 124 | Are you people lost?,Êtes-vous perdus ? 125 | She goes to the bookstore once a week.,Elle va à la librairie une fois par semaine. 126 | I don't mind hot weather.,Le temps chaud ne me dérange pas. 127 | We're sloshed.,Nous sommes bourrées. 128 | What do you need exactly?,De quoi as-tu besoin exactement ? 129 | He did what he promised to do.,Il a fait ce qu'il a promis. 130 | He might come tomorrow.,Il se peut qu'il vienne demain. 131 | She steered our efforts in the right direction.,Elle a dirigé nos efforts dans la bonne direction. 132 | That a boy!,T'es un bon garçon ! 133 | I think you must be getting tired.,Je pense que tu dois être fatigué. 134 | I don't speak fast.,Je ne parle pas rapidement. 135 | It used to be nearly impossible.,C'était presque impossible. 136 | We go to the movies together once in a while.,Nous allons ensemble au cinéma de temps en temps. 137 | "I want to play, too.","Je veux jouer, moi aussi." 138 | -------------------------------------------------------------------------------- /models/luong_attention/luong_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | 8 | 9 | class EncoderRNN(nn.Module): 10 | def __init__(self, hidden_size, input_size, n_layers, dropout, word_embedding_matrix, rnn_cell, use_cuda): 11 | super().__init__() 12 | self.hidden_size = hidden_size 13 | self.input_size = input_size 14 | self.n_layers = n_layers 15 | self.dropout = dropout 16 | self.embedding = word_embedding_matrix 17 | self.use_cuda = use_cuda 18 | 19 | if rnn_cell == "GRU": 20 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, dropout=dropout, bidirectional=True) 21 | 22 | def forward(self, input_seqs, hidden, input_lengths): 23 | """ 24 | input_seqs : (Max input length, batch_size) 25 | input_lengths: (batch_size) 26 | """ 27 | 28 | # Max input length, batch size, hidden_size 29 | embedded = self.embedding(input_seqs) 30 | packed = pack_padded_sequence(embedded, input_lengths) 31 | outputs, hidden = self.rnn(packed, hidden) 32 | outputs, output_lengths = pad_packed_sequence(outputs) 33 | 34 | # Max input length, batch_size, hidden_size, we add the backward and forward 35 | # hidden states together 36 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 37 | 38 | # Get the forwards and backwards hidden states 39 | hidden_layers = [] 40 | for i in range(self.n_layers): 41 | hidden_layers.append((hidden[i * 2, :, :] + hidden[(i * 2) + 1, :, :]).unsqueeze(0)) 42 | return outputs, torch.cat(hidden_layers, 0) 43 | 44 | def init_hidden(self, batch_size): 45 | hidden = 
Variable(torch.zeros(self.n_layers * 2, batch_size, self.hidden_size)) 46 | if self.use_cuda: hidden = hidden.cuda() 47 | return hidden 48 | 49 | class Attn(nn.Module): 50 | def __init__(self, method, hidden_size, use_cuda=False): 51 | super().__init__() 52 | self.method = method 53 | self.hidden_size = hidden_size 54 | self.use_cuda = use_cuda 55 | 56 | if self.method == 'general': 57 | self.attn = nn.Linear(self.hidden_size, hidden_size) 58 | elif self.method == 'concat': 59 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 60 | self.v = nn.Parameter(torch.FloatTensor(1, hidden_size)) 61 | 62 | def forward(self, hidden, encoder_outputs): 63 | """ 64 | hidden : 1, batch_size, hidden_size 65 | encoder_outputs : max input length, batch_size, hidden_size 66 | """ 67 | # max_len = encoder_outputs.size(0) 68 | # this_batch_size = encoder_outputs.size(1) 69 | 70 | attn_energies = torch.bmm(self.attn(hidden).transpose(1, 0), encoder_outputs.permute(1, 2, 0)) 71 | 72 | # Batch size, 1, max input length 73 | return F.softmax(attn_energies) 74 | 75 | def score(self, hidden, encoder_output): 76 | 77 | if self.method == 'general': 78 | energy = self.attn(encoder_output).view(-1) 79 | energy = hidden.view(-1).dot(energy) 80 | return energy 81 | 82 | 83 | 84 | class LuongAttnDecoderRNN(nn.Module): 85 | def __init__(self, score_function, hidden_size, input_size, output_size, 86 | n_layers, dropout , word_embedding_matrix, use_cuda): 87 | super(LuongAttnDecoderRNN, self).__init__() 88 | 89 | # Keep for reference 90 | self.score_function = score_function 91 | self.hidden_size = hidden_size 92 | self.input_size = input_size 93 | self.output_size = output_size 94 | self.n_layers = n_layers 95 | self.dropout = dropout 96 | self.use_cuda = use_cuda 97 | 98 | # Define layers 99 | self.embedding = word_embedding_matrix 100 | self.gru = nn.GRU(self.input_size, self.hidden_size, n_layers, dropout=self.dropout) 101 | self.concat = nn.Linear(self.hidden_size * 2, self.hidden_size) 102 | self.out = nn.Linear(self.hidden_size, self.output_size) 103 | 104 | # Choose attention models 105 | if score_function != 'none': 106 | self.attn = Attn(score_function, hidden_size, use_cuda=use_cuda) 107 | 108 | def forward(self, input_seq, last_hidden, encoder_outputs): 109 | """ 110 | input_seq : batch_size 111 | hidden : hidden_size, batch_size 112 | encoder_outputs : max input length, batch_size, hidden_size 113 | """ 114 | # Note: we run this one step at a time 115 | 116 | # logging.debug(f"input_seq:\n{input_seq}") 117 | # logging.debug(f"last_hidden:\n{last_hidden}") 118 | # logging.debug(f"encoder_outputs:\n{encoder_outputs}") 119 | 120 | batch_size = input_seq.size(0) 121 | 122 | # (batch size, hidden_size) 123 | embedded = self.embedding(input_seq) 124 | 125 | # (1, batch size, input_size) add another dimension so that it works with 126 | # the GRU 127 | embedded = embedded.view(1, batch_size, self.input_size) # S=1 x B x N 128 | 129 | logging.debug(f"embedded:\n{embedded}") 130 | # Get current hidden state from input word and last hidden state 131 | rnn_output, hidden = self.gru(embedded, last_hidden) 132 | logging.debug(f"rnn_output:\n{rnn_output}") 133 | logging.debug(f"hidden:\n{hidden}") 134 | 135 | 136 | # Calculate attention from current RNN state and all encoder outputs; 137 | # apply to encoder outputs to get weighted average 138 | 139 | # batch size, max input length 140 | attn_weights = self.attn(rnn_output, encoder_outputs) 141 | logging.debug(f"attn_weights:\n{attn_weights}") 142 | 143 | # 
(batch_size, 1, max input length) @ (batch_size, max input length, hidden size) 144 | # note that we use this convention here to take advantage of the bmm function 145 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x S=1 x N 146 | 147 | # Attentional vector using the RNN hidden state and context vector 148 | # concatenated together (Luong eq. 5) 149 | rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N 150 | context = context.squeeze(1) # B x S=1 x N -> B x N 151 | concat_input = torch.cat((rnn_output, context), 1) 152 | concat_output = F.tanh(self.concat(concat_input)) 153 | 154 | # Finally predict next token (Luong eq. 6, without softmax) 155 | output = self.out(concat_output) 156 | 157 | # Return final output, hidden state, and attention weights (for visualization) 158 | return output, hidden, attn_weights -------------------------------------------------------------------------------- /models/luong_attention_batch/luong_attention_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import torch.nn as nn 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 8 | from utils.masked_cross_entropy import sequence_mask 9 | import sys 10 | 11 | 12 | class EncoderRNN(nn.Module): 13 | def __init__(self, hidden_size, input_size, n_layers, dropout, word_embedding_matrix, rnn_cell, use_cuda): 14 | super().__init__() 15 | self.hidden_size = hidden_size 16 | self.input_size = input_size 17 | self.n_layers = n_layers 18 | self.dropout = dropout 19 | self.embedding = word_embedding_matrix 20 | if rnn_cell == "GRU": 21 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, dropout=dropout, bidirectional=True) 22 | 23 | self.use_cuda = use_cuda 24 | 25 | def forward(self, input_seqs, hidden, input_lengths): 26 | """ 27 | input_seqs : (Max input length, batch_size) 28 | input_lengths: (batch_size) 29 | """ 30 | 31 | # # Max input length, batch size, hidden_size 32 | embedded = self.embedding(input_seqs) 33 | packed = pack_padded_sequence(embedded, input_lengths) 34 | outputs, hidden = self.rnn(packed, hidden) 35 | outputs, output_lengths = pad_packed_sequence(outputs) 36 | 37 | # Max input length, batch_size, hidden_size, we add the backward and forward 38 | # hidden states together 39 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 40 | 41 | hidden_layers = [] 42 | for i in range(self.n_layers): 43 | hidden_layers.append((hidden[i*2, :, :] + hidden[(i*2)+1, :, :]).unsqueeze(0)) 44 | return outputs, torch.cat(hidden_layers, 0) 45 | 46 | def init_hidden(self, batch_size): 47 | hidden = Variable(torch.zeros(self.n_layers * 2, batch_size, self.hidden_size)) 48 | if self.use_cuda: hidden = hidden.cuda() 49 | return hidden 50 | 51 | 52 | class LuongAttnDecoderRNN(nn.Module): 53 | def __init__(self, 54 | hidden_size, 55 | input_size, 56 | output_size, 57 | n_layers, 58 | word_embedding_matrix, 59 | dropout, 60 | rnn_cell, 61 | use_cuda): 62 | super(LuongAttnDecoderRNN, self).__init__() 63 | 64 | # Keep for reference 65 | self.hidden_size = hidden_size 66 | self.input_size = input_size 67 | self.output_size = output_size 68 | self.n_layers = n_layers 69 | self.use_cuda = use_cuda 70 | 71 | # Define layers 72 | self.embedding = word_embedding_matrix 73 | if rnn_cell == "GRU": 74 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, dropout=dropout) 75 | self.attn = 
LuongAttention(hidden_size, use_cuda) 76 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 77 | self.out = nn.Linear(hidden_size, output_size) 78 | 79 | def forward(self, 80 | input_seq, 81 | input_lengths, 82 | encoder_hidden, 83 | encoder_outputs, 84 | encoder_outputs_length): 85 | """ 86 | input_seq : batch_size 87 | hidden : hidden_size, batch_size 88 | encoder_outputs : max input length, batch_size, hidden_size 89 | """ 90 | # Note: we run this one step at a time 91 | 92 | # logging.debug(f"input_seq:\n{input_seq}") 93 | # logging.debug(f"last_hidden:\n{last_hidden}") 94 | # logging.debug(f"encoder_outputs:\n{encoder_outputs}") 95 | 96 | # sort the input by descending order, but now we also need to sort the encoder 97 | # outputs w/ the same index 98 | sorted_index = np.argsort(input_lengths).tolist()[::-1] 99 | unsorted_index = np.argsort(sorted_index) 100 | sorted_input_seq = input_seq[:, sorted_index] 101 | sorted_input_lengths = np.array(input_lengths)[sorted_index] 102 | sorted_encoder_hidden = encoder_hidden[:, sorted_index, :] 103 | sorted_encoder_outputs = encoder_outputs[:, sorted_index, :] 104 | sorted_encoder_outputs_length = np.array(encoder_outputs_length)[sorted_index].tolist() 105 | 106 | # decoder input: (batch size, hidden_size) 107 | embedded = self.embedding(sorted_input_seq) 108 | 109 | packed = pack_padded_sequence(embedded, sorted_input_lengths) 110 | decoder_outputs, decoder_hidden = self.rnn(packed, sorted_encoder_hidden) 111 | decoder_outputs, decoder_outputs_length = pad_packed_sequence(decoder_outputs) 112 | 113 | # Calculate attention from current RNN state and all encoder outputs; 114 | # apply to encoder outputs to get weighted average 115 | 116 | # batch size, max input length. if we use batch we would need to feed the attention 117 | # the sorted encoder outputs and sorted encoder output lengths 118 | attn_weights = self.attn(sorted_encoder_outputs, 119 | sorted_encoder_outputs_length, 120 | decoder_outputs, 121 | decoder_outputs_length) 122 | logging.debug(f"attn_weights:\n{attn_weights}") 123 | 124 | # attn_weights is in (batch_size, max_decoder_length, max_encoder_length) 125 | # encoder_outputs is (max_encoder_length, batch_size, hidden_size) 126 | # but transposed would be (batch_size, max_encoder_length, hidden_size) 127 | context = torch.bmm(attn_weights, encoder_outputs.transpose(1, 0)) 128 | 129 | 130 | concat_rep = torch.cat([decoder_outputs, context.transpose(0, 1)], -1) 131 | concat_output = self.concat(F.tanh(concat_rep)) 132 | 133 | final_decoder_output = self.out(concat_output) 134 | # Why do we need to use .tolist() for `output` but not for `final_decoder_output` 135 | original_position_outputs = final_decoder_output[:, unsorted_index.tolist(), :] 136 | 137 | # Return final output, hidden state, and attention weights (for visualization) 138 | return original_position_outputs, decoder_hidden, attn_weights 139 | 140 | 141 | class LuongAttention(nn.Module): 142 | """ 143 | Note here that we are only implementing the 'general' method as denoted in the paper 144 | """ 145 | 146 | def __init__(self, hidden_size, use_cuda): 147 | super().__init__() 148 | self.hidden_size = hidden_size 149 | self.general_weights = Variable(torch.randn(hidden_size, hidden_size)) 150 | self.use_cuda = use_cuda 151 | if use_cuda: 152 | self.general_weights = self.general_weights.cuda() 153 | 154 | def forward(self, 155 | encoder_outputs, 156 | encoder_outputs_length, 157 | decoder_outputs, 158 | decoder_outputs_length): 159 | """ 160 | 161 | :param 
encoder_outputs: max_encoder_length, batch_size, hidden_size 162 | :param encoder_outputs_length: batch_size 163 | :param decoder_outputs: max_decoder_length, batch_size, hidden_size 164 | :param decoder_outputs_length: batch_size 165 | :return: attention_aware_output 166 | """ 167 | 168 | # (batch_size, max_decoder_length, hidden_size) 169 | decoder_outputs = torch.transpose(decoder_outputs, 0, 1) 170 | 171 | # (batch_size, hidden_size, max_encoder_length) 172 | encoder_outputs = encoder_outputs.permute(1, 2, 0) 173 | 174 | # (batch_size, max_encoder_length, max_decoder_length 175 | score = torch.bmm(decoder_outputs @ self.general_weights, encoder_outputs) 176 | 177 | (attention_mask, 178 | max_enc_outputs_length, 179 | max_dec_outputs_length) = self.attention_sequence_mask(encoder_outputs_length, decoder_outputs_length) 180 | masked_score = score + attention_mask 181 | weights_flat = F.softmax(masked_score.view(-1, max_enc_outputs_length)) 182 | weights = weights_flat.view(-1, max_dec_outputs_length, max_enc_outputs_length) 183 | 184 | return weights 185 | 186 | def attention_sequence_mask(self, encoder_outputs_length, decoder_outputs_length): 187 | batch_size = len(encoder_outputs_length) 188 | max_encoder_outputs_length = max(encoder_outputs_length) 189 | max_decoder_outputs_length = max(decoder_outputs_length) 190 | 191 | encoder_sequence_mask = sequence_mask(encoder_outputs_length, use_cuda=self.use_cuda) 192 | encoder_sequence_mask_expand = (encoder_sequence_mask 193 | .unsqueeze(1) 194 | .expand(batch_size, 195 | max_decoder_outputs_length, 196 | max_encoder_outputs_length)) 197 | 198 | decoder_sequence_mask = sequence_mask(decoder_outputs_length, use_cuda=self.use_cuda) 199 | decoder_sequence_mask_expand = (decoder_sequence_mask 200 | .unsqueeze(2) 201 | .expand(batch_size, 202 | max_decoder_outputs_length, 203 | max_encoder_outputs_length)) 204 | attention_mask = (encoder_sequence_mask_expand * 205 | decoder_sequence_mask_expand).float() 206 | attention_mask = (attention_mask - 1) * sys.maxsize 207 | return (attention_mask, 208 | max_encoder_outputs_length, 209 | max_decoder_outputs_length) 210 | -------------------------------------------------------------------------------- /models/luong_attention_manual_mask/luong_attention_manual_mask.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 10 | 11 | from utils.masked_cross_entropy import sequence_mask 12 | 13 | 14 | class EncoderRNN(nn.Module): 15 | def __init__(self, hidden_size, input_size, n_layers, dropout, word_embedding_matrix, rnn_cell, should_mask, use_cuda): 16 | super().__init__() 17 | self.hidden_size = hidden_size 18 | self.input_size = input_size 19 | self.n_layers = n_layers 20 | self.dropout = dropout 21 | self.embedding = word_embedding_matrix 22 | if rnn_cell == "GRU": 23 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, dropout, bidirectional=True) 24 | 25 | self.use_cuda = use_cuda 26 | 27 | def forward(self, input_seqs, hidden, input_lengths): 28 | """ 29 | input_seqs : (Max input length, batch_size) 30 | input_lengths: (batch_size) 31 | """ 32 | # 33 | # embedded = self.embedding(input_seqs) 34 | # outputs = embedded 35 | # hidden_layers = [] 36 | # for _ in range(self.n_layers): 37 | # outputs, hidden = 
masked_rnn_cell(self.gru, outputs, input_lengths, None, True, self.use_cuda) 38 | # hidden_layers.append(hidden) 39 | # return outputs, torch.cat(hidden_layers, 0) 40 | 41 | # # Max input length, batch size, hidden_size 42 | embedded = self.embedding(input_seqs) 43 | packed = pack_padded_sequence(embedded, input_lengths) 44 | outputs, hidden = self.rnn(packed, hidden) 45 | outputs, output_lengths = pad_packed_sequence(outputs) 46 | 47 | # Max input length, batch_size, hidden_size, we add the backward and forward 48 | # hidden states together 49 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 50 | 51 | hidden_layers = [] 52 | for i in range(self.n_layers): 53 | hidden_layers.append((hidden[i*2, :, :] + hidden[(i*2)+1, :, :]).unsqueeze(0)) 54 | return outputs, torch.cat(hidden_layers, 0) 55 | 56 | 57 | def init_hidden(self, batch_size): 58 | hidden = Variable(torch.zeros(self.n_layers * 2, batch_size, self.hidden_size)) 59 | if self.use_cuda: hidden = hidden.cuda() 60 | return hidden 61 | 62 | 63 | class LuongAttnDecoderRNN(nn.Module): 64 | def __init__(self, 65 | hidden_size, 66 | input_size, 67 | output_size, 68 | n_layers, 69 | word_embedding_matrix, 70 | dropout, 71 | rnn_cell, 72 | use_cuda): 73 | super(LuongAttnDecoderRNN, self).__init__() 74 | 75 | # Keep for reference 76 | self.hidden_size = hidden_size 77 | self.input_size = input_size 78 | self.output_size = output_size 79 | self.n_layers = n_layers 80 | self.use_cuda = use_cuda 81 | 82 | # Define layers 83 | self.embedding = word_embedding_matrix 84 | if rnn_cell == "GRU": 85 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, dropout) 86 | elif rnn_cell == "SRU": 87 | self.rnn = SRU(input_size, hidden_size, n_layers, dropout, use_tanh=1) 88 | self.attn = LuongAttention(hidden_size, use_cuda) 89 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 90 | self.out = nn.Linear(hidden_size, output_size) 91 | 92 | def forward(self, 93 | input_seq, 94 | input_lengths, 95 | encoder_hidden, 96 | encoder_outputs, 97 | encoder_outputs_length): 98 | """ 99 | input_seq : batch_size 100 | hidden : hidden_size, batch_size 101 | encoder_outputs : max input length, batch_size, hidden_size 102 | """ 103 | # Note: we run this one step at a time 104 | 105 | # logging.debug(f"input_seq:\n{input_seq}") 106 | # logging.debug(f"last_hidden:\n{last_hidden}") 107 | # logging.debug(f"encoder_outputs:\n{encoder_outputs}") 108 | 109 | # sort the input by descending order, but now we also need to sort the encoder 110 | # outputs w/ the same index 111 | sorted_index = np.argsort(input_lengths).tolist()[::-1] 112 | unsorted_index = np.argsort(sorted_index) 113 | sorted_input_seq = input_seq[:, sorted_index] 114 | sorted_input_lengths = np.array(input_lengths)[sorted_index] 115 | sorted_encoder_hidden = encoder_hidden[:, sorted_index, :] 116 | sorted_encoder_outputs = encoder_outputs[:, sorted_index, :] 117 | sorted_encoder_outputs_length = np.array(encoder_outputs_length)[sorted_index].tolist() 118 | 119 | # decoder input: (batch size, hidden_size) 120 | embedded = self.embedding(sorted_input_seq) 121 | 122 | packed = pack_padded_sequence(embedded, sorted_input_lengths) 123 | decoder_outputs, decoder_hidden = self.rnn(packed, sorted_encoder_hidden) 124 | decoder_outputs, decoder_outputs_length = pad_packed_sequence(decoder_outputs) 125 | 126 | # Calculate attention from current RNN state and all encoder outputs; 127 | # apply to encoder outputs to get weighted average 128 | 129 | # batch size, max input length. 
if we use batch we would need to feed the attention 130 | # the sorted encoder outputs and sorted encoder output lengths 131 | attn_weights = self.attn(sorted_encoder_outputs, 132 | sorted_encoder_outputs_length, 133 | decoder_outputs, 134 | decoder_outputs_length) 135 | logging.debug(f"attn_weights:\n{attn_weights}") 136 | 137 | # attn_weights is in (batch_size, max_decoder_length, max_encoder_length) 138 | # encoder_outputs is (max_encoder_length, batch_size, hidden_size) 139 | context = torch.bmm(attn_weights, encoder_outputs.transpose(1, 0)) 140 | concat_rep = torch.cat([decoder_outputs, context.transpose(0, 1)], -1) 141 | concat_output = self.concat(F.tanh(concat_rep)) 142 | 143 | final_decoder_output = self.out(concat_output) 144 | # Why do we need to use .tolist() for `output` but not for `final_decoder_output` 145 | original_position_outputs = final_decoder_output[:, unsorted_index.tolist(), :] 146 | 147 | # Return final output, hidden state, and attention weights (for visualization) 148 | return original_position_outputs, decoder_hidden, attn_weights 149 | 150 | 151 | class LuongAttention(nn.Module): 152 | 153 | def __init__(self, hidden_size, use_cuda): 154 | super().__init__() 155 | self.hidden_size = hidden_size 156 | self.general_weights = Variable(torch.randn(hidden_size, hidden_size)) 157 | self.use_cuda = use_cuda 158 | if use_cuda: 159 | self.general_weights = self.general_weights.cuda() 160 | 161 | def forward(self, 162 | encoder_outputs, 163 | encoder_outputs_length, 164 | decoder_outputs, 165 | decoder_outputs_length): 166 | """ 167 | 168 | :param encoder_outputs: max_encoder_length, batch_size, hidden_size 169 | :param encoder_outputs_length: batch_size 170 | :param decoder_outputs: max_decoder_length, batch_size, hidden_size 171 | :param decoder_outputs_length: batch_size 172 | :return: attention_aware_output 173 | """ 174 | 175 | # (batch_size, max_decoder_length, hidden_size) 176 | decoder_outputs = torch.transpose(decoder_outputs, 0, 1) 177 | 178 | # (batch_size, hidden_size, max_encoder_length) 179 | encoder_outputs = encoder_outputs.permute(1, 2, 0) 180 | 181 | # (batch_size, max_encoder_length, max_decoder_length 182 | score = torch.bmm(decoder_outputs @ self.general_weights, encoder_outputs) 183 | 184 | (attention_mask, 185 | max_enc_outputs_length, 186 | max_dec_outputs_length) = self.attention_sequence_mask(encoder_outputs_length, decoder_outputs_length) 187 | masked_score = score + attention_mask 188 | weights_flat = F.softmax(masked_score.view(-1, max_enc_outputs_length)) 189 | weights = weights_flat.view(-1, max_dec_outputs_length, max_enc_outputs_length) 190 | 191 | return weights 192 | 193 | def attention_sequence_mask(self, encoder_outputs_length, decoder_outputs_length): 194 | batch_size = len(encoder_outputs_length) 195 | max_encoder_outputs_length = max(encoder_outputs_length) 196 | max_decoder_outputs_length = max(decoder_outputs_length) 197 | 198 | encoder_sequence_mask = sequence_mask(encoder_outputs_length, use_cuda=self.use_cuda) 199 | encoder_sequence_mask_expand = (encoder_sequence_mask 200 | .unsqueeze(1) 201 | .expand(batch_size, 202 | max_decoder_outputs_length, 203 | max_encoder_outputs_length)) 204 | 205 | decoder_sequence_mask = sequence_mask(decoder_outputs_length, use_cuda=self.use_cuda) 206 | decoder_sequence_mask_expand = (decoder_sequence_mask 207 | .unsqueeze(2) 208 | .expand(batch_size, 209 | max_decoder_outputs_length, 210 | max_encoder_outputs_length)) 211 | attention_mask = (encoder_sequence_mask_expand * 212 | 
decoder_sequence_mask_expand).float() 213 | attention_mask = (attention_mask - 1) * sys.maxsize 214 | return (attention_mask, 215 | max_encoder_outputs_length, 216 | max_decoder_outputs_length) 217 | -------------------------------------------------------------------------------- /models/luong_attention_manual_mask/masked_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.masked_cross_entropy import sequence_mask 4 | 5 | 6 | def masked_rnn_cell(rnn_cell, embedded_seq, input_lengths, hidden, bidirectional, use_cuda): 7 | 8 | outputs, _ = rnn_cell(embedded_seq, hidden) 9 | current_batch_size = len(input_lengths) 10 | if bidirectional: 11 | hidden_size = int(outputs.size()[-1]/2) 12 | else: 13 | hidden_size = outputs.size()[-1] 14 | max_length = max(input_lengths) 15 | mask = sequence_mask(input_lengths, use_cuda=use_cuda).float() 16 | mask = mask.transpose(0, 1).unsqueeze(-1).expand(max_length, 17 | current_batch_size, 18 | hidden_size * 2 if bidirectional else hidden_size) 19 | # output_ret = outputs * mask 20 | # last_time_step_indices = torch.from_numpy(np.array(input_lengths) - 1).long() 21 | 22 | output_ret = outputs 23 | last_time_step_indices = torch.from_numpy(np.ones(current_batch_size) * (max(input_lengths) - 1)).long() 24 | batch_extractor_indices = torch.from_numpy(np.arange(current_batch_size)).long() 25 | if use_cuda: 26 | last_time_step_indices = last_time_step_indices.cuda() 27 | batch_extractor_indices = batch_extractor_indices.cuda() 28 | hidden_ret = output_ret[last_time_step_indices, batch_extractor_indices, :] 29 | if bidirectional: 30 | hidden_ret = hidden_ret[:, hidden_size:] + hidden_ret[:, :hidden_size] 31 | output_ret = output_ret[:, :, hidden_size:] + output_ret[:, :,:hidden_size] 32 | return output_ret, hidden_ret.unsqueeze(0) 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /train_luong_attention.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import random 4 | import time 5 | import importlib 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from tqdm import tqdm 12 | 13 | from models.luong_attention import luong_attention 14 | from utils.batches import batches, data_from_batch 15 | from utils.embeddings import create_embedding_maps 16 | from utils.masked_cross_entropy import masked_cross_entropy 17 | from utils.tokens import Tokens 18 | 19 | cudnn.benchmark = True 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--train_dir", required=True) 25 | parser.add_argument("--batch_size", type=int, default=512) 26 | parser.add_argument("--learning_rate", default=0.0001) 27 | parser.add_argument("--input_size", default=128, type=int) 28 | parser.add_argument("--hidden_size", default=128, type=int) 29 | parser.add_argument("--eval_every", default=50, type=int) 30 | parser.add_argument("--eval_batch_size", default=10, type=int) 31 | parser.add_argument("--n_layers", default=2, type=int) 32 | parser.add_argument("--n_epochs", default=10, type=int) 33 | parser.add_argument("--dropout", default=0.1) 34 | parser.add_argument("--score_function", default="general") 35 | parser.add_argument("--teacher_forcing_ratio", default=1) 36 | parser.add_argument("--decoder_learning_ratio", default=5.) 
37 | parser.add_argument("--clip_norm", default=5.0) 38 | parser.add_argument("--log_level", default="INFO") 39 | parser.add_argument("--debug_restrict_data", type=int) 40 | parser.add_argument("--dataset_module", required=True) 41 | parser.add_argument("--different_vocab", action="store_true") 42 | parser.add_argument("--rnn_cell", default="GRU") 43 | parser.add_argument("--use_cuda", action="store_true") 44 | parser.add_argument("--use_batch_attention", action="store_true") 45 | args = parser.parse_args() 46 | 47 | log_level = logging.getLevelName(args.log_level) 48 | logging.basicConfig(level=log_level) 49 | return args 50 | 51 | def load_data(dataset_module, train_dir, debug_restrict_data): 52 | dataset_module_path = f"utils.load_and_preprocessing.{dataset_module}" 53 | dataset_module = importlib.import_module(dataset_module_path) 54 | train, val = dataset_module.load_data(train_dir, debug_restrict_data) 55 | return train, val 56 | 57 | def main(): 58 | args = parse_args() 59 | train, val = load_data(args.dataset_module, args.train_dir, args.debug_restrict_data) 60 | 61 | logging.info("Creating embedding maps") 62 | encoder_embedding_map, \ 63 | decoder_embedding_map, \ 64 | encoder_embedding_matrix, \ 65 | decoder_embedding_matrix = create_embedding_maps(train, val, args.input_size, args.different_vocab) 66 | 67 | encoder = luong_attention.EncoderRNN(args.hidden_size, 68 | args.input_size, 69 | args.n_layers, 70 | args.dropout, 71 | encoder_embedding_matrix, 72 | args.rnn_cell, 73 | args.use_cuda) 74 | 75 | decoder = luong_attention.LuongAttnDecoderRNN(args.score_function, 76 | args.hidden_size, 77 | args.input_size, 78 | decoder_embedding_map.n_words, 79 | args.n_layers, 80 | args.dropout, 81 | decoder_embedding_matrix, 82 | args.use_cuda) 83 | 84 | if args.use_cuda: 85 | encoder = encoder.cuda() 86 | decoder = decoder.cuda() 87 | 88 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.learning_rate) 89 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.learning_rate * args.decoder_learning_ratio) 90 | 91 | logging.info("Starting training") 92 | 93 | run_train(args.n_epochs, 94 | args.batch_size, 95 | args.eval_batch_size, 96 | args.eval_every, 97 | train, 98 | val, 99 | encoder_embedding_map, 100 | decoder_embedding_map, 101 | encoder, 102 | decoder, 103 | encoder_optimizer, 104 | decoder_optimizer, 105 | args.clip_norm, 106 | args.teacher_forcing_ratio, 107 | args.use_cuda, 108 | ) 109 | 110 | 111 | def run_train(n_epochs, 112 | batch_size, 113 | eval_batch_size, 114 | eval_every, 115 | train, 116 | val, 117 | encoder_embedding_map, 118 | decoder_embedding_map, 119 | encoder, 120 | decoder, 121 | encoder_optimizer, 122 | decoder_optimizer, 123 | clip_norm, 124 | teacher_forcing_ratio, 125 | use_cuda, 126 | ): 127 | for i in range(n_epochs): 128 | logging.info(f"EPOCH: {i+1}") 129 | for j, batch in enumerate(tqdm(batches(train, 130 | encoder_embedding_map, 131 | decoder_embedding_map, 132 | use_cuda=use_cuda, 133 | batch_size=batch_size))): 134 | encoder.train() 135 | decoder.train() 136 | encoder_optimizer.zero_grad() 137 | decoder_optimizer.zero_grad() 138 | 139 | source_var, source_lengths, target_var, target_lengths = data_from_batch(batch) 140 | 141 | current_batch_size = len(target_lengths) 142 | 143 | # encoder_outputs: max input length, batch size, hidden size 144 | # encoder_hidden: num_layers, batch size, hidden size 145 | logging.debug("Encoding") 146 | tic = time.time() 147 | encoder_outputs, encoder_hidden = encoder(source_var, 148 | 
encoder.init_hidden(current_batch_size), 149 | source_lengths) 150 | toc = time.time() 151 | logging.debug(f"Seconds take to encode: {round(toc-tic,2)}") 152 | decoder_input = Variable(torch.LongTensor([Tokens.SOS_token] * current_batch_size)) 153 | decoder_hidden = encoder_hidden 154 | if use_cuda: 155 | decoder_input = decoder_input.cuda() 156 | decoder_hidden = decoder_hidden.cuda() 157 | 158 | logging.debug("Decoding") 159 | tic = time.time() 160 | 161 | max_target_length = max(target_lengths) 162 | all_decoder_outputs = Variable(torch.zeros(max_target_length, current_batch_size, decoder.output_size)) 163 | 164 | if use_cuda: 165 | all_decoder_outputs = all_decoder_outputs.cuda() 166 | 167 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 168 | 169 | if use_teacher_forcing: 170 | logging.debug("Using teacher forcing") 171 | for t in range(max_target_length): 172 | decoder_output, decoder_hidden, _ = decoder(decoder_input, 173 | decoder_hidden, 174 | encoder_outputs) 175 | all_decoder_outputs[t] = decoder_output 176 | decoder_input = target_var[t] 177 | else: 178 | ## TODO: how do I do non-teacher forcing for batch inputs? 179 | logging.debug("Not using teacher forcing") 180 | for t in range(max_target_length): 181 | decoder_output, decoder_hidden, _ = decoder(decoder_input, 182 | decoder_hidden, 183 | encoder_outputs) 184 | all_decoder_outputs[t] = decoder_output 185 | _, top_i = decoder_output.data.topk(1) 186 | decoder_input = Variable(top_i).squeeze(1) 187 | 188 | toc = time.time() 189 | loss = masked_cross_entropy(all_decoder_outputs.transpose(0, 1).contiguous(), 190 | target_var.transpose(0, 1).contiguous(), 191 | target_lengths, 192 | use_cuda=use_cuda) 193 | logging.debug(f"Time taken for 1 decode step: {round(toc-tic, 2)}") 194 | logging.debug("Backpropagating") 195 | 196 | tic = time.time() 197 | loss.backward() 198 | toc = time.time() 199 | logging.debug(f"Seconds taken for backpropagation: {round(toc-tic, 2)}") 200 | 201 | # Clip gradients 202 | logging.debug("Clipping Gradients") 203 | _ = torch.nn.utils.clip_grad_norm(encoder.parameters(), clip_norm) 204 | _ = torch.nn.utils.clip_grad_norm(decoder.parameters(), clip_norm) 205 | 206 | logging.debug("Updating Weights") 207 | encoder_optimizer.step() 208 | decoder_optimizer.step() 209 | 210 | if j % eval_every == 0: 211 | run_eval(encoder_embedding_map, 212 | decoder_embedding_map, 213 | encoder, 214 | decoder, 215 | val, 216 | eval_batch_size, 217 | use_cuda=use_cuda) 218 | 219 | run_eval(encoder_embedding_map, 220 | decoder_embedding_map, 221 | encoder, 222 | decoder, 223 | train, 224 | batch_size, 225 | use_cuda=use_cuda) 226 | 227 | logging.info(f"LOSS: {loss.data[0]}") 228 | 229 | 230 | def run_eval(encoder_embedding_map, 231 | decoder_embedding_map, 232 | encoder, 233 | decoder, 234 | val, 235 | batch_size, 236 | use_cuda=False): 237 | 238 | batch = next(batches(val, 239 | encoder_embedding_map, 240 | decoder_embedding_map, 241 | bucket=False, 242 | batch_size=batch_size, 243 | use_cuda=use_cuda)) 244 | 245 | # Disable to avoid dropout 246 | encoder.eval() 247 | decoder.eval() 248 | 249 | source_var, source_lengths, target_var, target_lengths = data_from_batch(batch) 250 | 251 | current_batch_size = len(target_lengths) 252 | 253 | encoder_outputs, encoder_hidden = encoder(source_var, 254 | encoder.init_hidden(current_batch_size), 255 | source_lengths) 256 | 257 | # (1, eval_batch_size) 258 | decoder_inputs = Variable(torch.LongTensor([Tokens.SOS_token] * current_batch_size)) 259 | 
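# Greedy decoding for evaluation: unlike the training loop above, no teacher forcing is
# used here. Each step's argmax prediction is fed back as the next decoder input, and the
# per-step logits are collected in `all_decoder_outputs` so that `format_eval_output`
# below can map them back to sentences.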
decoder_hidden = encoder_hidden 260 | max_target_length = max(target_lengths) 261 | all_decoder_outputs = Variable(torch.zeros(max_target_length, current_batch_size, decoder.output_size)) 262 | 263 | if use_cuda: 264 | decoder_inputs = decoder_inputs.cuda() 265 | all_decoder_outputs = all_decoder_outputs.cuda() 266 | 267 | # decoder_attentions = torch.zeros(eval_batch_size, max_target_length, max_target_length) 268 | 269 | logging.info("Decoding") 270 | 271 | for t in tqdm(range(max_target_length)): 272 | decoder_output, decoder_hidden, attn_weights = decoder(decoder_inputs, 273 | decoder_hidden, 274 | encoder_outputs) 275 | 276 | # num_time_step (will be 1), batch_size, output_vocab_size 277 | all_decoder_outputs[t] = decoder_output 278 | 279 | # num_time_step (will be 1), batch_size, 1 280 | _, top_i = decoder_output.data.topk(1) 281 | decoder_inputs = Variable(top_i).squeeze(1) 282 | 283 | format_eval_output(encoder_embedding_map, 284 | decoder_embedding_map, 285 | source_var, 286 | target_var, 287 | all_decoder_outputs) 288 | 289 | 290 | def format_eval_output(encoder_embedding_map, 291 | decoder_embedding_map, 292 | inp, 293 | target, 294 | all_decoder_outputs): 295 | inp = inp.cpu() 296 | target = target.cpu() 297 | _, batch_top_i = all_decoder_outputs.topk(1) 298 | batch_top_i = batch_top_i.cpu().squeeze(-1).transpose(1, 0).data.numpy() 299 | 300 | input_sentences = [] 301 | target_sentences = [] 302 | decoded_sentences = [] 303 | 304 | for var in inp.transpose(1, 0).data.numpy(): 305 | input_sentences.append(encoder_embedding_map.get_sentence_from_indexes(var)) 306 | for var in target.transpose(1, 0).data.numpy(): 307 | target_sentences.append(decoder_embedding_map.get_sentence_from_indexes(var)) 308 | for indexes in batch_top_i: 309 | sentence = decoder_embedding_map.get_sentence_from_indexes(indexes) 310 | trunc_sentence = [] 311 | for word in sentence: 312 | if word != "": 313 | trunc_sentence.append(word) 314 | else: 315 | break 316 | decoded_sentences.append(trunc_sentence) 317 | 318 | for source, target, decode in zip(input_sentences, target_sentences, decoded_sentences): 319 | logging.info(f"\nSource:{' '.join(source)}\nTarget:{' '.join(target)}\nDecode:{' '.join(decode)}") 320 | 321 | if __name__ == "__main__": 322 | main() 323 | -------------------------------------------------------------------------------- /utils/__pycache__/batches.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rawmarshmellows/pytorch-batch-luong-attention/768c5b809e7cd0a5dee97c91a2c534f8f3d66f69/utils/__pycache__/batches.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/embeddings.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rawmarshmellows/pytorch-batch-luong-attention/768c5b809e7cd0a5dee97c91a2c534f8f3d66f69/utils/__pycache__/embeddings.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/masked_cross_entropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rawmarshmellows/pytorch-batch-luong-attention/768c5b809e7cd0a5dee97c91a2c534f8f3d66f69/utils/__pycache__/masked_cross_entropy.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tokens.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rawmarshmellows/pytorch-batch-luong-attention/768c5b809e7cd0a5dee97c91a2c534f8f3d66f69/utils/__pycache__/tokens.cpython-36.pyc -------------------------------------------------------------------------------- /utils/batches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from utils.tokens import Tokens 6 | 7 | 8 | def indexes_from_sentence(embedding_map, sentence): 9 | return [embedding_map.get_index_from_word(word) 10 | for word in sentence.split(' ')] + [Tokens.EOS_token] 11 | 12 | 13 | def pad_seq(seq, max_length): 14 | seq += [Tokens.PAD_token for _ in range(max_length - len(seq))] 15 | return seq 16 | 17 | 18 | def batches(data, 19 | encoder_embedding_map, 20 | decoder_embedding_map, 21 | batch_size, 22 | bucket=False, 23 | use_cuda=False): 24 | 25 | source = data["source"] 26 | target = data["target"] 27 | 28 | # Bucket to make training faster 29 | if bucket: 30 | sorted_source_target = sorted(zip(source, target), key=lambda p: len(p[0]), reverse=True) 31 | source, target = zip(*sorted_source_target) 32 | else: 33 | shuffled_source_target = np.random.permutation(list(zip(source, target))) 34 | source, target = zip(*shuffled_source_target) 35 | 36 | n_samples = len(source) 37 | 38 | for i in range(0, n_samples, batch_size): 39 | source_seqs = [] 40 | target_seqs = [] 41 | source_batch = source[i:i+batch_size] 42 | target_batch = target[i:i+batch_size] 43 | logging.debug(f"Source batch:\n{source_batch}") 44 | logging.debug(f"Target batch:\n{target_batch}") 45 | 46 | for source_, target_ in zip(source_batch, target_batch): 47 | source_seqs.append(indexes_from_sentence(encoder_embedding_map, source_)) 48 | target_seqs.append(indexes_from_sentence(decoder_embedding_map, target_)) 49 | 50 | seq_pairs = sorted(zip(source_seqs, target_seqs), key=lambda p: len(p[0]), reverse=True) 51 | source_seqs, target_seqs = zip(*seq_pairs) 52 | 53 | source_lengths = [len(s) for s in source_seqs] 54 | source_padded = [pad_seq(seq, max(source_lengths)) for seq in source_seqs] 55 | target_lengths = [len(t) for t in target_seqs] 56 | target_padded = [pad_seq(seq, max(target_lengths)) for seq in target_seqs] 57 | 58 | source_var = Variable(torch.LongTensor(source_padded)).transpose(0, 1) 59 | target_var = Variable(torch.LongTensor(target_padded)).transpose(0, 1) 60 | 61 | if use_cuda: 62 | source_var = source_var.cuda() 63 | target_var = target_var.cuda() 64 | 65 | yield {"source_var": source_var, 66 | "source_lengths": source_lengths, 67 | "target_var": target_var, 68 | "target_lengths": target_lengths} 69 | 70 | def data_from_batch(batch): 71 | 72 | # max input length, batch size 73 | source_var = batch["source_var"] 74 | 75 | # batch size 76 | source_lengths = batch["source_lengths"] 77 | 78 | # max target length, batch size 79 | target_var = batch["target_var"] 80 | 81 | # batch size 82 | target_lengths = batch["target_lengths"] 83 | 84 | return source_var, source_lengths, target_var, target_lengths 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /utils/embeddings.py: -------------------------------------------------------------------------------- 1 | from utils.tokens import Tokens 2 | from tqdm import tqdm 3 | import torch.nn as nn 4 | import logging 5 | 6 | 7 | class EmbeddingMap(object): 8 | def 
__init__(self): 9 | self.word2index = {"<PAD>": Tokens.PAD_token, 10 | "<UNK>": Tokens.UNK_token, 11 | "<SOS>": Tokens.SOS_token, 12 | "<EOS>": Tokens.EOS_token} 13 | self.word2count = {} 14 | self.index2word = {Tokens.PAD_token: "<PAD>", 15 | Tokens.UNK_token: "<UNK>", 16 | Tokens.SOS_token: "<SOS>", 17 | Tokens.EOS_token: "<EOS>"} 18 | self.n_words = 4 19 | 20 | def index_words(self, sentence): 21 | for word in sentence.split(): 22 | self.index_word(word) 23 | 24 | def index_word(self, word): 25 | if word not in self.word2index: 26 | self.word2index[word] = self.n_words 27 | self.word2count[word] = 1 28 | self.index2word[self.n_words] = word 29 | self.n_words += 1 30 | else: 31 | self.word2count[word] += 1 32 | 33 | def get_index_from_word(self, word): 34 | if word not in self.word2index: 35 | return self.word2index["<UNK>"] 36 | else: 37 | return self.word2index[word] 38 | 39 | def get_indexes_from_sentences(self, sentence): 40 | indexes = [] 41 | for word in sentence: 42 | indexes.append(self.get_index_from_word(word)) 43 | return indexes 44 | 45 | def get_sentence_from_indexes(self, indexes): 46 | words = [] 47 | for index in indexes: 48 | words.append(self.index2word[index]) 49 | return words 50 | 51 | 52 | def create_embedding_map(data): 53 | embedding_map = EmbeddingMap() 54 | for d in tqdm(data): 55 | for row in tqdm(d): 56 | embedding_map.index_words(row) 57 | return embedding_map 58 | 59 | 60 | def create_embedding_maps(train, val, hidden_size, different_vocab=False): 61 | train_source = train["source"].tolist() 62 | train_target = train["target"].tolist() 63 | val_source = val["source"].tolist() 64 | val_target = val["target"].tolist() 65 | 66 | if different_vocab: 67 | encoder_embedding_map = create_embedding_map([train_source, val_source]) 68 | decoder_embedding_map = create_embedding_map([train_target, val_target]) 69 | encoder_embedding_matrix = nn.Embedding(encoder_embedding_map.n_words, hidden_size) 70 | decoder_embedding_matrix = nn.Embedding(decoder_embedding_map.n_words, hidden_size) 71 | return encoder_embedding_map, decoder_embedding_map, \ 72 | encoder_embedding_matrix, decoder_embedding_matrix 73 | else: 74 | embedding_map = create_embedding_map([train_source, 75 | train_target, 76 | val_source, 77 | val_target]) 78 | embedding_matrix = nn.Embedding(embedding_map.n_words, hidden_size) 79 | return embedding_map, embedding_map, embedding_matrix, embedding_matrix 80 | -------------------------------------------------------------------------------- /utils/load_and_preprocessing/translation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | from os.path import join as pjoin 4 | import re 5 | import unicodedata 6 | 7 | eng_prefixes = ( 8 | "i am ", "i m ", 9 | "he is", "he s ", 10 | "she is", "she s", 11 | "you are", "you re ", 12 | "we are", "we re ", 13 | "they are", "they re " 14 | ) 15 | 16 | 17 | def load_data(train_dir, debug_restrict_data, max_length=10, reverse=True): 18 | train = load_and_preprocess_file(pjoin(train_dir, "train.csv"), debug_restrict_data, max_length, reverse) 19 | val = load_and_preprocess_file(pjoin(train_dir, "val.csv"), debug_restrict_data, max_length, reverse) 20 | return train, val 21 | 22 | 23 | def load_and_preprocess_file(fpath, debug_restrict_data, max_length, reverse): 24 | data = pd.read_csv(fpath) 25 | if debug_restrict_data is not None: 26 | pairs = list(zip(data["source"][:debug_restrict_data], data["target"][:debug_restrict_data])) 27 | else: 28 | pairs = list(zip(data["source"],
data["target"])) 29 | 30 | pairs = [[normalize_string(p[0]), normalize_string(p[1])] for p in pairs] 31 | pairs = filter_pairs(pairs, max_length) 32 | 33 | output_data = [] 34 | for source, target in tqdm(pairs): 35 | if reverse: 36 | output_data.append({"source": target, "target": source}) 37 | else: 38 | output_data.append({"source": source, "target": target}) 39 | return pd.DataFrame(output_data) 40 | 41 | 42 | def unicode_to_ascii(s): 43 | return ''.join( 44 | c for c in unicodedata.normalize('NFD', s) 45 | if unicodedata.category(c) != 'Mn' 46 | ) 47 | 48 | 49 | def normalize_string(s): 50 | s = unicode_to_ascii(s.lower().strip()) 51 | s = re.sub(r"([.!?])", r" \1", s) 52 | s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 53 | return s 54 | 55 | 56 | def filter_pair(p, max_length): 57 | return len(p[0].split(' ')) < max_length and \ 58 | len(p[1].split(' ')) < max_length and \ 59 | p[0].startswith(eng_prefixes) 60 | 61 | 62 | def filter_pairs(pairs, max_length): 63 | return [pair for pair in pairs if filter_pair(pair, max_length)] 64 | 65 | -------------------------------------------------------------------------------- /utils/masked_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | def sequence_mask(sequence_length, max_len=None, use_cuda=False): 8 | if isinstance(sequence_length, np.ndarray): 9 | sequence_length = Variable(torch.from_numpy(sequence_length)) 10 | elif isinstance(sequence_length, list): 11 | sequence_length = Variable(torch.from_numpy(np.array(sequence_length))) 12 | 13 | if max_len is None: 14 | max_len = sequence_length.data.max() 15 | 16 | batch_size = sequence_length.size(0) 17 | seq_range = torch.arange(0, max_len).long() 18 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 19 | seq_range_expand = Variable(seq_range_expand) 20 | if sequence_length.is_cuda: 21 | seq_range_expand = seq_range_expand.cuda() 22 | seq_length_expand = (sequence_length.unsqueeze(1) 23 | .expand_as(seq_range_expand)) 24 | mask = seq_range_expand < seq_length_expand 25 | if use_cuda: 26 | mask = mask.cuda() 27 | return mask 28 | 29 | 30 | def masked_cross_entropy(logits, target, length, use_cuda=False): 31 | 32 | length = Variable(torch.LongTensor(length)) 33 | 34 | if use_cuda: 35 | length = length.cuda() 36 | 37 | """ 38 | Args: 39 | logits: A Variable containing a FloatTensor of size 40 | (batch, max_len, num_classes) which contains the 41 | unnormalized probability for each class. 42 | target: A Variable containing a LongTensor of size 43 | (batch, max_len) which contains the index of the true 44 | class for each corresponding step. 45 | length: A Variable containing a LongTensor of size (batch,) 46 | which contains the length of each data in a batch. 47 | Returns: 48 | loss: An average loss value masked by the length. 
49 | """ 50 | 51 | # logits_flat: (batch * max_len, num_classes) 52 | logits_flat = logits.view(-1, logits.size(-1)) 53 | # log_probs_flat: (batch * max_len, num_classes) 54 | log_probs_flat = F.log_softmax(logits_flat) 55 | # target_flat: (batch * max_len, 1) 56 | target_flat = target.view(-1, 1) 57 | # losses_flat: (batch * max_len, 1) 58 | losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) 59 | # losses: (batch, max_len) 60 | losses = losses_flat.view(*target.size()) 61 | # mask: (batch, max_len) 62 | mask = sequence_mask(sequence_length=length, max_len=target.size(1), use_cuda=use_cuda) 63 | losses = losses * mask.float() 64 | loss = losses.sum() / length.float().sum() 65 | return loss 66 | -------------------------------------------------------------------------------- /utils/tokens.py: -------------------------------------------------------------------------------- 1 | class Tokens: 2 | PAD_token = 0 3 | UNK_token = 1 4 | SOS_token = 2 5 | EOS_token = 3 --------------------------------------------------------------------------------
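
The snippet below is a minimal usage sketch of utils/masked_cross_entropy.py, not a file from the repository. It shows how the function expects its inputs: batch-first logits and targets (as produced in train_luong_attention.py by transposing all_decoder_outputs and target_var) plus a plain Python list of true target lengths. The batch size, sequence length, vocabulary size, and index values are made-up toy numbers; the snippet assumes it is run from the repository root so that the utils package is importable, and that 0 and 3 correspond to Tokens.PAD_token and Tokens.EOS_token as defined in utils/tokens.py.

import torch
from torch.autograd import Variable
from utils.masked_cross_entropy import masked_cross_entropy

# Toy dimensions (assumed for the example): 2 sentences, padded to 5 steps, 11-word vocabulary.
batch_size, max_len, vocab_size = 2, 5, 11

# Unnormalized decoder scores, shape (batch, max_len, vocab_size).
logits = Variable(torch.randn(batch_size, max_len, vocab_size))

# Padded gold indexes, shape (batch, max_len); 3 is Tokens.EOS_token, 0 is Tokens.PAD_token.
target = Variable(torch.LongTensor([[4, 5, 6, 3, 0],
                                    [7, 3, 0, 0, 0]]))

# True target lengths; positions beyond them are masked out of the loss by sequence_mask.
lengths = [4, 2]

loss = masked_cross_entropy(logits, target, lengths, use_cuda=False)
print(loss)  # average negative log-likelihood over the 4 + 2 unmasked positions

Note that the function divides the summed per-token losses by the summed lengths rather than by batch_size * max_len, so padded positions neither contribute to nor dilute the reported loss.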