├── AutoTencent.py
├── ParseCpp.py
├── README.md
├── SelfAttention.ipynb
├── TestTencent.py
├── tools
│   ├── README.md
│   └── clang-16.0.1.1-py3-none-any.whl
└── transformer.py

/AutoTencent.py:
--------------------------------------------------------------------------------
# utf-8
from selenium import webdriver
from selenium.webdriver.edge.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time, datetime

# start = time.time()
options = webdriver.ChromeOptions()
# Set up the browser options
driver = webdriver.Chrome(options=options)
# driver = webdriver.Edge()
# To use Edge instead, initialize the EdgeDriver with EdgeOptions

url = "https://docs.qq.com/form/page/..."  # URL of the online collection form
driver.get(url)


# Open the page
# time.sleep(1)
content = ''
elmet = driver.find_element(By.ID, "header-login-btn")
elmet.click()
driver.implicitly_wait(2)
# driver.execute_script("window.location.reload()")
elmet = driver.find_element(By.CSS_SELECTOR, 'span.qq')
elmet.click()
while True:  # Log in by scanning the QR code with your phone (or any other method), then enter y to start waiting for the fill-in window
    ch = input("Are you logged in? (y/n)")
    if ch == 'y':
        break
# elmet = driver.find_element(By.ID, "img_out_qqnum")
# elmet.click()
# print("Quick login succeeded")
# time.sleep(1)

# Set the execution time, e.g. 2023-10-20 16:38
execute_time = datetime.datetime(2023, 10, 20, 16, 38, 0)

# Compute how long to wait
wait_time = execute_time - datetime.datetime.now()

# Sleep until the scheduled time, then start
if wait_time.total_seconds() > 0:
    print("Waiting for the beginning...")
    time.sleep(wait_time.total_seconds())
print("Beginning!")
start = time.time()
# Refresh the page
driver.execute_script("window.location.reload()")

# Wait for the form fields to appear
timeout = 10  # timeout in seconds
locator = (By.XPATH, "//textarea[@placeholder='请输入']")
elements = []
while not elements:
    try:
        elements = WebDriverWait(driver, timeout).until(EC.presence_of_all_elements_located(locator))
    except Exception:
        pass
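# A hedged sketch (not in the original script): the hardcoded send_keys calls
# below could instead be driven by a list, so forms with a different number of
# questions need only one change; `answers` is a hypothetical name.
#     answers = ["11", "22", "33", "44"]
#     for element, answer in zip(elements, answers):
#         element.send_keys(answer)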
elements[0].send_keys("11")  # Fill in the required information: name, contact number, ...
elements[1].send_keys("22")
elements[2].send_keys("33")
elements[3].send_keys("44")

# The submit and confirm buttons seem to differ between collection forms, so selecting them by css_selector is not recommended
# button = driver.find_element(By.CSS_SELECTOR,'#root > div.form-root.fill-form-root > div > div > div.form-fill-container > div.form-with-history-record.fill-area > div.form-body.form-fill-body > div.question-commit > button')
# driver.execute_script("arguments[0].click();", button)
# elmet.click()
# button = driver.find_element(By.CSS_SELECTOR,'body > div.dui-modal-mask.dui-modal-mask-visible.dui-modal-mask-display > div > div.dui-modal-footer > button.dui-button.dui-modal-footer-ok.dui-button-type-primary.dui-button-size-default > div')
# button.click()
# elmet = driver.find_element(By.CSS_SELECTOR,'textarea[placeholder="请输入"]')
# elmet.send_keys("1562320xxxx")
# textareas = driver.find_elements_by_css_selector('textarea[placeholder="请输入"]')

button = driver.find_element(By.XPATH, "//button[text()='提交']")
driver.execute_script("arguments[0].click();", button)
locator = (By.XPATH, "//button[contains(.,'确认')]")
button = WebDriverWait(driver, timeout).until(EC.presence_of_element_located(locator))
button.click()

print(time.time() - start)
print("The current date and time is", datetime.datetime.now())

--------------------------------------------------------------------------------
/ParseCpp.py:
--------------------------------------------------------------------------------
import clang.cindex
import re, chardet, sys, json, time

def extract_class_function_names_cpp(cpp_file_path):
    '''
    Extract class names and function names with a regular expression
    '''
    # with open(cpp_file_path, 'rb') as file:
    #     encoding = chardet.detect(file.read())['encoding']
    with open(cpp_file_path, 'r', encoding='utf-8') as file:
        content = file.read()
    function_pattern = r'\w+\**[\x20\n]+\w+::\w+\([\s\S]*\)'
    function_names = re.findall(function_pattern, content)
    class_names = set()
    for function in function_names:
        class_names.add(function.split('::')[0].split(' ')[-1])
    return class_names, function_names

def get_functions(node):
    '''
    Capture function nodes with clang
    '''
    if node.kind == clang.cindex.CursorKind.FUNCTION_DECL or node.kind == clang.cindex.CursorKind.CXX_METHOD:
        yield node
    for child in node.get_children():
        yield from get_functions(child)

def get_function_code(filename, start, end):
    '''
    Extract a function body given its first and last line numbers
    '''
    with open(filename, 'r', errors='ignore', encoding='utf-8') as f:
        function_code = f.readlines()[start-1:end]
    return ''.join(function_code)

def extract_functions(filename):
    '''
    Use clang to resolve each function's start and end lines, then extract the bodies
    '''
    index = clang.cindex.Index.create()
    tu = index.parse(filename)
    functions = list()
    for function in get_functions(tu.cursor):
        if function.is_definition():
            start, end = function.extent.start.line, function.extent.end.line
            functions.append(get_function_code(filename, start, end))
    return functions
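# A hedged usage sketch (assumes libclang is configured as in __main__ below;
# 'example.cpp' is a hypothetical file name):
#     funcs = extract_functions('example.cpp')
#     for f in funcs:
#         print(f.splitlines()[0])  # first line of each extracted function
# Note that index.parse() also visits declarations pulled in from headers;
# filtering on function.location.file can restrict results to the input file.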
57 | ''' 58 | # with open(file_name, 'rb') as file: 59 | # encoding = chardet.detect(file.read())['encoding'] 60 | with open(file_name, 'r+', encoding='utf-8') as f: 61 | content = f.read() 62 | for class_name in class_names: 63 | f.seek(0, 0) 64 | # print('class {}'.format(class_name)+'{}\n') 65 | f.write('class {}'.format(class_name)+'{}\n') 66 | 67 | if __name__ == '__main__': 68 | if len(sys.argv) < 2: 69 | print("请提供文件路径作为参数") 70 | sys.exit(1) 71 | file_path = sys.argv[1] 72 | # file_path = '../../class/messages.cpp' 73 | 74 | clang.cindex.Config.set_library_file('D:/software/LLVM/bin/libclang.dll') 75 | 76 | class_names, _ = extract_class_function_names_cpp(file_path) 77 | # add_class_define(file_path, class_names) 78 | start = time.time() 79 | functions = extract_functions(file_path) 80 | print(f"Time cost in extract_functions: {time.time()-start}s") 81 | print(class_names, functions) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Useful Scripts 2 | | File | Description | Remark | 3 | | ---- | ----------- | ------ | 4 | | AutoTencent.py | 腾讯文档在线收集表定时自动抢填脚本 | selenium环境依赖 | 5 | | TestTencent.py | 腾讯文档在线收集表选择题、填空题测试脚本 | selenium环境依赖 | 6 | | ParseCpp.py | cpp源代码文件简易解析脚本,可提取函数 | llvm, clang环境依赖, 见[tools](https://github.com/LiKe-rm/Useful-Scripts/tree/main/tools) | 7 | | ImageCompress.py | png/jpg图片定制尺寸压缩(极简)| - | 8 | | SelfAttention.ipynb | 细致拆解,pytorch实现自注意力、多头、交叉、因果(掩码)注意力 | pytorch | 9 | -------------------------------------------------------------------------------- /SelfAttention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "687d4428-758c-4ef9-9493-1601095f9ec0", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "tensor([0, 4, 5, 2, 1, 3])" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "# 参考资料: https://www.jiqizhixin.com/articles/2024-02-16\n", 22 | "# 一、文本输入转为向量嵌入\n", 23 | "# 分词、转为token下标\n", 24 | "import torch\n", 25 | "sentence = 'Life is short, eat dessert first'\n", 26 | "dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split(' ')))}\n", 27 | "s_index = [dc[s] for s in sentence.replace(',', '').split(' ')]\n", 28 | "s_index = torch.tensor(s_index)\n", 29 | "s_index" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "ba38b19a-db1b-41c7-84e4-46be06cee15b", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/plain": [ 41 | "tensor([[ 0.3374, -0.1778, -0.3035],\n", 42 | " [ 0.1794, 1.8951, 0.4954],\n", 43 | " [ 0.2692, -0.0770, -1.0205],\n", 44 | " [-0.2196, -0.3792, 0.7671],\n", 45 | " [-0.5880, 0.3486, 0.6603],\n", 46 | " [-1.1925, 0.6984, -1.4097]])" 47 | ] 48 | }, 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "output_type": "execute_result" 52 | } 53 | ], 54 | "source": [ 55 | "# 对下标进行embed\n", 56 | "vocab_size = 50_000\n", 57 | "torch.manual_seed(123)\n", 58 | "embed = torch.nn.Embedding(vocab_size, 3)\n", 59 | "embeded_sentence = embed(s_index).detach()\n", 60 | "embeded_sentence" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "id": "9fe94023-1935-4274-bcc5-61f323dda914", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | 
"text": [ 73 | "tensor([[ 0.0327, -0.2112],\n", 74 | " [ 0.5667, 1.8269],\n", 75 | " [-0.0152, -0.7982],\n", 76 | " [-0.1037, 0.2902],\n", 77 | " [-0.0375, 0.5085],\n", 78 | " [-0.2816, -1.3567]], grad_fn=)\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# 二、自注意力机制\n", 84 | "# 初始化q、k、v的权重矩阵\n", 85 | "torch.manual_seed(123)\n", 86 | "d = embeded_sentence.shape[1]\n", 87 | "d_q, d_k, d_v = 2, 2, 4\n", 88 | "\n", 89 | "w_q = torch.nn.Parameter(torch.rand(d, d_q))\n", 90 | "w_k = torch.nn.Parameter(torch.rand(d, d_k))\n", 91 | "w_v = torch.nn.Parameter(torch.rand(d, d_v))\n", 92 | "print(embeded_sentence @ w_q)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 16, 98 | "id": "b55c9e99-94cc-4332-b68a-4ffba8e24b37", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "tensor([-0.0152, -0.7982], grad_fn=)\n", 106 | "tensor([[ 0.0327, -0.2112],\n", 107 | " [ 0.5667, 1.8269],\n", 108 | " [-0.0152, -0.7982],\n", 109 | " [-0.1037, 0.2902],\n", 110 | " [-0.0375, 0.5085],\n", 111 | " [-0.2816, -1.3567]], grad_fn=)\n", 112 | "tensor([[-0.0823, -0.3031],\n", 113 | " [ 0.5295, 1.7355],\n", 114 | " [-0.2991, -0.7295],\n", 115 | " [ 0.1420, 0.2291],\n", 116 | " [ 0.1920, 0.6467],\n", 117 | " [-0.4788, -0.5835]], grad_fn=)\n", 118 | "tensor([[-0.2546, -0.2608, -0.1544, -0.2801],\n", 119 | " [ 0.6612, 1.8972, 1.0963, 1.8106],\n", 120 | " [-0.8598, -0.6161, -0.5940, -0.9455],\n", 121 | " [ 0.5932, 0.0981, 0.2741, 0.4151],\n", 122 | " [ 0.5605, 0.5645, 0.3676, 0.6429],\n", 123 | " [-1.2107, -0.4929, -1.0081, -1.4031]], grad_fn=)\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# 计算q、k、v\n", 129 | "x_2 = embeded_sentence[2]\n", 130 | "q_2 = x_2 @ w_q\n", 131 | "k_2 = x_2 @ w_k\n", 132 | "v_2 = x_2 @ w_v\n", 133 | "print(q_2)\n", 134 | "querys = embeded_sentence @ w_q\n", 135 | "keys = embeded_sentence @ w_k\n", 136 | "values = embeded_sentence @ w_v\n", 137 | "print(querys)\n", 138 | "print(keys)\n", 139 | "print(values)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 10, 145 | "id": "9f4a2814-4469-4100-80f0-79d82acd626a", 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "tensor(-0.5191, grad_fn=)\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "# 注意力权重ω(i,j) 是查询和键序列之间的点积 ω(i,j) = q⁽ⁱ⁾ k⁽ʲ⁾\n", 158 | "omega_24 = q_2.dot(keys[4])\n", 159 | "print(omega_24)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 19, 165 | "id": "f15509b6-2e68-4146-9d85-df2a82bc8e7b", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "tensor([ 0.2432, -1.3934, 0.5869, -0.1851, -0.5191, 0.4730],\n", 172 | " grad_fn=)" 173 | ] 174 | }, 175 | "execution_count": 19, 176 | "metadata": {}, 177 | "output_type": "execute_result" 178 | } 179 | ], 180 | "source": [ 181 | "# 例:计算第三个词对整个序列的注意力权重 w, omega\n", 182 | "omega_2 = q_2 @ keys.T\n", 183 | "omega_2" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 20, 189 | "id": "67a1ff6f-909c-4f7c-b5d7-f0e938e7bb63", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "tensor([0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],\n", 197 | " grad_fn=)\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "# 归一化\n", 203 | "import torch.nn.functional as F\n", 204 | "attention_w_2 = F.softmax(omega_2/d_k ** 0.5, dim=0)\n", 205 | 
"print(attention_w_2)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 21, 211 | "id": "47d906a2-8ce4-45a2-bd50-23c520429719", 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "tensor([-0.3542, -0.1234, -0.2627, -0.3706], grad_fn=)" 218 | ] 219 | }, 220 | "execution_count": 21, 221 | "metadata": {}, 222 | "output_type": "execute_result" 223 | } 224 | ], 225 | "source": [ 226 | "# 使用归一化后的注意力,计算上下文向量嵌入\n", 227 | "context_vec_2 = attention_w_2 @ values\n", 228 | "context_vec_2" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 25, 234 | "id": "0ae9d3a5-9991-449b-b944-546062af6324", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# 将自注意力融合为一个类\n", 239 | "import torch.nn as nn\n", 240 | "class SelfAttention(nn.Module):\n", 241 | " def __init__(self, d_in, d_out_kq, d_out_v):\n", 242 | " super().__init__()\n", 243 | " self.d_out_kq = d_out_kq\n", 244 | " self.w_query = nn.Parameter(torch.rand(d_in, d_out_kq))\n", 245 | " self.w_key = nn.Parameter(torch.rand(d_in, d_out_kq))\n", 246 | " self.w_value = nn.Parameter(torch.rand(d_in, d_out_v))\n", 247 | " \n", 248 | " def forward(self, x):\n", 249 | " keys = x @ self.w_key\n", 250 | " queries = x @ self.w_query\n", 251 | " values = x @ self.w_value\n", 252 | " attn_scores = queries @ keys.T\n", 253 | " # 得到归一化的,每个token彼此之间的注意力值,seq_length * seq_length\n", 254 | " attn_weights = torch.softmax(attn_scores/self.d_out_kq ** 0.5, dim=-1)\n", 255 | " # 得到在每一个value维度上,每个token使用自己与其他token的注意力 @ 该维度的value , seq_length * d_v\n", 256 | " context_vec = attn_weights @ values\n", 257 | " return context_vec" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 26, 263 | "id": "692b397c-eb2f-4cd3-a91e-301593663a70", 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "tensor([[-0.1564, 0.1028, -0.0763, -0.0764],\n", 271 | " [ 0.5313, 1.3607, 0.7891, 1.3110],\n", 272 | " [-0.3542, -0.1234, -0.2627, -0.3706],\n", 273 | " [ 0.0071, 0.3345, 0.0969, 0.1998],\n", 274 | " [ 0.1008, 0.4780, 0.2021, 0.3674],\n", 275 | " [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=)\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "# 测试,结果中的第三行与上文计算的上下文嵌入一致\n", 281 | "torch.manual_seed(123)\n", 282 | "\n", 283 | "d_in, d_out_kq, d_out_v = 3, 2, 4\n", 284 | "\n", 285 | "sa = SelfAttention(d_in, d_out_kq, d_out_v)\n", 286 | "print(sa(embeded_sentence))" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 31, 292 | "id": "4d9d8261-4b19-4c61-a7c1-5864797e3978", 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "# 三、多头注意力\n", 297 | "class MultiHeadAttentionWrapper(nn.Module):\n", 298 | " def __init__(self, d_in, d_out_kq, d_out_v, num_heads):\n", 299 | " super().__init__()\n", 300 | " self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])\n", 301 | " def forward(self, x):\n", 302 | " return torch.cat([head(x) for head in self.heads], dim=-1)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 33, 308 | "id": "d73c9284-76a0-4519-b6a8-1fb6af8ee839", 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "tensor([[-0.0185],\n", 316 | " [ 0.4003],\n", 317 | " [-0.1103],\n", 318 | " [ 0.0668],\n", 319 | " [ 0.1180],\n", 320 | " [-0.1827]], grad_fn=)\n", 321 | "tensor([[-0.0185, 0.0170, 0.1999, 
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4d9d8261-4b19-4c61-a7c1-5864797e3978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Multi-head attention\n",
    "class MultiHeadAttentionWrapper(nn.Module):\n",
    "    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):\n",
    "        super().__init__()\n",
    "        self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])\n",
    "    def forward(self, x):\n",
    "        return torch.cat([head(x) for head in self.heads], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d73c9284-76a0-4519-b6a8-1fb6af8ee839",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.0185],\n",
      "        [ 0.4003],\n",
      "        [-0.1103],\n",
      "        [ 0.0668],\n",
      "        [ 0.1180],\n",
      "        [-0.1827]], grad_fn=<MmBackward0>)\n",
      "tensor([[-0.0185,  0.0170,  0.1999, -0.0860],\n",
      "        [ 0.4003,  1.7137,  1.3981,  1.0497],\n",
      "        [-0.1103, -0.1609,  0.0079, -0.2416],\n",
      "        [ 0.0668,  0.3534,  0.2322,  0.1008],\n",
      "        [ 0.1180,  0.6949,  0.3157,  0.2807],\n",
      "        [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Example test of the multi-head attention mechanism\n",
    "torch.manual_seed(123)\n",
    "\n",
    "# Single-head attention: output shape is seq_len * d_v\n",
    "d_in, d_out_kq, d_out_v = 3, 2, 1\n",
    "sa = SelfAttention(d_in, d_out_kq, d_out_v)\n",
    "print(sa(embeded_sentence))\n",
    "\n",
    "# Multi-head attention: heads are concatenated along the last dimension, output shape is seq_len * (d_v * num_heads)\n",
    "torch.manual_seed(123)\n",
    "mha = MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads=4)\n",
    "context_vecs = mha(embeded_sentence)\n",
    "print(context_vecs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0c6df9cb-f33b-4e32-b792-77704a4b9c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. Cross-attention (adapted from the SelfAttention class above)\n",
    "import torch.nn as nn\n",
    "class CrossAttention(nn.Module):\n",
    "    def __init__(self, d_in, d_out_kq, d_out_v):\n",
    "        super().__init__()\n",
    "        self.d_out_kq = d_out_kq\n",
    "        self.w_query = nn.Parameter(torch.rand(d_in, d_out_kq))\n",
    "        self.w_key = nn.Parameter(torch.rand(d_in, d_out_kq))\n",
    "        self.w_value = nn.Parameter(torch.rand(d_in, d_out_v))\n",
    "    \n",
    "    def forward(self, x1, x2):\n",
    "        queries = x2 @ self.w_query\n",
    "        \n",
    "        keys = x1 @ self.w_key\n",
    "        values = x1 @ self.w_value\n",
    "        attn_scores = queries @ keys.T\n",
    "        # Normalized attention of every x2 token to every x1 token, shape x2_seq_length * x1_seq_length\n",
    "        attn_weights = torch.softmax(attn_scores/self.d_out_kq ** 0.5, dim=-1)\n",
    "        # For each value dimension, each x2 token's attention over the x1 tokens @ the x1 values in that dimension\n",
    "        # Output shape is x2_seq_length * d_v: each value dimension (a different angle/context) contributes its own attention-weighted part of the context embedding\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec"
   ]
  },
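  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-cross-equals-self",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Added sketch, not executed in the original notebook) with identical seeds,\n",
    "# cross-attention with x1 == x2 reduces to plain self-attention\n",
    "torch.manual_seed(123)\n",
    "ca = CrossAttention(3, 2, 4)\n",
    "torch.manual_seed(123)\n",
    "sa_ref = SelfAttention(3, 2, 4)\n",
    "assert torch.allclose(ca(embeded_sentence, embeded_sentence), sa_ref(embeded_sentence))"
   ]
  },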
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "2f25a946-381a-4125-9fbc-5aad1d2825f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.3374, -0.1778, -0.3035],\n",
      "        [ 0.1794,  1.8951,  0.4954],\n",
      "        [ 0.2692, -0.0770, -1.0205],\n",
      "        [-0.2196, -0.3792,  0.7671],\n",
      "        [-0.5880,  0.3486,  0.6603],\n",
      "        [-1.1925,  0.6984, -1.4097]])\n",
      "tensor([[0.2745, 0.6584, 0.2775],\n",
      "        [0.8573, 0.8993, 0.0390],\n",
      "        [0.9268, 0.7388, 0.7179],\n",
      "        [0.7058, 0.9156, 0.4340],\n",
      "        [0.0772, 0.3565, 0.1479],\n",
      "        [0.5331, 0.4066, 0.2318],\n",
      "        [0.4545, 0.9737, 0.4606],\n",
      "        [0.5159, 0.4220, 0.5786]])\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "d_in, d_out_kq, d_out_v = 3, 2, 4\n",
    "cat = CrossAttention(d_in, d_out_kq, d_out_v)\n",
    "\n",
    "x1 = embeded_sentence\n",
    "x2 = torch.rand(8, d_in)\n",
    "print(x1)\n",
    "print(x2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "8aad73b9-9b8a-4200-b930-3a21c59f0c15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2628, 0.7515, 0.3963, 0.6775],\n",
      "        [0.3689, 0.9600, 0.5367, 0.9030],\n",
      "        [0.4914, 1.2517, 0.7219, 1.2023],\n",
      "        [0.4381, 1.1187, 0.6384, 1.0672],\n",
      "        [0.0906, 0.4545, 0.1880, 0.3441],\n",
      "        [0.2374, 0.7029, 0.3635, 0.6248],\n",
      "        [0.4167, 1.0701, 0.6070, 1.0166],\n",
      "        [0.3376, 0.8998, 0.4955, 0.8371]], grad_fn=<MmBackward0>)\n",
      "torch.Size([8, 4])\n"
     ]
    }
   ],
   "source": [
    "context_vecs = cat(x1, x2)\n",
    "print(context_vecs)\n",
    "print(context_vecs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "d68e27b1-d2e7-44d1-9772-c8052db13361",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0613, -0.3491,  0.1443, -0.0437, -0.1303,  0.1076],\n",
      "        [-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],\n",
      "        [ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],\n",
      "        [-0.0794,  0.4487, -0.1807,  0.0518,  0.1677, -0.1197],\n",
      "        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216, -0.2787],\n",
      "        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],\n",
      "       grad_fn=<MmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 5. Masked self-attention (causal self-attention)\n",
    "# Recap plain self-attention first\n",
    "torch.manual_seed(123)\n",
    "\n",
    "d_in, d_out_kq, d_out_v = 3, 2, 4\n",
    "\n",
    "w_q = torch.nn.Parameter(torch.rand(d_in, d_out_kq))\n",
    "w_k = torch.nn.Parameter(torch.rand(d_in, d_out_kq))\n",
    "w_v = torch.nn.Parameter(torch.rand(d_in, d_out_v))\n",
    "\n",
    "x = embeded_sentence\n",
    "\n",
    "q = x @ w_q\n",
    "k = x @ w_k\n",
    "atten_scores = q @ k.T\n",
    "print(atten_scores)"
   ]
  },
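  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-scores-shape",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Added sketch, not executed in the original notebook) the score matrix is\n",
    "# seq_length x seq_length: row i holds token i's raw scores against every token j\n",
    "assert atten_scores.shape == (6, 6)"
   ]
  },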
"text/plain": [ 549 | "tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 550 | " [0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],\n", 551 | " [0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],\n", 552 | " [0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],\n", 553 | " [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],\n", 554 | " [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],\n", 555 | " grad_fn=)" 556 | ] 557 | }, 558 | "execution_count": 42, 559 | "metadata": {}, 560 | "output_type": "execute_result" 561 | } 562 | ], 563 | "source": [ 564 | "# 使用*构建掩码注意力\n", 565 | "masked_atten = atten_weights * mask_simple\n", 566 | "masked_atten" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 48, 572 | "id": "a4f04b4f-2fa1-4636-91e2-f31f8dc447f9", 573 | "metadata": {}, 574 | "outputs": [ 575 | { 576 | "name": "stdout", 577 | "output_type": "stream", 578 | "text": [ 579 | "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 580 | " [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],\n", 581 | " [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],\n", 582 | " [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],\n", 583 | " [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],\n", 584 | " [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],\n", 585 | " grad_fn=)\n", 586 | "tensor([[-0.2546, -0.2608, -0.1544, -0.2801],\n", 587 | " [ 0.6612, 1.8972, 1.0963, 1.8106],\n", 588 | " [-0.8598, -0.6161, -0.5940, -0.9455],\n", 589 | " [ 0.5932, 0.0981, 0.2741, 0.4151],\n", 590 | " [ 0.5605, 0.5645, 0.3676, 0.6429],\n", 591 | " [-1.2107, -0.4929, -1.0081, -1.4031]], grad_fn=)\n", 592 | "tensor([[-0.2546, -0.2608, -0.1544, -0.2801],\n", 593 | " [ 0.6124, 1.7823, 1.0298, 1.6994],\n", 594 | " [-0.4415, -0.1738, -0.2191, -0.3539],\n", 595 | " [ 0.1242, 0.4529, 0.2647, 0.4297],\n", 596 | " [ 0.2848, 0.6142, 0.3719, 0.6158],\n", 597 | " [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=)\n" 598 | ] 599 | } 600 | ], 601 | "source": [ 602 | "# 再度进行归一化,逐行、注意力保持和为1\n", 603 | "row_sums = masked_atten.sum(dim=1, keepdim=True)\n", 604 | "masked_atten_norm = masked_atten / row_sums\n", 605 | "print(masked_atten_norm)\n", 606 | "\n", 607 | "# 使用掩码注意力计算上下文嵌入\n", 608 | "v = x @ w_v\n", 609 | "print(v)\n", 610 | "masked_context_vec = masked_atten_norm @ v\n", 611 | "print(masked_context_vec)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 55, 617 | "id": "59f5c460-6e4c-4a0d-b299-e5f8a2d9d8a6", 618 | "metadata": {}, 619 | "outputs": [ 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "tensor([[False, True, True, True, True, True],\n", 625 | " [False, False, True, True, True, True],\n", 626 | " [False, False, False, True, True, True],\n", 627 | " [False, False, False, False, True, True],\n", 628 | " [False, False, False, False, False, True],\n", 629 | " [False, False, False, False, False, False]])\n", 630 | "tensor([[ 0.0613, -inf, -inf, -inf, -inf, -inf],\n", 631 | " [-0.6004, 3.4707, -inf, -inf, -inf, -inf],\n", 632 | " [ 0.2432, -1.3934, 0.5869, -inf, -inf, -inf],\n", 633 | " [-0.0794, 0.4487, -0.1807, 0.0518, -inf, -inf],\n", 634 | " [-0.1510, 0.8626, -0.3597, 0.1112, 0.3216, -inf],\n", 635 | " [ 0.4344, -2.5037, 1.0740, -0.3509, -0.9315, 0.9265]],\n", 636 | " grad_fn=)\n", 637 | "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 638 | " [0.0168, 0.9832, 0.0000, 0.0000, 0.0000, 0.0000],\n", 639 | " [0.3839, 0.0747, 0.5414, 0.0000, 0.0000, 0.0000],\n", 640 | " [0.2110, 0.3578, 0.1907, 0.2406, 0.0000, 0.0000],\n", 641 | " 
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "59f5c460-6e4c-4a0d-b299-e5f8a2d9d8a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[False,  True,  True,  True,  True,  True],\n",
      "        [False, False,  True,  True,  True,  True],\n",
      "        [False, False, False,  True,  True,  True],\n",
      "        [False, False, False, False,  True,  True],\n",
      "        [False, False, False, False, False,  True],\n",
      "        [False, False, False, False, False, False]])\n",
      "tensor([[ 0.0613,    -inf,    -inf,    -inf,    -inf,    -inf],\n",
      "        [-0.6004,  3.4707,    -inf,    -inf,    -inf,    -inf],\n",
      "        [ 0.2432, -1.3934,  0.5869,    -inf,    -inf,    -inf],\n",
      "        [-0.0794,  0.4487, -0.1807,  0.0518,    -inf,    -inf],\n",
      "        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216,    -inf],\n",
      "        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],\n",
      "       grad_fn=<MaskedFillBackward0>)\n",
      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0168, 0.9832, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.3839, 0.0747, 0.5414, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2110, 0.3578, 0.1907, 0.2406, 0.0000, 0.0000],\n",
      "        [0.1338, 0.3688, 0.1086, 0.1740, 0.2147, 0.0000],\n",
      "        [0.1888, 0.0100, 0.3580, 0.0861, 0.0482, 0.3089]],\n",
      "       grad_fn=<SoftmaxBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 6. A more efficient implementation of masked self-attention\n",
    "# Replace the pipeline above [attention scores -> softmax weights -> mask -> renormalize] with [mask with -inf -> softmax]\n",
    "mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
    "print(mask.bool())\n",
    "masked_atten = atten_scores.masked_fill(mask.bool(), -torch.inf)\n",
    "print(masked_atten)\n",
    "# Note: the 1/sqrt(d_k) scaling is skipped here, so these weights differ from masked_atten_norm above\n",
    "masked_atten_soft = torch.softmax(masked_atten, dim=1)\n",
    "print(masked_atten_soft)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
--------------------------------------------------------------------------------
/TestTencent.py:
--------------------------------------------------------------------------------
# utf-8
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time, datetime


# options = webdriver.ChromeOptions()
# driver = webdriver.Chrome(options=options)

# Edge: point to the matching driver.exe via an environment variable, e.g. D:\software\edgedriver_win64 on this machine
options = webdriver.EdgeOptions()
driver = webdriver.Edge()

def test0():
    driver.get('https://bing.com')
    element = driver.find_element(By.ID, 'sb_form_q')
    element.send_keys('WebDriver')
    element.submit()
    time.sleep(5)
    driver.quit()


def test():
    url = "https://docs.qq.com/form/page/DSXJMWUVYWnl3UEZs?u=3c7b87fb66f7484e8549216d8fd08aa0#/fill"
    driver.get(url)

    elmet = driver.find_element(By.ID, "header-login-btn")
    elmet.click()
    driver.implicitly_wait(2)
    elmet = driver.find_element(By.CSS_SELECTOR, 'span.qq')
    elmet.click()
    while True:  # Log in by scanning the QR code with your phone (or any other method), then enter y to start the fill-in
        ch = input("Are you logged in? (y/n)")
        if ch == 'y':
            break

    # Fill-in (free text) questions
    timeout = 10  # timeout in seconds
    locator = (By.XPATH, "//textarea[@placeholder='请输入']")
    elements = []
    while not elements:
        try:
            elements = WebDriverWait(driver, timeout).until(EC.presence_of_all_elements_located(locator))
        except Exception:
            pass
    elements[0].send_keys("11")

    # Multiple-choice questions
    radio_buttons = WebDriverWait(driver, 10).until(
        EC.presence_of_all_elements_located((By.CLASS_NAME, "form-choice-radio-option"))
    )
    desired_class = "form-choice-radio-option-text-content"

    time.sleep(5)
    # Set the key matching the option to pick
    key = "一班"

    for radio_button in radio_buttons:
        text_content = radio_button.find_element(By.CLASS_NAME, desired_class).text
        if text_content == key:
            radio_button.click()
            break
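    # (Added, hedged) warn if no option text matched the key; `texts` is a
    # hypothetical name, not in the original script
    texts = [rb.find_element(By.CLASS_NAME, desired_class).text for rb in radio_buttons]
    if key not in texts:
        print(f"Warning: option {key!r} not found among {texts}")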
    button = driver.find_element(By.XPATH, "//button[text()='提交']")
    driver.execute_script("arguments[0].click();", button)
    locator = (By.XPATH, "//button[contains(.,'确认')]")
    button = WebDriverWait(driver, timeout).until(EC.presence_of_element_located(locator))
    button.click()

if __name__ == "__main__":
    test()

--------------------------------------------------------------------------------
/tools/README.md:
--------------------------------------------------------------------------------
## C++ source file parsing tools
### Installing clang:
pip install clang-16.0.1.1-py3-none-any.whl
### Using llvm
llvm [download page](https://github.com/llvm/llvm-project/releases/tag/llvmorg-17.0.1)

The version used in this example is `LLVM-17.0.1-win64.exe`

After downloading, run the installer, then point to the matching path in your Python code, for example:
```python
clang.cindex.Config.set_library_file('D:/software/LLVM/bin/libclang.dll')
```
--------------------------------------------------------------------------------
/tools/clang-16.0.1.1-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiKe-rm/Useful-Scripts/44098708cd31358733d79819e0cc4b4c920e2a13/tools/clang-16.0.1.1-py3-none-any.whl
--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def scaled_softmax_attention(query, key, value):
    """
    Args:
        query: torch.Tensor (..., L, D)
        key: torch.Tensor (..., L, D)
        value: torch.Tensor (..., L, D)
    Returns:
        res: torch.Tensor (..., L, D), output of the attention layer (softmax(Q K^T / sqrt(D)) V)
        attention: torch.Tensor (..., L, L), attention weights (softmax(Q K^T / sqrt(D)))

    L is the length of sequence, D is the embedding dimension
    """
    QK = torch.matmul(query, torch.transpose(key, -2, -1))
    QK /= torch.sqrt(torch.tensor(query.shape[-1]))

    attention = F.softmax(QK, dim=-1)

    res = torch.matmul(attention, value)
    return res, attention
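# A hedged usage sketch (arbitrary sizes, not part of the original file):
#     q = k = v = torch.rand(2, 5, 8)               # (batch, L, D)
#     out, attn = scaled_softmax_attention(q, k, v)
#     out.shape  -> torch.Size([2, 5, 8])
#     attn.shape -> torch.Size([2, 5, 5])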
class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: dimensionality of embedding (total)
            num_heads: number of heads (must divide embed_dim)
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # W_i_Q = d_model x d_k
        # W_i_K = d_model x d_k
        # W_i_V = d_model x d_v
        # W_O = h * d_v x d_model

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    # the original Transformer implementation uses this initialization
    def _reset_parameters(self):
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.fill_(0)


    def forward(self, x, return_attention=False):
        """
        Args:
            x: torch.Tensor (B, L, D)
            return_attention: If specified, returns attention along with outputs
        Returns:
            outputs: torch.Tensor (B, L, D)
            attention: Optional[torch.Tensor] (B, num_heads, L, L)

        B is batch size, L is the length of sequence, D is the embedding dimension
        """
        L = x.shape[1]
        batch_len = x.shape[0]

        # project, then split the embedding into num_heads chunks of head_dim
        Q = self.q_proj(x).reshape((batch_len, L, self.num_heads, self.head_dim))
        K = self.k_proj(x).reshape((batch_len, L, self.num_heads, self.head_dim))
        V = self.v_proj(x).reshape((batch_len, L, self.num_heads, self.head_dim))

        # (B, L, H, head_dim) -> (B, H, L, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        outputs, attention = scaled_softmax_attention(Q, K, V)

        # (B, H, L, head_dim) -> (B, L, D), then mix the heads with the output projection
        outputs = outputs.transpose(1, 2).reshape((batch_len, L, self.embed_dim))
        outputs = self.o_proj(outputs)

        if return_attention:
            return outputs, attention
        else:
            return outputs



class EncoderBlock(nn.Module):

    def __init__(self, embed_dim, num_heads, feedforward_dim, activation=nn.ReLU, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            feedforward_dim - Dimensionality of the hidden layer in the MLP
            activation - activation function in FFN
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        self.dropout1 = nn.Dropout(dropout)
        self.layernorm1 = nn.LayerNorm(embed_dim)

        self.multihead = MultiheadAttention(embed_dim, num_heads)
        self.activation = activation
        self.feedforward = nn.Sequential(*[
            nn.Linear(embed_dim, feedforward_dim),
            nn.Dropout(dropout),
            self.activation(),
            nn.Linear(feedforward_dim, embed_dim)
        ])

        self.dropout2 = nn.Dropout(dropout)
        self.layernorm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, return_attention=False):
        """
        Args:
            x: torch.Tensor (B, L, D)
        Returns:
            outputs: torch.Tensor (B, L, D)
            attention: Optional[torch.Tensor] (B, num_heads, L, L)
        """
        # attention sublayer with residual connection (post-norm)
        residual = x
        if return_attention:
            outputs, attention = self.multihead(x, return_attention=return_attention)
        else:
            outputs = self.multihead(x)
        outputs = self.dropout1(outputs)
        outputs = self.layernorm1(outputs + residual)

        # feed-forward sublayer with residual connection
        residual2 = outputs
        outputs = self.feedforward(outputs)
        outputs = self.dropout2(outputs)
        outputs = self.layernorm2(outputs + residual2)

        if return_attention:
            return outputs, attention
        else:
            return outputs
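# A hedged shape sketch (arbitrary sizes, not part of the original file):
#     block = EncoderBlock(embed_dim=32, num_heads=4, feedforward_dim=64)
#     block(torch.rand(8, 20, 32)).shape  -> torch.Size([8, 20, 32])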


class PositionalEncoding(nn.Module):

    def __init__(self, embed_dim, max_len: int = 5000):
        """
        Inputs
            embed_dim - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()
        # a tensor of size (1, max_len, embed_dim); the dummy batch dimension is needed for proper addition
        pe = torch.zeros((1, max_len, embed_dim)).float()
        positions = torch.arange(0, max_len).float()
        positions = positions.unsqueeze(1)
        i_s = torch.arange(0, embed_dim, 2).float()

        # sin on even indices, cos on odd ones, as in "Attention Is All You Need"
        pe[:, :, ::2] = torch.sin(positions / torch.pow(10000, i_s / embed_dim))
        pe[:, :, 1::2] = torch.cos(positions / torch.pow(10000, i_s / embed_dim))

        # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1]]
        return x

class TransformerForSequenceClassification(nn.Module):

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        num_classes: int,
        num_heads: int,
        feedforward_dim: int,
        num_layers: int,
        activation = nn.GELU,
        max_len: int = 5000,
        dropout: float = 0.0
    ):
        super().__init__()
        # define layers
        # learnable CLS token drawn from N(0, 1); it must be an nn.Parameter so it is trained and moved with the module
        self.cls_token = nn.Parameter(torch.randn(embed_dim))
        self.input_embedding = nn.Linear(input_dim, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_len)

        encoder_blocks = nn.ModuleList([EncoderBlock(embed_dim, num_heads, feedforward_dim, activation, dropout) for i in range(num_layers)])
        self.encoder = encoder_blocks

        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """
        Args:
            x: torch.Tensor (B, L, |V|)
        Returns:
            x: torch.Tensor (B, |C|)
        """
        x = self.input_embedding(x)
        x = self.positional_encoding(x)
        # append the CLS token at the end of the sequence
        x = torch.cat((x, self.cls_token.repeat(x.shape[0], 1, 1)), dim=1)

        for i, encoder_layer in enumerate(self.encoder):
            x = encoder_layer(x, return_attention=False)

        # classify from the CLS position (the last token)
        x = self.classifier(x[:, -1, :])

        return x
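
# A hedged smoke test (added; all sizes here are arbitrary assumptions, not
# part of the original file)
if __name__ == '__main__':
    model = TransformerForSequenceClassification(
        input_dim=10, embed_dim=32, num_classes=2,
        num_heads=4, feedforward_dim=64, num_layers=2,
    )
    logits = model(torch.rand(8, 20, 10))  # (B, L, |V|) -> (B, |C|)
    print(logits.shape)  # torch.Size([8, 2])
--------------------------------------------------------------------------------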