├── README.md └── eng-hebrew-transformer-pytorch.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # 🌍 English ↔ Hebrew Transformer (PyTorch) 2 | 3 | A transformer-based deep learning model that **translates English sentences into Hebrew** using the `opus100` bilingual dataset. The notebook builds everything from data preparation to training and inference. 4 | 5 | --- 6 | 7 | ## 📦 What This Project Does 8 | 9 | * Loads bilingual English-Hebrew data. 10 | * Trains tokenizers that convert text into token ids (saved as JSON files for reuse). 11 | * Feeds data into a transformer model. 12 | * Trains the model to learn translation. 13 | * Uses the trained model to translate new English sentences into Hebrew. 14 | 15 | --- 16 | 17 | ## 🧠 How It Works (Simple Flow) 18 | 19 | ```mermaid 20 | flowchart TD 21 | A[Start] --> B[Load English-Hebrew Sentences] 22 | B --> C[Tokenize Sentences] 23 | C --> D[Prepare Inputs for Model] 24 | D --> E[Train Transformer Model] 25 | E --> F[Save Trained Model] 26 | F --> G[Translate New Sentences] 27 | ``` 28 | 29 | --- 30 | 31 | ## 🔍 What You Need 32 | 33 | * **Tokenizers** for English and Hebrew (trained automatically on the first run and saved as JSON files). 34 | * The **opus100** dataset from HuggingFace. 35 | * A GPU is helpful for training but not necessary for inference. 36 | 37 | --- 38 | 39 | ## 💡 Key Concepts (Explained Simply) 40 | 41 | ### 1. Tokenization 42 | 43 | Breaking text into units (tokens) and assigning each one a number so the model can work with numeric input. The notebook trains a BPE (byte-pair encoding) tokenizer per language; a short sketch appears near the end of this README. 44 | 45 | ### 2. Encoder & Decoder 46 | 47 | * **Encoder** reads the English sentence. 48 | * **Decoder** learns to produce the Hebrew sentence, one token at a time. 49 | 50 | ### 3. Masks 51 | 52 | Masks help the model (a standalone demo appears just before the notebook code below): 53 | 54 | * Ignore padding. 55 | * Prevent the decoder from "cheating" by seeing future words. 56 | 57 | --- 58 | 59 | ## 🔄 A Simple Example 60 | 61 | **English:** 62 | `"How are you?"` 63 | 64 | **Model translates it to Hebrew:** 65 | `"מה שלומך?"` 66 | 67 | **Another Example:** 68 | 69 | | English | Hebrew | 70 | | ------------------- | -------------- | 71 | | Good morning | בוקר טוב | 72 | | I love learning | אני אוהב ללמוד | 73 | | Where is the hotel? | איפה המלון? | 74 | 75 | --- 76 | 77 | ## ✅ What Happens Internally 78 | 79 | 1. Each sentence is wrapped with `[SOS]` (start) and `[EOS]` (end) tokens. 80 | 2. Each sequence is padded with `[PAD]` tokens to a fixed length (`seq_len`). 81 | 3. The model learns from many such examples. 82 | 4. Translations typically improve after every training cycle (epoch). 83 | 5. Saved checkpoints let you reuse the trained model without retraining. 84 | 85 | --- 86 | 87 | ## 🎯 Final Outcome 88 | 89 | You’ll have: 90 | 91 | * A trained model that can translate English into Hebrew. 92 | * The ability to reuse the model to translate any English sentence you input (see the inference sketch below). 93 | 94 | --- 95 | 96 | ## 👤 Author 97 | 98 | For any questions or issues, please open an issue on GitHub: [@Siddharth Mishra](https://github.com/Sid3503) 99 | 100 | --- 101 | 102 |
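## 🧪 Sketch: Tokenization in Code

The notebook builds its tokenizers with the HuggingFace `tokenizers` library (BPE with `[UNK]`, `[PAD]`, `[SOS]`, `[EOS]` specials). Below is a minimal, self-contained sketch of the same idea; the three training sentences are made-up stand-ins for the real `opus100` data, so the printed ids will differ from the notebook's.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

# Same setup as the notebook's get_or_build_tokenizer, but on a toy corpus.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)

corpus = ["How are you?", "Good morning", "Where is the hotel?"] * 20  # repeated so merges fire
tokenizer.train_from_iterator(corpus, trainer=trainer)

ids = tokenizer.encode("How are you?").ids
print(ids)                    # the ids depend on the training corpus
print(tokenizer.decode(ids))  # roughly reconstructs the original text
```

---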

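## 🔁 Sketch: Translating with a Trained Checkpoint

Once training has saved a checkpoint, a new sentence can be translated through the notebook's own `greedy_decode`. A minimal sketch, assuming every notebook cell has been run in this session (so `get_config`, `get_model`, `get_weights_file_path`, and `greedy_decode` are defined) and that epoch 24's weights exist; swap in whichever epoch you actually saved.

```python
import torch
from tokenizers import Tokenizer

config = get_config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reload the tokenizers that were trained and saved during get_ds().
tokenizer_src = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_src']))
tokenizer_tgt = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_tgt']))

model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
state = torch.load(get_weights_file_path(config, "24"), map_location=device)
model.load_state_dict(state['model_state_dict'])
model.eval()

# Wrap the source sentence with [SOS]/[EOS], exactly as BilingualDataset does.
sentence = "How are you?"
ids = tokenizer_src.encode(sentence).ids
sos = tokenizer_src.token_to_id('[SOS]')
eos = tokenizer_src.token_to_id('[EOS]')
source = torch.tensor([[sos] + ids + [eos]], dtype=torch.int64).to(device)
source_mask = torch.ones((1, 1, source.size(1)), dtype=torch.int).to(device)  # nothing is padded here

with torch.no_grad():
    out = greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt,
                        config['seq_len'], device)
print(tokenizer_tgt.decode(out.detach().cpu().numpy()))
```

---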
103 | Made with ❤️ and lots of ☕ 104 |

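## 🎭 Sketch: The Decoder Masks, Standalone

Before reading the notebook source below, it helps to see its two decoder-side masks in isolation: a padding mask hides `[PAD]` positions, and `causal_mask` hides future positions. This demo copies the notebook's `causal_mask` verbatim; the word ids are made up, but pad id 1 is real (the special tokens are registered in the order `[UNK]`, `[PAD]`, `[SOS]`, `[EOS]`, so `[PAD]` gets id 1 and `[SOS]` gets id 2).

```python
import torch

def causal_mask(size):
    # Ones strictly above the diagonal mark "future" positions; the comparison
    # flips them off, so each position can attend only to itself and the past.
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

pad_id = 1
decoder_input = torch.tensor([2, 57, 89, 1, 1])  # [SOS], two made-up word ids, two [PAD]s

padding_mask = (decoder_input != pad_id).unsqueeze(0).int()       # (1, 5)
decoder_mask = padding_mask & causal_mask(decoder_input.size(0))  # broadcasts to (1, 5, 5)
print(decoder_mask)
# Row i is 1 up to column i, except the padded columns, which are zeroed everywhere.
```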
105 | -------------------------------------------------------------------------------- /eng-hebrew-transformer-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "kernelspec": { 4 | "name": "python3", 5 | "display_name": "Python 3", 6 | "language": "python" 7 | }, 8 | "language_info": { 9 | "name": "python", 10 | "version": "3.11.11", 11 | "mimetype": "text/x-python", 12 | "codemirror_mode": { 13 | "name": "ipython", 14 | "version": 3 15 | }, 16 | "pygments_lexer": "ipython3", 17 | "nbconvert_exporter": "python", 18 | "file_extension": ".py" 19 | }, 20 | "colab": { 21 | "provenance": [], 22 | "gpuType": "T4" 23 | }, 24 | "accelerator": "GPU", 25 | "kaggle": { 26 | "accelerator": "gpu", 27 | "dataSources": [], 28 | "dockerImageVersionId": 31011, 29 | "isInternetEnabled": true, 30 | "language": "python", 31 | "sourceType": "notebook", 32 | "isGpuEnabled": true 33 | } 34 | }, 35 | "nbformat_minor": 4, 36 | "nbformat": 4, 37 | "cells": [ 38 | { 39 | "cell_type": "code", 40 | "source": "import torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset\n\nclass BilingualDataset(Dataset):\n def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):\n super().__init__()\n self.seq_len = seq_len\n self.ds = ds\n self.tokenizer_src = tokenizer_src\n self.tokenizer_tgt = tokenizer_tgt\n self.src_lang = src_lang\n self.tgt_lang = tgt_lang\n\n self.sos_token = torch.tensor([tokenizer_tgt.token_to_id(\"[SOS]\")], dtype=torch.int64)\n self.eos_token = torch.tensor([tokenizer_tgt.token_to_id(\"[EOS]\")], dtype=torch.int64)\n self.pad_token = torch.tensor([tokenizer_tgt.token_to_id(\"[PAD]\")], dtype=torch.int64)\n\n def __len__(self):\n return len(self.ds)\n\n def __getitem__(self, index):\n src_target_pair = self.ds[index]\n src_text = src_target_pair['translation'][self.src_lang]\n tgt_text = src_target_pair['translation'][self.tgt_lang]\n\n enc_input_tokens = self.tokenizer_src.encode(src_text).ids\n dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids\n\n enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2\n dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1\n\n if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:\n raise ValueError('Sentence is Too Long!')\n\n # adding SOS and EOS and PADDING to the Source Text\n encoder_input = torch.cat(\n [\n self.sos_token,\n torch.tensor(enc_input_tokens, dtype=torch.int64),\n self.eos_token,\n torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)\n ]\n )\n\n # adding SOS and PADDING to the target Text(as decoder will predict EOS token, we do not give EOS token as input here)\n decoder_input = torch.cat(\n [\n self.sos_token,\n torch.tensor(dec_input_tokens, dtype=torch.int64),\n torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)\n ]\n )\n\n # adding EOS and PADDING to the label(which is what we expect as output from decoder)\n label = torch.cat(\n [\n torch.tensor(dec_input_tokens, dtype=torch.int64),\n self.eos_token,\n torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)\n ]\n )\n\n assert encoder_input.size(0) == self.seq_len\n assert decoder_input.size(0) == self.seq_len\n assert label.size(0) == self.seq_len\n\n return {\n \"encoder_input\": encoder_input, # (seq_len)\n \"decoder_input\": decoder_input, # (seq_len)\n # mask the padded tokens\n \"encoder_mask\": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, 
seq_len)\n \"decoder_mask\": (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len)\n \"label\": label, # (seq_len)\n \"src_text\": src_text,\n \"tgt_text\": tgt_text,\n\n }\n\ndef causal_mask(size):\n mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)\n return mask == 0", 41 | "metadata": { 42 | "id": "g7UcnvaX5bEJ", 43 | "trusted": true, 44 | "execution": { 45 | "iopub.status.busy": "2025-05-02T14:53:26.576123Z", 46 | "iopub.execute_input": "2025-05-02T14:53:26.576771Z", 47 | "iopub.status.idle": "2025-05-02T14:53:26.587136Z", 48 | "shell.execute_reply.started": "2025-05-02T14:53:26.576751Z", 49 | "shell.execute_reply": "2025-05-02T14:53:26.586304Z" 50 | } 51 | }, 52 | "outputs": [], 53 | "execution_count": 30 54 | }, 55 | { 56 | "cell_type": "code", 57 | "source": "from pathlib import Path\n\ndef get_config():\n return {\n \"batch_size\": 32,\n \"num_epochs\": 25,\n \"lr\": 3e-4,\n \"seq_len\": 350,\n \"d_model\": 512,\n \"lang_src\": \"en\",\n \"lang_tgt\": \"he\",\n \"model_folder\": \"/kaggle/working/weights\", # Updated path\n \"model_basename\": \"tmodel_\",\n \"preload\": None,\n \"tokenizer_file\": \"/kaggle/working/tokenizer_{0}.json\", # Updated path\n \"experiment_name\": \"/kaggle/working/runs/tmodel\", # Updated path\n \"dataset_name\": \"opus100\",\n \"dataset_max_samples\": 10000,\n \"train_size\": 0.9,\n }\n\ndef get_weights_file_path(config, epoch: str):\n model_folder = config[\"model_folder\"]\n model_basename = config[\"model_basename\"]\n model_filename = f\"{model_basename}{epoch}.pt\"\n return str(Path(model_folder) / model_filename) # Proper path joining", 58 | "metadata": { 59 | "id": "WAH2Tm8m5nWe", 60 | "trusted": true, 61 | "execution": { 62 | "iopub.status.busy": "2025-05-02T14:53:28.286957Z", 63 | "iopub.execute_input": "2025-05-02T14:53:28.287613Z", 64 | "iopub.status.idle": "2025-05-02T14:53:28.292765Z", 65 | "shell.execute_reply.started": "2025-05-02T14:53:28.287581Z", 66 | "shell.execute_reply": "2025-05-02T14:53:28.291969Z" 67 | } 68 | }, 69 | "outputs": [], 70 | "execution_count": 31 71 | }, 72 | { 73 | "cell_type": "code", 74 | "source": "import torch\nimport torch.nn as nn\nimport math\n\n\nclass InputEmbeddings(nn.Module):\n def __init__(self, d_model: int, vocab_size: int):\n super().__init__()\n self.d_model = d_model\n self.vocab_size = vocab_size\n self.embedding = nn.Embedding(vocab_size, d_model)\n\n def forward(self, x):\n return self.embedding(x) * math.sqrt(self.d_model)\n\n\nclass PositionalEncoding(nn.Module):\n def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:\n super().__init__()\n self.d_model = d_model\n self.seq_len = seq_len\n self.dropout = nn.Dropout(dropout)\n\n # creating a matrix of shape (seq_len, d_model)\n pe = torch.zeros(seq_len, d_model)\n\n\n # creating a position index vector for every word in seq of size (seq_len, 1)\n position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)\n\n div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n\n # apply sin to even pos\n pe[:, 0::2] = torch.sin(position * div_term)\n\n # apply cos to odd pos\n pe[:, 1::2] = torch.cos(position * div_term)\n\n # include batch_size dimensions\n pe = pe.unsqueeze(0) # (1, seq_len, d_model)\n\n self.register_buffer('pe', pe)\n\n def forward(self, x):\n x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)\n return self.dropout(x)\n\n\nclass 
LayerNormalization(nn.Module):\n def __init__(self, eps: float = 1e-6) -> None:\n super().__init__()\n self.eps = eps\n self.alpha = nn.Parameter(torch.ones(1)) # multiplied\n self.bias = nn.Parameter(torch.zeros(1)) # added\n\n def forward(self, x):\n mean = x.mean(dim = -1, keepdim=True)\n std = x.std(dim = -1, keepdim=True)\n return self.alpha * (x-mean) / (std + self.eps) + self.bias\n\n\nclass FeedForwardBlock(nn.Module):\n def __init__(self, d_model: int, d_ff: int, dropout: float):\n super().__init__()\n self.linear_1 = nn.Linear(d_model, d_ff)\n self.dropout = nn.Dropout(dropout)\n self.linear_2 = nn.Linear(d_ff, d_model)\n\n def forward(self, x):\n # (Batch, seq_len, d_model) -> (Batch, seq_len, d_ff) -> (Batch, seq_len, d_model)\n return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))\n\n\nclass MultiHeadAttentionBlock(nn.Module):\n def __init__(self, d_model: int, h: int, dropout: float):\n super().__init__()\n self.d_model = d_model\n self.h = h\n assert d_model % h == 0, \"d_model is not divisible by h\"\n\n # dimension of each head\n self.d_k = d_model // h\n\n # defining query, key and val matrices\n self.w_q = nn.Linear(d_model, d_model) #Wq\n self.w_k = nn.Linear(d_model, d_model) #Wk\n self.w_v = nn.Linear(d_model, d_model) #Wv\n\n # defining Wo\n self.w_o = nn.Linear(d_model, d_model)\n\n self.dropout = nn.Dropout(dropout)\n\n @staticmethod\n def attention(query, key, value, mask, dropout: nn.Dropout):\n d_k = query.shape[-1]\n\n # (Batch, num_heads, seq_len, head_dim) x (Batch, num_heads, head_dim, seq_len) --> (Batch, num_heads, seq_len, seq_len)\n attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)\n\n if mask is not None:\n attention_scores.masked_fill_(mask == 0, -1e9)\n\n attention_scores = attention_scores.softmax(dim = -1)\n\n if dropout is not None:\n attention_scores = dropout(attention_scores)\n\n return (attention_scores @ value), attention_scores\n\n\n\n\n def forward(self, q, k, v, mask):\n query = self.w_q(q) # (Batch, seq_len, d_model) x (Batch, d_model, d_model) --> (Batch, seq_len, d_model)\n key = self.w_k(k) # (Batch, seq_len, d_model) x (Batch, d_model, d_model) --> (Batch, seq_len, d_model)\n value = self.w_v(v) # (Batch, seq_len, d_model) x (Batch, d_model, d_model) --> (Batch, seq_len, d_model)\n\n # splitting (Batch, seq_len, d_model) --> (Batch, seq_len, num_heads, head_dim) --> (Batch, num_heads, seq_len, head_dim)\n query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)\n key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)\n value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)\n\n x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)\n\n # (Batch, num_heads, seq_len, head_dim) --> (Batch, seq_len, num_heads, head_dim) --> (Batch, seq_len, d_model)\n x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)\n\n return self.w_o(x) # (Batch, seq_len, d_model)\n\n\nclass ResidualConnection(nn.Module):\n def __init__(self, dropout: float):\n super().__init__()\n self.dropout = nn.Dropout(dropout)\n self.norm = LayerNormalization()\n\n def forward(self, x, sublayer):\n return x + self.dropout(sublayer(self.norm(x)))\n\n\nclass EncoderBlock(nn.Module):\n def __init__(self, self_attention_block: MultiHeadAttentionBlock, feeed_forward_block: FeedForwardBlock, dropout: float):\n super().__init__()\n self.self_attention_block = self_attention_block\n self.feed_forward_block = 
feeed_forward_block\n self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])\n\n\n def forward(self, x, src_mask):\n x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))\n x = self.residual_connections[1](x, self.feed_forward_block)\n\n return x\n\n\nclass Encoder(nn.Module):\n def __init__(self, layers: nn.ModuleList):\n super().__init__()\n self.layers = layers\n self.norm = LayerNormalization()\n\n def forward(self, x, mask):\n for layer in self.layers:\n x = layer(x, mask)\n\n return self.norm(x)\n\n\nclass DecoderBlock(nn.Module):\n def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):\n super().__init__()\n self.self_attention_block = self_attention_block\n self.cross_attention_block = cross_attention_block\n self.feed_forward_block = feed_forward_block\n\n self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])\n\n def forward(self, x, encoder_output, src_mask, tgt_mask):\n x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))\n x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))\n x = self.residual_connections[2](x, self.feed_forward_block)\n return x\n\n\nclass Decoder(nn.Module):\n def __init__(self, layers: nn.ModuleList):\n super().__init__()\n self.layers = layers\n self.norm = LayerNormalization()\n\n def forward(self, x, encoder_output, src_mask, tgt_mask):\n for layer in self.layers:\n x = layer(x, encoder_output, src_mask, tgt_mask)\n\n return self.norm(x)\n\n\nclass ProjectionLayer(nn.Module):\n def __init__(self, d_model: int, vocab_size: int):\n super().__init__()\n self.proj = nn.Linear(d_model, vocab_size)\n\n def forward(self, x):\n # (Batch, seq_len, d_model) --> (Batch, seq_len, vocab_size)\n return torch.log_softmax(self.proj(x), dim = -1)\n\n\nclass Transformer(nn.Module):\n def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):\n super().__init__()\n self.encoder = encoder\n self.decoder = decoder\n self.src_embed = src_embed\n self.tgt_embed = tgt_embed\n self.src_pos = src_pos\n self.tgt_pos = tgt_pos\n self.projection_layer = projection_layer\n\n\n def encode(self, src, src_mask):\n src = self.src_embed(src)\n src = self.src_pos(src)\n return self.encoder(src, src_mask)\n\n def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):\n tgt = self.tgt_embed(tgt)\n tgt = self.tgt_pos(tgt)\n return self.decoder(tgt, encoder_output, src_mask, tgt_mask)\n\n def project(self, x):\n return self.projection_layer(x)\n\n\ndef build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:\n # create the embedding layers\n src_embed = InputEmbeddings(d_model, src_vocab_size)\n tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)\n\n # create positional encoding layers\n src_pos = PositionalEncoding(d_model, src_seq_len, dropout)\n tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)\n\n # create encoder blocks\n encoder_blocks = []\n for _ in range(N):\n encoder_self_attention_block = 
MultiHeadAttentionBlock(d_model, h, dropout)\n feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)\n encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)\n encoder_blocks.append(encoder_block)\n\n # create decoder blocks\n decoder_blocks = []\n for _ in range(N):\n decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)\n decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)\n feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)\n decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)\n decoder_blocks.append(decoder_block)\n\n # create encoder and decoder\n encoder = Encoder(nn.ModuleList(encoder_blocks))\n decoder = Decoder(nn.ModuleList(decoder_blocks))\n\n # create projection layer\n projection_layer = ProjectionLayer(d_model, tgt_vocab_size)\n\n # create the transformer\n transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)\n\n\n\n # initializing PARAMETERS\n for p in transformer.parameters():\n if p.dim() > 1:\n nn.init.xavier_uniform_(p)\n\n return transformer", 75 | "metadata": { 76 | "id": "tKr_mS7y5tgl", 77 | "trusted": true, 78 | "execution": { 79 | "iopub.status.busy": "2025-05-02T14:53:30.205270Z", 80 | "iopub.execute_input": "2025-05-02T14:53:30.205523Z", 81 | "iopub.status.idle": "2025-05-02T14:53:30.234929Z", 82 | "shell.execute_reply.started": "2025-05-02T14:53:30.205505Z", 83 | "shell.execute_reply": "2025-05-02T14:53:30.234087Z" 84 | } 85 | }, 86 | "outputs": [], 87 | "execution_count": 32 88 | }, 89 | { 90 | "cell_type": "code", 91 | "source": "!pip install datasets", 92 | "metadata": { 93 | "colab": { 94 | "base_uri": "https://localhost:8080/" 95 | }, 96 | "id": "prn2EzAG7SPT", 97 | "outputId": "5ba68a0a-56ab-4c3a-bb47-ff4522c96a5b", 98 | "trusted": true, 99 | "execution": { 100 | "iopub.status.busy": "2025-05-02T12:57:23.901417Z", 101 | "iopub.execute_input": "2025-05-02T12:57:23.902129Z", 102 | "iopub.status.idle": "2025-05-02T12:57:26.953553Z", 103 | "shell.execute_reply.started": "2025-05-02T12:57:23.902106Z", 104 | "shell.execute_reply": "2025-05-02T12:57:26.952577Z" 105 | } 106 | }, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "text": "Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (3.5.0)\nRequirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets) (3.18.0)\nRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (1.26.4)\nRequirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (19.0.1)\nRequirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.3.8)\nRequirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets) (2.2.3)\nRequirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets) (2.32.3)\nRequirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets) (4.67.1)\nRequirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets) (3.5.0)\nRequirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.70.16)\nRequirement already satisfied: fsspec<=2024.12.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from 
fsspec[http]<=2024.12.0,>=2023.1.0->datasets) (2024.12.0)\nRequirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets) (3.11.16)\nRequirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.30.2)\nRequirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from datasets) (24.2)\nRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from datasets) (6.0.2)\nRequirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (2.6.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.3.2)\nRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (25.3.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.5.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (6.2.0)\nRequirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (0.3.1)\nRequirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.19.0)\nRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.13.1)\nRequirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (1.3.8)\nRequirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (1.2.4)\nRequirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (0.1.1)\nRequirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (2025.1.0)\nRequirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (2022.1.0)\nRequirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17->datasets) (2.4.1)\nRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.4.1)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.10)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2.3.0)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2025.1.31)\nRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2.9.0.post0)\nRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.2)\nRequirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.2)\nRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\nRequirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy>=1.17->datasets) (2024.2.0)\nRequirement already satisfied: tbb==2022.* in 
/usr/local/lib/python3.11/dist-packages (from mkl->numpy>=1.17->datasets) (2022.1.0)\nRequirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy>=1.17->datasets) (1.2.0)\nRequirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy>=1.17->datasets) (2024.2.0)\nRequirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy>=1.17->datasets) (2024.2.0)\n", 111 | "output_type": "stream" 112 | } 113 | ], 114 | "execution_count": 10 115 | }, 116 | { 117 | "cell_type": "code", 118 | "source": "import torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset, DataLoader, random_split\nimport tqdm\nimport warnings\n\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom datasets import load_dataset\nfrom tokenizers import Tokenizer\nfrom tokenizers.models import WordLevel\nfrom tokenizers.trainers import WordLevelTrainer\nfrom tokenizers.pre_tokenizers import Whitespace\nfrom tokenizers.models import BPE\nfrom tokenizers.trainers import BpeTrainer\nfrom pathlib import Path\n\ndef get_all_sentences(ds, lang):\n for item in ds:\n yield item['translation'][lang]\n\ndef get_or_build_tokenizer(config, ds, lang):\n tokenizer_path = Path(config['tokenizer_file'].format(lang))\n if not tokenizer_path.exists():\n # Use BPE for Hebrew/English\n tokenizer = Tokenizer(BPE(unk_token=\"[UNK]\"))\n tokenizer.pre_tokenizer = Whitespace()\n trainer = BpeTrainer(\n special_tokens=[\"[UNK]\", \"[PAD]\", \"[SOS]\", \"[EOS]\"],\n min_frequency=2\n )\n tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)\n tokenizer.save(str(tokenizer_path))\n else:\n tokenizer = Tokenizer.from_file(str(tokenizer_path))\n return tokenizer\n\n\ndef greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n\n # Precompute the encoder output and reuse it for every step\n encoder_output = model.encode(source, source_mask)\n # Initialize the decoder input with the sos token\n decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n while True:\n if decoder_input.size(1) == max_len:\n break\n\n # build mask for target\n decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)\n\n # calculate output\n out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)\n\n # get next token\n prob = model.project(out[:, -1])\n _, next_word = torch.max(prob, dim=1)\n decoder_input = torch.cat(\n [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1\n )\n\n if next_word == eos_idx:\n break\n\n return decoder_input.squeeze(0)\n\n\n\ndef run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):\n model.eval()\n count = 0\n\n console_width = 80\n\n with torch.no_grad():\n for batch in validation_ds:\n count += 1\n encoder_input = batch['encoder_input'].to(device)\n encoder_mask = batch['encoder_mask'].to(device)\n\n assert encoder_input.size(0) == 1, \"Batch Size must be 1 for validation\"\n\n model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n\n source_text = batch['src_text'][0]\n target_text = batch['tgt_text'][0]\n\n model_out_text = 
tokenizer_tgt.decode(model_out.detach().cpu().numpy())\n\n            print_msg('-'*console_width)\n            print_msg(f'SOURCE: {source_text}')\n            print_msg(f'TARGET: {target_text}')\n            print_msg(f'PREDICTED: {model_out_text}')\n\n            if count == num_examples:\n                break\n\ndef get_ds(config):\n    # Load the dataset (OPUS100 or JW300)\n    try:\n        ds_raw = load_dataset(config[\"dataset_name\"], f'{config[\"lang_src\"]}-{config[\"lang_tgt\"]}', split=f'train[:{config[\"dataset_max_samples\"]}]')\n    except ValueError as e:\n        print(f\"Error loading dataset: {e}\")\n        print(f\"Check that the '{config['lang_src']}-{config['lang_tgt']}' pair exists for '{config['dataset_name']}'.\")\n        raise\n\n    # Build tokenizers\n    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config[\"lang_src\"])\n    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config[\"lang_tgt\"])\n\n    # Split into train/val\n    train_size = int(config[\"train_size\"] * len(ds_raw))\n    val_size = len(ds_raw) - train_size\n    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_size, val_size])\n\n    # Create datasets\n    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config[\"lang_src\"], config[\"lang_tgt\"], config[\"seq_len\"])\n    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config[\"lang_src\"], config[\"lang_tgt\"], config[\"seq_len\"])\n\n    # Check max sentence lengths\n    max_len_src = max(len(tokenizer_src.encode(item[\"translation\"][config['lang_src']]).ids) for item in ds_raw)\n    max_len_tgt = max(len(tokenizer_tgt.encode(item[\"translation\"][config['lang_tgt']]).ids) for item in ds_raw)\n    print(f'Max source length: {max_len_src}, Max target length: {max_len_tgt}')\n\n    # Create dataloaders\n    train_dataloader = DataLoader(train_ds, batch_size=config[\"batch_size\"], shuffle=True)\n    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)\n\n    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt\n\n\ndef get_model(config, vocab_src_len, vocab_tgt_len):\n    model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], config['d_model'])\n    return model", 119 |   "metadata": { 120 |     "id": "Iuj4Oc645zWm", 121 |     "trusted": true, 122 |     "execution": { 123 |       "iopub.status.busy": "2025-05-02T14:53:34.986679Z", 124 |       "iopub.execute_input": "2025-05-02T14:53:34.987380Z", 125 |       "iopub.status.idle": "2025-05-02T14:53:35.012283Z", 126 |       "shell.execute_reply.started": "2025-05-02T14:53:34.987349Z", 127 |       "shell.execute_reply": "2025-05-02T14:53:35.011449Z" 128 |     } 129 |   }, 130 |   "outputs": [], 131 |   "execution_count": 33 132 | }, 133 | { 134 |   "cell_type": "markdown", 135 |   "source": "## Training", 136 |   "metadata": { 137 |     "id": "X0EJGo56_k_d" 138 |   } 139 | }, 140 | { 141 |   "cell_type": "code", 142 |   "source": "def train_model(config):\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    print(f'Using Device: {device}')\n\n    # Create model directory\n    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)\n\n    # Load dataset and tokenizers\n    try:\n        train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n    except ValueError as e:\n        print(f\"Failed to load dataset: {e}\")\n        print(\"Trying JW300 as fallback...\")\n        config[\"dataset_name\"] = \"jw300\"  # Fallback to JW300\n        train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n\n    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n    writer = SummaryWriter(config['experiment_name'])\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], 
eps=1e-9)\n loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1) # Note: Use tgt tokenizer for PAD\n\n # Resume training if preload=True\n initial_epoch = 0\n global_step = 0\n if config['preload']:\n model_filename = get_weights_file_path(config, config['preload'])\n print(f\"Preloading model: {model_filename}\")\n state = torch.load(model_filename, map_location=device)\n initial_epoch = state['epoch'] + 1\n optimizer.load_state_dict(state['optimizer_state_dict'])\n global_step = state['global_step']\n model.load_state_dict(state['model_state_dict'])\n\n # Training loop\n for epoch in range(initial_epoch, config['num_epochs']):\n model.train()\n batch_iterator = tqdm.tqdm(train_dataloader, desc=f\"Epoch {epoch:02d}\")\n\n for batch in batch_iterator:\n encoder_input = batch['encoder_input'].to(device) # (B, seq_len)\n decoder_input = batch['decoder_input'].to(device) # (B, seq_len)\n encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)\n decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)\n\n # Forward pass\n encoder_output = model.encode(encoder_input, encoder_mask)\n decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)\n proj_output = model.project(decoder_output) # (B, seq_len, tgt_vocab_size)\n\n # Compute loss\n label = batch['label'].to(device) # (B, seq_len)\n loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))\n\n # Backpropagation\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # Logging\n batch_iterator.set_postfix({\"loss\": f\"{loss.item():6.3f}\"})\n writer.add_scalar('train_loss', loss.item(), global_step)\n global_step += 1\n\n # Validation after each epoch\n run_validation(\n model, val_dataloader, tokenizer_src, tokenizer_tgt,\n config['seq_len'], device,\n lambda msg: batch_iterator.write(msg), global_step, writer\n )\n\n # Save checkpoint\n model_filename = get_weights_file_path(config, f\"{epoch:02d}\")\n torch.save({\n 'epoch': epoch,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'global_step': global_step,\n }, model_filename)", 143 | "metadata": { 144 | "id": "Xyeoa7Cc53I9", 145 | "trusted": true, 146 | "execution": { 147 | "iopub.status.busy": "2025-05-02T14:53:38.486842Z", 148 | "iopub.execute_input": "2025-05-02T14:53:38.487243Z", 149 | "iopub.status.idle": "2025-05-02T14:53:38.498170Z", 150 | "shell.execute_reply.started": "2025-05-02T14:53:38.487221Z", 151 | "shell.execute_reply": "2025-05-02T14:53:38.497399Z" 152 | } 153 | }, 154 | "outputs": [], 155 | "execution_count": 34 156 | }, 157 | { 158 | "cell_type": "code", 159 | "source": "config = get_config()\ntrain_model(config)", 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/", 163 | "height": 1000, 164 | "referenced_widgets": [ 165 | "e4014b4765df4c62a1862056fece2702", 166 | "e3375860219b4085984dd78640c0901c", 167 | "e3256d366d96418da443c101a3c46b48", 168 | "8f9a148d1e3543c08161ac29d8e196f4", 169 | "bbba318124aa47168c01ab6adcb6a823", 170 | "feb3e2c1ceb24944b8c78ca02db08e60", 171 | "e93a895118474610ad13c69845a8a2bd", 172 | "6014bbc0c4de45ec91e077a8c9d5a3c0", 173 | "f3607f3b154a42199a506b5a2870b6ae", 174 | "aa4c3e880ccc41cfa8fce987fe7995ac", 175 | "187ab74e550a4649b7f87c1162e17e28", 176 | "d810cf72ab414d08a616a66e3bbcdb55", 177 | "8a5364bb9d424b688a134b9821880b1c", 178 | "7e8979e9a7e54a76bc371d4cf7a7f3e1", 179 | "ddd565646c314a41a1926cfb1624bb1f", 180 | 
"7f54f242b1764963823be7f35a9fdce7", 181 | "bbe14cc40f7d416a9a5a4b1117a28641", 182 | "bf606add805049c39a547938721f4f19", 183 | "5ed3b22c11434a66b35d88da2f1c7c76", 184 | "eda79d51fab844249a3a4d7c0a688c3c", 185 | "ad59a86fa2b04e4fa7cd698d68037233", 186 | "0528a44b80bf46d18a9a4218b90917e3", 187 | "b7a2928afa8e473da22a8c2df4851451", 188 | "1889cc4d0ef44da1b3b1cebaca12c828", 189 | "902c609b94dd493ab60065ab8bec2f44", 190 | "4d3168886c2242eea99cf8e40d973db9", 191 | "75d10e72f0bd45a58fe4fe030fd1f831", 192 | "0335e7e630d6404bba099d7f1d36ac25", 193 | "37e2b96b61ce44608b2c3ffa1a575147", 194 | "4514f7e9734e46e88fcf6df4bb48ef2d", 195 | "36502cb912534fb6a4dbb858dfc29408", 196 | "e69a13da112a4cd4b04b758001ee9519", 197 | "6c0f8957814649bda80ff23711227795", 198 | "01040e8e9bdc45beb663100424eb4d62", 199 | "b86d67d72fd84c4d811f180a98753a17", 200 | "0dc6816095f648b1a1693fb3330b7a27", 201 | "18d38760c0b14a32bc91dcab976db588", 202 | "46ce8829f4bd45c7a766caa094b3c3c8", 203 | "528e6f3ebfe44547a2d947e66cfcf4b7", 204 | "4635ab530ea147019f649ba6f4c49e4a", 205 | "f5982a0bb4b04ed7a2958f0d6458e39f", 206 | "505a83d67ff6499285544fa7c36a06f0", 207 | "fdef06a95c6940db9d1e8bca51b0212f", 208 | "d30b4d0122cc4c469ad7186096221667", 209 | "843d19b6e0f74604b4f90f16b46e5afa", 210 | "d22d261660df46c6bf55c1ef5664754d", 211 | "314f82483fbb4fd186cec77a2afe596d", 212 | "d09c93e95d0f4deb966b1fcbfac0afae", 213 | "614bc215e2654c65bd4f38ca716472b8", 214 | "e3414886074545cf80af844df92f972d", 215 | "b51b1846ad0d48d198e993e7fe4d1f7d", 216 | "6887d1cd84f64419ac65dd521a8cf70a", 217 | "8ef3b0b2928b4055b4e13a98061f386d", 218 | "c04da2586e5a48b3a8082d560fe0169f", 219 | "7b1c68f2e46d47d287577263cc942de3", 220 | "ca82472a3d564259b25730077aee9d98", 221 | "a6c6d7d31d464005be04c3fe8835594d", 222 | "bc105a2d3c9c4c4ea8c58933cd59cce7", 223 | "41c8d47d376545bf9192801e40622173", 224 | "1a517388f01745b89c57238a12ec267d", 225 | "2eb2edc38a6f40a2a442c576374d8b86", 226 | "35b8dc696c254ab2950e352d21681e10", 227 | "b83ff38304be4beeac402ba8ef66768c", 228 | "8d50e7cc7bf7472e8a94841c3ca918e7", 229 | "fb756cf08b9c496ea19937efc83c8971", 230 | "9eaab1364e1a476d800b16343251fd6c", 231 | "a8eb547d732e4307a532209bf66226ad", 232 | "d8017044d85c421b9fc1c6ea0f311b21", 233 | "4cf8a64b95df454fa8862222ebb49584", 234 | "41554fe4df1c440890a3247ed0195392", 235 | "516a4c5a80554d578927c7c00ecb95ed", 236 | "18fb657a8d4f46df90987162aa48245b", 237 | "86511e63634b4f539108c4550b17ec0f", 238 | "b53bd3dfea934e548eaaa67dde8a11bd", 239 | "253b57a6f1c64972a3c4d29c7fa82452", 240 | "0af8a556045b44cb867464d01f3a3f7c", 241 | "93813e840a5044c0aa19bd11bdef27f5" 242 | ] 243 | }, 244 | "id": "8WjMqo766T0y", 245 | "outputId": "8f589ea3-984b-4a9c-bd5a-7d301e7537ba", 246 | "trusted": true, 247 | "execution": { 248 | "iopub.status.busy": "2025-05-02T12:57:46.443764Z", 249 | "iopub.execute_input": "2025-05-02T12:57:46.444342Z", 250 | "iopub.status.idle": "2025-05-02T14:36:56.297383Z", 251 | "shell.execute_reply.started": "2025-05-02T12:57:46.444321Z", 252 | "shell.execute_reply": "2025-05-02T14:36:56.296536Z" 253 | } 254 | }, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "text": "Using Device: cuda\n", 259 | "output_type": "stream" 260 | }, 261 | { 262 | "output_type": "display_data", 263 | "data": { 264 | "text/plain": "README.md: 0%| | 0.00/65.4k [00:00\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;31m# load pretrained weigts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mmodel_filename\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mget_weights_file_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_filename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'model_state_dict'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 640 | "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"encoding\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"utf-8\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1318\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1319\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1320\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1321\u001b[0m \u001b[0;31m# The zipfile reader is going to advance the current file position.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 641 | "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 657\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 658\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 659\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 660\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m\"w\"\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 642 | "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0;32mclass\u001b[0m 
\u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 641\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 643 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/kaggle/working/weights/tmodel_.pt'" 644 | ], 645 | "ename": "FileNotFoundError", 646 | "evalue": "[Errno 2] No such file or directory: '/kaggle/working/weights/tmodel_.pt'", 647 | "output_type": "error" 648 | } 649 | ], 650 | "execution_count": 35 651 | }, 652 | { 653 | "cell_type": "code", 654 | "source": "def test_model(config, epoch_to_load=\"09\"):\n device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n print(f'Using Device: {device}')\n\n # Create directories if they don't exist\n Path(config['model_folder']).mkdir(parents=True, exist_ok=True)\n Path(config['experiment_name']).mkdir(parents=True, exist_ok=True)\n\n # Load dataset and tokenizers\n try:\n _, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n except ValueError as e:\n print(f\"Failed to load dataset: {e}\")\n print(\"Trying JW300 as fallback...\")\n config[\"dataset_name\"] = \"jw300\"\n _, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n\n # Initialize model\n model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n\n # Load pretrained weights\n model_filename = get_weights_file_path(config, epoch_to_load)\n \n if not Path(model_filename).exists():\n available_files = list(Path(config['model_folder']).glob(\"*.pt\"))\n raise FileNotFoundError(\n f\"Model file {model_filename} not found.\\n\"\n f\"Available files: {[f.name for f in available_files]}\"\n )\n\n state = torch.load(model_filename, map_location=device)\n model.load_state_dict(state['model_state_dict'])\n model.eval()\n\n # Run validation\n def print_msg(msg):\n print(msg)\n\n print(\"\\n\" + \"=\"*50)\n print(f\"Testing model from epoch {epoch_to_load}\")\n print(\"=\"*50 + \"\\n\")\n \n run_validation(\n model, val_dataloader, tokenizer_src, tokenizer_tgt,\n config['seq_len'], device, print_msg, 0, None, num_examples=10\n )\n\n# Usage\nif __name__ == \"__main__\":\n config = get_config()\n \n # Example: test epoch 09\n test_model(config, epoch_to_load=\"24\")", 655 | "metadata": { 656 | "trusted": true, 657 | "execution": { 658 | "iopub.status.busy": "2025-05-02T14:54:39.678003Z", 659 | "iopub.execute_input": "2025-05-02T14:54:39.678660Z", 660 | "iopub.status.idle": 
"2025-05-02T14:54:45.409184Z", 661 | "shell.execute_reply.started": "2025-05-02T14:54:39.678637Z", 662 | "shell.execute_reply": "2025-05-02T14:54:45.408547Z" 663 | } 664 | }, 665 | "outputs": [ 666 | { 667 | "name": "stdout", 668 | "text": "Using Device: cuda\nMax source length: 81, Max target length: 114\n", 669 | "output_type": "stream" 670 | }, 671 | { 672 | "name": "stderr", 673 | "text": "/tmp/ipykernel_31/1329649009.py:31: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n state = torch.load(model_filename, map_location=device)\n", 674 | "output_type": "stream" 675 | }, 676 | { 677 | "name": "stdout", 678 | "text": "\n==================================================\nTesting model from epoch 24\n==================================================\n\n--------------------------------------------------------------------------------\nSOURCE: What's going on between you and Claudia?\nTARGET: \u05de\u05d4 \u05e7\u05d5\u05e8\u05d4 \u05d1\u05d9\u05e0\u05da \u05dc\u05d1\u05d9\u05df \u05e7\u05dc\u05d0\u05d5\u05d3\u05d9\u05d4?\nPREDICTED: \u05de\u05d4 \u05e7\u05d5\u05e8\u05d4 \u05d1\u05d9\u05e0\u05da \u05dc\u05d1\u05d9\u05df \u05e7\u05dc\u05d0\u05d5\u05d3\u05d9\u05d4 ?\n--------------------------------------------------------------------------------\nSOURCE: I mean, somebody is sending those texts-- call him Moriarty, call him whatever you want.\nTARGET: \u05d0\u05e0\u05d9 \u05de\u05ea\u05db\u05d5\u05d5\u05df, \u05de\u05d9\u05e9\u05d4\u05d5 \u05e9\u05d5\u05dc\u05d7 \u05d4\u05d8\u05e7\u05e1\u05d8\u05d9\u05dd \u05d4\u05d0\u05dc\u05d4 - \u05e7\u05d5\u05e8\u05d0 \u05dc\u05d5 \u05de\u05d5\u05e8\u05d9\u05d0\u05e8\u05d8\u05d9, \u05e7\u05d5\u05e8\u05d0 \u05dc\u05d5 \u05de\u05d4 \u05e9\u05d0\u05ea\u05d4 \u05e8\u05d5\u05e6\u05d4.\nPREDICTED: \u05d6\u05d0\u05ea \u05d0\u05d5\u05de\u05e8\u05ea , \u05d0\u05ea\u05dd \u05d3\u05d5 \u05de\u05d6\u05dc ... 
\u05d0\u05ea \u05e8\u05d5\u05e6\u05d4 , \" \u05dc\u05d4\u05e8\u05d0\u05d5\u05ea \u05dc\u05d5 \u05de\u05d4 \u05e9\u05d0\u05ea\u05d4 \u05e8\u05d5\u05e6\u05d4 \u05e9\u05ea \u05e0\u05e1\u05d4 \u05dc\u05e4 \u05e7\u05d9 \u05e2\u05ea \u05d1\u05e9\u05d1\u05d9\u05dc\u05da .\n--------------------------------------------------------------------------------\nSOURCE: Thats Aprilss...\nTARGET: \u05d6\u05d5 \u05d0\u05e4\u05e8\u05d9\u05dc...\nPREDICTED: \u05d6\u05d5 \u05d0\u05e4\u05e8\u05d9\u05dc ...\n--------------------------------------------------------------------------------\nSOURCE: I didn't ask for it.\nTARGET: \u05d0\u05e0\u05d9 \u05dc\u05d0 \u05d1\u05d9\u05e7\u05e9\u05ea\u05d9 \u05d0\u05ea \u05d6\u05d4.\nPREDICTED: \u05d0\u05e0\u05d9 \u05dc\u05d0 \u05d1\u05d9\u05e7\u05e9\u05ea\u05d9 \u05d0\u05ea \u05d6\u05d4 .\n--------------------------------------------------------------------------------\nSOURCE: I don't need a tutorial.\nTARGET: \u05d0\u05e0\u05d9 \u05dc\u05d0 \u05e6\u05e8\u05d9\u05db\u05d4 \u05e9\u05d9\u05e2\u05d5\u05e8 \u05e4\u05e8\u05d8\u05d9.\nPREDICTED: \u05d0\u05e0\u05d9 \u05dc\u05d0 \u05e6\u05e8\u05d9\u05db\u05d4 \u05e9\u05d9\u05e2\u05d5\u05e8 \u05e4\u05e8\u05d8\u05d9 .\n--------------------------------------------------------------------------------\nSOURCE: - Jesus, Dad, stop!\nTARGET: - \u05d9\u05e9\u05d5, \u05d0\u05d1\u05d0, \u05e2\u05e6\u05d5\u05e8!\nPREDICTED: - \u05d9\u05e9\u05d5 , \u05d0\u05d1\u05d0 , \u05e2\u05e6\u05d5\u05e8 !\n--------------------------------------------------------------------------------\nSOURCE: Without her cooperation, I doubt if it'll go forward.\nTARGET: \u05d1\u05dc\u05d9 \u05e9\u05d9\u05ea\u05d5\u05e3 \u05d4\u05e4\u05e2\u05d5\u05dc\u05d4 \u05e9\u05dc\u05d4 \u05e1\u05e4\u05e7 \u05d0\u05dd \u05d6\u05d4 \u05d9\u05ea\u05e7\u05d3\u05dd.\nPREDICTED: \u05d1\u05dc\u05d9 \u05e9\u05d9\u05ea\u05d5\u05e3 \u05d4\u05e4 \u05e2\u05d5\u05dc\u05d4 \u05e9\u05dc\u05d4 \u05e1\u05e4\u05e7 \u05d0\u05dd \u05d6\u05d4 \u05d9\u05ea\u05e7\u05d3\u05dd .\n--------------------------------------------------------------------------------\nSOURCE: Mom... taking your car out is gonna seem like the least of it... because whatever you're thinking right now... it's worse.\nTARGET: \u05d0\u05d9\u05de\u05d0... \u05dc\u05e7\u05d7\u05ea \u05d0\u05ea \u05d4\u05de\u05db\u05d5\u05e0\u05d9\u05ea \u05e9\u05dc\u05da \u05d9\u05d9\u05e8\u05d0\u05d4 \u05dc\u05da \u05d7\u05e1\u05e8 \u05de\u05e9\u05de\u05e2\u05d5\u05ea... \u05db\u05d9 \u05dc\u05d0 \u05de\u05e9\u05e0\u05d4 \u05de\u05d4 \u05d0\u05ea \u05d7\u05d5\u05e9\u05d1\u05ea \u05db\u05e2\u05ea... \u05d6\u05d4 \u05d9\u05d5\u05ea\u05e8 \u05d2\u05e8\u05d5\u05e2.\nPREDICTED: \u05d0\u05d9\u05de\u05d0 ... \u05dc\u05e7\u05d7\u05ea \u05d0\u05ea \u05d4\u05de\u05db\u05d5\u05e0\u05d9\u05ea \u05e9\u05dc\u05da \u05d9\u05d9\u05e8\u05d0\u05d4 \u05dc\u05da \u05d7\u05e1\u05e8 \u05de\u05e9\u05de\u05e2\u05d5\u05ea ... \u05db\u05d9 \u05dc\u05d0 \u05de\u05e9\u05e0\u05d4 \u05de\u05d4 \u05d0\u05ea \u05d7\u05d5\u05e9\u05d1\u05ea \u05db\u05e2\u05ea ... 
\u05d6\u05d4 \u05d9\u05d5\u05ea\u05e8 \u05d2\u05e8\u05d5\u05e2 .\n--------------------------------------------------------------------------------\nSOURCE: You know anything about the hunter's curse?\nTARGET: \u05d0\u05ea\u05d4 \u05d9\u05d5\u05d3\u05e2 \u05de\u05e9\u05d4\u05d5 \u05e2\u05dc \u05e7\u05dc\u05dc\u05ea \u05d4\u05e6\u05d9\u05d9\u05d3?\nPREDICTED: \u05d0\u05ea\u05d4 \u05d9\u05d5\u05d3\u05e2 \u05de\u05e9\u05d4\u05d5 \u05e2\u05dc \u05e7\u05dc \u05dc\u05ea \u05d4\u05e6\u05d9\u05d9\u05d3 ?\n--------------------------------------------------------------------------------\nSOURCE: Have you looked?\nTARGET: \u05d7\u05d9\u05e4\u05e9\u05ea?\nPREDICTED: \u05d7\u05d9\u05e4 \u05e9\u05ea ?\n", 679 | "output_type": "stream" 680 | } 681 | ], 682 | "execution_count": 37 683 | }, 684 | { 685 | "cell_type": "code", 686 | "source": "print(\"hi\")", 687 | "metadata": { 688 | "trusted": true, 689 | "execution": { 690 | "iopub.status.busy": "2025-05-02T14:57:08.322595Z", 691 | "iopub.execute_input": "2025-05-02T14:57:08.323097Z", 692 | "iopub.status.idle": "2025-05-02T14:57:08.327047Z", 693 | "shell.execute_reply.started": "2025-05-02T14:57:08.323074Z", 694 | "shell.execute_reply": "2025-05-02T14:57:08.326282Z" 695 | } 696 | }, 697 | "outputs": [ 698 | { 699 | "name": "stdout", 700 | "text": "hi\n", 701 | "output_type": "stream" 702 | } 703 | ], 704 | "execution_count": 39 705 | }, 706 | { 707 | "cell_type": "code", 708 | "source": "", 709 | "metadata": { 710 | "trusted": true 711 | }, 712 | "outputs": [], 713 | "execution_count": null 714 | } 715 | ] 716 | } --------------------------------------------------------------------------------