├── .gitignore ├── img ├── output.png └── output1.png ├── requirements.txt ├── docs └── Machine_translation.pdf ├── device.py ├── hyper_parameters.py ├── translate.py ├── README.md ├── data_processing.py ├── transformer.py ├── train.py ├── data_preparation.ipynb └── translation.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .vscode 3 | .__pycache__ 4 | ./Arab-Acquis -------------------------------------------------------------------------------- /img/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strifee/arabic2english/HEAD/img/output.png -------------------------------------------------------------------------------- /img/output1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strifee/arabic2english/HEAD/img/output1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | torchtext==0.10.0 3 | spacy 4 | transformers 5 | nltk 6 | pandas 7 | -------------------------------------------------------------------------------- /docs/Machine_translation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strifee/arabic2english/HEAD/docs/Machine_translation.pdf -------------------------------------------------------------------------------- /device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | print(torch.cuda.get_device_name(device)) 9 | -------------------------------------------------------------------------------- 
/hyper_parameters.py: -------------------------------------------------------------------------------- 1 | from train import BATCH_SIZE 2 | 3 | 4 | class Hyperparam: 5 | """Hyper parameters""" 6 | # Training 7 | BATCH_SIZE = 16 8 | learning_rate = 0.0001 9 | num_epochs = 30 10 | 11 | # Model 12 | num_heads = 8 13 | num_encoder_layers = 3 14 | num_decoder_layers = 3 15 | 16 | max_len= 230 17 | dropout = 0.4 18 | embedding_size= 256 19 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | from data_processing import SRC, TRG, engTokenizer 2 | import torch 3 | import device 4 | from train import model 5 | 6 | def translate_sentence(model,sentence,srcField,targetField,srcTokenizer): 7 | model.eval() 8 | processed_sentence = srcField.process([srcTokenizer(sentence)]).to(device) 9 | trg = ["بداية"] 10 | 11 | for _ in range(60): 12 | trg_indecies = [targetField.vocab.stoi[word] for word in trg] 13 | trg_tensor = torch.LongTensor(trg_indecies).unsqueeze(1).to(device) 14 | outputs = model(processed_sentence,trg_tensor) 15 | 16 | if targetField.vocab.itos[outputs.argmax(2)[-1:].item()] == "": 17 | continue 18 | trg.append(targetField.vocab.itos[outputs.argmax(2)[-1:].item()]) 19 | if targetField.vocab.itos[outputs.argmax(2)[-1:].item()] == "نهاية": 20 | break 21 | return " ".join([word for word in trg if word != ""][1:-1]) 22 | 23 | 24 | if __name__ == '__main__': 25 | print("I'm home -> {}",translate_sentence(model,"I'm at home" ,SRC,TRG,engTokenizer)) 26 | print("I'm alone -> {}",translate_sentence(model,"I'm alone" ,SRC,TRG,engTokenizer)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Arabic2English - Arabic to English Translator 3 | 4 | **This is a PyTorch implementation of an Arabic to English Neural Machine 
Translation model built using the Transformer architecture ([Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf))** 5 | 6 | 7 | # Setup and Requirements 8 | **1. CUDA:** 9 |
10 | Install [CUDA](https://developer.nvidia.com/cuda-downloads) before installing the required packages, or check whether it is already installed. 11 |
12 |
13 | **2. Clone the Translate repo:** 14 | ``` 15 | $ git clone https://github.com/Strifee/arabic2english.git 16 | ``` 17 | **3. Install requirements:** 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | `If you have a problem with the CUDA package, try this:` 22 | ``` 23 | conda install -q pytorch torchvision cudatoolkit=11 -c pytorch-nightly 24 | ``` 25 | 26 | # Data 27 | 28 | **Arabic to English Translation Sentences :** 29 | 30 | [Arabic to English Translation Sentences](https://www.kaggle.com/samirmoustafa/arabic-to-english-translation-sentences) is a dataset for machine translation between English and Arabic. 31 | 32 | # Training 33 | 34 | **1. Clone the Translate repo:** 35 | ``` 36 | $ git clone https://github.com/Strifee/arabic2english.git 37 | ``` 38 | **2. Training** 39 | ``` 40 | $ python translate.py 41 | ``` 42 | **3. Regularization** 43 | ### Hyperparameters : 44 | ```python 45 | BATCH_SIZE = 16 46 | learning_rate = 0.0001 47 | num_epochs = 30 48 | 49 | num_heads = 8 50 | num_encoder_layers = 3 51 | num_decoder_layers = 3 52 | 53 | max_len= 230 54 | dropout = 0.4 55 | embedding_size= 256 56 | ``` 57 | ### Before regularization : 58 | ![image](img/output1.png) 59 |
60 | 61 | ### After regularization : 62 | ![image](img/output.png) 63 |
64 | 65 | # Results 66 | ``` 67 | "I'm ready" -> 'أنا مستعد' 68 | "i'm lucky" -> 'انا محظوظ' 69 | "I'm sad" -> 'أنا حزين' 70 | 71 | ``` 72 | 73 | -------------------------------------------------------------------------------- /data_processing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch.nn as nn 3 | import random 4 | import re 5 | import spacy 6 | from torchtext.legacy import data 7 | from spacy.tokenizer import Tokenizer 8 | from spacy.lang.ar import Arabic 9 | 10 | 11 | 12 | random.seed(0) 13 | df = pd.read_csv("data/arabic_english.txt",delimiter="\t",names=["eng","ar"]) 14 | 15 | ''' 16 | First : 17 | python -m spacy download en_core_web_sm 18 | ''' 19 | spacy_eng = spacy.load("en_core_web_sm") 20 | 21 | ar = Arabic() 22 | ar_Tokenizer = Tokenizer(ar.vocab) 23 | 24 | def engTokenizer(text): 25 | return [word.text for word in spacy_eng.tokenizer(text)] 26 | 27 | def arTokenizer(sentence): 28 | return [word.text for word in 29 | ar_Tokenizer(re.sub(r"\s+"," ",re.sub(r"[\.\'\"\n+]"," ",sentence)).strip())] 30 | 31 | SRC = data.Field(tokenize=engTokenizer,batch_first=False,init_token="",eos_token="") 32 | TRG = data.Field(tokenize=arTokenizer,batch_first=False,tokenizer_language="ar",init_token="ببدأ",eos_token="نهها") 33 | 34 | class TextDataset(data.Dataset): 35 | 36 | def __init__(self, df, src_field, target_field, is_test=False, **kwargs): 37 | fields = [('eng', src_field), ('ar',target_field)] 38 | samples = [] 39 | for i, row in df.iterrows(): 40 | eng = row.eng 41 | ar = row.ar 42 | samples.append(data.Example.fromlist([eng, ar], fields)) 43 | 44 | super().__init__(samples, fields, **kwargs) 45 | 46 | def __len__(self): 47 | return len(self.samples) 48 | 49 | def __getitem__(self, idx): 50 | return self.samples[idx] 51 | 52 | torchdataset = TextDataset(df,SRC,TRG) 53 | 54 | train_data, valid_data = torchdataset.split(split_ratio=0.8, random_state = random.seed(0)) 55 | 56 | 
SRC.build_vocab(train_data,min_freq=2) 57 | TRG.build_vocab(train_data,min_freq=2) 58 | 59 | if __name__=='__main__': 60 | print(TRG.vocab.freqs.most_common(50)) 61 | 62 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from device import device 4 | 5 | class Transformer(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_size, 9 | src_vocab_size, 10 | trg_vocab_size, 11 | src_pad_idx, 12 | num_heads, 13 | num_encoder_layers, 14 | num_decoder_layers, 15 | forward_expansion, 16 | dropout, 17 | max_len, 18 | device, 19 | ): 20 | super(Transformer, self).__init__() 21 | self.src_embeddings = nn.Embedding(src_vocab_size,embedding_size) 22 | self.src_positional_embeddings= nn.Embedding(max_len,embedding_size) 23 | self.trg_embeddings= nn.Embedding(trg_vocab_size,embedding_size) 24 | self.trg_positional_embeddings= nn.Embedding(max_len,embedding_size) 25 | self.device = device 26 | self.transformer = nn.Transformer( 27 | embedding_size, 28 | num_heads, 29 | num_encoder_layers, 30 | num_decoder_layers, 31 | forward_expansion, 32 | dropout, 33 | ) 34 | 35 | self.fc_out = nn.Linear(embedding_size, trg_vocab_size) 36 | self.dropout = nn.Dropout(dropout) 37 | self.src_pad_idx = src_pad_idx 38 | 39 | def make_src_mask(self, src): 40 | src_mask = src.transpose(0,1) == self.src_pad_idx 41 | 42 | return src_mask 43 | 44 | def forward(self,src,trg): 45 | src_seq_length, S = src.shape 46 | trg_seq_length, S = trg.shape 47 | #adding zeros is an easy way 48 | src_positions = ( 49 | torch.arange(0, src_seq_length).unsqueeze(1).expand(src_seq_length, S).to(self.device) 50 | ) 51 | 52 | 53 | trg_positions = ( 54 | torch.arange(0, trg_seq_length).unsqueeze(1).expand(trg_seq_length, S).to(self.device) 55 | ) 56 | 57 | embed_src = self.dropout( 58 | ( self.src_embeddings(src) + 
self.src_positional_embeddings(src_positions) ) 59 | ) 60 | 61 | embed_trg = self.dropout( 62 | ( self.trg_embeddings(trg) + self.trg_positional_embeddings(trg_positions) ) 63 | ) 64 | 65 | src_padding_mask = self.make_src_mask(src) 66 | trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(device) 67 | 68 | 69 | out = self.transformer(embed_src,embed_trg, src_key_padding_mask=src_padding_mask,tgt_mask=trg_mask ) 70 | out= self.fc_out(out) 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import optim 4 | from torch import nn 5 | from torchtext.legacy import data 6 | from data_processing import SRC,TRG 7 | from transformer import Transformer 8 | from device import device 9 | from data_processing import train_data, valid_data 10 | from hyper_parameters import BATCH_SIZE, embedding_size, src_pad_idx, num_heads, num_encoder_layers, num_decoder_layers, forward_expansion, dropout, max_len, learning_rate, num_epochs 11 | 12 | train_iter, valid_iter = data.BucketIterator.splits( 13 | (train_data,valid_data), 14 | batch_size = BATCH_SIZE, 15 | sort=None, 16 | sort_within_batch=False, 17 | sort_key=lambda x: len(x.eng), 18 | device = device, 19 | shuffle=True 20 | ) 21 | 22 | src_vocab_size = len(SRC.vocab) 23 | print("Size of english vocabulary:",src_vocab_size) 24 | 25 | #No. 
of unique tokens in label 26 | trg_vocab_size =len(TRG.vocab) 27 | print("Size of arabic vocabulary:",trg_vocab_size) 28 | 29 | 30 | model = Transformer( 31 | embedding_size, 32 | src_vocab_size, 33 | trg_vocab_size, 34 | src_pad_idx, 35 | num_heads, 36 | num_encoder_layers, 37 | num_decoder_layers, 38 | forward_expansion, 39 | dropout, 40 | max_len, 41 | device, 42 | ).to(device) 43 | 44 | loss_track = [] 45 | loss_validation_track= [] 46 | 47 | 48 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 49 | 50 | pad_idx = SRC.vocab.stoi[""] 51 | criterion = nn.CrossEntropyLoss(ignore_index = pad_idx) 52 | for epoch in range(num_epochs): 53 | stepLoss=[] 54 | model.train() 55 | for batch in train_iter: 56 | input_data = batch.eng.to(device) 57 | target = batch.ar.to(device) 58 | 59 | output = model(input_data,target[:-1]) 60 | optimizer.zero_grad() 61 | 62 | output = output.reshape(-1,trg_vocab_size) 63 | target = target[1:].reshape(-1) 64 | 65 | loss = criterion(output,target) 66 | loss.backward() 67 | 68 | optimizer.step() 69 | stepLoss.append(loss.item()) 70 | 71 | loss_track.append(np.mean(stepLoss)) 72 | print(" Epoch {} | Train Cross Entropy Loss: ".format(epoch),np.mean(stepLoss)) 73 | with torch.no_grad(): 74 | stepValidLoss=[] 75 | model.eval() # the evaluation mode for the model (doesn't apply dropout and batchNorm) 76 | for i,batch in enumerate(valid_iter): 77 | input_sentence = batch.eng.to(device) 78 | target = batch.ar.to(device) 79 | optimizer.zero_grad() 80 | output = model(input_sentence,target[:-1]) 81 | output = output.reshape(-1,trg_vocab_size) 82 | target = target[1:].reshape(-1) 83 | loss = criterion(output,target) 84 | 85 | stepValidLoss.append(loss.item()) 86 | 87 | loss_validation_track.append(np.mean(stepValidLoss)) 88 | print(" Epoch {} | Validation Cross Entropy Loss: ".format(epoch),np.mean(stepValidLoss)) -------------------------------------------------------------------------------- /data_preparation.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 24, 6 | "source": [ 7 | "import string\r\n", 8 | "import pandas as pd\r\n", 9 | "import regex as re\r\n", 10 | "import nltk\r\n", 11 | "from unicodedata import normalize\r\n", 12 | "from pickle import load\r\n", 13 | "from pickle import dump\r\n", 14 | "from collections import Counter" 15 | ], 16 | "outputs": [], 17 | "metadata": {} 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 25, 22 | "source": [ 23 | "# load doc into memory\r\n", 24 | "def load_doc(filename):\r\n", 25 | "\t# open the file as read only\r\n", 26 | "\tfile = open(filename, mode='rt', encoding='utf-8')\r\n", 27 | "\t# read all text\r\n", 28 | "\ttext = file.read()\r\n", 29 | "\t# close the file\r\n", 30 | "\tfile.close()\r\n", 31 | "\treturn text" 32 | ], 33 | "outputs": [], 34 | "metadata": {} 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 26, 39 | "source": [ 40 | "# split a loaded document into sentences\r\n", 41 | "def to_sentences(doc):\r\n", 42 | "\treturn doc.strip().split('\\n')" 43 | ], 44 | "outputs": [], 45 | "metadata": {} 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 27, 50 | "source": [ 51 | "# shortest and longest sentence lengths\r\n", 52 | "def sentence_lengths(sentences):\r\n", 53 | "\tlengths = [len(s.split()) for s in sentences]\r\n", 54 | "\treturn min(lengths), max(lengths)" 55 | ], 56 | "outputs": [], 57 | "metadata": {} 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 28, 62 | "source": [ 63 | "# load doc into memory\r\n", 64 | "def load_doc(filename):\r\n", 65 | "\t# open the file as read only\r\n", 66 | "\tfile = open(filename, mode='rt', encoding='utf-8')\r\n", 67 | "\t# read all text\r\n", 68 | "\ttext = file.read()\r\n", 69 | "\t# close the file\r\n", 70 | "\tfile.close()\r\n", 71 | "\treturn text\r\n", 72 | "\r\n", 73 | "# split a loaded document 
into sentences\r\n", 74 | "def to_sentences(doc):\r\n", 75 | "\treturn doc.strip().split('\\n')\r\n", 76 | "\r\n", 77 | "# shortest and longest sentence lengths\r\n", 78 | "def sentence_lengths(sentences):\r\n", 79 | "\tlengths = [len(s.split()) for s in sentences]\r\n", 80 | "\treturn min(lengths), max(lengths)" 81 | ], 82 | "outputs": [], 83 | "metadata": {} 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 29, 88 | "source": [ 89 | "# load 1st Arabic data\r\n", 90 | "arabic_filename = 'data/Arab-Acquis/Arabic-Translations/test.en_ref.ar'\r\n", 91 | "arabic_doc = load_doc(arabic_filename)\r\n", 92 | "arabic_sentences = to_sentences(arabic_doc)\r\n", 93 | "minlen1, maxlen1 = sentence_lengths(arabic_sentences)\r\n", 94 | "print('Arabic data: sentences=%d, min=%d, max=%d' % (len(arabic_sentences), minlen1, maxlen1))\r\n", 95 | "\r\n", 96 | "# load 1st English data\r\n", 97 | "english_filename = 'data/Arab-Acquis/JRC-ACQUIS/ac-test.en'\r\n", 98 | "english_doc = load_doc(english_filename)\r\n", 99 | "english_sentences = to_sentences(english_doc)\r\n", 100 | "minlen1, maxlen1 = sentence_lengths(english_sentences)\r\n", 101 | "print('English data: sentences=%d, min=%d, max=%d' % (len(english_sentences), minlen1, maxlen1))" 102 | ], 103 | "outputs": [ 104 | { 105 | "output_type": "stream", 106 | "name": "stdout", 107 | "text": [ 108 | "Arabic data: sentences=4107, min=1, max=246\n", 109 | "English data: sentences=4107, min=1, max=269\n" 110 | ] 111 | } 112 | ], 113 | "metadata": {} 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 30, 118 | "source": [ 119 | "arabic_sentences[0]" 120 | ], 121 | "outputs": [ 122 | { 123 | "output_type": "execute_result", 124 | "data": { 125 | "text/plain": [ 126 | "'مجلس الجماعة الاقتصادية الأوروبية'" 127 | ] 128 | }, 129 | "metadata": {}, 130 | "execution_count": 30 131 | } 132 | ], 133 | "metadata": {} 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 31, 138 | "source": [ 139 | 
"english_sentences[0]" 140 | ], 141 | "outputs": [ 142 | { 143 | "output_type": "execute_result", 144 | "data": { 145 | "text/plain": [ 146 | "'THE COUNCIL OF THE EUROPEAN ECONOMIC COMMUNITY,'" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "execution_count": 31 151 | } 152 | ], 153 | "metadata": {} 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 32, 158 | "source": [ 159 | "import pandas as pd\r\n", 160 | "df1 = pd.read_csv(\"data/arabic_english.txt\",delimiter=\"\\t\",names=[\"eng\",\"ar\"])\r\n", 161 | "df1" 162 | ], 163 | "outputs": [ 164 | { 165 | "output_type": "execute_result", 166 | "data": { 167 | "text/plain": [ 168 | " eng \\\n", 169 | "0 Hi. \n", 170 | "1 Run! \n", 171 | "2 Help! \n", 172 | "3 Jump! \n", 173 | "4 Stop! \n", 174 | "... ... \n", 175 | "24633 rising voices promoting a more linguistically ... \n", 176 | "24634 following last year s successful campaign we i... \n", 177 | "24635 during last year s challenge we also met langu... \n", 178 | "24636 to take part just follow the simple steps outl... \n", 179 | "24637 you will also find links to some free web base... \n", 180 | "\n", 181 | " ar \n", 182 | "0 مرحبًا. \n", 183 | "1 اركض! \n", 184 | "2 النجدة! \n", 185 | "3 اقفز! \n", 186 | "4 قف! \n", 187 | "... ... \n", 188 | "24633 شاركنا تحدي ابداع ميم بلغتك الام تعزيزا للتنوع... \n", 189 | "24634 استكمالا لنجاح حملة العام السابق ندعوكم للمشار... \n", 190 | "24635 تعرفنا خلال تحدي العام الماضي على ابطال لغويين... \n", 191 | "24636 للمشاركة في التحدي اتبع الخطوات الموضحة على ال... \n", 192 | "24637 ستجد ايضا روابط لمجموعة من منصات ابداع الميم ا... \n", 193 | "\n", 194 | "[24638 rows x 2 columns]" 195 | ], 196 | "text/html": [ 197 | "
\n", 198 | "\n", 211 | "\n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | "
engar
0Hi.مرحبًا.
1Run!اركض!
2Help!النجدة!
3Jump!اقفز!
4Stop!قف!
.........
24633rising voices promoting a more linguistically ...شاركنا تحدي ابداع ميم بلغتك الام تعزيزا للتنوع...
24634following last year s successful campaign we i...استكمالا لنجاح حملة العام السابق ندعوكم للمشار...
24635during last year s challenge we also met langu...تعرفنا خلال تحدي العام الماضي على ابطال لغويين...
24636to take part just follow the simple steps outl...للمشاركة في التحدي اتبع الخطوات الموضحة على ال...
24637you will also find links to some free web base...ستجد ايضا روابط لمجموعة من منصات ابداع الميم ا...
\n", 277 | "

24638 rows × 2 columns

\n", 278 | "
" 279 | ] 280 | }, 281 | "metadata": {}, 282 | "execution_count": 32 283 | } 284 | ], 285 | "metadata": {} 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 33, 290 | "source": [ 291 | "import pandas as pd\r\n", 292 | "df2 = pd.DataFrame(list(zip(english_sentences,arabic_sentences)), columns=['eng','ar'])\r\n", 293 | "df2" 294 | ], 295 | "outputs": [ 296 | { 297 | "output_type": "execute_result", 298 | "data": { 299 | "text/plain": [ 300 | " eng \\\n", 301 | "0 THE COUNCIL OF THE EUROPEAN ECONOMIC COMMUNITY, \n", 302 | "1 Whereas the adoption of a common transport pol... \n", 303 | "2 Article 1 \n", 304 | "3 3. The types of carriage listed in Annex II sh... \n", 305 | "4 Member States shall inform the Commission of t... \n", 306 | "... ... \n", 307 | "4102 Having regard to the request made by Luxembour... \n", 308 | "4103 (2) Such derogations should be granted, at the... \n", 309 | "4104 Article 1 \n", 310 | "4105 (b) France is granted derogations for the prod... \n", 311 | "4106 After expiry of the transitional period, Austr... \n", 312 | "\n", 313 | " ar \n", 314 | "0 مجلس الجماعة الاقتصادية الأوروبية \n", 315 | "1 حيث أن اعتماد سياسة نقل مشتركة تنطوي من بين أم... \n", 316 | "2 المادة 1 \n", 317 | "3 3. لا تخضع أنواع النقل المدرجة في الملحق الثان... \n", 318 | "4 تبلغ الدول الأعضاء المفوضية الأوروبية بالتدابي... \n", 319 | "... ... \n", 320 | "4102 باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم... \n", 321 | "4103 (2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ... \n", 322 | "4104 المادة 1 \n", 323 | "4105 (ب) تمنح فرنسا الاستثناءات للحصول على النتائج ... \n", 324 | "4106 بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن... \n", 325 | "\n", 326 | "[4107 rows x 2 columns]" 327 | ], 328 | "text/html": [ 329 | "
\n", 330 | "\n", 343 | "\n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | "
engar
0THE COUNCIL OF THE EUROPEAN ECONOMIC COMMUNITY,مجلس الجماعة الاقتصادية الأوروبية
1Whereas the adoption of a common transport pol...حيث أن اعتماد سياسة نقل مشتركة تنطوي من بين أم...
2Article 1المادة 1
33. The types of carriage listed in Annex II sh...3. لا تخضع أنواع النقل المدرجة في الملحق الثان...
4Member States shall inform the Commission of t...تبلغ الدول الأعضاء المفوضية الأوروبية بالتدابي...
.........
4102Having regard to the request made by Luxembour...باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم...
4103(2) Such derogations should be granted, at the...(2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ...
4104Article 1المادة 1
4105(b) France is granted derogations for the prod...(ب) تمنح فرنسا الاستثناءات للحصول على النتائج ...
4106After expiry of the transitional period, Austr...بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن...
\n", 409 | "

4107 rows × 2 columns

\n", 410 | "
" 411 | ] 412 | }, 413 | "metadata": {}, 414 | "execution_count": 33 415 | } 416 | ], 417 | "metadata": {} 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 34, 422 | "source": [ 423 | "data = pd.concat([df1, df2], ignore_index=True)\r\n", 424 | "data" 425 | ], 426 | "outputs": [ 427 | { 428 | "output_type": "execute_result", 429 | "data": { 430 | "text/plain": [ 431 | " eng \\\n", 432 | "0 Hi. \n", 433 | "1 Run! \n", 434 | "2 Help! \n", 435 | "3 Jump! \n", 436 | "4 Stop! \n", 437 | "... ... \n", 438 | "28740 Having regard to the request made by Luxembour... \n", 439 | "28741 (2) Such derogations should be granted, at the... \n", 440 | "28742 Article 1 \n", 441 | "28743 (b) France is granted derogations for the prod... \n", 442 | "28744 After expiry of the transitional period, Austr... \n", 443 | "\n", 444 | " ar \n", 445 | "0 مرحبًا. \n", 446 | "1 اركض! \n", 447 | "2 النجدة! \n", 448 | "3 اقفز! \n", 449 | "4 قف! \n", 450 | "... ... \n", 451 | "28740 باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم... \n", 452 | "28741 (2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ... \n", 453 | "28742 المادة 1 \n", 454 | "28743 (ب) تمنح فرنسا الاستثناءات للحصول على النتائج ... \n", 455 | "28744 بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن... \n", 456 | "\n", 457 | "[28745 rows x 2 columns]" 458 | ], 459 | "text/html": [ 460 | "
\n", 461 | "\n", 474 | "\n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | "
engar
0Hi.مرحبًا.
1Run!اركض!
2Help!النجدة!
3Jump!اقفز!
4Stop!قف!
.........
28740Having regard to the request made by Luxembour...باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم...
28741(2) Such derogations should be granted, at the...(2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ...
28742Article 1المادة 1
28743(b) France is granted derogations for the prod...(ب) تمنح فرنسا الاستثناءات للحصول على النتائج ...
28744After expiry of the transitional period, Austr...بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن...
\n", 540 | "

28745 rows × 2 columns

\n", 541 | "
" 542 | ] 543 | }, 544 | "metadata": {}, 545 | "execution_count": 34 546 | } 547 | ], 548 | "metadata": {} 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 36, 553 | "source": [ 554 | "arabic_stopwords = set(nltk.corpus.stopwords.words(\"arabic\"))\r\n", 555 | "arabic_punctuations = '''`÷×؛<>_()*&^%][ـ،/:\"؟.,'{}~¦+|!”…“–ـ'''\r\n", 556 | "\r\n", 557 | "punctuations = arabic_punctuations + string.punctuation\r\n", 558 | "def remove_stopwords(text):\r\n", 559 | " filtered_sentence = [w for w in text.split() if not w in punctuations]\r\n", 560 | " return ' '.join(filtered_sentence)\r\n", 561 | "\r\n", 562 | "def clean_Data(line):\r\n", 563 | " if (isinstance(line, float)):\r\n", 564 | " return None\r\n", 565 | " line.replace('\\n', ' ')\r\n", 566 | " line = ' '.join(line)\r\n", 567 | " translator = str.maketrans('', '', punctuations)\r\n", 568 | " line = line.translate(translator)\r\n", 569 | " line = ' '.join(line)\r\n", 570 | " return line" 571 | ], 572 | "outputs": [], 573 | "metadata": {} 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 22, 578 | "source": [ 579 | "data.eng = data.eng.apply(clean_Data)" 580 | ], 581 | "outputs": [], 582 | "metadata": {} 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 18, 587 | "source": [ 588 | "data" 589 | ], 590 | "outputs": [ 591 | { 592 | "output_type": "execute_result", 593 | "data": { 594 | "text/plain": [ 595 | " eng \\\n", 596 | "0 H i \n", 597 | "1 R u n \n", 598 | "2 H e l p \n", 599 | "3 J u m p \n", 600 | "4 S t o p \n", 601 | "... ... \n", 602 | "28740 H a v i n g r e g a r ... \n", 603 | "28741 2 S u c h d e r o ... \n", 604 | "28742 A r t i c l e 1 \n", 605 | "28743 b F r a n c e i s ... \n", 606 | "28744 A f t e r e x p i r y ... \n", 607 | "\n", 608 | " ar \n", 609 | "0 مرحبًا. \n", 610 | "1 اركض! \n", 611 | "2 النجدة! \n", 612 | "3 اقفز! \n", 613 | "4 قف! \n", 614 | "... ... \n", 615 | "28740 باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم... 
\n", 616 | "28741 (2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ... \n", 617 | "28742 المادة 1 \n", 618 | "28743 (ب) تمنح فرنسا الاستثناءات للحصول على النتائج ... \n", 619 | "28744 بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن... \n", 620 | "\n", 621 | "[28745 rows x 2 columns]" 622 | ], 623 | "text/html": [ 624 | "
\n", 625 | "\n", 638 | "\n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | "
engar
0H iمرحبًا.
1R u nاركض!
2H e l pالنجدة!
3J u m pاقفز!
4S t o pقف!
.........
28740H a v i n g r e g a r ...باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم...
287412 S u c h d e r o ...(2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ...
28742A r t i c l e 1المادة 1
28743b F r a n c e i s ...(ب) تمنح فرنسا الاستثناءات للحصول على النتائج ...
28744A f t e r e x p i r y ...بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن...
\n", 704 | "

28745 rows × 2 columns

\n", 705 | "
" 706 | ] 707 | }, 708 | "metadata": {}, 709 | "execution_count": 18 710 | } 711 | ], 712 | "metadata": {} 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 21, 717 | "source": [ 718 | "data.to_csv('data/data.txt')" 719 | ], 720 | "outputs": [], 721 | "metadata": {} 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": 22, 726 | "source": [ 727 | "df = pd.read_csv(\"data/data.txt\")\r\n", 728 | "df = df.drop(df.columns[0], axis=1)\r\n", 729 | "df" 730 | ], 731 | "outputs": [ 732 | { 733 | "output_type": "execute_result", 734 | "data": { 735 | "text/plain": [ 736 | " eng \\\n", 737 | "0 Hi. \n", 738 | "1 Run! \n", 739 | "2 Help! \n", 740 | "3 Jump! \n", 741 | "4 Stop! \n", 742 | "... ... \n", 743 | "28740 Having regard to the request made by Luxembour... \n", 744 | "28741 (2) Such derogations should be granted, at the... \n", 745 | "28742 Article 1 \n", 746 | "28743 (b) France is granted derogations for the prod... \n", 747 | "28744 After expiry of the transitional period, Austr... \n", 748 | "\n", 749 | " ar \n", 750 | "0 مرحبًا. \n", 751 | "1 اركض! \n", 752 | "2 النجدة! \n", 753 | "3 اقفز! \n", 754 | "4 قف! \n", 755 | "... ... \n", 756 | "28740 باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم... \n", 757 | "28741 (2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ... \n", 758 | "28742 المادة 1 \n", 759 | "28743 (ب) تمنح فرنسا الاستثناءات للحصول على النتائج ... \n", 760 | "28744 بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن... \n", 761 | "\n", 762 | "[28745 rows x 2 columns]" 763 | ], 764 | "text/html": [ 765 | "
\n", 766 | "\n", 779 | "\n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | "
engar
0Hi.مرحبًا.
1Run!اركض!
2Help!النجدة!
3Jump!اقفز!
4Stop!قف!
.........
28740Having regard to the request made by Luxembour...باعتبار الطلب الذي تقدمت به لوكسمبورغ في 25 تم...
28741(2) Such derogations should be granted, at the...(2) يجب أن يتم منح هذه الاستثناءات إلى النمسا ...
28742Article 1المادة 1
28743(b) France is granted derogations for the prod...(ب) تمنح فرنسا الاستثناءات للحصول على النتائج ...
28744After expiry of the transitional period, Austr...بعد انتهاء الفترة الانتقالية، تقوم النمسا وفرن...
\n", 845 | "

28745 rows × 2 columns

\n", 846 | "
" 847 | ] 848 | }, 849 | "metadata": {}, 850 | "execution_count": 22 851 | } 852 | ], 853 | "metadata": {} 854 | } 855 | ], 856 | "metadata": { 857 | "orig_nbformat": 4, 858 | "language_info": { 859 | "name": "python", 860 | "version": "3.7.11", 861 | "mimetype": "text/x-python", 862 | "codemirror_mode": { 863 | "name": "ipython", 864 | "version": 3 865 | }, 866 | "pygments_lexer": "ipython3", 867 | "nbconvert_exporter": "python", 868 | "file_extension": ".py" 869 | }, 870 | "kernelspec": { 871 | "name": "python3", 872 | "display_name": "Python 3.7.11 64-bit ('torch': conda)" 873 | }, 874 | "interpreter": { 875 | "hash": "4bb0fe8ced3cf0716ac3718fe834e829af40e8ba0fef1c4cadecb390da29a017" 876 | } 877 | }, 878 | "nbformat": 4, 879 | "nbformat_minor": 2 880 | } -------------------------------------------------------------------------------- /translation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# English to Arabic Translation " 7 | ], 8 | "metadata": {} 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "source": [ 13 | "## Imports" 14 | ], 15 | "metadata": {} 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "source": [ 21 | "import pandas as pd\r\n", 22 | "import numpy as np\r\n", 23 | "import matplotlib.pyplot as plt \r\n", 24 | "import torch\r\n", 25 | "import random\r\n", 26 | "import re\r\n", 27 | "import os\r\n", 28 | "\r\n", 29 | "import spacy\r\n", 30 | "from spacy.tokenizer import Tokenizer\r\n", 31 | "from spacy.lang.ar import Arabic\r\n", 32 | "\r\n", 33 | "\r\n", 34 | "import torch\r\n", 35 | "import torch.nn as nn\r\n", 36 | "from torch import optim\r\n", 37 | "from torch.utils.tensorboard import SummaryWriter\r\n", 38 | "\r\n", 39 | "from torchtext import data\r\n", 40 | "from torchtext.legacy import data\r\n", 41 | "\r\n", 42 | "\r\n", 43 | "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" 44 | ], 45 | "outputs": [], 
46 | "metadata": {} 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "## Data Processing" 52 | ], 53 | "metadata": {} 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "source": [ 59 | "df = pd.read_csv(\"data/arabic_english.txt\",delimiter=\"\\t\",names=[\"eng\",\"ar\"])\r\n", 60 | "df" 61 | ], 62 | "outputs": [ 63 | { 64 | "output_type": "execute_result", 65 | "data": { 66 | "text/html": [ 67 | "
\n", 68 | "\n", 81 | "\n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | "
engar
0Hi.مرحبًا.
1Run!اركض!
2Help!النجدة!
3Jump!اقفز!
4Stop!قف!
.........
24633rising voices promoting a more linguistically ...شاركنا تحدي ابداع ميم بلغتك الام تعزيزا للتنوع...
24634following last year s successful campaign we i...استكمالا لنجاح حملة العام السابق ندعوكم للمشار...
24635during last year s challenge we also met langu...تعرفنا خلال تحدي العام الماضي على ابطال لغويين...
24636to take part just follow the simple steps outl...للمشاركة في التحدي اتبع الخطوات الموضحة على ال...
24637you will also find links to some free web base...ستجد ايضا روابط لمجموعة من منصات ابداع الميم ا...
\n", 147 | "

24638 rows × 2 columns

\n", 148 | "
" 149 | ], 150 | "text/plain": [ 151 | " eng \\\n", 152 | "0 Hi. \n", 153 | "1 Run! \n", 154 | "2 Help! \n", 155 | "3 Jump! \n", 156 | "4 Stop! \n", 157 | "... ... \n", 158 | "24633 rising voices promoting a more linguistically ... \n", 159 | "24634 following last year s successful campaign we i... \n", 160 | "24635 during last year s challenge we also met langu... \n", 161 | "24636 to take part just follow the simple steps outl... \n", 162 | "24637 you will also find links to some free web base... \n", 163 | "\n", 164 | " ar \n", 165 | "0 مرحبًا. \n", 166 | "1 اركض! \n", 167 | "2 النجدة! \n", 168 | "3 اقفز! \n", 169 | "4 قف! \n", 170 | "... ... \n", 171 | "24633 شاركنا تحدي ابداع ميم بلغتك الام تعزيزا للتنوع... \n", 172 | "24634 استكمالا لنجاح حملة العام السابق ندعوكم للمشار... \n", 173 | "24635 تعرفنا خلال تحدي العام الماضي على ابطال لغويين... \n", 174 | "24636 للمشاركة في التحدي اتبع الخطوات الموضحة على ال... \n", 175 | "24637 ستجد ايضا روابط لمجموعة من منصات ابداع الميم ا... \n", 176 | "\n", 177 | "[24638 rows x 2 columns]" 178 | ] 179 | }, 180 | "metadata": {}, 181 | "execution_count": 3 182 | } 183 | ], 184 | "metadata": {} 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 33, 189 | "source": [ 190 | "!python -m spacy download en_core_web_sm" 191 | ], 192 | "outputs": [ 193 | { 194 | "output_type": "stream", 195 | "name": "stdout", 196 | "text": [ 197 | "Collecting en-core-web-sm==3.1.0\n", 198 | " Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0-py3-none-any.whl (13.6 MB)\n", 199 | "Requirement already satisfied: spacy<3.2.0,>=3.1.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from en-core-web-sm==3.1.0) (3.1.2)\n", 200 | "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.4.1)\n", 201 | "Requirement already satisfied: 
thinc<8.1.0,>=8.0.8 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (8.0.8)\n", 202 | "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (3.0.5)\n", 203 | "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (1.0.5)\n", 204 | "Requirement already satisfied: catalogue<2.1.0,>=2.0.4 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.0.5)\n", 205 | "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.7 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (3.0.8)\n", 206 | "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (0.8.2)\n", 207 | "Requirement already satisfied: packaging>=20.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (21.0)\n", 208 | "Requirement already satisfied: pathy>=0.3.5 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (0.6.0)\n", 209 | "Requirement already satisfied: requests<3.0.0,>=2.13.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.26.0)\n", 210 | "Requirement already satisfied: jinja2 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (3.0.1)\n", 211 | "Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from 
spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (3.10.0.0)\n", 212 | "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (4.62.1)\n", 213 | "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.0.5)\n", 214 | "Requirement already satisfied: numpy>=1.15.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (1.21.2)\n", 215 | "Requirement already satisfied: typer<0.4.0,>=0.3.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (0.3.2)\n", 216 | "Requirement already satisfied: blis<0.8.0,>=0.4.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (0.7.4)\n", 217 | "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (1.8.2)\n", 218 | "Requirement already satisfied: setuptools in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (52.0.0.post20210125)\n", 219 | "Requirement already satisfied: zipp>=0.5 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from catalogue<2.1.0,>=2.0.4->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (3.5.0)\n", 220 | "Requirement already satisfied: pyparsing>=2.0.2 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from packaging>=20.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.4.7)\n", 221 | "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from pathy>=0.3.5->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (5.2.0)\n", 222 | 
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2021.5.30)\n", 223 | "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.10)\n", 224 | "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.0.4)\n", 225 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (1.26.6)\n", 226 | "Requirement already satisfied: colorama in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from tqdm<5.0.0,>=4.38.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (0.4.4)\n", 227 | "Requirement already satisfied: click<7.2.0,>=7.1.1 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from typer<0.4.0,>=0.3.0->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (7.1.2)\n", 228 | "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\ultrapc\\anaconda3\\envs\\torch\\lib\\site-packages (from jinja2->spacy<3.2.0,>=3.1.0->en-core-web-sm==3.1.0) (2.0.1)\n", 229 | "✔ Download and installation successful\n", 230 | "You can now load the package via spacy.load('en_core_web_sm')\n" 231 | ] 232 | } 233 | ], 234 | "metadata": {} 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "source": [ 239 | "### tokenizers" 240 | ], 241 | "metadata": {} 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 5, 246 | "source": [ 247 | "seed=32\r\n", 248 | "\r\n", 249 | "spacy_eng = spacy.load(\"en_core_web_sm\")\r\n", 250 | "\r\n", 251 | "arab = Arabic()\r\n", 252 | "ar_Tokenizer = Tokenizer(arab.vocab)\r\n", 253 
| "\r\n", 254 | "def engTokenizer(text):\r\n", 255 | " return [word.text for word in spacy_eng.tokenizer(text)] \r\n", 256 | "\r\n", 257 | "def arTokenizer(sentence):\r\n", 258 | " return [word.text for word in \r\n", 259 | " ar_Tokenizer(re.sub(r\"\\s+\",\" \",re.sub(r\"[\\.\\'\\\"\\n+]\",\" \",sentence)).strip())]\r\n", 260 | "\r\n", 261 | "SRC = data.Field(tokenize=engTokenizer,batch_first=False,init_token=\"\",eos_token=\"\")\r\n", 262 | "TRG = data.Field(tokenize=arTokenizer,batch_first=False,tokenizer_language=\"ar\",init_token=\"بداية\",eos_token=\"نهاية\")\r\n", 263 | "\r\n", 264 | "class TextDataset(data.Dataset):\r\n", 265 | "\r\n", 266 | " def __init__(self, df, src_field, target_field, is_test=False, **kwargs):\r\n", 267 | " fields = [('eng', src_field), ('ar',target_field)]\r\n", 268 | " samples = []\r\n", 269 | " for i, row in df.iterrows():\r\n", 270 | " eng = row.eng \r\n", 271 | " ar = row.ar\r\n", 272 | " samples.append(data.Example.fromlist([eng, ar], fields))\r\n", 273 | "\r\n", 274 | " super().__init__(samples, fields, **kwargs)\r\n", 275 | " def __len__(self):\r\n", 276 | " return len(self.samples)\r\n", 277 | " \r\n", 278 | " def __getitem__(self, idx):\r\n", 279 | " return self.samples[idx]\r\n", 280 | "\r\n", 281 | "torchdataset = TextDataset(df,SRC,TRG)\r\n", 282 | "\r\n", 283 | "train_data, valid_data = torchdataset.split(split_ratio=0.8, random_state = random.seed(32))\r\n", 284 | "\r\n", 285 | "SRC.build_vocab(train_data,min_freq=2)\r\n", 286 | "TRG.build_vocab(train_data,min_freq=2)\r\n", 287 | "\r\n", 288 | "print(train_data[1].__dict__)\r\n" 289 | ], 290 | "outputs": [ 291 | { 292 | "output_type": "stream", 293 | "name": "stdout", 294 | "text": [ 295 | "{'eng': ['I', 'was', 'delayed', 'by', 'a', 'traffic', 'jam', '.'], 'ar': ['أخّرني', 'زحام', 'السير']}\n" 296 | ] 297 | } 298 | ], 299 | "metadata": {} 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "source": [ 304 | "### Setting up the device" 305 | ], 306 | "metadata": {} 307 | 
}, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 6, 311 | "source": [ 312 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n", 313 | "#device = torch.device(\"cpu\")\r\n", 314 | "print(device)\r\n", 315 | "print(torch.cuda.get_device_name(0))\r\n", 316 | "# full infos\r\n", 317 | "# !nvidia-smi" 318 | ], 319 | "outputs": [ 320 | { 321 | "output_type": "stream", 322 | "name": "stdout", 323 | "text": [ 324 | "cuda\n", 325 | "NVIDIA GeForce RTX 3060 Ti\n" 326 | ] 327 | } 328 | ], 329 | "metadata": {} 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "source": [ 334 | "## Transformer Class" 335 | ], 336 | "metadata": {} 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 7, 341 | "source": [ 342 | "class Transformer(nn.Module):\r\n", 343 | " def __init__(\r\n", 344 | " self,\r\n", 345 | " embedding_size,\r\n", 346 | " src_vocab_size,\r\n", 347 | " trg_vocab_size,\r\n", 348 | " src_pad_idx,\r\n", 349 | " num_heads,\r\n", 350 | " num_encoder_layers,\r\n", 351 | " num_decoder_layers,\r\n", 352 | " max_len,\r\n", 353 | " ):\r\n", 354 | " super(Transformer, self).__init__()\r\n", 355 | " self.src_embeddings = nn.Embedding(src_vocab_size,embedding_size)\r\n", 356 | " self.src_positional_embeddings= nn.Embedding(max_len,embedding_size)\r\n", 357 | " self.trg_embeddings= nn.Embedding(trg_vocab_size,embedding_size)\r\n", 358 | " self.trg_positional_embeddings= nn.Embedding(max_len,embedding_size)\r\n", 359 | " self.device = device\r\n", 360 | " self.transformer = nn.Transformer(\r\n", 361 | " embedding_size,\r\n", 362 | " num_heads,\r\n", 363 | " num_encoder_layers,\r\n", 364 | " num_decoder_layers,\r\n", 365 | " )\r\n", 366 | "\r\n", 367 | " self.fc_out = nn.Linear(embedding_size, trg_vocab_size)\r\n", 368 | " self.dropout = nn.Dropout(dropout)\r\n", 369 | " self.src_pad_idx = src_pad_idx\r\n", 370 | " \r\n", 371 | " def make_src_mask(self, src):\r\n", 372 | " src_mask = src.transpose(0,1) == 
self.src_pad_idx\r\n", 373 | "\r\n", 374 | " return src_mask.to(device)\r\n", 375 | "\r\n", 376 | " def forward(self,src,trg) :\r\n", 377 | " src_seq_length, S = src.shape\r\n", 378 | " trg_seq_length, S = trg.shape\r\n", 379 | " #adding zeros is an easy way\r\n", 380 | " src_positions = (\r\n", 381 | " torch.arange(0, src_seq_length).unsqueeze(1).expand(src_seq_length, S).to(self.device)\r\n", 382 | " )\r\n", 383 | " \r\n", 384 | " \r\n", 385 | " trg_positions = (\r\n", 386 | " torch.arange(0, trg_seq_length).unsqueeze(1).expand(trg_seq_length, S).to(self.device)\r\n", 387 | " )\r\n", 388 | "\r\n", 389 | " embed_src = self.dropout(\r\n", 390 | " ( self.src_embeddings(src) + self.src_positional_embeddings(src_positions) )\r\n", 391 | " )\r\n", 392 | "\r\n", 393 | " embed_trg = self.dropout(\r\n", 394 | " ( self.trg_embeddings(trg) + self.trg_positional_embeddings(trg_positions) )\r\n", 395 | " )\r\n", 396 | " \r\n", 397 | " src_padding_mask = self.make_src_mask(src)\r\n", 398 | " trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(device)\r\n", 399 | " \r\n", 400 | " \r\n", 401 | " out = self.transformer(embed_src,embed_trg, src_key_padding_mask=src_padding_mask,tgt_mask=trg_mask )\r\n", 402 | " out= self.fc_out(out)\r\n", 403 | "\r\n", 404 | " return out" 405 | ], 406 | "outputs": [], 407 | "metadata": {} 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "source": [ 412 | "## Model and Parameters " 413 | ], 414 | "metadata": {} 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 8, 419 | "source": [ 420 | "BATCH_SIZE = 16\r\n", 421 | "\r\n", 422 | "train_iter, valid_iter = data.BucketIterator.splits(\r\n", 423 | " (train_data,valid_data), \r\n", 424 | " batch_size = BATCH_SIZE,\r\n", 425 | " sort=None,\r\n", 426 | " sort_within_batch=False,\r\n", 427 | " sort_key=lambda x: len(x.eng),\r\n", 428 | " device=device,\r\n", 429 | " shuffle=True\r\n", 430 | ")" 431 | ], 432 | "outputs": [], 433 | "metadata": {} 434 | }, 
435 | { 436 | "cell_type": "code", 437 | "execution_count": 14, 438 | "source": [ 439 | "load_model = False\r\n", 440 | "save_model = True\r\n", 441 | "\r\n", 442 | "num_epochs = 30\r\n", 443 | "learning_rate = 0.0001\r\n", 444 | "\r\n", 445 | "num_heads = 8\r\n", 446 | "num_encoder_layers = 3\r\n", 447 | "num_decoder_layers = 3\r\n", 448 | "\r\n", 449 | "max_len= 230\r\n", 450 | "dropout = 0.4\r\n", 451 | "embedding_size= 256\r\n", 452 | "src_pad_idx = SRC.vocab.stoi[\"\"]\r\n", 453 | "\r\n", 454 | "\r\n", 455 | "src_vocab_size = len(SRC.vocab)\r\n", 456 | "print(\"Size of english vocabulary:\",src_vocab_size)\r\n", 457 | "\r\n", 458 | "trg_vocab_size =len(TRG.vocab)\r\n", 459 | "print(\"Size of arabic vocabulary:\",trg_vocab_size)\r\n", 460 | "\r\n", 461 | "\r\n", 462 | "model = Transformer( \r\n", 463 | " embedding_size,\r\n", 464 | " src_vocab_size,\r\n", 465 | " trg_vocab_size,\r\n", 466 | " src_pad_idx,\r\n", 467 | " num_heads,\r\n", 468 | " num_encoder_layers,\r\n", 469 | " num_decoder_layers,\r\n", 470 | " max_len,\r\n", 471 | ").to(device)\r\n", 472 | "\r\n" 473 | ], 474 | "outputs": [ 475 | { 476 | "output_type": "stream", 477 | "name": "stdout", 478 | "text": [ 479 | "Size of english vocabulary: 12812\n", 480 | "Size of arabic vocabulary: 22067\n" 481 | ] 482 | } 483 | ], 484 | "metadata": {} 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 15, 489 | "source": [ 490 | "print(model)" 491 | ], 492 | "outputs": [ 493 | { 494 | "output_type": "stream", 495 | "name": "stdout", 496 | "text": [ 497 | "Transformer(\n", 498 | " (src_embeddings): Embedding(12812, 256)\n", 499 | " (src_positional_embeddings): Embedding(230, 256)\n", 500 | " (trg_embeddings): Embedding(22067, 256)\n", 501 | " (trg_positional_embeddings): Embedding(230, 256)\n", 502 | " (transformer): Transformer(\n", 503 | " (encoder): TransformerEncoder(\n", 504 | " (layers): ModuleList(\n", 505 | " (0): TransformerEncoderLayer(\n", 506 | " (self_attn): MultiheadAttention(\n", 
507 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 508 | " )\n", 509 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 510 | " (dropout): Dropout(p=0.1, inplace=False)\n", 511 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 512 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 513 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 514 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 515 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 516 | " )\n", 517 | " (1): TransformerEncoderLayer(\n", 518 | " (self_attn): MultiheadAttention(\n", 519 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 520 | " )\n", 521 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 522 | " (dropout): Dropout(p=0.1, inplace=False)\n", 523 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 524 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 525 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 526 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 527 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 528 | " )\n", 529 | " (2): TransformerEncoderLayer(\n", 530 | " (self_attn): MultiheadAttention(\n", 531 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 532 | " )\n", 533 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 534 | " (dropout): Dropout(p=0.1, inplace=False)\n", 535 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 536 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 537 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 538 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 539 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 540 | " )\n", 541 | " )\n", 542 | " (norm): 
LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 543 | " )\n", 544 | " (decoder): TransformerDecoder(\n", 545 | " (layers): ModuleList(\n", 546 | " (0): TransformerDecoderLayer(\n", 547 | " (self_attn): MultiheadAttention(\n", 548 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 549 | " )\n", 550 | " (multihead_attn): MultiheadAttention(\n", 551 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 552 | " )\n", 553 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 554 | " (dropout): Dropout(p=0.1, inplace=False)\n", 555 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 556 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 557 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 558 | " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 559 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 560 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 561 | " (dropout3): Dropout(p=0.1, inplace=False)\n", 562 | " )\n", 563 | " (1): TransformerDecoderLayer(\n", 564 | " (self_attn): MultiheadAttention(\n", 565 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 566 | " )\n", 567 | " (multihead_attn): MultiheadAttention(\n", 568 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 569 | " )\n", 570 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 571 | " (dropout): Dropout(p=0.1, inplace=False)\n", 572 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 573 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 574 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 575 | " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 576 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 577 | " 
(dropout2): Dropout(p=0.1, inplace=False)\n", 578 | " (dropout3): Dropout(p=0.1, inplace=False)\n", 579 | " )\n", 580 | " (2): TransformerDecoderLayer(\n", 581 | " (self_attn): MultiheadAttention(\n", 582 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 583 | " )\n", 584 | " (multihead_attn): MultiheadAttention(\n", 585 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 586 | " )\n", 587 | " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", 588 | " (dropout): Dropout(p=0.1, inplace=False)\n", 589 | " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", 590 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 591 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 592 | " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 593 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 594 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 595 | " (dropout3): Dropout(p=0.1, inplace=False)\n", 596 | " )\n", 597 | " )\n", 598 | " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 599 | " )\n", 600 | " )\n", 601 | " (fc_out): Linear(in_features=256, out_features=22067, bias=True)\n", 602 | " (dropout): Dropout(p=0.3, inplace=False)\n", 603 | ")\n" 604 | ] 605 | } 606 | ], 607 | "metadata": {} 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "source": [ 612 | "## Training" 613 | ], 614 | "metadata": {} 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 16, 619 | "source": [ 620 | "torch.cuda.empty_cache()" 621 | ], 622 | "outputs": [], 623 | "metadata": {} 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 17, 628 | "source": [ 629 | "loss_track = []\r\n", 630 | "loss_validation_track= []\r\n", 631 | "\r\n", 632 | "\r\n", 633 | "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\r\n", 634 | "\r\n", 635 | "pad_idx = 
SRC.vocab.stoi[\"\"]\r\n", 636 | "criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)\r\n", 637 | "for epoch in range(num_epochs):\r\n", 638 | " stepLoss=[]\r\n", 639 | " model.train()\r\n", 640 | " for batch in train_iter:\r\n", 641 | " input_data = batch.eng.to(device)\r\n", 642 | " target = batch.ar.to(device)\r\n", 643 | "\r\n", 644 | " output = model(input_data,target[:-1])\r\n", 645 | " optimizer.zero_grad()\r\n", 646 | " \r\n", 647 | " output = output.reshape(-1,trg_vocab_size)\r\n", 648 | " target = target[1:].reshape(-1)\r\n", 649 | "\r\n", 650 | " loss = criterion(output,target)\r\n", 651 | " loss.backward()\r\n", 652 | "\r\n", 653 | " optimizer.step()\r\n", 654 | " stepLoss.append(loss.item())\r\n", 655 | "\r\n", 656 | " loss_track.append(np.mean(stepLoss))\r\n", 657 | " print(\" Epoch {} | Train Cross Entropy Loss: \".format(epoch),np.mean(stepLoss))\r\n", 658 | " with torch.no_grad(): \r\n", 659 | " stepValidLoss=[]\r\n", 660 | " model.eval() # the evaluation mode for the model (doesn't apply dropout and batchNorm)\r\n", 661 | " for i,batch in enumerate(valid_iter):\r\n", 662 | " input_sentence = batch.eng.to(device)\r\n", 663 | " target = batch.ar.to(device)\r\n", 664 | " optimizer.zero_grad()\r\n", 665 | " output = model(input_sentence,target[:-1])\r\n", 666 | " output = output.reshape(-1,trg_vocab_size)\r\n", 667 | " target = target[1:].reshape(-1)\r\n", 668 | " loss = criterion(output,target)\r\n", 669 | " \r\n", 670 | " stepValidLoss.append(loss.item())\r\n", 671 | " \r\n", 672 | " loss_validation_track.append(np.mean(stepValidLoss))\r\n", 673 | " print(\" Epoch {} | Validation Cross Entropy Loss: \".format(epoch),np.mean(stepValidLoss)) " 674 | ], 675 | "outputs": [ 676 | { 677 | "output_type": "stream", 678 | "name": "stdout", 679 | "text": [ 680 | " Epoch 0 | Train Cross Entropy Loss: 7.435049949141292\n", 681 | " Epoch 0 | Validation Cross Entropy Loss: 5.988370576462188\n", 682 | " Epoch 1 | Train Cross Entropy Loss: 6.977232069164128\n", 
683 | " Epoch 1 | Validation Cross Entropy Loss: 5.755757348103956\n", 684 | " Epoch 2 | Train Cross Entropy Loss: 6.727246981162529\n", 685 | " Epoch 2 | Validation Cross Entropy Loss: 5.524308341664153\n", 686 | " Epoch 3 | Train Cross Entropy Loss: 6.47047579946456\n", 687 | " Epoch 3 | Validation Cross Entropy Loss: 5.395510543476451\n", 688 | " Epoch 4 | Train Cross Entropy Loss: 6.229757437845329\n", 689 | " Epoch 4 | Validation Cross Entropy Loss: 5.26321325131825\n", 690 | " Epoch 5 | Train Cross Entropy Loss: 5.9832683666185895\n", 691 | " Epoch 5 | Validation Cross Entropy Loss: 5.190949131142009\n", 692 | " Epoch 6 | Train Cross Entropy Loss: 5.732032602870619\n", 693 | " Epoch 6 | Validation Cross Entropy Loss: 5.094015572752271\n", 694 | " Epoch 7 | Train Cross Entropy Loss: 5.481263580647382\n", 695 | " Epoch 7 | Validation Cross Entropy Loss: 5.061300001361153\n", 696 | " Epoch 8 | Train Cross Entropy Loss: 5.221863861207838\n", 697 | " Epoch 8 | Validation Cross Entropy Loss: 5.003460148712257\n", 698 | " Epoch 9 | Train Cross Entropy Loss: 4.972142156068381\n", 699 | " Epoch 9 | Validation Cross Entropy Loss: 4.966425876338761\n", 700 | " Epoch 10 | Train Cross Entropy Loss: 4.713377501283373\n", 701 | " Epoch 10 | Validation Cross Entropy Loss: 4.975224611047026\n", 702 | " Epoch 11 | Train Cross Entropy Loss: 4.459344664177337\n", 703 | " Epoch 11 | Validation Cross Entropy Loss: 4.963081742648955\n", 704 | " Epoch 12 | Train Cross Entropy Loss: 4.216912533942756\n", 705 | " Epoch 12 | Validation Cross Entropy Loss: 4.948643896099809\n", 706 | " Epoch 13 | Train Cross Entropy Loss: 3.9796205166872447\n", 707 | " Epoch 13 | Validation Cross Entropy Loss: 5.041126395587797\n", 708 | " Epoch 14 | Train Cross Entropy Loss: 3.754485915427084\n", 709 | " Epoch 14 | Validation Cross Entropy Loss: 5.053339765830473\n", 710 | " Epoch 15 | Train Cross Entropy Loss: 3.545525817127971\n", 711 | " Epoch 15 | Validation Cross Entropy Loss: 
5.0739604056655585\n", 712 | " Epoch 16 | Train Cross Entropy Loss: 3.3388276191113833\n", 713 | " Epoch 16 | Validation Cross Entropy Loss: 5.128433674960942\n", 714 | " Epoch 17 | Train Cross Entropy Loss: 3.1556254357293057\n", 715 | " Epoch 17 | Validation Cross Entropy Loss: 5.156803880032006\n", 716 | " Epoch 18 | Train Cross Entropy Loss: 2.9846578783222606\n", 717 | " Epoch 18 | Validation Cross Entropy Loss: 5.189165875896231\n", 718 | " Epoch 19 | Train Cross Entropy Loss: 2.8305458025885866\n", 719 | " Epoch 19 | Validation Cross Entropy Loss: 5.24085449514451\n", 720 | " Epoch 20 | Train Cross Entropy Loss: 2.6819157435909493\n", 721 | " Epoch 20 | Validation Cross Entropy Loss: 5.320800462713489\n", 722 | " Epoch 21 | Train Cross Entropy Loss: 2.5431851465980726\n", 723 | " Epoch 21 | Validation Cross Entropy Loss: 5.328162451843163\n", 724 | " Epoch 22 | Train Cross Entropy Loss: 2.4220719888999866\n", 725 | " Epoch 22 | Validation Cross Entropy Loss: 5.437457458539442\n", 726 | " Epoch 23 | Train Cross Entropy Loss: 2.3040674579414455\n", 727 | " Epoch 23 | Validation Cross Entropy Loss: 5.465106003470235\n", 728 | " Epoch 24 | Train Cross Entropy Loss: 2.1977922209091\n", 729 | " Epoch 24 | Validation Cross Entropy Loss: 5.51173290803835\n", 730 | " Epoch 25 | Train Cross Entropy Loss: 2.098390280619844\n", 731 | " Epoch 25 | Validation Cross Entropy Loss: 5.609959487017099\n", 732 | " Epoch 26 | Train Cross Entropy Loss: 2.0048216544575506\n", 733 | " Epoch 26 | Validation Cross Entropy Loss: 5.693122586646638\n", 734 | " Epoch 27 | Train Cross Entropy Loss: 1.9181972871256339\n", 735 | " Epoch 27 | Validation Cross Entropy Loss: 5.720419735103459\n", 736 | " Epoch 28 | Train Cross Entropy Loss: 1.8346951236198474\n", 737 | " Epoch 28 | Validation Cross Entropy Loss: 5.777838484033362\n", 738 | " Epoch 29 | Train Cross Entropy Loss: 1.7577206774481706\n", 739 | " Epoch 29 | Validation Cross Entropy Loss: 5.902188576661147\n" 740 | ] 741 | } 742 | 
], 743 | "metadata": {} 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 39, 748 | "source": [ 749 | "# train and validation loss over the 30 training epochs\r\n", 750 | "plt.figure(figsize=(10,5))\r\n", 751 | "plt.plot(range(30),loss_track,label=\"train loss\")\r\n", 752 | "plt.plot(range(30),loss_validation_track,label=\"valiadtion loss\")\r\n", 753 | "plt.legend()\r\n", 754 | "plt.show()" 755 | ], 756 | "outputs": [ 757 | { 758 | "output_type": "display_data", 759 | "data": { 760 | "image/png": "", 761 | "text/plain": [ 762 | "
" 763 | ] 764 | }, 765 | "metadata": {} 766 | } 767 | ], 768 | "metadata": {} 769 | }, 770 | { 771 | "cell_type": "markdown", 772 | "source": [ 773 | "## Translation" 774 | ], 775 | "metadata": {} 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": 19, 780 | "source": [ 781 | "def translate_sentence(model,sentence,srcField,targetField,srcTokenizer):\r\n", 782 | " model.eval()\r\n", 783 | " processed_sentence = srcField.process([srcTokenizer(sentence)]).to(device)\r\n", 784 | " trg = [\"بداية\"]\r\n", 785 | "\r\n", 786 | " for _ in range(60):\r\n", 787 | " trg_indecies = [targetField.vocab.stoi[word] for word in trg]\r\n", 788 | " trg_tensor = torch.LongTensor(trg_indecies).unsqueeze(1).to(device)\r\n", 789 | " outputs = model(processed_sentence,trg_tensor)\r\n", 790 | " \r\n", 791 | " if targetField.vocab.itos[outputs.argmax(2)[-1:].item()] == \"\":\r\n", 792 | " continue \r\n", 793 | " trg.append(targetField.vocab.itos[outputs.argmax(2)[-1:].item()])\r\n", 794 | " if targetField.vocab.itos[outputs.argmax(2)[-1:].item()] == \"نهاية\":\r\n", 795 | " break\r\n", 796 | " return \" \".join([word for word in trg if word != \"\"][1:-1])\r\n" 797 | ], 798 | "outputs": [], 799 | "metadata": {} 800 | }, 801 | { 802 | "cell_type": "code", 803 | "execution_count": 20, 804 | "source": [ 805 | "translate_sentence(model,\"I'm ready\" ,SRC,TRG,engTokenizer)" 806 | ], 807 | "outputs": [ 808 | { 809 | "output_type": "execute_result", 810 | "data": { 811 | "text/plain": [ 812 | "'أنا مستعد'" 813 | ] 814 | }, 815 | "metadata": {}, 816 | "execution_count": 20 817 | } 818 | ], 819 | "metadata": {} 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 23, 824 | "source": [ 825 | "translate_sentence(model,\"i'm lucky\" ,SRC,TRG,engTokenizer)" 826 | ], 827 | "outputs": [ 828 | { 829 | "output_type": "execute_result", 830 | "data": { 831 | "text/plain": [ 832 | "'انا محظوظ'" 833 | ] 834 | }, 835 | "metadata": {}, 836 | "execution_count": 23 837 | } 838 | ], 
839 | "metadata": {} 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 24, 844 | "source": [ 845 | "translate_sentence(model,\"I'm sad\" ,SRC,TRG,engTokenizer)" 846 | ], 847 | "outputs": [ 848 | { 849 | "output_type": "execute_result", 850 | "data": { 851 | "text/plain": [ 852 | "'أنا حزين'" 853 | ] 854 | }, 855 | "metadata": {}, 856 | "execution_count": 24 857 | } 858 | ], 859 | "metadata": {} 860 | } 861 | ], 862 | "metadata": { 863 | "interpreter": { 864 | "hash": "4bb0fe8ced3cf0716ac3718fe834e829af40e8ba0fef1c4cadecb390da29a017" 865 | }, 866 | "kernelspec": { 867 | "name": "python3", 868 | "display_name": "Python 3.7.11 64-bit ('torch': conda)" 869 | }, 870 | "language_info": { 871 | "codemirror_mode": { 872 | "name": "ipython", 873 | "version": 3 874 | }, 875 | "file_extension": ".py", 876 | "mimetype": "text/x-python", 877 | "name": "python", 878 | "nbconvert_exporter": "python", 879 | "pygments_lexer": "ipython3", 880 | "version": "3.7.11" 881 | } 882 | }, 883 | "nbformat": 4, 884 | "nbformat_minor": 2 885 | } --------------------------------------------------------------------------------