├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── data_set.py
│   ├── gpt_model.py
│   └── train.py
├── note
│   ├── 1-tokenizer.ipynb
│   ├── 2-embedding.ipynb
│   ├── 4-attention.ipynb
│   ├── data_set.ipynb
│   └── gpt_model.ipynb
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. 
A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any
94 | reason be judged legally invalid or ineffective under applicable law, such
95 | partial invalidity or ineffectiveness shall not invalidate the remainder
96 | of the License, and in such case Affirmer hereby affirms that he or she
97 | will not (i) exercise any of his or her remaining Copyright and Related
98 | Rights in the Work or (ii) assert any associated claims and causes of
99 | action with respect to the Work, in either case contrary to Affirmer's
100 | express Statement of Purpose.
101 |
102 | 4. Limitations and Disclaimers.
103 |
104 |  a. No trademark or patent rights held by Affirmer are waived, abandoned,
105 |     surrendered, licensed or otherwise affected by this document.
106 |  b. Affirmer offers the Work as-is and makes no representations or
107 |     warranties of any kind concerning the Work, express, implied,
108 |     statutory or otherwise, including without limitation warranties of
109 |     title, merchantability, fitness for a particular purpose, non
110 |     infringement, or the absence of latent or other defects, accuracy, or
111 |     the present or absence of errors, whether or not discoverable, all to
112 |     the greatest extent permissible under applicable law.
113 |  c. Affirmer disclaims responsibility for clearing rights of other persons
114 |     that may apply to the Work or any use thereof, including without
115 |     limitation any person's Copyright and Related Rights in the Work.
116 |     Further, Affirmer disclaims responsibility for obtaining any necessary
117 |     consents, permissions or other rights required for any use of the
118 |     Work.
119 |  d. Affirmer understands and acknowledges that Creative Commons is not a
120 |     party to this document and has no duty or obligation with respect to
121 |     this CC0 or use of the Work.
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # llm-hero
2 |
3 | Because of potential copyright concerns, the corpus files are not uploaded. Wherever a corpus is used, you can substitute your own file.
4 |
5 | ## note (jupyter notebook)
6 |
7 | Contains the following notebooks:
8 |
9 | * tokenizer: how a tokenizer is used
10 | * embedding: analyzing the relationships between words in the corpus
11 | * self-attention: the detailed computation steps of self-attention
12 | * data_set: generating the model's input/output data
13 | * gpt_model: building the model architecture in detail
14 |
15 | ## code (python src_code)
16 |
17 | To use the code:
18 |
19 | * Install the libraries listed in requirements.txt
20 | * Prepare a corpus file, change the corresponding file name in data_set.py, then run `python data_set.py`
21 | * Adjust the config parameters in gpt_model.py to match your setup, then run `python train.py` to train the model
--------------------------------------------------------------------------------
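As a quick sanity check of the tokenizer step described in the README, here is a minimal SentencePiece round trip (a sketch, not part of the repo; it assumes `python data_set.py` has already been run from the code/ directory, so that the model file `bird_shooter.model` exists):

```python
# Minimal tokenizer round trip (sketch; assumes bird_shooter.model exists).
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load(model_file="bird_shooter.model")

ids = sp.encode("郭靖一掌挥出", out_type=int)  # text -> token ids
print(ids)
print(sp.decode(ids))                          # token ids -> original text
```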
/code/data_set.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sentencepiece as spm
4 | import sys
5 | import torch
6 |
7 |
8 | def train_model(fname, prefix):
9 |     spm.SentencePieceTrainer.train(
10 |         input=fname, model_prefix=prefix, vocab_size=16000,)
11 |         # user_defined_symbols=['侠之大者', '为国为民'])
12 |
13 | def load_file_into_splits(text_file, split_ratio):
14 |     # the corpus is Chinese text, so read it as UTF-8 regardless of the OS locale
15 |     with open(text_file, 'r', encoding='utf-8') as file:
16 |         data = file.read()
17 |     split_idx = int(len(data) * split_ratio)
18 |     return data[:split_idx], data[split_idx:]
19 |
20 | def load_tokenizer(model_file):
21 |     sp = spm.SentencePieceProcessor()
22 |     if not sp.load(model_file=model_file):
23 |         return False, None
24 |     else:
25 |         return True, sp
26 |
27 | def encode_and_save(sp, content, prefix):
28 |     token_ids = sp.encode(content, out_type=int)
29 |     print(f"data split of {prefix} has {len(token_ids)} tokens")
30 |     token_ids = np.array(token_ids, dtype=np.int32)
31 |     token_ids.tofile(os.path.join(os.path.dirname(__file__), "{}.dat".format(prefix)))
32 |
33 | def gen_dataset(text_file, model_file):
34 |     flag, sp = load_tokenizer(model_file)
35 |     if not flag:
36 |         print(f"load tokenizer model from: {model_file} failed")
37 |         sys.exit(1)
38 |
39 |     split_ratio = 0.9
40 |     train_text, test_text = load_file_into_splits(text_file, split_ratio)
41 |     encode_and_save(sp, train_text, "train")
42 |     encode_and_save(sp, test_text, "test")
43 |
44 | def get_batch(data):
45 |     batch_size = 4
46 |     block_size = 16
47 |     ix = torch.randint(len(data) - block_size, (batch_size,))
48 |     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int32)) for i in ix])
49 |     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int32)) for i in ix])
50 |     return x, y
51 |
52 | def test_samples():
53 |     train_data = np.memmap(os.path.join("./", 'train.dat'), dtype=np.int32, mode='r')
54 |     x, y = get_batch(train_data)
55 |
56 |     model_file = "bird_shooter.model"
57 |     flag, sp = load_tokenizer(model_file)
58 |     if not flag:
59 |         print(f"load tokenizer model from: {model_file} failed")
60 |         sys.exit(1)
61 |
62 |     # print(x, y)
63 |     for features, targets in zip(x, y):
64 |         print("feature:", sp.decode(features.tolist()))
65 |         print("target:", sp.decode(targets.tolist()))
66 |
67 |
68 | if __name__ == '__main__':
69 |     train_model("bird_shooter.txt", "bird_shooter")
70 |     gen_dataset("bird_shooter.txt", "bird_shooter.model")
71 |     # test_samples()
--------------------------------------------------------------------------------
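The essential property of `get_batch` above is that each target sequence `y` is its input `x` shifted left by one token, which turns a flat token stream into a next-token-prediction dataset. A minimal check (a sketch; assumes `python data_set.py` has been run so `train.dat` exists in the current directory):

```python
# Verify the one-token shift between features and targets (sketch).
import numpy as np
from data_set import get_batch

train_data = np.memmap("train.dat", dtype=np.int32, mode="r")
x, y = get_batch(train_data)
print(x.shape, y.shape)               # torch.Size([4, 16]) twice: batch_size=4, block_size=16
assert (x[0, 1:] == y[0, :-1]).all()  # y is x shifted left by one position
```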
/code/gpt_model.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torchinfo import summary
8 |
9 |
10 | class GPTConfig:
11 |     vocab_size: int = 16000
12 |     seq_len: int = 128
13 |     d_model: int = 128  # embedding dimension
14 |     n_layer: int = 4
15 |     n_head: int = 4
16 |     bias: bool = True
17 |     dropout: float = 0.0
18 |
19 |
20 | class SinusoidPE(nn.Module):
21 |     """ sin/cos position encoding """
22 |
23 |     def __init__(self, config):
24 |         super().__init__()
25 |         d_model, seq_len = config.d_model, config.seq_len
26 |         pe = torch.zeros(seq_len, d_model)
27 |         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
28 |         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
29 |         pe[:, 0::2] = torch.sin(position * div_term)
30 |         pe[:, 1::2] = torch.cos(position * div_term)
31 |         pe = pe.unsqueeze(0)
32 |         self.register_buffer('sinusoid_pe', pe)
33 |
34 |     def forward(self, x):
35 |         return self.sinusoid_pe[:, :x.shape[1], :]
36 |
37 |
38 | class SelfAttention(nn.Module):
39 |     """ multi-head attention """
40 |
41 |     def __init__(self, config):
42 |         super().__init__()
43 |         assert config.d_model % config.n_head == 0
44 |         # key, query, value projections for all heads, but in a batch
45 |         self.attn = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
46 |         # output projection
47 |         self.proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
48 |         # regularization
49 |         self.attn_dropout = nn.Dropout(config.dropout)
50 |         # self.resid_dropout = nn.Dropout(config.dropout)
51 |         self.n_head = config.n_head
52 |         self.d_model = config.d_model
53 |         self.dropout = config.dropout
54 |         # causal mask to ensure that attention is only applied to the left in the input sequence
55 |         self.register_buffer("mask", torch.tril(torch.ones(config.seq_len, config.seq_len))
56 |                              .view(1, 1, config.seq_len, config.seq_len))
57 |
58 |     def forward(self, x):
59 |         B, C, E = x.size()  # batch size, context (sequence) length, embedding dim
60 |
61 |         q, k, v = self.attn(x).split(self.d_model, dim=2)
62 |         q = q.view(B, C, self.n_head, E // self.n_head).transpose(1, 2)  # (B, nh, C, hs)
63 |         k = k.view(B, C, self.n_head, E // self.n_head).transpose(1, 2)  # (B, nh, C, hs)
64 |         v = v.view(B, C, self.n_head, E // self.n_head).transpose(1, 2)  # (B, nh, C, hs)
65 |
66 |         # self-attention: (B, nh, C, hs) x (B, nh, hs, C) -> (B, nh, C, C)
67 |         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68 |         att = att.masked_fill(self.mask[:,:,:C,:C] == 0, float('-inf'))
69 |         att = F.softmax(att, dim=-1)
70 |         att = self.attn_dropout(att)
71 |         y = att @ v  # (B, nh, C, C) x (B, nh, C, hs) -> (B, nh, C, hs)
72 |         y = y.transpose(1, 2).contiguous().view(B, C, E)
73 |
74 |         return self.proj(y)
75 |
76 |
77 | class FeedForward(nn.Module):
78 |     """ a two-layer MLP """
79 |
80 |     def __init__(self, config):
81 |         super().__init__()
82 |         d_model = config.d_model
83 |         self.net = nn.Sequential(
84 |             nn.Linear(d_model, 4 * d_model),
85 |             nn.GELU(),
86 |             nn.Linear(4 * d_model, d_model),
87 |             nn.Dropout(config.dropout),
88 |         )
89 |
90 |     def forward(self, x):
91 |         return self.net(x)
92 |
93 |
94 | class Block(nn.Module):
95 |     """ Decoder Block """
96 |
97 |     def __init__(self, config):
98 |         # d_model: embedding dimension, n_head: the number of attention heads
99 |         super().__init__()
100 |         self.ln1 = nn.LayerNorm(config.d_model, bias=config.bias)
101 |         self.attn = SelfAttention(config)
102 |         self.ln2 = nn.LayerNorm(config.d_model, bias=config.bias)
103 |         self.ffn = FeedForward(config)
104 |
105 |     def forward(self, x):
106 |         x = x + self.attn(self.ln1(x))
107 |         x = x + self.ffn(self.ln2(x))
108 |         return x
109 |
110 |
111 | class GPTModel(nn.Module):
112 |
113 |     def __init__(self, config):
114 |         super().__init__()
115 |         self.config = config
116 |         self.tok_embed_table = nn.Embedding(config.vocab_size, config.d_model)
117 |         self.pos_embed_table = SinusoidPE(config)
118 |         self.decoder_blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
119 |         self.layer_norm = nn.LayerNorm(config.d_model, bias=config.bias)
120 |         self.final_linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
121 |
122 |         # init all weights
123 |         self.apply(self._init_weights)
124 |         for pn, p in self.named_parameters():
125 |             if pn.endswith('proj.weight'):
126 |                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
127 |
128 |     def _init_weights(self, module):
129 |         if isinstance(module, nn.Linear):
130 |             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
131 |             if module.bias is not None:
132 |                 torch.nn.init.zeros_(module.bias)
133 |         elif isinstance(module, nn.Embedding):
134 |             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
135 |
136 |     def forward(self, features, targets=None):
137 |         tok_emb = self.tok_embed_table(features)  # (B,C,E)
138 |         pos_emb = self.pos_embed_table(tok_emb)
139 |         x = tok_emb + pos_emb  # (B,C,E)
140 |         x = self.decoder_blocks(x)
141 |
142 |         x = self.layer_norm(x)
143 |         if targets is not None:
144 |             logits = self.final_linear(x)  # (B,C,V)
145 |             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
146 |         else:
147 |             logits = self.final_linear(x[:, [-1], :])
148 |             loss = None
149 |         return logits, loss
150 |
151 |     @torch.no_grad()
152 |     def generate(self, seq, max_new_tokens):
153 |         for _ in range(max_new_tokens):
154 |             seq_cond = seq[:, -self.config.seq_len:]  # crop a copy, so the returned seq keeps its full history
155 |             logits, _ = self(seq_cond)
156 |             # focus only on the last time step
157 |             logits = logits[:, -1, :]  # becomes (B, V)
158 |             # apply softmax to get probabilities
159 |             probs = F.softmax(logits, dim=-1)  # (B, V)
160 |             # sample from the distribution
161 |             seq_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
162 |             seq = torch.cat((seq, seq_next), dim=1)
163 |         return seq
164 |
165 |
166 | def main():
167 |     config = GPTConfig()
168 |     model = GPTModel(config)
169 |     summary(model, input_size=[(100, config.seq_len), (100, config.seq_len)],
170 |             dtypes=[torch.long, torch.long])
171 |
172 |
173 | if __name__ == '__main__':
174 |     main()
--------------------------------------------------------------------------------
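The model has two forward modes: with `targets` it returns logits for every position plus a cross-entropy loss; without `targets` it returns logits for the last position only, which is what `generate` samples from. A quick smoke test with random weights (a sketch, not part of the repo):

```python
# Shape sanity check for GPTModel (sketch; run from the code/ directory).
import torch
from gpt_model import GPTConfig, GPTModel

config = GPTConfig()
model = GPTModel(config)
tokens = torch.randint(0, config.vocab_size, (2, 8))  # (B=2, C=8) random token ids

logits, loss = model(tokens, tokens)  # training mode: logits over all positions
print(logits.shape)                   # torch.Size([2, 8, 16000])
print(loss.item())                    # roughly ln(vocab_size) ~ 9.7 at init

out = model.generate(tokens, max_new_tokens=4)  # inference mode: append 4 sampled tokens
print(out.shape)                                # torch.Size([2, 12])
```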
/code/train.py:
--------------------------------------------------------------------------------
1 | from gpt_model import GPTConfig, GPTModel
2 |
3 | import numpy as np
4 | import sentencepiece as spm
5 | import sys
6 | import torch
7 |
8 |
9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
10 |
11 | learning_rate = 1e-3
12 | max_iters = 12000
13 |
14 | train_data = np.memmap('train.dat', dtype=np.int32, mode='r')
15 | test_data = np.memmap('test.dat', dtype=np.int32, mode='r')
16 |
17 | def get_batch(split, config):
18 |     data = train_data if split == 'train' else test_data
19 |     ix = torch.randint(len(data) - config.seq_len, (config.batch_size,))
20 |     x = torch.stack([torch.from_numpy((data[i:i+config.seq_len]).astype(np.int32)) for i in ix])
21 |     y = torch.stack([torch.from_numpy((data[i+1:i+1+config.seq_len]).astype(np.int64)) for i in ix])
22 |     if device == 'cuda':
23 |         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
24 |     else:
25 |         x, y = x.to(device), y.to(device)
26 |     return x, y
27 |
28 | def train(config, model):
29 |     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
30 |
31 |     for iter_num in range(max_iters):
32 |         optimizer.zero_grad()
33 |
34 |         xb, yb = get_batch('train', config)
35 |
36 |         # forward and loss calculation
37 |         _, loss = model(xb, yb)
38 |         if (iter_num + 1) % 100 == 0:
39 |             print(f"[train_info] iter:{iter_num+1:5d}, loss:{loss.item():5.3f}")
40 |
41 |         # backward and gradient descent
42 |         loss.backward()
43 |
44 |         # update weights
45 |         optimizer.step()
46 |
47 |     print(f"final loss: {loss.item()}")
48 |
49 |
50 | def main():
51 |     config = GPTConfig()
52 |     config.batch_size = 32
53 |     config.dropout = 0.1
54 |
55 |     model = GPTModel(config).to(device)
56 |
57 |     train(config, model)
58 |
59 |     # load tokenizer
60 |     from data_set import load_tokenizer
61 |     model_file = "bird_shooter.model"
62 |     flag, sp = load_tokenizer(model_file)
63 |     if not flag:
64 |         print(f"load tokenizer model from: {model_file} failed")
65 |         sys.exit(1)
66 |
67 |     # generate from the model
68 |     user_inputs = ["郭靖一掌挥出", "黄蓉突然想到", "周伯通好奇心大起", "洪七公哈哈大笑"]
69 |     for user_input in user_inputs:
70 |         # use int64 ids so they match the dtype that generate() appends via multinomial
71 |         context = torch.tensor([sp.encode(user_input)], dtype=torch.long, device=device)
72 |         gpt_output = model.generate(context, max_new_tokens=50)[0].tolist()
73 |         print(f"gpt({user_input}) => {sp.decode(gpt_output)}")
74 |
75 | if __name__ == '__main__':
76 |     main()
--------------------------------------------------------------------------------
/note/1-tokenizer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "1b032014-b3c1-4ea4-ac5a-7646b1c79fdf",
7 |    "metadata": {
8 |     "tags": []
9 |    },
10 |    "outputs": [],
11 |    "source": [
12 |     "# def print_format_table():\n",
13 |     "#     \"\"\"\n",
14 |     "#     prints table of formatted text format options\n",
15 |     "#     \"\"\"\n",
16 |     "#     for style in range(8):\n",
17 |     "#         for fg in range(30, 38):\n",
18 |     "#             s1 = ''\n",
19 |     "#             for bg in range(40, 48):\n",
20 |     "#                 format = ';'.join([str(style), str(fg), str(bg)])\n",
21 |     "#                 s1 += '\\x1b[%sm %s \\x1b[0m' % (format, format)\n",
22 |     "#             print(s1)\n",
23 |     "#     print('\\n')\n",
24 |     "# print_format_table()\n",
25 |     "\n",
26 |     "# for fg in range(31, 38):\n",
27 |     "#     print('\\x1b[2;%s;40m %s \\x1b[0m' % (fg, \"hello\"))\n",
28 |     "    \n",
29 |     "class RainbowPrinter:\n",
30 |     "    def __init__(self):\n",
31 |     "        self.idx = 0\n",
32 |     "        self.format_str = '\\x1b[1;%s;48m%s \\x1b[0m'\n",
33 |     "\n",
34 |     "    def print_word(self, word):\n",
35 |     "        self.idx += 1\n",
36 |     "        if self.idx == 7:\n",
37 |     "            self.idx = 1\n",
38 |     "        print(self.format_str % (30+self.idx, word), end='')\n",
39 |     "\n",
40 |     "    def print_words(self, words):\n",
41 |     "        \"\"\" print a sentence made up of token words \"\"\"\n",
42 |     "        if isinstance(words, list) or isinstance(words, tuple):\n",
43 |     "            for token_word in words:\n",
44 |     "                self.print_word(token_word)\n",
45 |     "            print('\\n')\n",
46 |     "        else:\n",
47 |     "            raise TypeError('words must be a list or tuple')"
48 |    ]
49 |   },
50 |   {
51 |    "cell_type": "code",
52 |    "execution_count": 63,
53 |    "id": "83a7952c-2859-4621-990e-0a93407236db",
54 |    "metadata": {
55 |     "tags": []
56 |    },
57 |    "outputs": [],
58 |    "source": [
59 |     "import re\n",
60 |     "from collections import Counter\n",
61 |     "\n",
62 |     "\"\"\"\n",
63 |     "SentencePiece treats the input text just as a sequence of Unicode characters. \n",
64 |     "Whitespace is also handled as a normal symbol. 
\n", 65 | "To handle the whitespace as a basic token explicitly, SentencePiece first \n", 66 | "escapes the whitespace with a meta symbol \"▁\" (U+2581) as follows.\n", 67 | "\"\"\"\n", 68 | "\n", 69 | "#\n", 70 | "# corpus = {\n", 71 | "# word_0: [token0, token1, ..., tokenm],\n", 72 | "# word_1: [token0, token1, ..., tokenm],\n", 73 | "# ...\n", 74 | "# word_n: [token0, token1, ..., tokenm],\n", 75 | "# }\n", 76 | "#\n", 77 | "# vocab {\n", 78 | "# token_0: count_0,\n", 79 | "# token_1: count_1,\n", 80 | "# ...\n", 81 | "# token_m: count_m,\n", 82 | "# }\n", 83 | "#\n", 84 | "\n", 85 | "class BytePairEncoder:\n", 86 | " \n", 87 | " def __init__(self):\n", 88 | " self.ws_token = '▁'\n", 89 | " self.unk_token = ''\n", 90 | " \n", 91 | " self.corpus = {}\n", 92 | " self.word_count = {}\n", 93 | " self.vocab = Counter()\n", 94 | " \n", 95 | " self.id_tokens = {}\n", 96 | " self.token_ids = {}\n", 97 | " \n", 98 | " \n", 99 | " def init_state(self, content):\n", 100 | " # init corpus and wordcnt\n", 101 | " for line in content:\n", 102 | " sentence = self.preprocess(line.strip())\n", 103 | " self.process_sentence(sentence)\n", 104 | " \n", 105 | " alphabet = {}\n", 106 | " for word, chrs in self.corpus.items():\n", 107 | " for ch in chrs:\n", 108 | " alphabet[ch] = alphabet.get(ch, 0) + self.word_count[word]\n", 109 | " self.vocab.update(alphabet)\n", 110 | " \n", 111 | " # for debug\n", 112 | " self._dump_init()\n", 113 | " \n", 114 | " \n", 115 | " def process_sentence(self, sentence):\n", 116 | " words = sentence.split()\n", 117 | " for word in words:\n", 118 | " word = self.ws_token + word\n", 119 | " if word not in self.corpus:\n", 120 | " self.corpus[word] = [ch for ch in word]\n", 121 | " self.word_count[word] = 1\n", 122 | " else:\n", 123 | " self.word_count[word] += 1\n", 124 | " \n", 125 | " \n", 126 | " def preprocess(self, text):\n", 127 | " return re.sub('\\s+', ' ', text)\n", 128 | " \n", 129 | " \n", 130 | " def _dump_init(self):\n", 131 | " print(\"=\" * 12 + \" dump initial state \" + \"=\" * 12)\n", 132 | " print(\"==> dump corpus <==\")\n", 133 | " for word, text in self.corpus.items():\n", 134 | " print(f\"{word} => {text}\")\n", 135 | " print('-.' * 20)\n", 136 | " print(\"==> dump wordcnt <==\")\n", 137 | " for word, count in self.word_count.items():\n", 138 | " print(f\"{word} => {count}\")\n", 139 | " print('-.' 
* 20)\n", 140 | " print(\"==> dump vocab <==\")\n", 141 | " for token, count in self.vocab.items():\n", 142 | " print(f\"{token} => {count}\")\n", 143 | " print(\"-\" * 40)\n", 144 | " \n", 145 | " \n", 146 | " def gen_bigrams(self):\n", 147 | " bigram_counter = Counter()\n", 148 | " for word, text in self.corpus.items():\n", 149 | " for i in range(len(text) - 1):\n", 150 | " # NOTE: use '+' instead of (l,r) to deal with the case\n", 151 | " # a,aa is same as aa,a when generate bigram.\n", 152 | " bigram = text[i] + text[i+1]\n", 153 | " bigram_counter[bigram] += self.word_count[word]\n", 154 | " \n", 155 | " # for debug\n", 156 | " # print(\"==> dump bigram counter <==\")\n", 157 | " # for symbol, counter in bigram_counter.most_common(5):\n", 158 | " # print(f\"{symbol} => {counter}\")\n", 159 | " return bigram_counter\n", 160 | "\n", 161 | " \n", 162 | " def merge_pair(self):\n", 163 | " top_bigram, top_count = self.gen_bigrams().most_common(1)[0]\n", 164 | " print(f\"=> top_bigram:{top_bigram}, top_count:{top_count}\")\n", 165 | " if top_count == 1:\n", 166 | " return\n", 167 | " for word, text in self.corpus.items():\n", 168 | " merged = False\n", 169 | " for i in range(len(text) - 1): \n", 170 | " if (text[i] + text[i+1] == top_bigram):\n", 171 | " self.update_vocab(text[i], -self.word_count[word])\n", 172 | " self.update_vocab(text[i+1], -self.word_count[word])\n", 173 | " text[i] = top_bigram\n", 174 | " text[i+1] = ''\n", 175 | " merged = True\n", 176 | " if merged:\n", 177 | " self.corpus[word] = [token for token in text if token]\n", 178 | " self.update_vocab(top_bigram, top_count)\n", 179 | " \n", 180 | " \n", 181 | " def update_vocab(self, symbol, count):\n", 182 | " if symbol in self.vocab:\n", 183 | " self.vocab[symbol] += count\n", 184 | " # NOTE: must comment off, will cut off the way to combine tokenwords\n", 185 | " # if self.vocab[symbol] == 0:\n", 186 | " # del self.vocab[symbol]\n", 187 | " else:\n", 188 | " self.vocab[symbol] = count\n", 189 | " \n", 190 | " \n", 191 | " def train(self, text, steps=3):\n", 192 | " self.init_state(text)\n", 193 | " \n", 194 | " for step in range(steps):\n", 195 | " print(\"=\" * 12 + f\" step:{step} \" + \"=\" * 12)\n", 196 | " self.merge_pair()\n", 197 | " # for debug\n", 198 | " # self._dump_merge()\n", 199 | " \n", 200 | " print(\"==> dump final vocab <==\")\n", 201 | " for token, count in sorted(self.vocab.items(), key=lambda x:x[1], reverse=True):\n", 202 | " print(f\"{token} => {count}\")\n", 203 | " self.gen_id_token_map()\n", 204 | " \n", 205 | "\n", 206 | " def _dump_merge(self):\n", 207 | " print(\"-\" * 40)\n", 208 | " print(\"==> dump vocab <==\")\n", 209 | " for token, count in sorted(self.vocab.items(), key=lambda x:x[1], reverse=True):\n", 210 | " print(f\"{token} => {count}\")\n", 211 | " print('-' * 40)\n", 212 | " print(\"==> dump corpus <==\")\n", 213 | " for word, tokens in self.corpus.items():\n", 214 | " print(f\"[{self.word_count[word]:3d}] * {word} => {tokens}\")\n", 215 | " print(\"-\" * 40) \n", 216 | "\n", 217 | "\n", 218 | " def gen_id_token_map(self):\n", 219 | " # descent order\n", 220 | " self.id_tokens[0] = self.unk_token\n", 221 | " self.token_ids[self.unk_token] = 0\n", 222 | " \n", 223 | " idx = 1\n", 224 | " for token, _ in self.vocab.most_common():\n", 225 | " self.id_tokens[idx] = token\n", 226 | " self.token_ids[token] = idx\n", 227 | " idx += 1\n", 228 | " \n", 229 | " \n", 230 | " def encode(self, text):\n", 231 | " if not text: return\n", 232 | " text = self.preprocess(text)\n", 233 | " text = 
self.ws_token + re.sub(' ', self.ws_token, text.strip())\n",
234 |     "        seg_txt = self.segment(text)\n",
235 |     "        seg_ids = [self.token_ids[token] if token in self.token_ids else 0 for token in seg_txt]\n",
236 |     "        return (seg_txt, seg_ids)\n",
237 |     "    \n",
238 |     "    \n",
239 |     "    def segment(self, text):\n",
240 |     "        if len(text) == 1:\n",
241 |     "            return text if text in self.vocab else self.unk_token\n",
242 |     "        \n",
243 |     "        segments = [ch for ch in text]\n",
244 |     "        merge_rules = Counter()\n",
245 |     "        \n",
246 |     "        # iterate over adjacent segment pairs [i, i+1]\n",
247 |     "        for i in range(len(segments)-1):\n",
248 |     "            token_word = segments[i] + segments[i+1]\n",
249 |     "            if token_word in self.vocab:\n",
250 |     "                # print(f\"* update rule of combine {segments[i]} and {segments[i+1]} into {token_word}\")\n",
251 |     "                merge_rules.update({(i, token_word): self.vocab[token_word]})\n",
252 |     "\n",
253 |     "        while merge_rules:\n",
254 |     "            (i, token_word), _ = merge_rules.most_common(1)[0]\n",
255 |     "            # e.g. for a,b,c: first merge (b,c); then (a,b) no longer exists\n",
256 |     "            if i >= len(segments)-1 or segments[i] + segments[i+1] != token_word:\n",
257 |     "                # print(f\"! discard rule of combine {segments[i]} and {segments[i+1]} into {token_word}, i={i}\")\n",
258 |     "                merge_rules.pop((i, token_word))\n",
259 |     "                continue\n",
260 |     "            # print(f\"> apply rule of combine {segments[i]} and {segments[i+1]} into {token_word}\")\n",
261 |     "            for i in range(len(segments)-1):\n",
262 |     "                if segments[i] + segments[i+1] == token_word:\n",
263 |     "                    segments[i] = token_word\n",
264 |     "                    segments[i+1] = ''\n",
265 |     "            # print(\"before merge: \", segments)\n",
266 |     "            segments = [seg for seg in segments if seg]\n",
267 |     "            # print(\"after merge: \", segments)\n",
268 |     "            if len(segments) <= 1:\n",
269 |     "                break\n",
270 |     "            for i in range(len(segments)-1):\n",
271 |     "                token_word = segments[i] + segments[i+1]\n",
272 |     "                if token_word in self.vocab:\n",
273 |     "                    merge_rules.update({(i, token_word): self.vocab[token_word]})\n",
274 |     "    \n",
275 |     "        return segments\n",
276 |     "    \n",
277 |     "    \n",
278 |     "    def decode(self, ids):\n",
279 |     "        text = ''.join([self.id_tokens[idx] for idx in ids]).replace(self.ws_token, ' ')\n",
280 |     "        return text\n",
281 |     "    \n"
282 |    ]
283 |   },
284 |   {
285 |    "cell_type": "code",
286 |    "execution_count": 64,
287 |    "id": "08e3bb27-4acf-4179-9eca-7bb3016a8523",
288 |    "metadata": {
289 |     "tags": []
290 |    },
291 |    "outputs": [
292 |     {
293 |      "name": "stdout",
294 |      "output_type": "stream",
295 |      "text": [
296 |       "============ dump initial state ============\n",
297 |       "==> dump corpus <==\n",
298 |       "▁这是OpenAI => ['▁', '这', '是', 'O', 'p', 'e', 'n', 'A', 'I']\n",
299 |       "▁团队前一段时间放出来的预印版论文。 => ['▁', '团', '队', '前', '一', '段', '时', '间', '放', '出', '来', '的', '预', '印', '版', '论', '文', '。']\n",
300 |       "▁他们的目标是学习一个通用的表示,能够在大量任务上进行应用。 => ['▁', '他', '们', '的', '目', '标', '是', '学', '习', '一', '个', '通', '用', '的', '表', '示', ',', '能', '够', '在', '大', '量', '任', '务', '上', '进', '行', '应', '用', '。']\n",
301 |       "▁这篇论文的亮点主要在于, => ['▁', '这', '篇', '论', '文', '的', '亮', '点', '主', '要', '在', '于', ',']\n",
302 |       "▁他们利用了Transformer网络代替了LSTM作为语言模型来更好的捕获长距离语言结构。 => ['▁', '他', '们', '利', '用', '了', 'T', 'r', 'a', 'n', 's', 'f', 'o', 'r', 'm', 'e', 'r', '网', '络', '代', '替', '了', 'L', 'S', 'T', 'M', '作', '为', '语', '言', '模', '型', '来', '更', '好', '的', '捕', '获', '长', '距', '离', '语', '言', '结', '构', '。']\n",
303 |       "▁然后在进行具体任务有监督微调时, => ['▁', '然', '后', '在', '进', '行', '具', '体', '任', '务', '有', '监', '督', '微', '调', '时', ',']\n",
304 |       "▁使用了模型作为附属任务训练目标。 => ['▁', '使', '用', '了', '模', '型', '作', '为', '附', '属', 
'任', '务', '训', '练', '目', '标', '。']\n", 305 | "-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.\n", 306 | "==> dump wordcnt <==\n", 307 | "▁这是OpenAI => 1\n", 308 | "▁团队前一段时间放出来的预印版论文。 => 1\n", 309 | "▁他们的目标是学习一个通用的表示,能够在大量任务上进行应用。 => 1\n", 310 | "▁这篇论文的亮点主要在于, => 1\n", 311 | "▁他们利用了Transformer网络代替了LSTM作为语言模型来更好的捕获长距离语言结构。 => 1\n", 312 | "▁然后在进行具体任务有监督微调时, => 1\n", 313 | "▁使用了模型作为附属任务训练目标。 => 1\n", 314 | "-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.\n", 315 | "==> dump vocab <==\n", 316 | "▁ => 7\n", 317 | "这 => 2\n", 318 | "是 => 2\n", 319 | "O => 1\n", 320 | "p => 1\n", 321 | "e => 2\n", 322 | "n => 2\n", 323 | "A => 1\n", 324 | "I => 1\n", 325 | "团 => 1\n", 326 | "队 => 1\n", 327 | "前 => 1\n", 328 | "一 => 2\n", 329 | "段 => 1\n", 330 | "时 => 2\n", 331 | "间 => 1\n", 332 | "放 => 1\n", 333 | "出 => 1\n", 334 | "来 => 2\n", 335 | "的 => 5\n", 336 | "预 => 1\n", 337 | "印 => 1\n", 338 | "版 => 1\n", 339 | "论 => 2\n", 340 | "文 => 2\n", 341 | "。 => 4\n", 342 | "他 => 2\n", 343 | "们 => 2\n", 344 | "目 => 2\n", 345 | "标 => 2\n", 346 | "学 => 1\n", 347 | "习 => 1\n", 348 | "个 => 1\n", 349 | "通 => 1\n", 350 | "用 => 4\n", 351 | "表 => 1\n", 352 | "示 => 1\n", 353 | ", => 2\n", 354 | "能 => 1\n", 355 | "够 => 1\n", 356 | "在 => 3\n", 357 | "大 => 1\n", 358 | "量 => 1\n", 359 | "任 => 3\n", 360 | "务 => 3\n", 361 | "上 => 1\n", 362 | "进 => 2\n", 363 | "行 => 2\n", 364 | "应 => 1\n", 365 | "篇 => 1\n", 366 | "亮 => 1\n", 367 | "点 => 1\n", 368 | "主 => 1\n", 369 | "要 => 1\n", 370 | "于 => 1\n", 371 | "利 => 1\n", 372 | "了 => 3\n", 373 | "T => 2\n", 374 | "r => 3\n", 375 | "a => 1\n", 376 | "s => 1\n", 377 | "f => 1\n", 378 | "o => 1\n", 379 | "m => 1\n", 380 | "网 => 1\n", 381 | "络 => 1\n", 382 | "代 => 1\n", 383 | "替 => 1\n", 384 | "L => 1\n", 385 | "S => 1\n", 386 | "M => 1\n", 387 | "作 => 2\n", 388 | "为 => 2\n", 389 | "语 => 2\n", 390 | "言 => 2\n", 391 | "模 => 2\n", 392 | "型 => 2\n", 393 | "更 => 1\n", 394 | "好 => 1\n", 395 | "捕 => 1\n", 396 | "获 => 1\n", 397 | "长 => 1\n", 398 | "距 => 1\n", 399 | "离 => 1\n", 400 | "结 => 1\n", 401 | "构 => 1\n", 402 | "然 => 1\n", 403 | "后 => 1\n", 404 | "具 => 1\n", 405 | "体 => 1\n", 406 | "有 => 1\n", 407 | "监 => 1\n", 408 | "督 => 1\n", 409 | "微 => 1\n", 410 | "调 => 1\n", 411 | ", => 1\n", 412 | "使 => 1\n", 413 | "附 => 1\n", 414 | "属 => 1\n", 415 | "训 => 1\n", 416 | "练 => 1\n", 417 | "----------------------------------------\n", 418 | "============ step:0 ============\n", 419 | "=> top_bigram:任务, top_count:3\n", 420 | "============ step:1 ============\n", 421 | "=> top_bigram:▁这, top_count:2\n", 422 | "============ step:2 ============\n", 423 | "=> top_bigram:论文, top_count:2\n", 424 | "============ step:3 ============\n", 425 | "=> top_bigram:▁他, top_count:2\n", 426 | "============ step:4 ============\n", 427 | "=> top_bigram:▁他们, top_count:2\n", 428 | "============ step:5 ============\n", 429 | "=> top_bigram:目标, top_count:2\n", 430 | "============ step:6 ============\n", 431 | "=> top_bigram:进行, top_count:2\n", 432 | "============ step:7 ============\n", 433 | "=> top_bigram:用了, top_count:2\n", 434 | "============ step:8 ============\n", 435 | "=> top_bigram:作为, top_count:2\n", 436 | "============ step:9 ============\n", 437 | "=> top_bigram:语言, top_count:2\n", 438 | "============ step:10 ============\n", 439 | "=> top_bigram:模型, top_count:2\n", 440 | "============ step:11 ============\n", 441 | "=> top_bigram:▁这是, top_count:1\n", 442 | "============ step:12 ============\n", 443 | "=> top_bigram:▁这是, top_count:1\n", 444 | "============ step:13 ============\n", 445 | "=> top_bigram:▁这是, top_count:1\n", 446 | 
"============ step:14 ============\n", 447 | "=> top_bigram:▁这是, top_count:1\n", 448 | "============ step:15 ============\n", 449 | "=> top_bigram:▁这是, top_count:1\n", 450 | "============ step:16 ============\n", 451 | "=> top_bigram:▁这是, top_count:1\n", 452 | "============ step:17 ============\n", 453 | "=> top_bigram:▁这是, top_count:1\n", 454 | "============ step:18 ============\n", 455 | "=> top_bigram:▁这是, top_count:1\n", 456 | "============ step:19 ============\n", 457 | "=> top_bigram:▁这是, top_count:1\n", 458 | "==> dump final vocab <==\n", 459 | "的 => 5\n", 460 | "。 => 4\n", 461 | "▁ => 3\n", 462 | "在 => 3\n", 463 | "r => 3\n", 464 | "任务 => 3\n", 465 | "是 => 2\n", 466 | "e => 2\n", 467 | "n => 2\n", 468 | "一 => 2\n", 469 | "时 => 2\n", 470 | "来 => 2\n", 471 | "用 => 2\n", 472 | ", => 2\n", 473 | "T => 2\n", 474 | "▁这 => 2\n", 475 | "论文 => 2\n", 476 | "▁他们 => 2\n", 477 | "目标 => 2\n", 478 | "进行 => 2\n", 479 | "用了 => 2\n", 480 | "作为 => 2\n", 481 | "语言 => 2\n", 482 | "模型 => 2\n", 483 | "O => 1\n", 484 | "p => 1\n", 485 | "A => 1\n", 486 | "I => 1\n", 487 | "团 => 1\n", 488 | "队 => 1\n", 489 | "前 => 1\n", 490 | "段 => 1\n", 491 | "间 => 1\n", 492 | "放 => 1\n", 493 | "出 => 1\n", 494 | "预 => 1\n", 495 | "印 => 1\n", 496 | "版 => 1\n", 497 | "学 => 1\n", 498 | "习 => 1\n", 499 | "个 => 1\n", 500 | "通 => 1\n", 501 | "表 => 1\n", 502 | "示 => 1\n", 503 | "能 => 1\n", 504 | "够 => 1\n", 505 | "大 => 1\n", 506 | "量 => 1\n", 507 | "上 => 1\n", 508 | "应 => 1\n", 509 | "篇 => 1\n", 510 | "亮 => 1\n", 511 | "点 => 1\n", 512 | "主 => 1\n", 513 | "要 => 1\n", 514 | "于 => 1\n", 515 | "利 => 1\n", 516 | "了 => 1\n", 517 | "a => 1\n", 518 | "s => 1\n", 519 | "f => 1\n", 520 | "o => 1\n", 521 | "m => 1\n", 522 | "网 => 1\n", 523 | "络 => 1\n", 524 | "代 => 1\n", 525 | "替 => 1\n", 526 | "L => 1\n", 527 | "S => 1\n", 528 | "M => 1\n", 529 | "更 => 1\n", 530 | "好 => 1\n", 531 | "捕 => 1\n", 532 | "获 => 1\n", 533 | "长 => 1\n", 534 | "距 => 1\n", 535 | "离 => 1\n", 536 | "结 => 1\n", 537 | "构 => 1\n", 538 | "然 => 1\n", 539 | "后 => 1\n", 540 | "具 => 1\n", 541 | "体 => 1\n", 542 | "有 => 1\n", 543 | "监 => 1\n", 544 | "督 => 1\n", 545 | "微 => 1\n", 546 | "调 => 1\n", 547 | ", => 1\n", 548 | "使 => 1\n", 549 | "附 => 1\n", 550 | "属 => 1\n", 551 | "训 => 1\n", 552 | "练 => 1\n", 553 | "这 => 0\n", 554 | "论 => 0\n", 555 | "文 => 0\n", 556 | "他 => 0\n", 557 | "们 => 0\n", 558 | "目 => 0\n", 559 | "标 => 0\n", 560 | "任 => 0\n", 561 | "务 => 0\n", 562 | "进 => 0\n", 563 | "行 => 0\n", 564 | "作 => 0\n", 565 | "为 => 0\n", 566 | "语 => 0\n", 567 | "言 => 0\n", 568 | "模 => 0\n", 569 | "型 => 0\n", 570 | "▁他 => 0\n" 571 | ] 572 | } 573 | ], 574 | "source": [ 575 | "bpe = BytePairEncoder()\n", 576 | "corpus = [\n", 577 | " # \"Alice is running faster than Bob\",\n", 578 | " # \"Bob run slower than Alice\",\n", 579 | " # \"FloydHub is the fastest way to build, train and deploy deep learning models. Build deep learning models in the cloud. 
Train deep learning models.\"\n", 580 | " # \"old \" * 7 + \"older \" * 3 + \"finest \" * 9 + \"lowest \" * 4\n", 581 | " # \"hug \" * 10 + \"pug \" * 5 + \"pun \" * 12 + \"bun \" * 4 + \"hugs \" * 5\n", 582 | " \"这是OpenAI 团队前一段时间放出来的预印版论文。 他们的目标是学习一个通用的表示,能够在大量任务上进行应用。\",\n", 583 | " \"这篇论文的亮点主要在于, 他们利用了Transformer网络代替了LSTM作为语言模型来更好的捕获长距离语言结构。\",\n", 584 | " \"然后在进行具体任务有监督微调时, 使用了模型作为附属任务训练目标。\"\n", 585 | "]\n", 586 | "bpe.train(corpus, 20)\n", 587 | "# bpe.init_state('\\n'.join(corpus))" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "id": "e3f71ed8-12a0-4fe1-9116-86cea4e217c2", 594 | "metadata": { 595 | "tags": [] 596 | }, 597 | "outputs": [ 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "\u001b[1;31;48m▁他们 \u001b[0m\u001b[1;32;48m论文 \u001b[0m\u001b[1;33;48m的 \u001b[0m\u001b[1;34;48m亮 \u001b[0m\u001b[1;35;48m点 \u001b[0m\u001b[1;36;48m是 \u001b[0m\u001b[1;31;48m用 \u001b[0m\u001b[1;32;48m语言 \u001b[0m\u001b[1;33;48m模型 \u001b[0m\u001b[1;34;48m完 \u001b[0m\u001b[1;35;48m成 \u001b[0m\u001b[1;36;48m对 \u001b[0m\u001b[1;31;48m应 \u001b[0m\u001b[1;32;48m的 \u001b[0m\u001b[1;33;48m目标 \u001b[0m\u001b[1;34;48m任务 \u001b[0m\n", 603 | "\n", 604 | " 论文的亮点是用模型应的目标任务\n" 605 | ] 606 | } 607 | ], 608 | "source": [ 609 | "printer = RainbowPrinter()\n", 610 | "# segments, seg_ids = bpe.encode(\"huggpnun what ugg is haasnb\")\n", 611 | "seg_txt, seg_ids = bpe.encode(\"他们论文的亮点是用语言模型完成对应的目标任务\")\n", 612 | "printer.print_words(seg_txt)\n", 613 | "print(bpe.decode(seg_ids))" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "id": "e8d67db3-47ea-4693-b790-6e058e1a7b09", 620 | "metadata": { 621 | "tags": [] 622 | }, 623 | "outputs": [], 624 | "source": [ 625 | "bpe.train('\\n'.join(corpus))" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": null, 631 | "id": "f3c2b0f7-6f74-4f7b-a287-5b1d1b8ecd94", 632 | "metadata": { 633 | "tags": [] 634 | }, 635 | "outputs": [], 636 | "source": [ 637 | "bpe.merge_pair()" 638 | ] 639 | } 640 | ], 641 | "metadata": { 642 | "kernelspec": { 643 | "display_name": "Python 3 (ipykernel)", 644 | "language": "python", 645 | "name": "python3" 646 | }, 647 | "language_info": { 648 | "codemirror_mode": { 649 | "name": "ipython", 650 | "version": 3 651 | }, 652 | "file_extension": ".py", 653 | "mimetype": "text/x-python", 654 | "name": "python", 655 | "nbconvert_exporter": "python", 656 | "pygments_lexer": "ipython3", 657 | "version": "3.10.12" 658 | } 659 | }, 660 | "nbformat": 4, 661 | "nbformat_minor": 5 662 | } 663 | -------------------------------------------------------------------------------- /note/4-attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | } 12 | }, 13 | "cells": [ 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "0XQ6NsIuDtgr" 18 | }, 19 | "source": [ 20 | "# Illustrated: Self-Attention\n", 21 | "Step-by-step guide to self-attention with illustrations and code\n", 22 | "\n", 23 | "[medium article](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a)\n", 24 | "\n", 25 | "[Article author](https://towardsdatascience.com/@remykarem)\n", 26 | "\n", 27 | "> Colab made by: [Manuel Romero](https://twitter.com/mrm8488)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | 
"metadata": { 33 | "id": "U76qWlrbOmx7" 34 | }, 35 | "source": [ 36 | "![texto alternativo](https://miro.medium.com/max/1973/1*_92bnsMJy8Bl539G4v93yg.gif)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "wOkXKd60Q_Iu" 43 | }, 44 | "source": [ 45 | "What do *BERT, RoBERTa, ALBERT, SpanBERT, DistilBERT, SesameBERT, SemBERT, MobileBERT, TinyBERT and CamemBERT* all have in common? And I’m not looking for the answer “BERT” 🤭.\n", 46 | "Answer: **self-attention** 🤗. We are not only talking about architectures bearing the name “BERT’, but more correctly **Transformer-based architectures**. Transformer-based architectures, which are primarily used in modelling language understanding tasks, eschew the use of recurrence in neural network (RNNs) and instead trust entirely on self-attention mechanisms to draw global dependencies between inputs and outputs. But what’s the math behind this?" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": { 52 | "id": "yozzTBjBRbAA" 53 | }, 54 | "source": [ 55 | "The main content of this kernel is to walk you through the mathematical operations involved in a self-attention module." 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "atUYzU3TSD9z" 62 | }, 63 | "source": [ 64 | "### Step 0. What is self-attention?\n", 65 | "\n", 66 | "If you’re thinking if self-attention is similar to attention, then the answer is yes! They fundamentally share the same concept and many common mathematical operations.\n", 67 | "A self-attention module takes in n inputs, and returns n outputs. What happens in this module? In layman’s terms, the self-attention mechanism allows the inputs to interact with each other (“self”) and find out who they should pay more attention to (“attention”). The outputs are aggregates of these interactions and attention scores." 
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "markdown",
72 |    "metadata": {
73 |     "id": "SDMmHAaSTE6P"
74 |    },
75 |    "source": [
76 |     "In the following, we are going to explain and implement:\n",
77 |     "- Prepare inputs\n",
78 |     "- Initialise weights\n",
79 |     "- Derive key, query and value\n",
80 |     "- Calculate attention scores for Input 1\n",
81 |     "- Calculate softmax\n",
82 |     "- Multiply scores with values\n",
83 |     "- Sum weighted values to get Output 1\n",
84 |     "- Repeat steps 4–7 for Input 2 & Input 3"
85 |    ]
86 |   },
87 |   {
88 |    "cell_type": "code",
89 |    "metadata": {
90 |     "id": "u1UxPJlHBVmS"
91 |    },
92 |    "source": [
93 |     "import torch"
94 |    ],
95 |    "execution_count": null,
96 |    "outputs": []
97 |   },
98 |   {
99 |    "cell_type": "markdown",
100 |    "metadata": {
101 |     "id": "ENdzUZqSBsiB"
102 |    },
103 |    "source": [
104 |     "### Step 1: Prepare inputs\n",
105 |     "\n",
106 |     "For this tutorial, for the sake of simplicity, we start with 3 inputs, each with dimension 4.\n",
107 |     "\n",
108 |     "![texto alternativo](https://miro.medium.com/max/1973/1*hmvdDXrxhJsGhOQClQdkBA.png)\n"
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "code",
113 |    "metadata": {
114 |     "id": "jKYrJsljBhnv",
115 |     "colab": {
116 |      "base_uri": "https://localhost:8080/",
117 |      "height": 70
118 |     },
119 |     "outputId": "7b865905-2151-4a6a-a899-5439aa429af4"
120 |    },
121 |    "source": [
122 |     "x = [\n",
123 |     "  [1, 0, 1, 0], # Input 1\n",
124 |     "  [0, 2, 0, 2], # Input 2\n",
125 |     "  [1, 1, 1, 1]  # Input 3\n",
126 |     " ]\n",
127 |     "x = torch.tensor(x, dtype=torch.float32)\n",
128 |     "x"
129 |    ],
130 |    "execution_count": null,
131 |    "outputs": [
132 |     {
133 |      "output_type": "execute_result",
134 |      "data": {
135 |       "text/plain": [
136 |        "tensor([[1., 0., 1., 0.],\n",
137 |        "        [0., 2., 0., 2.],\n",
138 |        "        [1., 1., 1., 1.]])"
139 |       ]
140 |      },
141 |      "metadata": {
142 |       "tags": []
143 |      },
144 |      "execution_count": 2
145 |     }
146 |    ]
147 |   },
148 |   {
149 |    "cell_type": "markdown",
150 |    "metadata": {
151 |     "id": "DZ96EoE1Bvat"
152 |    },
153 |    "source": [
154 |     "### Step 2: Initialise weights\n",
155 |     "\n",
156 |     "Every input must have three representations (see diagram below). These representations are called **key** (orange), **query** (red), and **value** (purple). For this example, let’s say we want these representations to have a dimension of 3. 
Because every input has a dimension of 4, this means each set of the weights must have a shape of 4×3.\n", 157 | "\n", 158 | "![texto del enlace](https://miro.medium.com/max/1975/1*VPvXYMGjv0kRuoYqgFvCag.gif)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "id": "jUTNr15JBkSG", 165 | "outputId": "baa4c379-6174-4990-8cd2-51191e904550", 166 | "colab": { 167 | "base_uri": "https://localhost:8080/", 168 | "height": 284 169 | } 170 | }, 171 | "source": [ 172 | "w_key = [\n", 173 | " [0, 0, 1],\n", 174 | " [1, 1, 0],\n", 175 | " [0, 1, 0],\n", 176 | " [1, 1, 0]\n", 177 | "]\n", 178 | "w_query = [\n", 179 | " [1, 0, 1],\n", 180 | " [1, 0, 0],\n", 181 | " [0, 0, 1],\n", 182 | " [0, 1, 1]\n", 183 | "]\n", 184 | "w_value = [\n", 185 | " [0, 2, 0],\n", 186 | " [0, 3, 0],\n", 187 | " [1, 0, 3],\n", 188 | " [1, 1, 0]\n", 189 | "]\n", 190 | "w_key = torch.tensor(w_key, dtype=torch.float32)\n", 191 | "w_query = torch.tensor(w_query, dtype=torch.float32)\n", 192 | "w_value = torch.tensor(w_value, dtype=torch.float32)\n", 193 | "\n", 194 | "print(\"Weights for key: \\n\", w_key)\n", 195 | "print(\"Weights for query: \\n\", w_query)\n", 196 | "print(\"Weights for value: \\n\", w_value)" 197 | ], 198 | "execution_count": null, 199 | "outputs": [ 200 | { 201 | "output_type": "stream", 202 | "text": [ 203 | "Weights for key: \n", 204 | " tensor([[0., 0., 1.],\n", 205 | " [1., 1., 0.],\n", 206 | " [0., 1., 0.],\n", 207 | " [1., 1., 0.]])\n", 208 | "Weights for query: \n", 209 | " tensor([[1., 0., 1.],\n", 210 | " [1., 0., 0.],\n", 211 | " [0., 0., 1.],\n", 212 | " [0., 1., 1.]])\n", 213 | "Weights for value: \n", 214 | " tensor([[0., 2., 0.],\n", 215 | " [0., 3., 0.],\n", 216 | " [1., 0., 3.],\n", 217 | " [1., 1., 0.]])\n" 218 | ], 219 | "name": "stdout" 220 | } 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": { 226 | "id": "8pr9XZF9X_Ed" 227 | }, 228 | "source": [ 229 | "Note: *In a neural network setting, these weights are usually small numbers, initialised randomly using an appropriate random distribution like Gaussian, Xavier and Kaiming distributions.*" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "UxGT5awVB1Xw" 236 | }, 237 | "source": [ 238 | "### Step 3: Derive key, query and value\n", 239 | "\n", 240 | "Now that we have the three sets of weights, let’s actually obtain the **key**, **query** and **value** representations for every input." 
241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "id": "VQwhDIi7aGXp" 247 | }, 248 | "source": [ 249 | "Obtaining the keys:\n", 250 | "```\n", 251 | " [0, 0, 1]\n", 252 | "[1, 0, 1, 0] [1, 1, 0] [0, 1, 1]\n", 253 | "[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]\n", 254 | "[1, 1, 1, 1] [1, 1, 0] [2, 3, 1]\n", 255 | "```\n", 256 | "![texto alternativo](https://miro.medium.com/max/1975/1*dr6NIaTfTxEWzxB2rc0JWg.gif)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": { 262 | "id": "Qi0EblXTamFz" 263 | }, 264 | "source": [ 265 | "Obtaining the values:\n", 266 | "```\n", 267 | " [0, 2, 0]\n", 268 | "[1, 0, 1, 0] [0, 3, 0] [1, 2, 3]\n", 269 | "[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]\n", 270 | "[1, 1, 1, 1] [1, 1, 0] [2, 6, 3]\n", 271 | "```\n", 272 | "![texto alternativo](https://miro.medium.com/max/1975/1*5kqW7yEwvcC0tjDOW3Ia-A.gif)\n" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "GTp2izu1bLNq" 279 | }, 280 | "source": [ 281 | "Obtaining the querys:\n", 282 | "```\n", 283 | " [1, 0, 1]\n", 284 | "[1, 0, 1, 0] [1, 0, 0] [1, 0, 2]\n", 285 | "[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]\n", 286 | "[1, 1, 1, 1] [0, 1, 1] [2, 1, 3]\n", 287 | "```\n", 288 | "![texto alternativo](https://miro.medium.com/max/1975/1*wO_UqfkWkv3WmGQVHvrMJw.gif)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": { 294 | "id": "qegb9M0KbnRK" 295 | }, 296 | "source": [ 297 | "Notes: *Notes\n", 298 | "In practice, a bias vector may be added to the product of matrix multiplication.*" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "id": "rv2NXynOB7oG", 305 | "outputId": "a2656b52-4b1d-4726-9d42-522f941b3126", 306 | "colab": { 307 | "base_uri": "https://localhost:8080/", 308 | "height": 230 309 | } 310 | }, 311 | "source": [ 312 | "keys = x @ w_key\n", 313 | "querys = x @ w_query\n", 314 | "values = x @ w_value\n", 315 | "\n", 316 | "print(\"Keys: \\n\", keys)\n", 317 | "# tensor([[0., 1., 1.],\n", 318 | "# [4., 4., 0.],\n", 319 | "# [2., 3., 1.]])\n", 320 | "\n", 321 | "print(\"Querys: \\n\", querys)\n", 322 | "# tensor([[1., 0., 2.],\n", 323 | "# [2., 2., 2.],\n", 324 | "# [2., 1., 3.]])\n", 325 | "print(\"Values: \\n\", values)\n", 326 | "# tensor([[1., 2., 3.],\n", 327 | "# [2., 8., 0.],\n", 328 | "# [2., 6., 3.]])" 329 | ], 330 | "execution_count": null, 331 | "outputs": [ 332 | { 333 | "output_type": "stream", 334 | "text": [ 335 | "Keys: \n", 336 | " tensor([[0., 1., 1.],\n", 337 | " [4., 4., 0.],\n", 338 | " [2., 3., 1.]])\n", 339 | "Querys: \n", 340 | " tensor([[1., 0., 2.],\n", 341 | " [2., 2., 2.],\n", 342 | " [2., 1., 3.]])\n", 343 | "Values: \n", 344 | " tensor([[1., 2., 3.],\n", 345 | " [2., 8., 0.],\n", 346 | " [2., 6., 3.]])\n" 347 | ], 348 | "name": "stdout" 349 | } 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "3pmf0OQhCnD8" 356 | }, 357 | "source": [ 358 | "### Step 4: Calculate attention scores\n", 359 | "![texto alternativo](https://miro.medium.com/max/1973/1*u27nhUppoWYIGkRDmYFN2A.gif)\n", 360 | "\n", 361 | "To obtain **attention scores**, we start off with taking a dot product between Input 1’s **query** (red) with **all keys** (orange), including itself. 
Since there are 3 key representations (because we have 3 inputs), we obtain 3 attention scores (blue).\n", 362 | "\n", 363 | "```\n", 364 | " [0, 4, 2]\n", 365 | "[1, 0, 2] x [1, 4, 3] = [2, 4, 4]\n", 366 | " [1, 0, 1]\n", 367 | "```\n", 368 | "Notice that we only use the query from Input 1. Later we’ll work on repeating this same step for the other querys.\n", 369 | "\n", 370 | "Note: *The above operation is known as dot product attention, one of the several score functions. Other score functions include scaled dot product and additive/concat.* " 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "metadata": { 376 | "id": "6GDhKEl0Cokw", 377 | "outputId": "c91356df-202c-4816-e98d-eefd1e1031d3", 378 | "colab": { 379 | "base_uri": "https://localhost:8080/", 380 | "height": 70 381 | } 382 | }, 383 | "source": [ 384 | "attn_scores = querys @ keys.T\n", 385 | "print(attn_scores)\n", 386 | "\n", 387 | "# tensor([[ 2., 4., 4.], # attention scores from Query 1\n", 388 | "# [ 4., 16., 12.], # attention scores from Query 2\n", 389 | "# [ 4., 12., 10.]]) # attention scores from Query 3" 390 | ], 391 | "execution_count": null, 392 | "outputs": [ 393 | { 394 | "output_type": "stream", 395 | "text": [ 396 | "tensor([[ 2., 4., 4.],\n", 397 | " [ 4., 16., 12.],\n", 398 | " [ 4., 12., 10.]])\n" 399 | ], 400 | "name": "stdout" 401 | } 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": { 407 | "id": "bO3NmnbvCxpX" 408 | }, 409 | "source": [ 410 | "### Step 5: Calculate softmax\n", 411 | "![texto alternativo](https://miro.medium.com/max/1973/1*jf__2D8RNCzefwS0TP1Kyg.gif)\n", 412 | "\n", 413 | "Take the **softmax** across these **attention scores** (blue).\n", 414 | "```\n", 415 | "softmax([2, 4, 4]) = [0.0, 0.5, 0.5]\n", 416 | "```" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "PDNzdZHVC1ys", 423 | "outputId": "c528a7be-5c26-46a9-8fdb-1f2b029b6b93", 424 | "colab": { 425 | "base_uri": "https://localhost:8080/", 426 | "height": 124 427 | } 428 | }, 429 | "source": [ 430 | "from torch.nn.functional import softmax\n", 431 | "\n", 432 | "attn_scores_softmax = softmax(attn_scores, dim=-1)\n", 433 | "print(attn_scores_softmax)\n", 434 | "# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n", 435 | "# [6.0337e-06, 9.8201e-01, 1.7986e-02],\n", 436 | "# [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n", 437 | "\n", 438 | "# For readability, approximate the above as follows\n", 439 | "attn_scores_softmax = [\n", 440 | " [0.0, 0.5, 0.5],\n", 441 | " [0.0, 1.0, 0.0],\n", 442 | " [0.0, 0.9, 0.1]\n", 443 | "]\n", 444 | "attn_scores_softmax = torch.tensor(attn_scores_softmax)\n", 445 | "print(attn_scores_softmax)" 446 | ], 447 | "execution_count": null, 448 | "outputs": [ 449 | { 450 | "output_type": "stream", 451 | "text": [ 452 | "tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n", 453 | " [6.0337e-06, 9.8201e-01, 1.7986e-02],\n", 454 | " [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n", 455 | "tensor([[0.0000, 0.5000, 0.5000],\n", 456 | " [0.0000, 1.0000, 0.0000],\n", 457 | " [0.0000, 0.9000, 0.1000]])\n" 458 | ], 459 | "name": "stdout" 460 | } 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": { 466 | "id": "iBe71nseDBhb" 467 | }, 468 | "source": [ 469 | "### Step 6: Multiply scores with values\n", 470 | "![texto alternativo](https://miro.medium.com/max/1973/1*9cTaJGgXPbiJ4AOCc6QHyA.gif)\n", 471 | "\n", 472 | "The softmaxed attention scores for each input (blue) is multiplied with its corresponding **value** (purple). 
This results in 3 alignment vectors (yellow). In this tutorial, we’ll refer to them as **weighted values**.\n",
473 | "```\n",
474 | "1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]\n",
475 | "2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]\n",
476 | "3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]\n",
477 | "```"
478 | ]
479 | },
480 | {
481 | "cell_type": "code",
482 | "metadata": {
483 | "id": "tNnx-Fx5DFDi",
484 | "outputId": "abc7a8ec-f964-483a-9bfb-2848f0e8e592",
485 | "colab": {
486 | "base_uri": "https://localhost:8080/",
487 | "height": 212
488 | }
489 | },
490 | "source": [
491 | "weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]\n",
492 | "print(weighted_values)"
493 | ],
494 | "execution_count": null,
495 | "outputs": [
496 | {
497 | "output_type": "stream",
498 | "text": [
499 | "tensor([[[0.0000, 0.0000, 0.0000],\n",
500 | "         [0.0000, 0.0000, 0.0000],\n",
501 | "         [0.0000, 0.0000, 0.0000]],\n",
502 | "\n",
503 | "        [[1.0000, 4.0000, 0.0000],\n",
504 | "         [2.0000, 8.0000, 0.0000],\n",
505 | "         [1.8000, 7.2000, 0.0000]],\n",
506 | "\n",
507 | "        [[1.0000, 3.0000, 1.5000],\n",
508 | "         [0.0000, 0.0000, 0.0000],\n",
509 | "         [0.2000, 0.6000, 0.3000]]])\n"
510 | ],
511 | "name": "stdout"
512 | }
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "metadata": {
518 | "id": "gU6w0U9ADQIc"
519 | },
520 | "source": [
521 | "### Step 7: Sum weighted values\n",
522 | "![alt text](https://miro.medium.com/max/1973/1*1je5TwhVAwwnIeDFvww3ew.gif)\n",
523 | "\n",
524 | "Take all the **weighted values** (yellow) and sum them element-wise:\n",
525 | "\n",
526 | "```\n",
527 | "  [0.0, 0.0, 0.0]\n",
528 | "+ [1.0, 4.0, 0.0]\n",
529 | "+ [1.0, 3.0, 1.5]\n",
530 | "-----------------\n",
531 | "= [2.0, 7.0, 1.5]\n",
532 | "```\n",
533 | "\n",
534 | "The resulting vector ```[2.0, 7.0, 1.5]``` (dark green) **is Output 1**, which is based on the **query representation from Input 1** interacting with all other keys, including itself.\n"
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "metadata": {
540 | "id": "P3yNYDUEgAos"
541 | },
542 | "source": [
543 | "### Step 8: Repeat for Input 2 & Input 3\n",
544 | "![alt text](https://miro.medium.com/max/1973/1*G8thyDVqeD8WHim_QzjvFg.gif)\n",
545 | "\n",
546 | "Note: *The dimension of **query** and **key** must always be the same because of the dot product score function. However, the dimension of **value** may be different from **query** and **key**. 
The resulting output will consequently follow the dimension of **value**.*" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "metadata": { 552 | "id": "R6excNSUDRRj", 553 | "outputId": "e5161fbe-05a5-41d2-da1e-5951ce8b1674", 554 | "colab": { 555 | "base_uri": "https://localhost:8080/", 556 | "height": 70 557 | } 558 | }, 559 | "source": [ 560 | "outputs = weighted_values.sum(dim=0)\n", 561 | "print(outputs)\n", 562 | "\n", 563 | "# tensor([[2.0000, 7.0000, 1.5000], # Output 1\n", 564 | "# [2.0000, 8.0000, 0.0000], # Output 2\n", 565 | "# [2.0000, 7.8000, 0.3000]]) # Output 3" 566 | ], 567 | "execution_count": null, 568 | "outputs": [ 569 | { 570 | "output_type": "stream", 571 | "text": [ 572 | "tensor([[2.0000, 7.0000, 1.5000],\n", 573 | " [2.0000, 8.0000, 0.0000],\n", 574 | " [2.0000, 7.8000, 0.3000]])\n" 575 | ], 576 | "name": "stdout" 577 | } 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": { 583 | "id": "oavQirdbhAK7" 584 | }, 585 | "source": [ 586 | "### Bonus: Tensorflow 2 implementation" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "metadata": { 592 | "id": "575q0u_ahP-6", 593 | "colab": { 594 | "base_uri": "https://localhost:8080/", 595 | "height": 35 596 | }, 597 | "outputId": "867a4e88-2223-41e4-ccd5-dbc47f580c83" 598 | }, 599 | "source": [ 600 | "%tensorflow_version 2.x\n", 601 | "import tensorflow as tf" 602 | ], 603 | "execution_count": null, 604 | "outputs": [ 605 | { 606 | "output_type": "stream", 607 | "text": [ 608 | "TensorFlow 2.x selected.\n" 609 | ], 610 | "name": "stdout" 611 | } 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "metadata": { 617 | "id": "0vjwwEKMhqmZ", 618 | "colab": { 619 | "base_uri": "https://localhost:8080/", 620 | "height": 88 621 | }, 622 | "outputId": "56e5ed58-e100-434d-a8b2-00325bfc0d40" 623 | }, 624 | "source": [ 625 | "x = [\n", 626 | " [1, 0, 1, 0], # Input 1\n", 627 | " [0, 2, 0, 2], # Input 2\n", 628 | " [1, 1, 1, 1] # Input 3\n", 629 | " ]\n", 630 | "\n", 631 | "x = tf.convert_to_tensor(x, dtype=tf.float32)\n", 632 | "print(x)" 633 | ], 634 | "execution_count": null, 635 | "outputs": [ 636 | { 637 | "output_type": "stream", 638 | "text": [ 639 | "tf.Tensor(\n", 640 | "[[1. 0. 1. 0.]\n", 641 | " [0. 2. 0. 2.]\n", 642 | " [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)\n" 643 | ], 644 | "name": "stdout" 645 | } 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "metadata": { 651 | "id": "TN-pri7rhwJ-", 652 | "colab": { 653 | "base_uri": "https://localhost:8080/", 654 | "height": 337 655 | }, 656 | "outputId": "aa8b1395-80a3-41e1-b544-beb06ce65a96" 657 | }, 658 | "source": [ 659 | "w_key = [\n", 660 | " [0, 0, 1],\n", 661 | " [1, 1, 0],\n", 662 | " [0, 1, 0],\n", 663 | " [1, 1, 0]\n", 664 | "]\n", 665 | "w_query = [\n", 666 | " [1, 0, 1],\n", 667 | " [1, 0, 0],\n", 668 | " [0, 0, 1],\n", 669 | " [0, 1, 1]\n", 670 | "]\n", 671 | "w_value = [\n", 672 | " [0, 2, 0],\n", 673 | " [0, 3, 0],\n", 674 | " [1, 0, 3],\n", 675 | " [1, 1, 0]\n", 676 | "]\n", 677 | "w_key = tf.convert_to_tensor(w_key, dtype=tf.float32)\n", 678 | "w_query = tf.convert_to_tensor(w_query, dtype=tf.float32)\n", 679 | "w_value = tf.convert_to_tensor(w_value, dtype=tf.float32)\n", 680 | "print(\"Weights for key: \\n\", w_key)\n", 681 | "print(\"Weights for query: \\n\", w_query)\n", 682 | "print(\"Weights for value: \\n\", w_value)\n" 683 | ], 684 | "execution_count": null, 685 | "outputs": [ 686 | { 687 | "output_type": "stream", 688 | "text": [ 689 | "Weights for key: \n", 690 | " tf.Tensor(\n", 691 | "[[0. 0. 
1.]\n", 692 | " [1. 1. 0.]\n", 693 | " [0. 1. 0.]\n", 694 | " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n", 695 | "Weights for query: \n", 696 | " tf.Tensor(\n", 697 | "[[1. 0. 1.]\n", 698 | " [1. 0. 0.]\n", 699 | " [0. 0. 1.]\n", 700 | " [0. 1. 1.]], shape=(4, 3), dtype=float32)\n", 701 | "Weights for value: \n", 702 | " tf.Tensor(\n", 703 | "[[0. 2. 0.]\n", 704 | " [0. 3. 0.]\n", 705 | " [1. 0. 3.]\n", 706 | " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n" 707 | ], 708 | "name": "stdout" 709 | } 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "metadata": { 715 | "id": "Jp2DP46Sh19r", 716 | "colab": { 717 | "base_uri": "https://localhost:8080/", 718 | "height": 230 719 | }, 720 | "outputId": "5c1befaf-e096-454c-8402-885f049752e0" 721 | }, 722 | "source": [ 723 | "keys = tf.matmul(x, w_key)\n", 724 | "querys = tf.matmul(x, w_query)\n", 725 | "values = tf.matmul(x, w_value)\n", 726 | "print(keys)\n", 727 | "print(querys)\n", 728 | "print(values)" 729 | ], 730 | "execution_count": null, 731 | "outputs": [ 732 | { 733 | "output_type": "stream", 734 | "text": [ 735 | "tf.Tensor(\n", 736 | "[[0. 1. 1.]\n", 737 | " [4. 4. 0.]\n", 738 | " [2. 3. 1.]], shape=(3, 3), dtype=float32)\n", 739 | "tf.Tensor(\n", 740 | "[[1. 0. 2.]\n", 741 | " [2. 2. 2.]\n", 742 | " [2. 1. 3.]], shape=(3, 3), dtype=float32)\n", 743 | "tf.Tensor(\n", 744 | "[[1. 2. 3.]\n", 745 | " [2. 8. 0.]\n", 746 | " [2. 6. 3.]], shape=(3, 3), dtype=float32)\n" 747 | ], 748 | "name": "stdout" 749 | } 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "metadata": { 755 | "id": "tLJDo_bFigkm", 756 | "colab": { 757 | "base_uri": "https://localhost:8080/", 758 | "height": 88 759 | }, 760 | "outputId": "b5d8e02d-9531-49c8-a587-7a6e0b6f884d" 761 | }, 762 | "source": [ 763 | "attn_scores = tf.matmul(querys, keys, transpose_b=True)\n", 764 | "print(attn_scores)" 765 | ], 766 | "execution_count": null, 767 | "outputs": [ 768 | { 769 | "output_type": "stream", 770 | "text": [ 771 | "tf.Tensor(\n", 772 | "[[ 2. 4. 4.]\n", 773 | " [ 4. 16. 12.]\n", 774 | " [ 4. 12. 10.]], shape=(3, 3), dtype=float32)\n" 775 | ], 776 | "name": "stdout" 777 | } 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "metadata": { 783 | "id": "8QY858MEiibV", 784 | "colab": { 785 | "base_uri": "https://localhost:8080/", 786 | "height": 159 787 | }, 788 | "outputId": "2e84f48b-a4ed-4116-8655-21cbb9de8358" 789 | }, 790 | "source": [ 791 | "attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)\n", 792 | "print(attn_scores_softmax)\n", 793 | "\n", 794 | "# For readability, approximate the above as follows\n", 795 | "attn_scores_softmax = [\n", 796 | " [0.0, 0.5, 0.5],\n", 797 | " [0.0, 1.0, 0.0],\n", 798 | " [0.0, 0.9, 0.1]\n", 799 | "]\n", 800 | "attn_scores_softmax = tf.convert_to_tensor(attn_scores_softmax)\n", 801 | "print(attn_scores_softmax)" 802 | ], 803 | "execution_count": null, 804 | "outputs": [ 805 | { 806 | "output_type": "stream", 807 | "text": [ 808 | "tf.Tensor(\n", 809 | "[[6.3378938e-02 4.6831051e-01 4.6831051e-01]\n", 810 | " [6.0336647e-06 9.8200780e-01 1.7986100e-02]\n", 811 | " [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)\n", 812 | "tf.Tensor(\n", 813 | "[[0. 0.5 0.5]\n", 814 | " [0. 1. 0. ]\n", 815 | " [0. 
0.9 0.1]], shape=(3, 3), dtype=float32)\n" 816 | ], 817 | "name": "stdout" 818 | } 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "metadata": { 824 | "id": "TOJMfkFpi0KQ", 825 | "colab": { 826 | "base_uri": "https://localhost:8080/", 827 | "height": 230 828 | }, 829 | "outputId": "8de18989-50d7-4534-cf5c-2711c66d17ce" 830 | }, 831 | "source": [ 832 | "weighted_values = values[:,None] * tf.transpose(attn_scores_softmax)[:,:,None]\n", 833 | "print(weighted_values)" 834 | ], 835 | "execution_count": null, 836 | "outputs": [ 837 | { 838 | "output_type": "stream", 839 | "text": [ 840 | "tf.Tensor(\n", 841 | "[[[0. 0. 0. ]\n", 842 | " [0. 0. 0. ]\n", 843 | " [0. 0. 0. ]]\n", 844 | "\n", 845 | " [[1. 4. 0. ]\n", 846 | " [2. 8. 0. ]\n", 847 | " [1.8 7.2 0. ]]\n", 848 | "\n", 849 | " [[1. 3. 1.5]\n", 850 | " [0. 0. 0. ]\n", 851 | " [0.2 0.6 0.3]]], shape=(3, 3, 3), dtype=float32)\n" 852 | ], 853 | "name": "stdout" 854 | } 855 | ] 856 | }, 857 | { 858 | "cell_type": "code", 859 | "metadata": { 860 | "id": "jan_cyy7i-s7", 861 | "colab": { 862 | "base_uri": "https://localhost:8080/", 863 | "height": 88 864 | }, 865 | "outputId": "09b1406f-3a08-47e2-8dee-d4d6334ef1de" 866 | }, 867 | "source": [ 868 | "outputs = tf.reduce_sum(weighted_values, axis=0) # 6\n", 869 | "print(outputs)" 870 | ], 871 | "execution_count": null, 872 | "outputs": [ 873 | { 874 | "output_type": "stream", 875 | "text": [ 876 | "tf.Tensor(\n", 877 | "[[2. 7. 1.5 ]\n", 878 | " [2. 8. 0. ]\n", 879 | " [2. 7.7999997 0.3 ]], shape=(3, 3), dtype=float32)\n" 880 | ], 881 | "name": "stdout" 882 | } 883 | ] 884 | } 885 | ] 886 | } -------------------------------------------------------------------------------- /note/data_set.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1253ac2f-6750-43c7-827c-ebe0ed0499b0", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [ 11 | { 12 | "name": "stdout", 13 | "output_type": "stream", 14 | "text": [ 15 | "2.2.0.dev20230916+cu121\n" 16 | ] 17 | }, 18 | { 19 | "data": { 20 | "text/plain": [ 21 | "True" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "import torch\n", 31 | "print(torch.__version__)\n", 32 | "torch.cuda.is_available()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "54a3ad37-f151-43d9-b555-fcd8deddc0e9", 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "0.1.99\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "import sentencepiece as spm\n", 53 | "print(spm.__version__)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "f8ab86d2-8d7e-4af3-afbf-3fb6fa663988", 60 | "metadata": { 61 | "tags": [] 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "def train_model(fname, prefix):\n", 66 | " spm.SentencePieceTrainer.train(input=fname, model_prefix=prefix, vocab_size=16000)\n", 67 | " \n", 68 | "corpus = \"bird_shooter.txt\"\n", 69 | "prefix = \"bird_shooter\"\n", 70 | "train_model(corpus, prefix)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "id": "abf3510e-cb35-4854-9a76-a139d9f553f4", 77 | "metadata": { 78 | "tags": [] 79 | }, 80 | "outputs": [ 81 | { 82 | "name": "stdout", 83 | "output_type": "stream", 84 | "text": [ 85 | "data split of train has 505009 tokens\n", 86 | 
"data split of test has 58877 tokens\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "def load_tokenizer(model_file):\n", 92 | " sp = spm.SentencePieceProcessor()\n", 93 | " if not sp.load(model_file=model_file):\n", 94 | " return False, None\n", 95 | " else:\n", 96 | " return True, sp\n", 97 | "\n", 98 | "def load_file_into_splits(text_file, split_ratio):\n", 99 | " with open(text_file, 'r') as file:\n", 100 | " data = file.read()\n", 101 | " split_idx = int(len(data) * split_ratio)\n", 102 | " return data[:split_idx], data[split_idx:]\n", 103 | "\n", 104 | "import numpy as np\n", 105 | "def encode_and_save(sp, content, prefix):\n", 106 | " token_ids = sp.encode(content, out_type=int)\n", 107 | " print(f\"data split of {prefix} has {len(token_ids)} tokens\")\n", 108 | " token_ids = np.array(token_ids, dtype=np.int32)\n", 109 | " token_ids.tofile(\"{}.dat\".format(prefix))\n", 110 | " \n", 111 | "import sys\n", 112 | "def gen_dataset(text_file, model_file):\n", 113 | " flag, sp = load_tokenizer(model_file)\n", 114 | " if not flag:\n", 115 | " print(f\"load tokenizer model from: {model_file} failed\")\n", 116 | " sys.exit(1)\n", 117 | " split_ratio = 0.9\n", 118 | " train_text, test_text = load_file_into_splits(text_file, split_ratio)\n", 119 | " encode_and_save(sp, train_text, \"train\")\n", 120 | " encode_and_save(sp, test_text, \"test\")\n", 121 | " \n", 122 | "gen_dataset(corpus, prefix+\".model\")" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "id": "d1f4331e-9e16-42ca-8f57-77785d2095d8", 129 | "metadata": { 130 | "tags": [] 131 | }, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "features: 他不会不给六王爷的面子。”完颜洪烈道\n", 138 | "targets: 不会不给六王爷的面子。”完颜洪烈道:“\n", 139 | "features: 女儿家的痴情呆想,这人哪里\n", 140 | "targets: 家的痴情呆想,这人哪里是甚么\n", 141 | "features: 得更低了。完颜康心中一荡,伸出左臂\n", 142 | "targets: 更低了。完颜康心中一荡,伸出左臂去\n", 143 | "features: 呢”郭靖道,“我接她到桃花岛上\n", 144 | "targets: ”郭靖道,“我接她到桃花岛上住\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "def get_batch(data, batch_size=4):\n", 150 | " win_len = 10\n", 151 | " ix = torch.randint(len(data)-win_len, (batch_size,))\n", 152 | " x = np.stack([data[i:i+win_len] for i in ix])\n", 153 | " y = np.stack([data[i+1:i+1+win_len] for i in ix])\n", 154 | " return x, y\n", 155 | "\n", 156 | "model_file = prefix + \".model\"\n", 157 | "\n", 158 | "def gen_samples(fname):\n", 159 | " train_data = np.memmap(fname, dtype=np.int32, mode='r')\n", 160 | " x, y = get_batch(train_data)\n", 161 | " \n", 162 | " flag, sp = load_tokenizer(model_file)\n", 163 | " if not flag:\n", 164 | " print(f\"load tokenizer model from: {model_file} failed\")\n", 165 | " sys.exit(1)\n", 166 | " \n", 167 | " for features, targets in zip(x, y):\n", 168 | " print(\"features: \", sp.decode(features.tolist()))\n", 169 | " print(\"targets: \", sp.decode(targets.tolist()))\n", 170 | "\n", 171 | "gen_samples(\"train.dat\")" 172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "Python 3 (ipykernel)", 178 | "language": "python", 179 | "name": "python3" 180 | }, 181 | "language_info": { 182 | "codemirror_mode": { 183 | "name": "ipython", 184 | "version": 3 185 | }, 186 | "file_extension": ".py", 187 | "mimetype": "text/x-python", 188 | "name": "python", 189 | "nbconvert_exporter": "python", 190 | "pygments_lexer": "ipython3", 191 | "version": "3.10.12" 192 | } 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 5 196 | } 197 | 
-------------------------------------------------------------------------------- /note/gpt_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "9e089215-da34-4708-89f7-0f2e15a75e10", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 13 | "from torch import nn" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "f946bf15-e596-4dd9-861d-edbb6bb984f4", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "vocab_size = 16000\n", 26 | "seq_len = 128\n", 27 | "d_model = 128\n", 28 | "n_layer = 4\n", 29 | "n_head = 4" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 12, 35 | "id": "21432ea8-6b23-4549-980d-517e7748c8b2", 36 | "metadata": { 37 | "tags": [] 38 | }, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "gpt => 郭靖一掌挥出恶毒果是一技谄的话不加理会众姬犒原本赏去甚是畅快油腻一灯大师的两个嗤的一声点点头延倚奈\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "import math\n", 50 | "from torchinfo import summary\n", 51 | "from torch.nn import functional as F\n", 52 | "\n", 53 | "class SinusoidPE(nn.Module):\n", 54 | " \"\"\" sin/cos position encoding \"\"\"\n", 55 | " def __init__(self):\n", 56 | " super().__init__()\n", 57 | " \n", 58 | " pe = torch.zeros(seq_len, d_model)\n", 59 | " pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)\n", 60 | " emb = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n", 61 | " pe[:,0::2] = torch.sin(pos * emb)\n", 62 | " pe[:,1::2] = torch.cos(pos * emb)\n", 63 | " \n", 64 | " # token embedding: B * C * E\n", 65 | " # pos embedding: 1 * C * E\n", 66 | " pe = pe.unsqueeze(0)\n", 67 | " self.register_buffer('sinusoid_pe', pe)\n", 68 | " \n", 69 | " def forward(self, x):\n", 70 | " return self.sinusoid_pe[:, :x.shape[1],:]\n", 71 | "\n", 72 | "class FeedForward(nn.Module):\n", 73 | " def __init__(self, n_embd, dropout=0.0):\n", 74 | " super().__init__()\n", 75 | " self.net = nn.Sequential(\n", 76 | " nn.Linear(n_embd, 4 * n_embd),\n", 77 | " nn.ReLU(),\n", 78 | " nn.Linear(4 * n_embd, n_embd),\n", 79 | " nn.Dropout(dropout),\n", 80 | " )\n", 81 | " \n", 82 | " def forward(self, x):\n", 83 | " return self.net(x)\n", 84 | " \n", 85 | " \n", 86 | "class Head(nn.Module):\n", 87 | " \"\"\" one head of self-attention \"\"\"\n", 88 | "\n", 89 | " def __init__(self, head_size, dropout=0.0):\n", 90 | " super().__init__()\n", 91 | " self.key = nn.Linear(d_model, head_size, bias=False)\n", 92 | " self.query = nn.Linear(d_model, head_size, bias=False)\n", 93 | " self.value = nn.Linear(d_model, head_size, bias=False)\n", 94 | " self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)))\n", 95 | "\n", 96 | " self.dropout = nn.Dropout(dropout)\n", 97 | "\n", 98 | " def forward(self, x):\n", 99 | " B, C, E = x.shape\n", 100 | " k = self.key(x) # (B, C, E)\n", 101 | " q = self.query(x) # (B, C, E)\n", 102 | " # compute attention scores (\"affinities\")\n", 103 | " wei = q @ k.transpose(-2, -1) * E**-0.5 # (B, C, E) @ (B, E, C) -> (B, C, C)\n", 104 | " wei = wei.masked_fill(self.mask[:C,:C] == 0, float('-inf')) # (B, C, C)\n", 105 | " wei = F.softmax(wei, dim=-1) # (B, C, C)\n", 106 | " wei = self.dropout(wei)\n", 107 | " # perform the weighted aggregation of the values\n", 108 | " v = self.value(x) # (B, C, E)\n", 109 | " out = wei @ v # (B, C, C) @ (B, C, E) -> 
(B, C, E)\n", 110 | " return out\n", 111 | " \n", 112 | "class SelfAttention(nn.Module):\n", 113 | " def __init__(self, num_heads, head_size, dropout=0.0):\n", 114 | " super().__init__()\n", 115 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", 116 | " self.proj = nn.Linear(d_model, d_model)\n", 117 | " self.dropout = nn.Dropout(dropout)\n", 118 | " \n", 119 | " def forward(self, x):\n", 120 | " out = torch.cat([h(x) for h in self.heads], dim=-1)\n", 121 | " out = self.dropout(self.proj(out))\n", 122 | " return out\n", 123 | " \n", 124 | "class MultiAttention(nn.Module):\n", 125 | " def __init__(self, dropout=0.0):\n", 126 | " super().__init__()\n", 127 | " # self.w_q = nn.Linear(d_model, d_model)\n", 128 | " # self.w_k = nn.Linear(d_model, d_model)\n", 129 | " # self.w_v = nn.Linear(d_model, d_model)\n", 130 | " self.attn = nn.Linear(d_model, 3 * d_model)\n", 131 | " self.proj = nn.Linear(d_model, d_model)\n", 132 | " self.dropout = nn.Dropout(dropout) \n", 133 | " self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len))\n", 134 | " .view(1,1, seq_len, seq_len))\n", 135 | " \n", 136 | " def forward(self, x):\n", 137 | " B, C, E = x.shape\n", 138 | " q, k, v = self.attn(x).split(d_model, dim=2)\n", 139 | " q = q.view(B, C, n_head, E // n_head).transpose(1,2) # (B, C, nh, hs) -> (B, nh, C, hs)\n", 140 | " k = k.view(B, C, n_head, E // n_head).transpose(1,2) # (B, C, nh, hs)\n", 141 | " v = v.view(B, C, n_head, E // n_head).transpose(1,2) # (B, C, nh, hs)\n", 142 | " \n", 143 | " # (B, nh, C, hs) * (B, nh, hs, C) -> (B, nh, C, C)\n", 144 | " wei = q @ k.transpose(-2, -1) * (k.size(-1))**-0.5 \n", 145 | " wei = wei.masked_fill(self.mask[:,:,:C,:C] == 0, float('-inf'))\n", 146 | " wei = F.softmax(wei, dim=-1)\n", 147 | " wei = self.dropout(wei)\n", 148 | " att = wei @ v # (B, nh, C, C) * (B, nh, C, hs) -> (B, nh, C, hs)\n", 149 | " att = att.transpose(1,2).contiguous().view(B,C,E) # (B, nh, C, hs) -> (B, C, nh, hs) -> (B, C, E)\n", 150 | " \n", 151 | " out = self.proj(att)\n", 152 | " return out\n", 153 | " \n", 154 | "class Block(nn.Module):\n", 155 | " \n", 156 | " def __init__(self):\n", 157 | " super().__init__()\n", 158 | " head_size = d_model // n_head\n", 159 | " self.ln1 = nn.LayerNorm(d_model)\n", 160 | " # self.attn = SelfAttention(n_head, head_size)\n", 161 | " self.attn = MultiAttention()\n", 162 | " self.ln2 = nn.LayerNorm(d_model)\n", 163 | " self.ffn = FeedForward(d_model)\n", 164 | " \n", 165 | " def forward(self, x):\n", 166 | " x = x + self.attn(self.ln1(x))\n", 167 | " x = x + self.ffn(self.ln2(x))\n", 168 | " return x\n", 169 | "\n", 170 | "class GPTModel(nn.Module):\n", 171 | " \n", 172 | " def __init__(self):\n", 173 | " super().__init__()\n", 174 | " self.tok_embed_table = nn.Embedding(vocab_size, d_model)\n", 175 | " self.pos_embed_table = SinusoidPE()\n", 176 | " self.decoder_blocks = nn.Sequential(*[Block() for _ in range(n_layer)])\n", 177 | " self.ln = nn.LayerNorm(d_model)\n", 178 | " self.final_linear = nn.Linear(d_model, vocab_size)\n", 179 | " \n", 180 | " def forward(self, features, targets=None):\n", 181 | " tok_emb = self.tok_embed_table(features)\n", 182 | " pos_emb = self.pos_embed_table(tok_emb)\n", 183 | " x = tok_emb + pos_emb\n", 184 | " x = self.decoder_blocks(x)\n", 185 | " out = self.final_linear(self.ln(x))\n", 186 | " \n", 187 | " if targets is not None:\n", 188 | " B, C, V = out.shape\n", 189 | " out = out.view(B * C, V)\n", 190 | " targets = targets.view(B * C)\n", 191 | " loss = F.cross_entropy(out, 
targets)\n", 192 | " return out, loss\n", 193 | " else:\n", 194 | " return out, None\n", 195 | " \n", 196 | " @torch.no_grad()\n", 197 | " def generate(self, seq, max_new_tokens):\n", 198 | " for _ in range(max_new_tokens):\n", 199 | " seq = seq[:,-seq_len:] # B, L, E\n", 200 | " pred, _ = self(seq)\n", 201 | " pred = pred[:,-1,:] # B, C, V -> B, 1, V\n", 202 | " probs = F.softmax(pred, dim=-1)\n", 203 | " next_token = torch.multinomial(probs, num_samples=1) # [0.1, 0.7, 0.2]\n", 204 | " seq = torch.cat((seq, next_token), dim=1)\n", 205 | " return seq\n", 206 | " \n", 207 | "model = GPTModel()\n", 208 | "# summary(model)\n", 209 | "\n", 210 | "import sentencepiece as spm\n", 211 | "import sys\n", 212 | "\n", 213 | "model_file = \"bird_shooter.model\"\n", 214 | "sp = spm.SentencePieceProcessor()\n", 215 | "if not sp.load(model_file=model_file):\n", 216 | " print(\"load tokenizer model failed\")\n", 217 | " sys.exit(1)\n", 218 | "\n", 219 | "user_input = \"郭靖一掌挥出\" \n", 220 | "context = torch.tensor([sp.encode(user_input)], dtype=torch.int32)\n", 221 | "gpt_output = model.generate(context, max_new_tokens=20)[0].tolist()\n", 222 | "print(f\"gpt => {sp.decode(gpt_output)}\")" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3 (ipykernel)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.10.12" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 5 247 | } 248 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | sentencepiece 4 | torch 5 | torchinfo 6 | --------------------------------------------------------------------------------