├── CharRNN.py
├── LICENSE
├── README.md
├── data
│   ├── shakespeare.txt
│   └── sherlock.txt
├── preTrained
│   ├── CharRNN_shakespeare.pth
│   └── CharRNN_sherlock.pth
└── test.py

/CharRNN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_size, input_size)
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, hidden_state):
        embedding = self.embedding(input_seq)
        output, hidden_state = self.rnn(embedding, hidden_state)
        output = self.decoder(output)
        # detach the hidden state so backprop is truncated at sequence boundaries
        return output, (hidden_state[0].detach(), hidden_state[1].detach())

def train():
    ########### Hyperparameters ###########
    hidden_size = 512   # size of hidden state
    seq_len = 100       # length of LSTM sequence
    num_layers = 3      # num of layers in LSTM layer stack
    lr = 0.002          # learning rate
    epochs = 100        # max number of epochs
    op_seq_len = 200    # total num of characters in output test sequence
    load_chk = False    # load weights from save_path to continue training
    save_path = "./preTrained/CharRNN_shakespeare.pth"
    data_path = "./data/shakespeare.txt"
    #######################################

    # load the text file
    data = open(data_path, 'r').read()
    chars = sorted(list(set(data)))
    data_size, vocab_size = len(data), len(chars)
    print("----------------------------------------")
    print("Data has {} characters, {} unique".format(data_size, vocab_size))
    print("----------------------------------------")

    # char to index and index to char maps
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }

    # convert data from chars to indices
    data = list(data)
    for i, ch in enumerate(data):
        data[i] = char_to_ix[ch]

    # data tensor on device
    data = torch.tensor(data).to(device)
    data = torch.unsqueeze(data, dim=1)

    # model instance
    rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)

    # load checkpoint if True
    if load_chk:
        rnn.load_state_dict(torch.load(save_path, map_location=device))
        print("Model loaded successfully !!")
        print("----------------------------------------")

    # loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

    # training loop
    for i_epoch in range(1, epochs+1):

        # random starting point (within the first 100 chars) for this epoch
        data_ptr = np.random.randint(100)
        n = 0
        running_loss = 0
        hidden_state = None

        while True:
            input_seq = data[data_ptr : data_ptr+seq_len]
            target_seq = data[data_ptr+1 : data_ptr+seq_len+1]

            # forward pass
            output, hidden_state = rnn(input_seq, hidden_state)

            # compute loss
            loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
            running_loss += loss.item()

            # compute gradients and take optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the data pointer
            data_ptr += seq_len
            n += 1

            # if at end of data : break
            if data_ptr + seq_len + 1 > data_size:
                break

        # print loss and save weights after every epoch
        print("Epoch: {0} \t Loss: {1:.8f}".format(i_epoch, running_loss/n))
        torch.save(rnn.state_dict(), save_path)

        # sample / generate a text sequence after every epoch
        data_ptr = 0
        hidden_state = None

        # random character from data to begin
        # (clone the slice so writing the sampled index back does not overwrite the dataset tensor)
        rand_index = np.random.randint(data_size-1)
        input_seq = data[rand_index : rand_index+1].clone()

        print("----------------------------------------")
        while True:
            # forward pass
            output, hidden_state = rnn(input_seq, hidden_state)

            # construct categorical distribution and sample a character
            output = F.softmax(torch.squeeze(output), dim=0)
            dist = Categorical(output)
            index = dist.sample()

            # print the sampled character
            print(ix_to_char[index.item()], end='')

            # next input is current output
            input_seq[0][0] = index.item()
            data_ptr += 1

            if data_ptr > op_seq_len:
                break

        print("\n----------------------------------------")

if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Nikhil Barhate

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Char RNN PyTorch

Minimalist code for character-level language modelling using a multi-layer recurrent neural network (LSTM) in PyTorch. The RNN is trained to predict the next character in a given text sequence, and the trained model can then be used to generate a new text sequence resembling the original data.
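For illustration, here is a minimal sketch of the character-level encoding the code builds, using a hypothetical five-character corpus (the variable names mirror those in `CharRNN.py`):

```python
# build the vocabulary from the text itself, as CharRNN.py does
text = "hello"                                      # hypothetical tiny corpus
chars = sorted(set(text))                           # ['e', 'h', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(chars)}  # {'e': 0, 'h': 1, 'l': 2, 'o': 3}

# the model consumes index sequences; each position is trained to
# predict the index that follows it
inputs  = [char_to_ix[ch] for ch in text[:-1]]      # [1, 0, 2, 2] -> "hell"
targets = [char_to_ix[ch] for ch in text[1:]]       # [0, 2, 2, 3] -> "ello"
```

Generation runs this in reverse: the network outputs a distribution over indices, one index is sampled, and `ix_to_char` maps it back to a character.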
## Requirements

Trained and tested on:

- `Python 3.6`
- `PyTorch 1.0`
- `NumPy 1.16.3`

## Usage

### Training
To train a new network, run `CharRNN.py`. If you are using custom data, change the `data_path` and `save_path` variables accordingly (see the snippet below). To keep the code simple the batch size is one, so the training procedure is a bit slow. The average loss and a sample from the model are printed after every epoch.
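For example, to train on a custom corpus (the file names here are hypothetical placeholders, not files in the repository):

```python
# inside train() in CharRNN.py -- point these at your own files
data_path = "./data/my_corpus.txt"                 # hypothetical training text
save_path = "./preTrained/CharRNN_my_corpus.pth"   # where weights will be saved
```

Any plain-text file works; the vocabulary is built from whatever characters actually occur in it.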
### Testing
To test a pretrained network (~15 epochs), run `test.py`. The training dataset is required at test time as well: it is used to rebuild the vocabulary dictionary and to sample a random short (10-character) seed sequence that begins the generation.
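Both pretrained models ship in the `preTrained` directory. To sample from the Sherlock Holmes model instead of the default Shakespeare one, swap the two path variables near the top of `test.py` (these lines are already there, commented out):

```python
load_path = "./preTrained/CharRNN_sherlock.pth"
data_path = "./data/sherlock.txt"
```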
## Samples

**Shakespeare Dataset (~ 15 epochs) :**
```
DOCSER:
What, will thy fond law?
or that in all the chains that livinar?

KING HENRY V:
Come, come, I should our name answer'd for two mans
To deafly upbrain, and broke him so our
Master Athital. Mark ye, I say!

B-CANSSIO:
Come, let us die.

Hostes:
This was my prince of holy empress,
That shalt thou save you in it with brave cap of heaven.
Or is the digest and praud with their closets save of faitral'?

KING HENRY V:
Your treason follow Ncpius, Dout &ystermans' clent,
On the pity can, when tell them
Freely from direen prisoners town; and let us
know the man of all.

FLUELLEN:
Go tell you.
```

-----------------------------------------------------------------

**Sherlock Holmes Dataset (~ 15 epochs) :**
```
Mr. Holmes had drawn up and again so brick, at west who closed upon
the loud broken pallow and a cabmon ta the chair that we had fired
out.

"I wished in," said Holmes sobbily, "trust in the light. I said that you
have to do with Gardens, come, you will pass you
light, so you print?"

"We are it is impossible."

"I know that so submer a case here did he give you after I
tell you?"

"Ah, sir, I keep them, Watson," I said a tueler
inspectoruded upon either way. "Home!" said Admirable
Street. "But not considered a memory, which it was to complice him."

I had so vallemed found me about this gloomy men.
```

## Acknowledgements
This code is based on the [char-rnn](https://github.com/karpathy/char-rnn) and [min-char-rnn](https://gist.github.com/karpathy/d4dee566867f8291f086) code by Andrej Karpathy, which is in turn based on the Oxford University Machine Learning class [practical 6](https://github.com/oxford-cs-ml-2015/practical6), which is in turn based on the [learning to execute](https://github.com/wojciechz/learning_to_execute) code by Wojciech Zaremba.
--------------------------------------------------------------------------------
/preTrained/CharRNN_shakespeare.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Char-RNN-PyTorch/2cf836f236542fe4bc92593582dd66219bbceba4/preTrained/CharRNN_shakespeare.pth
--------------------------------------------------------------------------------
/preTrained/CharRNN_sherlock.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Char-RNN-PyTorch/2cf836f236542fe4bc92593582dd66219bbceba4/preTrained/CharRNN_sherlock.pth
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from CharRNN import RNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def test():
    ############ Hyperparameters ############
    hidden_size = 512    # size of hidden state
    num_layers = 3       # num of layers in LSTM layer stack
    op_seq_len = 1000    # total num of characters in output test sequence

    load_path = "./preTrained/CharRNN_shakespeare.pth"
    data_path = "./data/shakespeare.txt"

    # load_path = "./preTrained/CharRNN_sherlock.pth"
    # data_path = "./data/sherlock.txt"
    #########################################

    # load the text file
    data = open(data_path, 'r').read()
    chars = sorted(list(set(data)))
    data_size, vocab_size = len(data), len(chars)
    print("----------------------------------------")
    print("Data has {} characters, {} unique".format(data_size, vocab_size))
    print("----------------------------------------")

    # char to index and index to char maps
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }

    # convert data from chars to indices
    data = list(data)
    for i, ch in enumerate(data):
        data[i] = char_to_ix[ch]

    # data tensor on device
    data = torch.tensor(data).to(device)
    data = torch.unsqueeze(data, dim=1)

    # create and load model instance
    # (map_location lets a GPU-trained checkpoint load on a CPU-only machine)
    rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
    rnn.load_state_dict(torch.load(load_path, map_location=device))
    print("Model loaded successfully !!")

    # initialize variables
    data_ptr = 0
    hidden_state = None

    # randomly select an initial 9-character string from the data
    rand_index = np.random.randint(data_size - 11)
    input_seq = data[rand_index : rand_index + 9]

    # compute last hidden state of the sequence
    _, hidden_state = rnn(input_seq, hidden_state)

    # the next character is the first input for generation
    # (clone the slice so writing the sampled index back does not overwrite the dataset tensor)
    input_seq = data[rand_index + 9 : rand_index + 10].clone()

    # generate remaining sequence
    print("----------------------------------------")
    while True:
        # forward pass
        output, hidden_state = rnn(input_seq, hidden_state)

        # construct categorical distribution and sample a character
        output = F.softmax(torch.squeeze(output), dim=0)
        dist = Categorical(output)
        index = dist.sample().item()

        # print the sampled character
        print(ix_to_char[index], end='')

        # next input is current output
        input_seq[0][0] = index
        data_ptr += 1

        if data_ptr > op_seq_len:
            break

    print("\n----------------------------------------")

if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------