├── .gitignore
├── models
│   └── 2025-07-20_10-38-57
│       ├── config.json
│       └── log.txt
├── layers
│   ├── __init__.py
│   ├── attention_block.py
│   ├── feedforward_block.py
│   ├── embedding.py
│   ├── feedforward.py
│   ├── positional_encoding.py
│   ├── layer_norm.py
│   ├── linear.py
│   └── attention.py
├── README.md
├── generate.py
├── train.py
├── data
│   ├── utils.py
│   ├── little_char_idx.json
│   ├── tokenizer.py
│   └── char_idx.json
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | text_video.py
2 | __pycache__
3 | models
--------------------------------------------------------------------------------
/models/2025-07-20_10-38-57/config.json:
--------------------------------------------------------------------------------
1 | {"vocab_size": 80, "embed_dim": 24, "context_len": 64}
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .embedding import Embedding
2 | from .positional_encoding import PositionalEncoding
3 | from .attention_block import AttentionBlock
4 | from .feedforward_block import FeedForwardBlock
5 | from .linear import Linear
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NumPy Implementation of a Mini GPT-2
2 | 
3 | This repository presents an educational implementation of a **decoder-only** Transformer model, inspired by **GPT-2**, using **only NumPy**. The goal is to provide a deeper understanding of the inner workings of a neural language model, layer by layer.
4 | 
5 | The model was trained on a corpus based on the works **"Dom Casmurro"**, **"O Alienista"**, **"Memórias Póstumas de Brás Cubas"**, and **"Quincas Borba"**, with character-level (*char-level*) tokenization, which allows text to be generated with a limited vocabulary in a didactic way.
6 | ## Usage
7 | 
8 | First, train the model by running `train.py`. Then use the trained model by running `generate.py`; the generation context can be changed by editing the `prompt` variable in that script, as illustrated in the example below.
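9 | 
10 | A minimal sketch of what `generate.py` does (the checkpoint path and sampling settings below are simply the defaults already used in this repository; the saved weights must exist next to `config.json`):
11 | 
12 | ```python
13 | import numpy as np
14 | from model import Model
15 | from data.tokenizer import CharTokenizer
16 | 
17 | tokenizer = CharTokenizer.load("data/little_char_idx.json")
18 | model = Model.load("models/2025-07-20_10-38-57")
19 | 
20 | # Encode the prompt, then repeatedly sample one character at a time.
21 | tokens = np.array([tokenizer.encode("Uma noite destas")], dtype=int)
22 | for _ in range(100):
23 |     context = tokens[:, -model.context_len:]   # keep only the last context_len tokens
24 |     next_token = model.predict(context, top_k=4, temperature=0.5)
25 |     tokens = np.concatenate([tokens, next_token[:, None]], axis=1)
26 | 
27 | print(tokenizer.decode(tokens[0].tolist()))
28 | ```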
-------------------------------------------------------------------------------- /layers/attention_block.py: -------------------------------------------------------------------------------- 1 | from .attention import Attention 2 | from .layer_norm import LayerNorm 3 | 4 | class AttentionBlock: 5 | def __init__(self, embed_dim): 6 | self.attn = Attention(embed_dim, 2) 7 | self.norm = LayerNorm(embed_dim) 8 | 9 | def forward(self, x): 10 | self.input = x 11 | self.attn_out = self.attn.forward(x) 12 | self.out = self.norm.forward(x + self.attn_out) 13 | return self.out 14 | 15 | def backward(self, grad_out): 16 | grad_norm = self.norm.backward(grad_out) 17 | grad_attn = self.attn.backward(grad_norm) 18 | return grad_attn + grad_norm 19 | 20 | def update_params(self, lr): 21 | self.attn.update_params(lr) 22 | self.norm.update_params(lr) 23 | 24 | def get_params(self): 25 | return { 26 | 'attn': self.attn.get_params(), 27 | 'norm': self.norm.get_params() 28 | } 29 | 30 | def set_params(self, params): 31 | self.attn.set_params(params['attn']) 32 | self.norm.set_params(params['norm']) -------------------------------------------------------------------------------- /layers/feedforward_block.py: -------------------------------------------------------------------------------- 1 | from .feedforward import FeedForward 2 | from .layer_norm import LayerNorm 3 | 4 | class FeedForwardBlock: 5 | def __init__(self, embed_dim, hidden_dim): 6 | self.ff = FeedForward(embed_dim, hidden_dim) 7 | self.norm = LayerNorm(embed_dim) 8 | 9 | def forward(self, x): 10 | self.input = x 11 | self.ff_out = self.ff.forward(x) 12 | self.out = self.norm.forward(x + self.ff_out) 13 | return self.out 14 | 15 | def backward(self, grad_output): 16 | grad_norm = self.norm.backward(grad_output) 17 | grad_ff = self.ff.backward(grad_norm) 18 | return grad_ff + grad_norm 19 | 20 | def update_params(self, lr): 21 | self.ff.update_params(lr) 22 | self.norm.update_params(lr) 23 | 24 | def get_params(self): 25 | return { 26 | 'ff': self.ff.get_params(), 27 | 'norm': self.norm.get_params() 28 | } 29 | 30 | def set_params(self, params): 31 | self.ff.set_params(params['ff']) 32 | self.norm.set_params(params['norm']) -------------------------------------------------------------------------------- /layers/embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Embedding: 4 | def __init__(self, num_embed: int, embed_dim: int) -> None: 5 | self.num_embed = num_embed 6 | self.embed_dim = embed_dim 7 | self.weight = np.random.randn(num_embed, embed_dim).astype(np.float32) * 0.01 8 | self.inputs: np.ndarray | None = None 9 | self.dW = np.zeros_like(self.weight, dtype=np.float32) 10 | 11 | def forward(self, tokens: np.ndarray) -> np.ndarray: 12 | self.inputs = tokens 13 | 14 | return self.weight[tokens] 15 | 16 | def backward(self, dout: np.ndarray) -> None: 17 | self.dW.fill(0) 18 | np.add.at(self.dW, self.inputs, dout) 19 | 20 | def update_params(self, learning_rate: float) -> None: 21 | self.weight -= learning_rate * self.dW 22 | 23 | def get_params(self): 24 | return { 25 | 'weight': self.weight.tolist() 26 | } 27 | 28 | def set_params(self, params): 29 | self.weight = np.array(params['weight'], dtype=np.float32) 30 | self.dW = np.zeros_like(self.weight, dtype=np.float32) -------------------------------------------------------------------------------- /layers/feedforward.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np 2 | from .linear import Linear 3 | 4 | class FeedForward: 5 | def __init__(self, embed_dim, hidden_dim): 6 | self.linear1 = Linear(embed_dim, hidden_dim) 7 | self.linear2 = Linear(hidden_dim, embed_dim) 8 | 9 | def forward(self, x): 10 | self.input = x 11 | self.hidden = np.maximum(0, self.linear1.forward(x)) 12 | return self.linear2.forward(self.hidden) 13 | 14 | def backward(self, grad_output): 15 | grad_hidden = self.linear2.backward(grad_output) 16 | 17 | relu_grad = grad_hidden * (self.hidden > 0) 18 | 19 | return self.linear1.backward(relu_grad) 20 | 21 | def update_params(self, lr): 22 | self.linear1.update_params(lr) 23 | self.linear2.update_params(lr) 24 | 25 | def get_params(self): 26 | return { 27 | 'linear1': self.linear1.get_params(), 28 | 'linear2': self.linear2.get_params() 29 | } 30 | 31 | def set_params(self, params): 32 | self.linear1.set_params(params['linear1']) 33 | self.linear2.set_params(params['linear2']) -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from model import Model 2 | from data.tokenizer import CharTokenizer 3 | import numpy as np 4 | 5 | tokenizer = CharTokenizer.load("data/little_char_idx.json") 6 | 7 | def generate_text_streaming(model, tokenizer, prompt, max_new_tokens=100): 8 | model_input = np.array([tokenizer.encode(prompt)], dtype=int) 9 | 10 | print(prompt, end="", flush=True) 11 | 12 | for _ in range(max_new_tokens): 13 | context = model_input[:, -model.context_len:] 14 | 15 | next_token = model.predict(context, top_k=4, temperature=0.5) 16 | model_input = np.concatenate([model_input, next_token[:, None]], axis=1) 17 | att_scores = model.layers[2].attn.forward(model.layers[1].forward(model.layers[0].forward(context))) 18 | next_char = tokenizer.decode(next_token[0:1]) 19 | print(next_char, end="",flush=True) 20 | 21 | print() 22 | 23 | 24 | model = Model.load("models/2025-07-20_10-38-57") 25 | 26 | prompt = "Uma noite destas, vindo da cidade para o Engenho Novo, encontrei" 27 | generate_text_streaming(model, tokenizer, prompt, max_new_tokens=422) -------------------------------------------------------------------------------- /layers/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class PositionalEncoding: 4 | def __init__(self, max_len: int, embedding_dim: int): 5 | self.max_len = max_len 6 | self.embedding_dim = embedding_dim 7 | self.pe = self._create_positional_encoding() 8 | 9 | def _create_positional_encoding(self) -> np.ndarray: 10 | pe = np.zeros((self.max_len, self.embedding_dim), dtype=np.float32) 11 | position = np.arange(0, self.max_len, dtype=np.float32).reshape(-1, 1) 12 | div_term = np.exp(np.arange(0, self.embedding_dim, 2) * -(np.log(10000.0) / self.embedding_dim), dtype=np.float32) 13 | 14 | pe[:, 0::2] = np.sin(position * div_term, dtype=np.float32) 15 | pe[:, 1::2] = np.cos(position * div_term, dtype=np.float32) 16 | 17 | return pe 18 | 19 | def forward(self, x: np.ndarray) -> np.ndarray: 20 | seq_len = x.shape[1] 21 | return x + self.pe[:seq_len] 22 | 23 | def backward(self, dout: np.ndarray) -> np.ndarray: 24 | return dout 25 | 26 | def update_params(self, learning_rate: float): 27 | pass 28 | 29 | def get_params(self): 30 | return {} 31 | 32 | def set_params(self, params): 33 | pass -------------------------------------------------------------------------------- /models/2025-07-20_10-38-57/log.txt: 
-------------------------------------------------------------------------------- 1 | Start time: 23:11:23 2 | Epoch | Train Loss | Train Acc | Val Acc | Time 3 | --------------------------------------------- 4 | 1 | 2.3473 | 0.3773 | 0.3773 | 23:38:43 5 | 2 | 1.9565 | 0.4223 | 0.4223 | 00:06:10 6 | 3 | 1.8650 | 0.4405 | 0.4405 | 00:33:40 7 | 4 | 1.8208 | 0.4507 | 0.4507 | 01:01:03 8 | 5 | 1.7941 | 0.4564 | 0.4563 | 01:28:24 9 | 6 | 1.7743 | 0.4590 | 0.4591 | 01:55:46 10 | 7 | 1.7593 | 0.4634 | 0.4635 | 02:23:08 11 | 8 | 1.7476 | 0.4671 | 0.4671 | 02:50:29 12 | 9 | 1.7383 | 0.4693 | 0.4692 | 03:18:00 13 | 10 | 1.7308 | 0.4718 | 0.4717 | 03:45:27 14 | 11 | 1.7241 | 0.4731 | 0.4730 | 04:12:48 15 | 12 | 1.7182 | 0.4771 | 0.4770 | 04:40:13 16 | 13 | 1.7132 | 0.4788 | 0.4786 | 05:07:34 17 | 14 | 1.7087 | 0.4803 | 0.4802 | 05:35:00 18 | 15 | 1.7047 | 0.4815 | 0.4815 | 06:02:22 19 | 16 | 1.7009 | 0.4824 | 0.4824 | 06:29:44 20 | 17 | 1.6975 | 0.4837 | 0.4837 | 06:57:14 21 | 18 | 1.6944 | 0.4848 | 0.4847 | 07:26:01 22 | 19 | 1.6917 | 0.4854 | 0.4854 | 07:53:32 23 | 20 | 1.6893 | 0.4863 | 0.4863 | 08:21:07 24 | 21 | 1.6871 | 0.4867 | 0.4867 | 08:48:42 25 | 22 | 1.6850 | 0.4874 | 0.4873 | 09:16:18 26 | 23 | 1.6829 | 0.4873 | 0.4872 | 09:44:05 27 | 24 | 1.6810 | 0.4880 | 0.4880 | 10:11:31 28 | 25 | 1.6792 | 0.4884 | 0.4884 | 10:38:57 29 | Model saved to models/2025-07-20_10-38-57 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model import Model 3 | from data.tokenizer import CharTokenizer 4 | from data.utils import calculate_accuracy, load_text, create_dataset 5 | from datetime import datetime 6 | 7 | def main(): 8 | raw_text = load_text("data.txt") 9 | 10 | tokenizer = CharTokenizer.load('data/little_char_idx.json') 11 | 12 | encoded = np.array(tokenizer.encode(raw_text), dtype=np.int32) 13 | 14 | context_len = 8 15 | vocab_size = tokenizer.vocab_size 16 | embed_dim = 8 17 | 18 | x_train, y_train, x_val, y_val = create_dataset(encoded, context_len, train_split=0.8) 19 | 20 | model = Model(vocab_size, embed_dim, context_len) 21 | 22 | epochs = 1 23 | batch_size = 32 24 | learning_rate = 0.05 25 | print(f"Start time: {datetime.now().strftime('%H:%M:%S')}") 26 | print("Epoch | Train Loss | Train Acc | Val Acc | Time") 27 | print("-" * 45) 28 | 29 | for epoch in range(epochs): 30 | total_loss = 0 31 | num_batches = x_train.shape[0] // batch_size 32 | for i in range(num_batches): 33 | x_batch = x_train[i*batch_size:(i+1)*batch_size] 34 | y_batch = y_train[i*batch_size:(i+1)*batch_size] 35 | loss = model.train_step(x_batch, y_batch, learning_rate) 36 | total_loss += loss 37 | 38 | avg_loss = total_loss / num_batches 39 | train_accuracy = calculate_accuracy(model, x_train, y_train, batch_size) 40 | val_accuracy = calculate_accuracy(model, x_val, y_val, batch_size) 41 | 42 | now = datetime.now().strftime("%H:%M:%S") 43 | print(f"{epoch+1:5d} | {avg_loss:10.4f} | {train_accuracy:9.4f} | {val_accuracy:9.4f} | {now}") 44 | 45 | model.save("models") 46 | 47 | if __name__ == "__main__": 48 | main() -------------------------------------------------------------------------------- /layers/layer_norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class LayerNorm: 4 | def __init__(self, dim, eps=1e-5): 5 | self.eps = eps 6 | self.gamma = np.ones(dim, dtype=np.float32) 7 | self.beta = 
np.zeros(dim, dtype=np.float32) 8 | 9 | def forward(self, x): 10 | self.input = x 11 | self.mean = x.mean(axis=-1, keepdims=True) 12 | self.var = x.var(axis=-1, keepdims=True) 13 | self.norm = (x - self.mean) / np.sqrt(self.var + self.eps) 14 | return self.gamma * self.norm + self.beta 15 | 16 | def backward(self, grad_output): 17 | N = self.input.shape[-1] 18 | 19 | std_inv = 1.0 / np.sqrt(self.var + self.eps) 20 | x_mu = self.input - self.mean 21 | 22 | dnorm = grad_output * self.gamma 23 | 24 | dvar = np.sum(dnorm * x_mu * -0.5 * std_inv**3, axis=-1, keepdims=True) 25 | dmean = np.sum(dnorm * -std_inv, axis=-1, keepdims=True) + dvar * np.mean(-2.0 * x_mu, axis=-1, keepdims=True) 26 | 27 | grad_input = dnorm * std_inv + dvar * 2 * x_mu / N + dmean / N 28 | 29 | axes_to_sum = tuple(range(len(grad_output.shape) - 1)) 30 | 31 | self.grad_gamma = np.sum(grad_output * self.norm, axis=axes_to_sum) 32 | self.grad_beta = np.sum(grad_output, axis=axes_to_sum) 33 | 34 | return grad_input 35 | 36 | def update_params(self, lr): 37 | self.gamma -= lr * self.grad_gamma 38 | self.beta -= lr * self.grad_beta 39 | 40 | def get_params(self): 41 | return { 42 | 'gamma': self.gamma.tolist(), 43 | 'beta': self.beta.tolist() 44 | } 45 | 46 | def set_params(self, params): 47 | self.gamma = np.array(params['gamma'], dtype=np.float32) 48 | self.beta = np.array(params['beta'], dtype=np.float32) 49 | self.grad_gamma = np.zeros_like(self.gamma, dtype=np.float32) 50 | self.grad_beta = np.zeros_like(self.beta, dtype=np.float32) -------------------------------------------------------------------------------- /layers/linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Linear: 4 | def __init__(self, in_features: int, out_features: int, bias: bool = True): 5 | self.in_features = in_features 6 | self.out_features = out_features 7 | self.use_bias = bias 8 | 9 | limit = np.sqrt(6 / (in_features + out_features), dtype=np.float32) 10 | self.weight = np.random.uniform(-limit, limit, (in_features, out_features)).astype(np.float32) 11 | 12 | if self.use_bias: 13 | self.bias = np.zeros(out_features, dtype=np.float32) 14 | self.grad_bias = np.zeros_like(self.bias, dtype=np.float32) 15 | else: 16 | self.bias = None 17 | self.grad_bias = None 18 | 19 | self.grad_weight = np.zeros_like(self.weight, dtype=np.float32) 20 | 21 | def forward(self, x): 22 | self.input = x 23 | out = np.matmul(x, self.weight) 24 | 25 | if self.use_bias: 26 | out += self.bias 27 | 28 | return out 29 | 30 | def backward(self, grad_output): 31 | self.grad_weight = np.einsum('bij,bik->jk', self.input, grad_output) 32 | 33 | if self.use_bias: 34 | self.grad_bias = np.sum(grad_output, axis=(0, 1)) 35 | 36 | grad_input = np.matmul(grad_output, self.weight.T) 37 | return grad_input 38 | 39 | def update_params(self, lr): 40 | self.weight -= lr * self.grad_weight 41 | 42 | if self.use_bias: 43 | self.bias -= lr * self.grad_bias 44 | 45 | def get_params(self): 46 | params = {'weight': self.weight.tolist()} 47 | 48 | if self.use_bias: 49 | params['bias'] = self.bias.tolist() 50 | 51 | return params 52 | 53 | def set_params(self, params): 54 | self.weight = np.array(params['weight'], dtype=np.float32) 55 | self.grad_weight = np.zeros_like(self.weight, dtype=np.float32) 56 | 57 | if self.use_bias: 58 | self.bias = np.array(params['bias'], dtype=np.float32) 59 | self.grad_bias = np.zeros_like(self.bias, dtype=np.float32) 
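60 | 
61 | 
62 | if __name__ == "__main__":
63 |     # Illustrative, self-contained sanity check (not used anywhere in training):
64 |     # verify that the analytic grad_weight produced by backward() matches a
65 |     # central finite difference of the surrogate loss L = sum(forward(x) * g).
66 |     # The shapes below are small arbitrary values chosen only for this demo.
67 |     np.random.seed(0)
68 |     layer = Linear(4, 3)
69 |     x = np.random.randn(2, 5, 4).astype(np.float32)   # (batch, seq_len, in_features)
70 |     g = np.random.randn(2, 5, 3).astype(np.float32)   # upstream gradient, same shape as the output
71 | 
72 |     layer.forward(x)
73 |     layer.backward(g)
74 |     analytic = layer.grad_weight[0, 0]
75 | 
76 |     eps = 1e-3
77 |     w00 = layer.weight[0, 0]
78 |     layer.weight[0, 0] = w00 + eps
79 |     loss_plus = float(np.sum(layer.forward(x) * g))
80 |     layer.weight[0, 0] = w00 - eps
81 |     loss_minus = float(np.sum(layer.forward(x) * g))
82 |     layer.weight[0, 0] = w00
83 |     numeric = (loss_plus - loss_minus) / (2 * eps)
84 | 
85 |     print(f"dL/dW[0,0]: analytic={analytic:.5f}  finite-diff={numeric:.5f}")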
-------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def load_text(path): 4 | with open(path, 'r', encoding='utf-8') as f: 5 | return f.read() 6 | 7 | def create_dataset(encoded, context_len, train_split=0.8): 8 | x, y = [], [] 9 | 10 | for i in range(len(encoded) - context_len): 11 | x.append(encoded[i:i+context_len]) 12 | y.append(encoded[i+1:i+context_len+1]) 13 | 14 | x = np.array(x) 15 | y = np.array(y) 16 | 17 | indices = np.arange(x.shape[0]) 18 | np.random.shuffle(indices) 19 | x = x[indices] 20 | y = y[indices] 21 | 22 | split_idx = int(len(x) * train_split) 23 | 24 | x_train, x_val = x[:split_idx], x[split_idx:] 25 | y_train, y_val = y[:split_idx], y[split_idx:] 26 | 27 | return x_train, y_train, x_val, y_val 28 | 29 | def calculate_accuracy(model, x_data, y_data, batch_size=32): 30 | correct_predictions = 0 31 | total_predictions = 0 32 | 33 | num_batches = x_data.shape[0] // batch_size 34 | 35 | for i in range(num_batches): 36 | x_batch = x_data[i*batch_size:(i+1)*batch_size] 37 | y_batch = y_data[i*batch_size:(i+1)*batch_size] 38 | 39 | logits = model.forward(x_batch) 40 | 41 | logits_shifted = logits - np.max(logits, axis=-1, keepdims=True) 42 | exp_logits = np.exp(logits_shifted) 43 | probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) 44 | 45 | predictions = np.argmax(probs, axis=-1) 46 | 47 | correct_predictions += np.sum(predictions == y_batch) 48 | total_predictions += y_batch.size 49 | 50 | remaining = x_data.shape[0] % batch_size 51 | if remaining > 0: 52 | x_batch = x_data[-remaining:] 53 | y_batch = y_data[-remaining:] 54 | logits = model.forward(x_batch) 55 | 56 | logits_shifted = logits - np.max(logits, axis=-1, keepdims=True) 57 | exp_logits = np.exp(logits_shifted) 58 | probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) 59 | 60 | predictions = np.argmax(probs, axis=-1) 61 | correct_predictions += np.sum(predictions == y_batch) 62 | total_predictions += y_batch.size 63 | 64 | return correct_predictions / total_predictions -------------------------------------------------------------------------------- /data/little_char_idx.json: -------------------------------------------------------------------------------- 1 | { 2 | "char_to_idx": { 3 | "": 0, 4 | " ": 1, 5 | "a": 2, 6 | "e": 3, 7 | "o": 4, 8 | "s": 5, 9 | "r": 6, 10 | "i": 7, 11 | "m": 8, 12 | "u": 9, 13 | "d": 10, 14 | "n": 11, 15 | "t": 12, 16 | "c": 13, 17 | "l": 14, 18 | ",": 15, 19 | "p": 16, 20 | ".": 17, 21 | "v": 18, 22 | "h": 19, 23 | "q": 20, 24 | "\n": 21, 25 | "g": 22, 26 | "ã": 23, 27 | "b": 24, 28 | "f": 25, 29 | "-": 26, 30 | "é": 27, 31 | "ç": 28, 32 | ";": 29, 33 | "z": 30, 34 | "á": 31, 35 | "C": 32, 36 | "A": 33, 37 | "E": 34, 38 | "í": 35, 39 | "j": 36, 40 | "x": 37, 41 | "O": 38, 42 | "P": 39, 43 | "ó": 40, 44 | "—": 41, 45 | "N": 42, 46 | "S": 43, 47 | "ê": 44, 48 | "D": 45, 49 | "T": 46, 50 | "M": 47, 51 | "à": 48, 52 | "?": 49, 53 | "R": 50, 54 | "V": 51, 55 | "I": 52, 56 | "L": 53, 57 | "!": 54, 58 | "U": 55, 59 | "Q": 56, 60 | "B": 57, 61 | "ú": 58, 62 | "J": 59, 63 | "F": 60, 64 | "õ": 61, 65 | "X": 62, 66 | "Í": 63, 67 | ":": 64, 68 | "\"": 65, 69 | "ô": 66, 70 | "H": 67, 71 | "â": 68, 72 | "G": 69, 73 | "1": 70, 74 | "É": 71, 75 | "ü": 72, 76 | "(": 73, 77 | ")": 74, 78 | "8": 75, 79 | "0": 76, 80 | "4": 77, 81 | "”": 78, 82 | "Ã": 79 83 | }, 84 | "idx_to_char": { 85 | "0": "", 86 | "1": " ", 87 | "2": 
"a", 88 | "3": "e", 89 | "4": "o", 90 | "5": "s", 91 | "6": "r", 92 | "7": "i", 93 | "8": "m", 94 | "9": "u", 95 | "10": "d", 96 | "11": "n", 97 | "12": "t", 98 | "13": "c", 99 | "14": "l", 100 | "15": ",", 101 | "16": "p", 102 | "17": ".", 103 | "18": "v", 104 | "19": "h", 105 | "20": "q", 106 | "21": "\n", 107 | "22": "g", 108 | "23": "ã", 109 | "24": "b", 110 | "25": "f", 111 | "26": "-", 112 | "27": "é", 113 | "28": "ç", 114 | "29": ";", 115 | "30": "z", 116 | "31": "á", 117 | "32": "C", 118 | "33": "A", 119 | "34": "E", 120 | "35": "í", 121 | "36": "j", 122 | "37": "x", 123 | "38": "O", 124 | "39": "P", 125 | "40": "ó", 126 | "41": "—", 127 | "42": "N", 128 | "43": "S", 129 | "44": "ê", 130 | "45": "D", 131 | "46": "T", 132 | "47": "M", 133 | "48": "à", 134 | "49": "?", 135 | "50": "R", 136 | "51": "V", 137 | "52": "I", 138 | "53": "L", 139 | "54": "!", 140 | "55": "U", 141 | "56": "Q", 142 | "57": "B", 143 | "58": "ú", 144 | "59": "J", 145 | "60": "F", 146 | "61": "õ", 147 | "62": "X", 148 | "63": "Í", 149 | "64": ":", 150 | "65": "\"", 151 | "66": "ô", 152 | "67": "H", 153 | "68": "â", 154 | "69": "G", 155 | "70": "1", 156 | "71": "É", 157 | "72": "ü", 158 | "73": "(", 159 | "74": ")", 160 | "75": "8", 161 | "76": "0", 162 | "77": "4", 163 | "78": "”", 164 | "79": "Ã" 165 | }, 166 | "vocab_size": 80, 167 | "is_fitted": true 168 | } -------------------------------------------------------------------------------- /data/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Optional 3 | from collections import Counter 4 | import re 5 | 6 | class CharTokenizer: 7 | def __init__(self, special_tokens: Optional[List[str]] = None, max_vocab_size: int = 65): 8 | self.char_to_idx = {} 9 | self.idx_to_char = {} 10 | self.vocab_size = 0 11 | self._is_fitted = False 12 | self.max_vocab_size = max_vocab_size 13 | 14 | if special_tokens is None: 15 | special_tokens = [''] 16 | 17 | self.add_special_tokens(special_tokens) 18 | 19 | def fit(self, texts: List[str]) -> None: 20 | char_counter = Counter() 21 | for text in texts: 22 | char_counter.update(text) 23 | 24 | available_slots = self.max_vocab_size - len(self.char_to_idx) 25 | most_frequent_chars = [char for char, count in char_counter.most_common(available_slots) 26 | if char not in self.char_to_idx] 27 | 28 | for char in most_frequent_chars: 29 | self.char_to_idx[char] = self.vocab_size 30 | self.idx_to_char[self.vocab_size] = char 31 | self.vocab_size += 1 32 | 33 | self._is_fitted = True 34 | 35 | def encode(self, text: str) -> List[int]: 36 | if not self._is_fitted: 37 | raise ValueError("Tokenizer must be fitted before encoding") 38 | 39 | unk_token = self.char_to_idx.get('') 40 | if unk_token is None: 41 | raise ValueError("No token found in vocabulary") 42 | 43 | return [self.char_to_idx.get(char, unk_token) for char in text] 44 | 45 | def decode(self, tokens: List[int]) -> str: 46 | return ''.join([self.idx_to_char.get(idx, '') for idx in tokens]) 47 | 48 | def add_special_tokens(self, tokens: List[str]) -> None: 49 | for token in tokens: 50 | if token not in self.char_to_idx: 51 | self.char_to_idx[token] = self.vocab_size 52 | self.idx_to_char[self.vocab_size] = token 53 | self.vocab_size += 1 54 | 55 | def is_fitted(self) -> bool: 56 | return self._is_fitted 57 | 58 | def get_vocab(self) -> dict: 59 | return self.char_to_idx.copy() 60 | 61 | def get_vocab_size(self) -> int: 62 | return self.vocab_size 63 | 64 | def save(self, filepath: str) -> None: 
65 | data = { 66 | 'char_to_idx': self.char_to_idx, 67 | 'idx_to_char': {str(k): v for k, v in self.idx_to_char.items()}, 68 | 'vocab_size': self.vocab_size, 69 | 'is_fitted': self._is_fitted 70 | } 71 | 72 | with open(filepath, 'w', encoding='utf-8') as f: 73 | json.dump(data, f, ensure_ascii=False, indent=2) 74 | 75 | @staticmethod 76 | def load(filepath: str, max_vocab_size: int = 65) -> None: 77 | with open(filepath, 'r', encoding='utf-8') as f: 78 | data = json.load(f) 79 | 80 | tokenizer = CharTokenizer(max_vocab_size=max_vocab_size) 81 | tokenizer.char_to_idx = data['char_to_idx'] 82 | tokenizer.idx_to_char = {int(k): v for k, v in data['idx_to_char'].items()} 83 | tokenizer.vocab_size = data['vocab_size'] 84 | tokenizer._is_fitted = data.get('is_fitted', True) 85 | 86 | return tokenizer 87 | 88 | def __len__(self) -> int: 89 | return self.vocab_size 90 | 91 | def __repr__(self) -> str: 92 | return f"CharTokenizer(vocab_size={self.vocab_size}, fitted={self._is_fitted})" -------------------------------------------------------------------------------- /data/char_idx.json: -------------------------------------------------------------------------------- 1 | { 2 | "char_to_idx": { 3 | "": 0, 4 | " ": 1, 5 | "a": 2, 6 | "e": 3, 7 | "o": 4, 8 | "s": 5, 9 | "r": 6, 10 | "i": 7, 11 | "m": 8, 12 | "u": 9, 13 | "d": 10, 14 | "n": 11, 15 | "t": 12, 16 | "c": 13, 17 | "l": 14, 18 | ",": 15, 19 | "p": 16, 20 | ".": 17, 21 | "v": 18, 22 | "h": 19, 23 | "q": 20, 24 | "\n": 21, 25 | "g": 22, 26 | "ã": 23, 27 | "b": 24, 28 | "f": 25, 29 | "-": 26, 30 | "é": 27, 31 | "ç": 28, 32 | ";": 29, 33 | "z": 30, 34 | "á": 31, 35 | "C": 32, 36 | "A": 33, 37 | "E": 34, 38 | "í": 35, 39 | "j": 36, 40 | "x": 37, 41 | "O": 38, 42 | "P": 39, 43 | "ó": 40, 44 | "—": 41, 45 | "N": 42, 46 | "S": 43, 47 | "ê": 44, 48 | "D": 45, 49 | "T": 46, 50 | "M": 47, 51 | "à": 48, 52 | "?": 49, 53 | "R": 50, 54 | "V": 51, 55 | "I": 52, 56 | "L": 53, 57 | "!": 54, 58 | "U": 55, 59 | "Q": 56, 60 | "B": 57, 61 | "ú": 58, 62 | "J": 59, 63 | "F": 60, 64 | "õ": 61, 65 | "X": 62, 66 | "Í": 63, 67 | ":": 64, 68 | "\"": 65, 69 | "ô": 66, 70 | "H": 67, 71 | "â": 68, 72 | "G": 69, 73 | "1": 70, 74 | "É": 71, 75 | "ü": 72, 76 | "(": 73, 77 | ")": 74, 78 | "8": 75, 79 | "0": 76, 80 | "4": 77, 81 | "”": 78, 82 | "Ã": 79, 83 | "5": 80, 84 | "2": 81, 85 | "3": 82, 86 | "6": 83, 87 | "7": 84, 88 | "“": 85, 89 | "9": 86, 90 | "Z": 87, 91 | "Á": 88, 92 | "À": 89, 93 | "Ç": 90, 94 | "Ó": 91, 95 | "Ú": 92, 96 | "k": 93, 97 | "'": 94, 98 | "’": 95, 99 | "y": 96, 100 | "°": 97, 101 | "è": 98, 102 | "*": 99 103 | }, 104 | "idx_to_char": { 105 | "0": "", 106 | "1": " ", 107 | "2": "a", 108 | "3": "e", 109 | "4": "o", 110 | "5": "s", 111 | "6": "r", 112 | "7": "i", 113 | "8": "m", 114 | "9": "u", 115 | "10": "d", 116 | "11": "n", 117 | "12": "t", 118 | "13": "c", 119 | "14": "l", 120 | "15": ",", 121 | "16": "p", 122 | "17": ".", 123 | "18": "v", 124 | "19": "h", 125 | "20": "q", 126 | "21": "\n", 127 | "22": "g", 128 | "23": "ã", 129 | "24": "b", 130 | "25": "f", 131 | "26": "-", 132 | "27": "é", 133 | "28": "ç", 134 | "29": ";", 135 | "30": "z", 136 | "31": "á", 137 | "32": "C", 138 | "33": "A", 139 | "34": "E", 140 | "35": "í", 141 | "36": "j", 142 | "37": "x", 143 | "38": "O", 144 | "39": "P", 145 | "40": "ó", 146 | "41": "—", 147 | "42": "N", 148 | "43": "S", 149 | "44": "ê", 150 | "45": "D", 151 | "46": "T", 152 | "47": "M", 153 | "48": "à", 154 | "49": "?", 155 | "50": "R", 156 | "51": "V", 157 | "52": "I", 158 | "53": "L", 159 | "54": "!", 160 | 
"55": "U", 161 | "56": "Q", 162 | "57": "B", 163 | "58": "ú", 164 | "59": "J", 165 | "60": "F", 166 | "61": "õ", 167 | "62": "X", 168 | "63": "Í", 169 | "64": ":", 170 | "65": "\"", 171 | "66": "ô", 172 | "67": "H", 173 | "68": "â", 174 | "69": "G", 175 | "70": "1", 176 | "71": "É", 177 | "72": "ü", 178 | "73": "(", 179 | "74": ")", 180 | "75": "8", 181 | "76": "0", 182 | "77": "4", 183 | "78": "”", 184 | "79": "Ã", 185 | "80": "5", 186 | "81": "2", 187 | "82": "3", 188 | "83": "6", 189 | "84": "7", 190 | "85": "“", 191 | "86": "9", 192 | "87": "Z", 193 | "88": "Á", 194 | "89": "À", 195 | "90": "Ç", 196 | "91": "Ó", 197 | "92": "Ú", 198 | "93": "k", 199 | "94": "'", 200 | "95": "’", 201 | "96": "y", 202 | "97": "°", 203 | "98": "è", 204 | "99": "*" 205 | }, 206 | "vocab_size": 100, 207 | "is_fitted": true 208 | } -------------------------------------------------------------------------------- /layers/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .linear import Linear 3 | 4 | class Attention: 5 | def __init__(self, embed_dim, num_heads): 6 | assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" 7 | self.embed_dim = embed_dim 8 | self.num_heads = num_heads 9 | self.head_dim = embed_dim // num_heads 10 | 11 | self.q_proj = Linear(embed_dim, embed_dim, bias=False) 12 | self.k_proj = Linear(embed_dim, embed_dim, bias=False) 13 | self.v_proj = Linear(embed_dim, embed_dim, bias=False) 14 | self.out_proj = Linear(embed_dim, embed_dim, bias=False) 15 | 16 | self._causal_mask = None 17 | self._mask_size = 0 18 | 19 | def forward(self, x): 20 | self.x = x 21 | batch_size, seq_len, _ = x.shape 22 | 23 | Q = self.q_proj.forward(x) 24 | K = self.k_proj.forward(x) 25 | V = self.v_proj.forward(x) 26 | 27 | Q = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) 28 | K = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) 29 | V = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) 30 | 31 | self.Q, self.K, self.V = Q, K, V 32 | 33 | scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(self.head_dim) 34 | scores = self._apply_causal_mask(scores, seq_len) 35 | 36 | self.attn_weights = self._softmax(scores) 37 | attn_output = self.attn_weights @ V 38 | 39 | attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim) 40 | self.attn_output = attn_output 41 | 42 | return self.out_proj.forward(attn_output) 43 | 44 | def backward(self, grad_out): 45 | batch_size, seq_len, _ = grad_out.shape 46 | 47 | grad_attn_output = self.out_proj.backward(grad_out) 48 | 49 | grad_attn_output = grad_attn_output.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) 50 | 51 | grad_attn_weights = grad_attn_output @ self.V.transpose(0, 1, 3, 2) 52 | grad_V = self.attn_weights.transpose(0, 1, 3, 2) @ grad_attn_output 53 | 54 | grad_scores = self._softmax_backward(self.attn_weights, grad_attn_weights) 55 | causal_mask = self._get_causal_mask(seq_len) 56 | grad_scores = grad_scores * causal_mask[None, None, :, :] 57 | 58 | grad_Q = grad_scores @ self.K / np.sqrt(self.head_dim) 59 | grad_K = grad_scores.transpose(0, 1, 3, 2) @ self.Q / np.sqrt(self.head_dim) 60 | 61 | grad_Q = grad_Q.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim) 62 | grad_K = grad_K.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim) 63 | grad_V = grad_V.transpose(0, 2, 1, 
3).reshape(batch_size, seq_len, self.embed_dim) 64 | 65 | grad_input_Q = self.q_proj.backward(grad_Q) 66 | grad_input_K = self.k_proj.backward(grad_K) 67 | grad_input_V = self.v_proj.backward(grad_V) 68 | 69 | return grad_input_Q + grad_input_K + grad_input_V 70 | 71 | def update_params(self, lr): 72 | self.q_proj.update_params(lr) 73 | self.k_proj.update_params(lr) 74 | self.v_proj.update_params(lr) 75 | self.out_proj.update_params(lr) 76 | 77 | def _get_causal_mask(self, seq_len): 78 | if self._causal_mask is None or self._mask_size < seq_len: 79 | self._causal_mask = np.tril(np.ones((seq_len, seq_len))) 80 | self._mask_size = seq_len 81 | return self._causal_mask[:seq_len, :seq_len] 82 | 83 | def _apply_causal_mask(self, scores, seq_len): 84 | causal_mask = self._get_causal_mask(seq_len) 85 | return np.where(causal_mask[None, None, :, :], scores, -1e9) 86 | 87 | def _softmax(self, x): 88 | x = x - np.max(x, axis=-1, keepdims=True) 89 | exp = np.exp(x) 90 | return exp / np.sum(exp, axis=-1, keepdims=True) 91 | 92 | def _softmax_backward(self, softmax_out, grad_output): 93 | return grad_output * softmax_out - softmax_out * np.sum(grad_output * softmax_out, axis=-1, keepdims=True) 94 | 95 | def get_params(self): 96 | return { 97 | 'q_proj': self.q_proj.get_params(), 98 | 'k_proj': self.k_proj.get_params(), 99 | 'v_proj': self.v_proj.get_params(), 100 | 'out_proj': self.out_proj.get_params() 101 | } 102 | 103 | def set_params(self, params): 104 | self.q_proj.set_params(params['q_proj']) 105 | self.k_proj.set_params(params['k_proj']) 106 | self.v_proj.set_params(params['v_proj']) 107 | self.out_proj.set_params(params['out_proj']) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import layers 3 | import os 4 | import json 5 | from datetime import datetime 6 | 7 | class Model: 8 | def __init__(self, vocab_size, embed_dim, context_len): 9 | self.vocab_size = vocab_size 10 | self.embed_dim = embed_dim 11 | self.context_len = context_len 12 | 13 | self.layers = [ 14 | layers.Embedding(vocab_size, embed_dim), 15 | layers.PositionalEncoding(context_len, embed_dim), 16 | layers.AttentionBlock(embed_dim), 17 | layers.FeedForwardBlock(embed_dim, 4 * embed_dim), 18 | layers.AttentionBlock(embed_dim), 19 | layers.FeedForwardBlock(embed_dim, 4 * embed_dim), 20 | layers.Linear(embed_dim, vocab_size) 21 | ] 22 | 23 | def forward(self, x): 24 | out = x 25 | for layer in self.layers: 26 | out = layer.forward(out) 27 | return out 28 | 29 | def backward(self, d_logits): 30 | grad = d_logits 31 | for layer in reversed(self.layers): 32 | grad = layer.backward(grad) 33 | 34 | def compute_loss(self, logits, y_true): 35 | batch_size, context_len, vocab_size = logits.shape 36 | 37 | logits_flat = logits.reshape(-1, vocab_size) 38 | y_true_flat = y_true.flatten() 39 | 40 | logits_shifted = logits_flat - np.max(logits_flat, axis=1, keepdims=True) 41 | exp_logits = np.exp(logits_shifted) 42 | probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True) 43 | 44 | correct_logprobs = -np.log(probs[np.arange(len(y_true_flat)), y_true_flat] + 1e-9) 45 | loss = np.mean(correct_logprobs) 46 | 47 | d_logits = probs 48 | d_logits[np.arange(len(y_true_flat)), y_true_flat] -= 1 49 | d_logits /= batch_size * context_len 50 | d_logits = d_logits.reshape(batch_size, context_len, vocab_size) 51 | 52 | return loss, d_logits 53 | 54 | def update_params(self, learning_rate): 55 | for 
layer in self.layers: 56 | layer.update_params(learning_rate) 57 | 58 | def train_step(self, x_batch, y_batch, learning_rate): 59 | logits = self.forward(x_batch) 60 | loss, d_logits = self.compute_loss(logits, y_batch) 61 | self.backward(d_logits) 62 | self.update_params(learning_rate) 63 | return loss 64 | 65 | def predict(self, x: np.ndarray, temperature: float = 1.0, top_k: int = None) -> np.ndarray: 66 | logits = self.forward(x) 67 | last_logits = logits[:, -1, :] 68 | 69 | if temperature <= 0: 70 | raise ValueError("Temperature must be greater than 0") 71 | 72 | scaled_logits = last_logits / temperature 73 | 74 | if top_k is not None and top_k > 0: 75 | top_k_indices = np.argsort(scaled_logits, axis=-1)[:, -top_k:] 76 | mask = np.full_like(scaled_logits, -np.inf) 77 | batch_indices = np.arange(scaled_logits.shape[0])[:, None] 78 | mask[batch_indices, top_k_indices] = scaled_logits[batch_indices, top_k_indices] 79 | scaled_logits = mask 80 | 81 | exp_logits = np.exp(scaled_logits - np.max(scaled_logits, axis=-1, keepdims=True)) 82 | probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) 83 | 84 | return np.array([ 85 | np.random.choice(probs.shape[-1], p=probs[i]) 86 | for i in range(probs.shape[0]) 87 | ]) 88 | 89 | def get_config(self): 90 | return { 91 | 'vocab_size': self.vocab_size, 92 | 'embed_dim': self.embed_dim, 93 | 'context_len': self.context_len 94 | } 95 | 96 | def get_params(self): 97 | return [layer.get_params() for layer in self.layers] 98 | 99 | def set_params(self, params): 100 | for layer, param in zip(self.layers, params): 101 | layer.set_params(param) 102 | 103 | def save(self, path): 104 | os.makedirs(path, exist_ok=True) 105 | 106 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 107 | 108 | timestamped_path = os.path.join(path, timestamp) 109 | os.makedirs(timestamped_path, exist_ok=True) 110 | 111 | config_path = os.path.join(timestamped_path, "config.json") 112 | with open(config_path, 'w') as f: 113 | json.dump(self.get_config(), f) 114 | 115 | params_path = os.path.join(timestamped_path, "weights.json") 116 | with open(params_path, 'w') as f: 117 | json.dump(self.get_params(), f) 118 | 119 | print(f"Model saved to {timestamped_path}") 120 | 121 | @staticmethod 122 | def load(path): 123 | config_path = os.path.join(path, "config.json") 124 | with open(config_path, 'r') as f: 125 | config = json.load(f) 126 | 127 | model = Model(**config) 128 | 129 | params_path = os.path.join(path, "weights.json") 130 | with open(params_path, 'r') as f: 131 | params = json.load(f) 132 | 133 | model.set_params(params) 134 | print(f"Model loaded from {path}") 135 | return model --------------------------------------------------------------------------------
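Note: the repository ships the fitted vocabulary files (`data/little_char_idx.json`, `data/char_idx.json`) but no script that builds them. The sketch below shows how such a file could be regenerated with the `CharTokenizer` API above; the corpus filename `data.txt` is taken from `train.py`, and `max_vocab_size=80` mirrors the `vocab_size` stored in `little_char_idx.json`; both are assumptions about how the original file was produced, not a recorded recipe.

```python
from data.tokenizer import CharTokenizer
from data.utils import load_text

# Fit a character-level vocabulary on the raw corpus and persist it as JSON.
raw_text = load_text("data.txt")              # the same corpus file train.py reads
tokenizer = CharTokenizer(max_vocab_size=80)  # 80 matches vocab_size in little_char_idx.json
tokenizer.fit([raw_text])                     # keeps the most frequent characters
tokenizer.save("data/little_char_idx.json")
print(tokenizer)                              # e.g. CharTokenizer(vocab_size=80, fitted=True)
```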