├── .DS_Store ├── LICENSE ├── README.md ├── hearthstone ├── .vscode │ └── settings.json ├── Attention.py ├── CombinationLayer.py ├── ConvolutionForward.py ├── Dataset.py ├── DatasetSum.py ├── DenseLayer.py ├── Embedding.py ├── LayerNorm.py ├── Model.py ├── Multihead_Attention.py ├── Multihead_Combination.py ├── README.md ├── Radam.py ├── ScheduledOptim.py ├── SubLayerConnection.py ├── TokenEmbedding.py ├── Transfomer.py ├── TreeConv.py ├── TreeConvGen.py ├── __pycache__ │ ├── Attention.cpython-35.pyc │ ├── Attention.cpython-37.pyc │ ├── Attention.cpython-38.pyc │ ├── CombinationLayer.cpython-35.pyc │ ├── CombinationLayer.cpython-37.pyc │ ├── CombinationLayer.cpython-38.pyc │ ├── ConvolutionForward.cpython-35.pyc │ ├── ConvolutionForward.cpython-37.pyc │ ├── ConvolutionForward.cpython-38.pyc │ ├── Dataset.cpython-35.pyc │ ├── Dataset.cpython-37.pyc │ ├── Dataset.cpython-38.pyc │ ├── DenseLayer.cpython-35.pyc │ ├── DenseLayer.cpython-37.pyc │ ├── DenseLayer.cpython-38.pyc │ ├── Embedding.cpython-35.pyc │ ├── Embedding.cpython-37.pyc │ ├── Embedding.cpython-38.pyc │ ├── LayerNorm.cpython-35.pyc │ ├── LayerNorm.cpython-37.pyc │ ├── LayerNorm.cpython-38.pyc │ ├── Model.cpython-35.pyc │ ├── Model.cpython-37.pyc │ ├── Model.cpython-38.pyc │ ├── Multihead_Attention.cpython-35.pyc │ ├── Multihead_Attention.cpython-37.pyc │ ├── Multihead_Attention.cpython-38.pyc │ ├── Multihead_Combination.cpython-35.pyc │ ├── Multihead_Combination.cpython-37.pyc │ ├── Multihead_Combination.cpython-38.pyc │ ├── Radam.cpython-35.pyc │ ├── Radam.cpython-37.pyc │ ├── Radam.cpython-38.pyc │ ├── ScheduledOptim.cpython-35.pyc │ ├── ScheduledOptim.cpython-37.pyc │ ├── ScheduledOptim.cpython-38.pyc │ ├── SubLayerConnection.cpython-35.pyc │ ├── SubLayerConnection.cpython-37.pyc │ ├── SubLayerConnection.cpython-38.pyc │ ├── TokenEmbedding.cpython-35.pyc │ ├── TokenEmbedding.cpython-37.pyc │ ├── TokenEmbedding.cpython-38.pyc │ ├── Transfomer.cpython-35.pyc │ ├── Transfomer.cpython-37.pyc │ ├── 
Transfomer.cpython-38.pyc │ ├── TreeConv.cpython-35.pyc │ ├── TreeConv.cpython-37.pyc │ ├── TreeConv.cpython-38.pyc │ ├── TreeConvGen.cpython-35.pyc │ ├── TreeConvGen.cpython-37.pyc │ ├── TreeConvGen.cpython-38.pyc │ ├── decodeTrans.cpython-35.pyc │ ├── decodeTrans.cpython-37.pyc │ ├── decodeTrans.cpython-38.pyc │ ├── gcnn.cpython-35.pyc │ ├── gcnn.cpython-37.pyc │ ├── gcnn.cpython-38.pyc │ ├── gcnnnormal.cpython-35.pyc │ ├── gcnnnormal.cpython-37.pyc │ ├── gcnnnormal.cpython-38.pyc │ ├── gelu.cpython-35.pyc │ ├── gelu.cpython-37.pyc │ ├── gelu.cpython-38.pyc │ ├── postionEmbedding.cpython-35.pyc │ ├── postionEmbedding.cpython-37.pyc │ ├── postionEmbedding.cpython-38.pyc │ ├── rightnTransfomer.cpython-35.pyc │ ├── rightnTransfomer.cpython-37.pyc │ ├── rightnTransfomer.cpython-38.pyc │ ├── vocab.cpython-35.pyc │ ├── vocab.cpython-37.pyc │ └── vocab.cpython-38.pyc ├── cal.py ├── char_voc.pkl ├── code_voc.pkl ├── decodeTrans.py ├── dev.txt ├── dev_process.txt ├── dev_trans.txt ├── gcnn.py ├── gcnnnormal.py ├── gelu.py ├── nl.pkl ├── nl_voc.pkl ├── outval.txt ├── postionEmbedding.py ├── process.py ├── rightnTransfomer.py ├── rule.pkl ├── rulead.pkl ├── run.py ├── solvetree.py ├── test_process.txt ├── train.txt ├── train_process.txt └── vocab.py └── hearthstone_preprocess ├── .DS_Store ├── Code_Voc.pkl ├── README.md ├── codead.pkl ├── cp.sh ├── dev.txt ├── dev_hs.in ├── dev_hs.out ├── dev_process.txt ├── py3_asdl.simplified.txt ├── rule.pkl ├── rulead.pkl ├── runComplex.py ├── runall.sh ├── solvetree.py ├── test.txt ├── test_hs.in ├── test_hs.out ├── test_process.txt ├── train.txt ├── train_hs.in ├── train_hs.out └── train_process.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/.DS_Store -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zeyu Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TreeGen 2 | 3 | This is a Pytorch version of TreeGen (A Tree-Based Transformer Architecture for Code Generation). 4 | 5 | Our paper is available at https://arxiv.org/abs/1911.09983. (Accepted by AAAI'20) 6 | 7 | The model is in the ```hearthstone``` folder. 8 | 9 | The preprocess is in the ```hearthstone_preprocess``` folder. 
10 | 11 | 12 | ## Dependenices 13 | * NLTK 3.2.1 14 | * Pytorch 1.10.0 15 | * Python 3.7 16 | * Ubuntu 16.04 17 | -------------------------------------------------------------------------------- /hearthstone/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/usr/bin/python3" 3 | } -------------------------------------------------------------------------------- /hearthstone/Attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | import math 6 | 7 | 8 | class Attention(nn.Module): 9 | """ 10 | Compute 'Scaled Dot Product Attention 11 | """ 12 | 13 | def forward(self, query, key, value, mask=None, dropout=None): 14 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 15 | / math.sqrt(query.size(-1)) 16 | 17 | if mask is not None: 18 | if len(list(mask.size())) != 4: 19 | mask = mask.unsqueeze(1).repeat(1, query.size(2), 1).unsqueeze(1) 20 | #print(mask.shape) 21 | scores = scores.masked_fill(mask == 0, -1e9) 22 | 23 | p_attn = F.softmax(scores, dim=-1) 24 | 25 | if dropout is not None: 26 | p_attn = dropout(p_attn) 27 | 28 | return torch.matmul(p_attn, value), p_attn -------------------------------------------------------------------------------- /hearthstone/CombinationLayer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.nn.functional as F 5 | class CombinationLayer(nn.Module): 6 | def forward(self, query, key, value, dropout=None): 7 | query_key = query * key / math.sqrt(query.size(-1)) 8 | query_value = query * value / math.sqrt(query.size(-1)) 9 | tmpW = torch.stack([query_key, query_value], -1) 10 | tmpsum = torch.softmax(tmpW, dim=-1) 11 | tmpV = torch.stack([key, value], dim=-1) 12 | #print(tmpV.shape) 13 | #print(tmpsum.shape) 14 | 
tmpsum = tmpsum * tmpV 15 | tmpsum = torch.squeeze(torch.sum(tmpsum, dim=-1), -1) 16 | if dropout: 17 | tmpsum = dropout(tmpsum) 18 | return tmpsum 19 | 20 | 21 | -------------------------------------------------------------------------------- /hearthstone/ConvolutionForward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from gelu import GELU 3 | class ConvolutionLayer(nn.Module): 4 | def __init__(self, dmodel, layernum, kernelsize=3, dropout=0.1): 5 | super(ConvolutionLayer, self).__init__() 6 | self.conv1 = nn.Conv1d(dmodel, layernum, kernelsize, padding=(kernelsize-1)//2) 7 | self.conv2 = nn.Conv1d(dmodel, layernum, kernelsize, padding=(kernelsize-1)//2) 8 | self.activation = GELU() 9 | self.dropout = nn.Dropout(dropout) 10 | def forward(self, x, mask): 11 | convx = self.conv1(x.permute(0, 2, 1)) 12 | convx = self.conv2(convx) 13 | out = self.dropout(self.activation(convx.permute(0, 2, 1))) 14 | return out#self.dropout(self.activation(self.conv1(self.conv2(x)))) 15 | -------------------------------------------------------------------------------- /hearthstone/Dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.utils.data as data 4 | import random 5 | import pickle 6 | import os 7 | from nltk import word_tokenize 8 | from vocab import VocabEntry 9 | import numpy as np 10 | import re 11 | import h5py 12 | from tqdm import tqdm 13 | import json 14 | sys.setrecursionlimit(500000000) 15 | class SumDataset(data.Dataset): 16 | def __init__(self, config, dataName="train"): 17 | self.train_path = "train_process.txt" 18 | self.val_path = "dev_process.txt" # "validD.txt" 19 | self.test_path = "test_process.txt" 20 | self.Nl_Voc = {"pad": 0, "Unknown": 1} 21 | self.Code_Voc = {"pad": 0, "Unknown": 1} 22 | self.Char_Voc = {"pad": 0, "Unknown": 1} 23 | self.Nl_Len = config.NlLen 24 | self.Code_Len = config.CodeLen 25 | 
self.Char_Len = config.WoLen 26 | self.batch_size = config.batch_size 27 | self.PAD_token = 0 28 | self.data = None 29 | self.dataName = dataName 30 | self.Codes = [] 31 | self.Nls = [] 32 | self.num_step = 50 33 | self.ruledict = pickle.load(open("rule.pkl", "rb")) 34 | self.ruledict["start -> Module"] = len(self.ruledict) 35 | self.ruledict["start -> copyword"] = len(self.ruledict) 36 | self.rrdict = {} 37 | for x in self.ruledict: 38 | self.rrdict[self.ruledict[x]] = x 39 | if not os.path.exists("nl_voc.pkl"): 40 | self.init_dic() 41 | self.Load_Voc() 42 | #print(self.Nl_Voc) 43 | if dataName == "train": 44 | if os.path.exists("data.pkl"): 45 | self.data = pickle.load(open("data.pkl", "rb")) 46 | return 47 | self.data = self.preProcessData(open(self.train_path, "r", encoding='utf-8')) 48 | elif dataName == "val": 49 | if os.path.exists("valdata.pkl"): 50 | self.data = pickle.load(open("valdata.pkl", "rb")) 51 | self.nl = pickle.load(open("valnl.pkl", "rb")) 52 | return 53 | self.data = self.preProcessData(open(self.val_path, "r", encoding='utf-8')) 54 | else: 55 | if os.path.exists("testdata.pkl"): 56 | self.data = pickle.load(open("testdata.pkl", "rb")) 57 | #self.code = pickle.load(open("testcode.pkl", "rb")) 58 | self.nl = pickle.load(open("testnl.pkl", "rb")) 59 | return 60 | self.data = self.preProcessData(open(self.test_path, "r", encoding='utf-8')) 61 | 62 | def Load_Voc(self): 63 | if os.path.exists("nl_voc.pkl"): 64 | self.Nl_Voc = pickle.load(open("nl_voc.pkl", "rb")) 65 | if os.path.exists("code_voc.pkl"): 66 | self.Code_Voc = pickle.load(open("code_voc.pkl", "rb")) 67 | if os.path.exists("char_voc.pkl"): 68 | self.Char_Voc = pickle.load(open("char_voc.pkl", "rb")) 69 | self.Nl_Voc[""] = len(self.Nl_Voc) 70 | self.Code_Voc[""] = len(self.Code_Voc) 71 | 72 | def init_dic(self): 73 | print("initVoc") 74 | f = open(self.train_path, "r", encoding='utf-8') 75 | lines = f.readlines() 76 | maxNlLen = 0 77 | maxCodeLen = 0 78 | maxCharLen = 0 79 | nls = [] 80 
| rules = [] 81 | for i in tqdm(range(int(len(lines) / 5))): 82 | data = lines[5 * i].strip().lower().split() 83 | nls.append(data) 84 | rulelist = lines[5 * i + 1].strip().split() 85 | tmp = [] 86 | for x in rulelist: 87 | if int(x) >= 10000: 88 | tmp.append(data[int(x) - 10000]) 89 | rules.append(tmp) 90 | f.close() 91 | nl_voc = VocabEntry.from_corpus(nls, size=50000, freq_cutoff=0) 92 | code_voc = VocabEntry.from_corpus(rules, size=50000, freq_cutoff=10) 93 | self.Nl_Voc = nl_voc.word2id 94 | self.Code_Voc = code_voc.word2id 95 | for x in self.ruledict: 96 | lst = x.strip().lower().split() 97 | tmp = [lst[0]] + lst[2:] 98 | for y in tmp: 99 | if y not in self.Code_Voc: 100 | self.Code_Voc[y] = len(self.Code_Voc) 101 | #rules.append([lst[0]] + lst[2:]) 102 | #print(self.Code_Voc) 103 | assert("module" in self.Code_Voc) 104 | for x in self.Nl_Voc: 105 | maxCharLen = max(maxCharLen, len(x)) 106 | for c in x: 107 | if c not in self.Char_Voc: 108 | self.Char_Voc[c] = len(self.Char_Voc) 109 | for x in self.Code_Voc: 110 | maxCharLen = max(maxCharLen, len(x)) 111 | for c in x: 112 | if c not in self.Char_Voc: 113 | self.Char_Voc[c] = len(self.Char_Voc) 114 | open("nl_voc.pkl", "wb").write(pickle.dumps(self.Nl_Voc)) 115 | open("code_voc.pkl", "wb").write(pickle.dumps(self.Code_Voc)) 116 | open("char_voc.pkl", "wb").write(pickle.dumps(self.Char_Voc)) 117 | print(maxNlLen, maxCodeLen, maxCharLen) 118 | def Get_Em(self, WordList, voc): 119 | ans = [] 120 | for x in WordList: 121 | x = x.lower() 122 | if x not in voc: 123 | ans.append(1) 124 | else: 125 | ans.append(voc[x]) 126 | return ans 127 | def Get_Char_Em(self, WordList): 128 | ans = [] 129 | for x in WordList: 130 | x = x.lower() 131 | tmp = [] 132 | for c in x: 133 | c_id = self.Char_Voc[c] if c in self.Char_Voc else 1 134 | tmp.append(c_id) 135 | ans.append(tmp) 136 | return ans 137 | def pad_seq(self, seq, maxlen): 138 | act_len = len(seq) 139 | if len(seq) < maxlen: 140 | seq = seq + [self.PAD_token] * maxlen 
141 | seq = seq[:maxlen] 142 | else: 143 | seq = seq[:maxlen] 144 | act_len = maxlen 145 | return seq 146 | def pad_str_seq(self, seq, maxlen): 147 | act_len = len(seq) 148 | if len(seq) < maxlen: 149 | seq = seq + [""] * maxlen 150 | seq = seq[:maxlen] 151 | else: 152 | seq = seq[:maxlen] 153 | act_len = maxlen 154 | return seq 155 | def pad_list(self,seq, maxlen1, maxlen2): 156 | if len(seq) < maxlen1: 157 | seq = seq + [[self.PAD_token] * maxlen2] * maxlen1 158 | seq = seq[:maxlen1] 159 | else: 160 | seq = seq[:maxlen1] 161 | return seq 162 | def pad_multilist(self, seq, maxlen1, maxlen2, maxlen3): 163 | if len(seq) < maxlen1: 164 | seq = seq + [[[self.PAD_token] * maxlen3] * maxlen2] * maxlen1 165 | seq = seq[:maxlen1] 166 | else: 167 | seq = seq[:maxlen1] 168 | return seq 169 | def preProcessData(self, dataFile): 170 | lines = dataFile.readlines() 171 | inputNl = [] 172 | inputNlChar = [] 173 | inputRuleParent = [] 174 | inputRuleChild = [] 175 | inputParent = [] 176 | inputParentPath = [] 177 | inputRes = [] 178 | inputRule = [] 179 | inputDepth = [] 180 | nls = [] 181 | for i in tqdm(range(int(len(lines) / 5))): 182 | child = {} 183 | nl = lines[5 * i].lower().strip().split() 184 | nls.append(nl) 185 | inputparent = lines[5 * i + 2].strip().split() 186 | inputres = lines[5 * i + 1].strip().split() 187 | depth = lines[5 * i + 3].strip().split() 188 | parentname = lines[5 * i + 4].strip().lower().split() 189 | inputad = np.zeros([self.Nl_Len + self.Code_Len, self.Nl_Len + self.Code_Len]) 190 | for i in range(min(self.Nl_Len, len(nl))): 191 | for j in range(min(self.Nl_Len, len(nl))): 192 | inputad[i, j] = 1 193 | inputrule = [self.ruledict["start -> Module"]] 194 | for j in range(len(inputres)): 195 | inputres[j] = int(inputres[j]) 196 | #depth[j] = int(depth[j]) 197 | inputparent[j] = int(inputparent[j]) + 1 198 | child.setdefault(inputparent[j], []).append(j + 1) 199 | if inputres[j] >= 10000: 200 | inputres[j] = len(self.ruledict) + inputres[j] - 10000 201 
| if j + 1 < self.Code_Len: 202 | inputad[self.Nl_Len + j + 1, inputres[j] - len(self.ruledict)] = 1 203 | inputrule.append(self.ruledict['start -> copyword']) 204 | else: 205 | inputrule.append(inputres[j]) 206 | if inputres[j] - len(self.ruledict) >= self.Nl_Len: 207 | print(inputres[j] - len(self.ruledict)) 208 | if j + 1 < self.Code_Len: 209 | inputad[self.Nl_Len + j + 1, self.Nl_Len + inputparent[j]] = 1 210 | depth = [self.pad_seq([1], 40)] 211 | for j in range(len(inputres)): 212 | tmp = [] 213 | ids = child[inputparent[j]].index(j + 1) + 1 214 | tmp.append(ids) 215 | tmp.extend(depth[inputparent[j]]) 216 | tmp = self.pad_seq(tmp, 40) 217 | depth.append(tmp) 218 | depth = self.pad_list(depth, self.Code_Len, 40) 219 | #inputrule = [self.ruledict["start -> Module"]] + inputres 220 | #depth = self.pad_seq([1] + depth, self.Code_Len) 221 | inputnls = self.Get_Em(nl, self.Nl_Voc) 222 | inputNl.append(self.pad_seq(inputnls, self.Nl_Len)) 223 | inputnlchar = self.Get_Char_Em(nl) 224 | for j in range(len(inputnlchar)): 225 | inputnlchar[j] = self.pad_seq(inputnlchar[j], self.Char_Len) 226 | inputnlchar = self.pad_list(inputnlchar, self.Nl_Len, self.Char_Len) 227 | inputNlChar.append(inputnlchar) 228 | inputruleparent = self.pad_seq(self.Get_Em(["start"] + parentname, self.Code_Voc), self.Code_Len) 229 | inputrulechild = [] 230 | for x in inputrule: 231 | if x >= len(self.rrdict): 232 | inputrulechild.append(self.pad_seq(self.Get_Em(["copyword"], self.Code_Voc), self.Char_Len)) 233 | else: 234 | rule = self.rrdict[x].strip().lower().split() 235 | inputrulechild.append(self.pad_seq(self.Get_Em(rule[2:], self.Code_Voc), self.Char_Len)) 236 | 237 | inputparentpath = [] 238 | for j in range(len(inputres)): 239 | if inputres[j] in self.rrdict: 240 | tmppath = [self.rrdict[inputres[j]].strip().lower().split()[0]] 241 | assert(tmppath[0] == parentname[j].lower()) 242 | else: 243 | tmppath = [parentname[j].lower()] 244 | '''siblings = child[inputparent[j]] 245 | for x in 
siblings: 246 | if x == j + 1: 247 | break 248 | tmppath.append(parentname[x - 1])''' 249 | curr = inputparent[j] 250 | while curr != 0: 251 | rule = self.rrdict[inputres[curr - 1]].strip().lower().split()[0] 252 | tmppath.append(rule) 253 | curr = inputparent[curr - 1] 254 | inputparentpath.append(self.pad_seq(self.Get_Em(tmppath, self.Code_Voc), 10)) 255 | inputrule = self.pad_seq(inputrule, self.Code_Len) 256 | inputres = self.pad_seq(inputres, self.Code_Len) 257 | tmp = [self.pad_seq(self.Get_Em(['start'], self.Code_Voc), 10)] + inputparentpath 258 | inputrulechild = self.pad_list(tmp, self.Code_Len, 10) 259 | inputRuleParent.append(inputruleparent) 260 | inputRuleChild.append(inputrulechild) 261 | inputRes.append(inputres) 262 | inputRule.append(inputrule) 263 | inputparent = [0] + inputparent 264 | inputParent.append(inputad) 265 | inputParentPath.append(self.pad_list(inputparentpath, self.Code_Len, 10)) 266 | inputDepth.append(depth) 267 | batchs = [inputNl, inputNlChar, inputRule, inputRuleParent, inputRuleChild, inputRes, inputParent, inputParentPath, inputDepth] 268 | self.data = batchs 269 | self.nls = nls 270 | #self.code = codes 271 | if self.dataName == "train": 272 | open("data.pkl", "wb").write(pickle.dumps(batchs, protocol=4)) 273 | open("nl.pkl", "wb").write(pickle.dumps(nls)) 274 | if self.dataName == "val": 275 | open("valdata.pkl", "wb").write(pickle.dumps(batchs, protocol=4)) 276 | open("valnl.pkl", "wb").write(pickle.dumps(nls)) 277 | if self.dataName == "test": 278 | open("testdata.pkl", "wb").write(pickle.dumps(batchs)) 279 | #open("testcode.pkl", "wb").write(pickle.dumps(self.code)) 280 | open("testnl.pkl", "wb").write(pickle.dumps(self.nls)) 281 | return batchs 282 | 283 | def __getitem__(self, offset): 284 | ans = [] 285 | '''if self.dataName == "train": 286 | h5f = h5py.File("data.h5", 'r') 287 | if self.dataName == "val": 288 | h5f = h5py.File("valdata.h5", 'r') 289 | if self.dataName == "test": 290 | h5f = h5py.File("testdata.h5", 
'r')''' 291 | for i in range(len(self.data)): 292 | d = self.data[i][offset] 293 | '''if i == 6: 294 | #print(self.data[i][offset]) 295 | tmp = np.eye(self.Code_Len)[d] 296 | #print(tmp.shape) 297 | tmp = np.concatenate([tmp, np.zeros([self.Code_Len, self.Code_Len])], axis=0)[:self.Code_Len,:]#self.pad_list(tmp, self.Code_Len, self.Code_Len) 298 | ans.append(np.array(tmp)) 299 | else:''' 300 | ans.append(np.array(d)) 301 | return ans 302 | def __len__(self): 303 | return len(self.data[0]) 304 | class Node: 305 | def __init__(self, name, s): 306 | self.name = name 307 | self.id = s 308 | self.father = None 309 | self.child = [] 310 | self.sibiling = None 311 | 312 | #dset = SumDataset(args) 313 | -------------------------------------------------------------------------------- /hearthstone/DatasetSum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.utils.data as data 4 | import random 5 | import pickle 6 | import os 7 | from nltk import word_tokenize 8 | from vocab import VocabEntry 9 | import numpy as np 10 | class SumDataset(data.Dataset): 11 | def __init__(self, config, dataName="train"): 12 | self.train_path = "DataProcess/train.txt" 13 | self.val_path = "DataProcess/test.txt" # "validD.txt" 14 | self.test_path = "DataProcess/test.txt" 15 | self.Nl_Voc = {"pad": 0, "Unknown": 1} 16 | self.Code_Voc = {"pad": 0, "Unknown": 1} 17 | self.Char_Voc = {"pad": 0, "Unknown": 1} 18 | self.Nl_Len = config.NlLen 19 | self.Code_Len = config.CodeLen 20 | self.Char_Len = config.WoLen 21 | self.batch_size = config.batch_size 22 | self.PAD_token = 0 23 | self.data = None 24 | self.dataName = dataName 25 | self.Codes = [] 26 | self.Nls = [] 27 | if not os.path.exists("nl_sum_voc.pkl"): 28 | self.init_dic() 29 | self.Load_Voc() 30 | if dataName == "train": 31 | if os.path.exists("data_sum.pkl"): 32 | self.data = pickle.load(open("data_sum.pkl", "rb")) 33 | return 34 | self.data = 
self.preProcessData(open(self.train_path, "r", encoding='iso-8859-1')) 35 | elif dataName == "val": 36 | if os.path.exists("valdata_sum.pkl"): 37 | self.data = pickle.load(open("valdata_sum.pkl", "rb")) 38 | return 39 | self.data = self.preProcessData(open(self.val_path, "r", encoding='iso-8859-1')) 40 | else: 41 | if os.path.exists("testdata_sum.pkl"): 42 | self.data = pickle.load(open("testdata_sum.pkl", "rb")) 43 | return 44 | self.data = self.preProcessData(open(self.test_path, "r", encoding='iso-8859-1')) 45 | 46 | def Load_Voc(self): 47 | if os.path.exists("nl_sum_voc.pkl"): 48 | self.Nl_Voc = pickle.load(open("nl_sum_voc.pkl", "rb")) 49 | if os.path.exists("code_sum_voc.pkl"): 50 | self.Code_Voc = pickle.load(open("code_sum_voc.pkl", "rb")) 51 | if os.path.exists("char_sum_voc.pkl"): 52 | self.Char_Voc = pickle.load(open("char_sum_voc.pkl", "rb")) 53 | 54 | def init_dic(self): 55 | print("initVoc") 56 | f = open(self.train_path, "r", encoding='iso-8859-1') 57 | lines = f.readlines() 58 | maxNlLen = 0 59 | maxCodeLen = 0 60 | maxCharLen = 0 61 | Nls = [] 62 | Codes = [] 63 | for i in range(int(len(lines) / 2)): 64 | Code = lines[2 * i + 1].strip() 65 | Nl = lines[2 * i].strip() 66 | #if "^" in Nl 67 | #print(Nl) 68 | Nl_tokens = [""] + word_tokenize(Nl.lower()) + [""] 69 | Code_Tokens = Code.lower().split() 70 | Nls.append(Nl_tokens) 71 | # Nls.append(Code_Tokens) 72 | Codes.append(Code_Tokens) 73 | maxNlLen = max(maxNlLen, len(Nl_tokens)) 74 | maxCodeLen = max(maxCodeLen, len(Code_Tokens)) 75 | # print(Nls) 76 | # print("------------------") 77 | nl_voc = VocabEntry.from_corpus(Nls, size=50000, freq_cutoff=3) 78 | code_voc = VocabEntry.from_corpus(Codes, size=50000, freq_cutoff=3) 79 | self.Nl_Voc = nl_voc.word2id 80 | self.Code_Voc = code_voc.word2id 81 | 82 | for x in self.Nl_Voc: 83 | maxCharLen = max(maxCharLen, len(x)) 84 | for c in x: 85 | if c not in self.Char_Voc: 86 | self.Char_Voc[c] = len(self.Char_Voc) 87 | for x in self.Code_Voc: 88 | maxCharLen 
= max(maxCharLen, len(x)) 89 | for c in x: 90 | if c not in self.Char_Voc: 91 | self.Char_Voc[c] = len(self.Char_Voc) 92 | if "" in self.Nl_Voc: 93 | print("right") 94 | print(len(self.Nl_Voc), len(self.Code_Voc)) 95 | open("nl_sum_voc.pkl", "wb").write(pickle.dumps(self.Nl_Voc)) 96 | open("code_sum_voc.pkl", "wb").write(pickle.dumps(self.Code_Voc)) 97 | open("char_sum_voc.pkl", "wb").write(pickle.dumps(self.Char_Voc)) 98 | #print(self.Nl_Voc) 99 | #print(self.Code_Voc) 100 | print(maxNlLen, maxCodeLen, maxCharLen) 101 | def Get_Em(self, WordList, NlFlag=True): 102 | ans = [] 103 | for x in WordList: 104 | if NlFlag: 105 | if x not in self.Nl_Voc: 106 | ans.append(1) 107 | else: 108 | ans.append(self.Nl_Voc[x]) 109 | else: 110 | if x not in self.Code_Voc: 111 | ans.append(1) 112 | else: 113 | ans.append(self.Code_Voc[x]) 114 | return ans 115 | def Get_Char_Em(self, WordList): 116 | ans = [] 117 | for x in WordList: 118 | tmp = [] 119 | for c in x: 120 | c_id = self.Char_Voc[c] if c in self.Char_Voc else 1 121 | tmp.append(c_id) 122 | ans.append(tmp) 123 | return ans 124 | def pad_seq(self, seq, maxlen): 125 | act_len = len(seq) 126 | if len(seq) < maxlen: 127 | seq = seq + [self.PAD_token] * maxlen 128 | seq = seq[:maxlen] 129 | else: 130 | seq = seq[:maxlen] 131 | act_len = maxlen 132 | return seq, act_len 133 | def pad_list(self,seq, maxlen1, maxlen2): 134 | if len(seq) < maxlen1: 135 | seq = seq + [[self.PAD_token] * maxlen2] * maxlen1 136 | seq = seq[:maxlen1] 137 | else: 138 | seq = seq[:maxlen1] 139 | return seq 140 | def getAdMatrix(self, codetokens): 141 | lst = codetokens#codetokens.split() 142 | #print(codetokens) 143 | currNode = node(lst[0]) 144 | currNode.id = 0 145 | nodedist = {} 146 | for i, x in enumerate(lst): 147 | if i == 0: 148 | nodedist[i] = currNode 149 | continue 150 | if not x[-1] == "^" or ("^" in x and "_" not in x): 151 | newNode = node(x) 152 | newNode.father = currNode 153 | currNode.child.append(newNode) 154 | newNode.id = i 155 | 
currNode = newNode 156 | nodedist[i] = newNode 157 | else: 158 | newNode = node(x) 159 | newNode.child.append(currNode) 160 | if currNode.father: 161 | newNode.child.append(currNode.father) 162 | newNode.id = i 163 | nodedist[i] = newNode 164 | currNode.child.append(newNode) 165 | if currNode.father: 166 | currNode.father.child.append(newNode) 167 | #print(x, currNode.name) 168 | currNode = currNode.father 169 | admatrix = [] 170 | upbound = min(self.Code_Len, len(lst)) 171 | for i in range(upbound): 172 | ids = [] 173 | for x in nodedist[i].child: 174 | if x.id < self.Code_Len: 175 | ids.append(x.id) 176 | ids.append(nodedist[i].id) 177 | if nodedist[i].father: 178 | if nodedist[i].father.id < self.Code_Len: 179 | ids.append(nodedist[i].father.id) 180 | #tmp = np.sum(np.eye(len(lst))[ids]) 181 | admatrix.append(ids) 182 | return admatrix 183 | 184 | 185 | 186 | def preProcessData(self, dataFile): 187 | lines = dataFile.readlines() 188 | Nl_Sentences = [] 189 | Code_Sentences = [] 190 | Nl_Chars = [] 191 | Code_Chars = [] 192 | admatrix = [] 193 | res = [] 194 | from tqdm import tqdm 195 | for i in tqdm(range(int(len(lines) / 2))): 196 | code = lines[2 * i + 1].strip() 197 | nl = lines[2 * i].strip() 198 | code_tokens = code.lower().split() 199 | try: 200 | admatrix.append(self.getAdMatrix(code_tokens)) 201 | except: 202 | continue 203 | nl_tokens = [""] + word_tokenize(nl.lower()) + [""] 204 | Code_Sentences.append(self.Get_Em(code_tokens, False)) 205 | Nl_Sentences.append(self.Get_Em(nl_tokens)) 206 | Nl_Chars.append(self.Get_Char_Em(nl_tokens)) 207 | Code_Chars.append(self.Get_Char_Em(code_tokens)) 208 | #admatrix.append(self.getAdMatrix(code_tokens)) 209 | res.append(Nl_Sentences[-1][1:]) 210 | for i in range(len(Nl_Sentences)): 211 | Nl_Sentences[i], _ = self.pad_seq(Nl_Sentences[i], self.Nl_Len) 212 | Code_Sentences[i], _ = self.pad_seq(Code_Sentences[i], self.Code_Len) 213 | res[i], _ = self.pad_seq(res[i], self.Nl_Len) 214 | for j in 
range(len(Nl_Chars[i])): 215 | Nl_Chars[i][j], _ = self.pad_seq(Nl_Chars[i][j], self.Char_Len) 216 | for j in range(len(Code_Chars[i])): 217 | Code_Chars[i][j], _ = self.pad_seq(Code_Chars[i][j], self.Char_Len) 218 | Nl_Chars[i] = self.pad_list(Nl_Chars[i], self.Nl_Len, self.Char_Len) 219 | Code_Chars[i] = self.pad_list(Code_Chars[i], self.Code_Len, self.Char_Len) 220 | batchs = [Nl_Sentences, Nl_Chars, Code_Sentences, Code_Chars, admatrix, res] 221 | batchs = np.array(batchs) 222 | self.data = batchs 223 | if self.dataName == "train": 224 | open("data_sum.pkl", "wb").write(pickle.dumps(batchs, protocol=4)) 225 | if self.dataName == "val": 226 | open("valdata_sum.pkl", "wb").write(pickle.dumps(batchs, protocol=4)) 227 | if self.dataName == "test": 228 | open("testdata_sum.pkl", "wb").write(pickle.dumps(batchs)) 229 | return batchs 230 | 231 | 232 | 233 | def __getitem__(self, offset): 234 | ans = [] 235 | for i in range(len(self.data)): 236 | if i == 4: 237 | tmp = [] 238 | for j in range(len(self.data[i][offset])): 239 | #print(np.sum(np.eye(self.Code_Len)[self.data[i][offset][j]], axis=0)) 240 | tmp.append(np.sum(np.eye(self.Code_Len)[self.data[i][offset][j]], axis=0)) 241 | tmp = self.pad_list(tmp, self.Code_Len, self.Code_Len) 242 | tmp = np.array(tmp) 243 | tmp = tmp.reshape(1, self.Code_Len, self.Code_Len) 244 | ans.append(tmp) 245 | else: 246 | ans.append(np.array(self.data[i][offset])) 247 | return ans 248 | def __len__(self): 249 | return len(self.data[0]) 250 | class node: 251 | def __init__(self, name): 252 | self.name = name 253 | self.father = None 254 | self.child = [] 255 | self.id = -1 256 | -------------------------------------------------------------------------------- /hearthstone/DenseLayer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from gelu import GELU 3 | 4 | 5 | class DenseLayer(nn.Module): 6 | "Implements FFN equation." 
7 | 8 | def __init__(self, d_model, d_ff, dropout=0.1): 9 | super(DenseLayer, self).__init__() 10 | self.w_1 = nn.Linear(d_model, d_ff) 11 | self.w_2 = nn.Linear(d_ff, d_model) 12 | self.dropout = nn.Dropout(dropout) 13 | self.activation = GELU() 14 | 15 | def forward(self, x): 16 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) -------------------------------------------------------------------------------- /hearthstone/Embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from TokenEmbedding import TokenEmbedding 3 | from postionEmbedding import PositionalEmbedding 4 | 5 | 6 | class Embedding(nn.Module): 7 | """ 8 | BERT Embedding which is consisted with under features 9 | 1. TokenEmbedding : normal embedding matrix 10 | 2. PositionalEmbedding : adding positional information using sin, cos 11 | 2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2) 12 | sum of all these features are output of BERTEmbedding 13 | """ 14 | 15 | def __init__(self, vocab_size, embed_size, dropout=0.1): 16 | """ 17 | :param vocab_size: total vocab size 18 | :param embed_size: embedding size of token embedding 19 | :param dropout: dropout rate 20 | """ 21 | super().__init__() 22 | self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) 23 | self.position = PositionalEmbedding(d_model=self.token.embedding_dim) 24 | self.depth_embedding = nn.Embedding(20, embed_size, padding_idx=0) 25 | self.dropout = nn.Dropout(p=dropout) 26 | self.embed_size = embed_size 27 | 28 | def forward(self, sequence, inputdept=None, usedepth=False): 29 | x = self.token(sequence) + self.position(sequence) 30 | if usedepth: 31 | x = x + self.depth_embedding(inputdept) 32 | return self.dropout(x) -------------------------------------------------------------------------------- /hearthstone/LayerNorm.py: -------------------------------------------------------------------------------- 1 | import 
torch.nn as nn 2 | import torch 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | "Construct a layernorm module (See citation for details)." 7 | 8 | def __init__(self, features, eps=1e-6): 9 | super(LayerNorm, self).__init__() 10 | self.a_2 = nn.Parameter(torch.ones(features)) 11 | self.b_2 = nn.Parameter(torch.zeros(features)) 12 | self.eps = eps 13 | 14 | def forward(self, x): 15 | mean = x.mean(-1, keepdim=True) 16 | std = x.std(-1, keepdim=True) 17 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 -------------------------------------------------------------------------------- /hearthstone/Model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from Transfomer import TransformerBlock 5 | from rightnTransfomer import rightTransformerBlock 6 | from Multihead_Combination import MultiHeadedCombination 7 | from Embedding import Embedding 8 | from TreeConvGen import TreeConvGen 9 | from Multihead_Attention import MultiHeadedAttention 10 | from gelu import GELU 11 | from LayerNorm import LayerNorm 12 | from decodeTrans import decodeTransformerBlock 13 | from gcnnnormal import GCNNM 14 | from postionEmbedding import PositionalEmbedding 15 | class TreeAttEncoder(nn.Module): 16 | def __init__(self, args): 17 | super(TreeAttEncoder, self).__init__() 18 | self.embedding_size = args.embedding_size 19 | self.nl_len = args.NlLen 20 | self.word_len = args.WoLen 21 | self.char_embedding = nn.Embedding(args.Vocsize, self.embedding_size) 22 | self.token_embedding = Embedding(args.Code_Vocsize, self.embedding_size) 23 | self.feed_forward_hidden = 4 * self.embedding_size 24 | self.conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, self.word_len)) 25 | self.transformerBlocks = nn.ModuleList( 26 | [TransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(3)]) 27 | self.transformerBlocksTree = nn.ModuleList( 28 | 
[TransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(3)]) 29 | 30 | 31 | def forward(self, input_code, input_codechar, inputAd): 32 | codemask = torch.gt(input_code, 0) 33 | charEm = self.char_embedding(input_codechar) 34 | charEm = self.conv(charEm.permute(0, 3, 1, 2)) 35 | charEm = charEm.permute(0, 2, 3, 1).squeeze(dim=-2) 36 | #print(charEm.shape) 37 | x = self.token_embedding(input_code.long()) 38 | for trans in self.transformerBlocksTree: 39 | x = trans.forward(x, codemask, charEm, inputAd, True) 40 | for trans in self.transformerBlocks: 41 | x = trans.forward(x, codemask, charEm) 42 | return x 43 | 44 | class NlEncoder(nn.Module): 45 | def __init__(self, args): 46 | super(NlEncoder, self).__init__() 47 | self.embedding_size = args.embedding_size 48 | self.nl_len = args.NlLen 49 | self.word_len = args.WoLen 50 | self.char_embedding = nn.Embedding(args.Vocsize, self.embedding_size) 51 | self.feed_forward_hidden = 4 * self.embedding_size 52 | self.conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, self.word_len)) 53 | self.transformerBlocks = nn.ModuleList( 54 | [TransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(5)]) 55 | self.token_embedding = Embedding(args.Nl_Vocsize, self.embedding_size) 56 | '''self.transformerBlocksTree = nn.ModuleList( 57 | [TransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(5)])''' 58 | 59 | 60 | def forward(self, input_nl, input_nlchar): 61 | nlmask = torch.gt(input_nl, 0) 62 | charEm = self.char_embedding(input_nlchar.long()) 63 | charEm = self.conv(charEm.permute(0, 3, 1, 2)) 64 | charEm = charEm.permute(0, 2, 3, 1).squeeze(dim=-2) 65 | x = self.token_embedding(input_nl.long()) 66 | for trans in self.transformerBlocks: 67 | x = trans.forward(x, nlmask, charEm) 68 | return x, nlmask 69 | class CopyNet(nn.Module): 70 | def __init__(self, args): 71 | super(CopyNet, self).__init__() 72 | self.embedding_size = 
args.embedding_size 73 | self.LinearSource = nn.Linear(self.embedding_size, self.embedding_size, bias=False) 74 | self.LinearTarget = nn.Linear(self.embedding_size, self.embedding_size, bias=False) 75 | self.LinearRes = nn.Linear(self.embedding_size, 1) 76 | self.LinearProb = nn.Linear(self.embedding_size, 2) 77 | def forward(self, source, traget): 78 | sourceLinear = self.LinearSource(source) 79 | targetLinear = self.LinearTarget(traget) 80 | genP = self.LinearRes(F.tanh(sourceLinear.unsqueeze(1) + targetLinear.unsqueeze(2))).squeeze(-1) 81 | prob = F.softmax(self.LinearProb(traget), dim=-1)#.squeeze(-1)) 82 | return genP, prob 83 | class Decoder(nn.Module): 84 | def __init__(self, args): 85 | super(Decoder, self).__init__() 86 | self.embedding_size = args.embedding_size 87 | self.word_len = args.WoLen 88 | self.nl_len = args.NlLen 89 | self.code_len = args.CodeLen 90 | self.feed_forward_hidden = 4 * self.embedding_size 91 | self.conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, args.WoLen)) 92 | self.path_conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, 10)) 93 | self.rule_conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, 2)) 94 | self.depth_conv = nn.Conv2d(self.embedding_size, self.embedding_size, (1, 40)) 95 | self.resLen = args.rulenum - args.NlLen 96 | self.encodeTransformerBlock = nn.ModuleList( 97 | [rightTransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(9)]) 98 | self.decodeTransformerBlocksP = nn.ModuleList( 99 | [decodeTransformerBlock(self.embedding_size, 8, self.feed_forward_hidden, 0.1) for _ in range(2)]) 100 | self.finalLinear = nn.Linear(self.embedding_size, 2048) 101 | self.resLinear = nn.Linear(2048, self.resLen) 102 | self.rule_token_embedding = Embedding(args.Code_Vocsize, self.embedding_size) 103 | self.rule_embedding = nn.Embedding(args.rulenum, self.embedding_size) 104 | self.encoder = NlEncoder(args) 105 | self.layernorm = LayerNorm(self.embedding_size) 106 | 
self.activate = GELU() 107 | self.copy = CopyNet(args) 108 | self.copy2 = CopyNet(args) 109 | self.dropout = nn.Dropout(p=0.1) 110 | self.depthembedding = nn.Embedding(40, self.embedding_size, padding_idx=0) 111 | self.gcnnm = GCNNM(self.embedding_size) 112 | self.position = PositionalEmbedding(self.embedding_size) 113 | def getBleu(self, losses, ngram): 114 | bleuloss = F.max_pool1d(losses.unsqueeze(1), ngram, 1).squeeze(1) 115 | bleuloss = torch.sum(bleuloss, dim=-1) 116 | return bleuloss 117 | def forward(self, inputnl, inputnlchar, inputrule, inputruleparent, inputrulechild, inputParent, inputParentPath, inputdepth, tmpf, tmpc, tmpindex, rulead, antimask, inputRes=None, mode="train"): 118 | selfmask = antimask 119 | #selfmask = antimask.unsqueeze(0).repeat(inputtype.size(0), 1, 1).unsqueeze(1) 120 | #admask = admask.unsqueeze(0).repeat(inputtype.size(0), 1, 1).float() 121 | rulemask = torch.gt(inputrule, 0) 122 | inputParent = inputParent.float() 123 | #encode_nl 124 | nlencode, nlmask = self.encoder(inputnl, inputnlchar) 125 | #encode_rule 126 | childEm = self.rule_token_embedding(tmpc) 127 | childEm = self.conv(childEm.permute(0, 3, 1, 2)) 128 | childEm = childEm.permute(0, 2, 3, 1).squeeze(dim=-2) 129 | childEm = self.layernorm(childEm) 130 | fatherEm = self.rule_token_embedding(tmpf) 131 | ruleEmCom = self.rule_conv(torch.stack([fatherEm, childEm], dim=-2).permute(0, 3, 1, 2)) 132 | ruleEmCom = self.layernorm(ruleEmCom.permute(0, 2, 3, 1).squeeze(dim=-2)) 133 | x = self.rule_embedding(tmpindex[0]) 134 | # for i in range(9): 135 | # x = self.gcnnm(x, rulead[0], ruleEmCom[0]).view(self.resLen, self.embedding_size) 136 | ruleEm = self.rule_embedding(inputrule) 137 | ruleselect = x 138 | #print(inputdepth.shape) 139 | #depthEm = self.depthembedding(inputdepth.long()) 140 | #depthEm = self.depth_conv(depthEm.permute(0, 3, 1, 2)) 141 | #depthEm = depthEm.permute(0, 2, 3, 1).squeeze(dim=-2) 142 | #depthEm = self.layernorm(depthEm) 143 | Ppath = 
self.rule_token_embedding(inputrulechild) 144 | ppathEm = self.path_conv(Ppath.permute(0, 3, 1, 2)) 145 | ppathEm = ppathEm.permute(0, 2, 3, 1).squeeze(dim=-2) 146 | ppathEm = self.layernorm(ppathEm) 147 | x = self.dropout(ruleEm + self.position(inputrule)) 148 | for trans in self.encodeTransformerBlock: 149 | x = trans(x, selfmask, nlencode, nlmask, ppathEm, inputParent) 150 | decode = x 151 | #ppath 152 | Ppath = self.rule_token_embedding(inputParentPath) 153 | ppathEm = self.path_conv(Ppath.permute(0, 3, 1, 2)) 154 | ppathEm = ppathEm.permute(0, 2, 3, 1).squeeze(dim=-2) 155 | ppathEm = self.layernorm(ppathEm) 156 | x = self.dropout(ppathEm + self.position(inputrule)) 157 | for trans in self.decodeTransformerBlocksP: 158 | x = trans(x, rulemask, decode, antimask, nlencode, nlmask) 159 | decode = x 160 | #genP1, _ = self.copy2(ruleselect.unsqueeze(0), decode) 161 | #resSoftmax = F.softmax(genP, dim=-1) 162 | genP, prob = self.copy(nlencode, decode) 163 | copymask = nlmask.unsqueeze(1).repeat(1, inputrule.size(1), 1) 164 | genP = genP.masked_fill(copymask==0, -1e9) 165 | #genP = torch.cat([genP1, genP], dim=2) 166 | #genP = F.softmax(genP, dim=-1) 167 | 168 | x = self.finalLinear(decode) 169 | x = self.activate(x) 170 | x = self.resLinear(x) 171 | resSoftmax = F.softmax(x, dim=-1) 172 | 173 | resSoftmax = resSoftmax * prob[:,:,0].unsqueeze(-1) 174 | genP = genP * prob[:,:,1].unsqueeze(-1) 175 | resSoftmax = torch.cat([resSoftmax, genP], -1) 176 | if mode != "train": 177 | return resSoftmax 178 | resmask = torch.gt(inputRes, 0) 179 | loss = -torch.log(torch.gather(resSoftmax, -1, inputRes.unsqueeze(-1)).squeeze(-1)) 180 | loss = loss.masked_fill(resmask == 0, 0.0) 181 | resTruelen = torch.sum(resmask, dim=-1).float() 182 | totalloss = torch.mean(loss, dim=-1) * self.code_len / resTruelen 183 | totalloss = totalloss# + (self.getBleu(loss, 2) + self.getBleu(loss, 3) + self.getBleu(loss, 4)) / resTruelen 184 | #totalloss = torch.mean(totalloss) 185 | return totalloss, 
resSoftmax 186 | 187 | 188 | 189 | class JointEmbber(nn.Module): 190 | def __init__(self, args): 191 | super(JointEmbber, self).__init__() 192 | self.embedding_size = args.embedding_size 193 | self.codeEncoder = TreeAttEncoder(args) 194 | self.margin = args.margin 195 | self.nlEncoder = NlEncoder(args) 196 | self.poolConvnl = nn.Conv1d(self.embedding_size, self.embedding_size, 3) 197 | self.poolConvcode = nn.Conv1d(self.embedding_size, self.embedding_size, 3) 198 | self.maxPoolnl = nn.MaxPool1d(args.NlLen) 199 | self.maxPoolcode = nn.MaxPool1d(args.CodeLen) 200 | def scoring(self, qt_repr, cand_repr): 201 | sim = F.cosine_similarity(qt_repr, cand_repr) 202 | return sim 203 | def nlencoding(self, inputnl, inputnlchar): 204 | nl = self.nlEncoder(inputnl, inputnlchar) 205 | nl = self.maxPoolnl(self.poolConvnl(nl.permute(0, 2, 1))).squeeze(-1) 206 | return nl 207 | def codeencoding(self, inputcode, inputcodechar, ad): 208 | code = self.codeEncoder(inputcode, inputcodechar, ad) 209 | code = self.maxPoolcode(self.poolConvcode(code.permute(0, 2, 1))).squeeze(-1) 210 | return code 211 | def forward(self, inputnl, inputnlchar, inputcode, inputcodechar, ad, inputcodeneg, inputcodenegchar, adneg): 212 | nl = self.nlEncoder(inputnl, inputnlchar) 213 | code = self.codeEncoder(inputcode, inputcodechar, ad) 214 | codeneg = self.codeEncoder(inputcodeneg, inputcodenegchar, adneg) 215 | nl = self.maxPoolnl(self.poolConvnl(nl.permute(0, 2, 1))).squeeze(-1) 216 | code = self.maxPoolcode(self.poolConvcode(code.permute(0, 2, 1))).squeeze(-1) 217 | codeneg = self.maxPoolcode(self.poolConvcode(codeneg.permute(0, 2, 1))).squeeze(-1) 218 | good_score = self.scoring(nl, code) 219 | bad_score = self.scoring(nl, codeneg) 220 | loss = (self.margin - good_score + bad_score).clamp(min=1e-6).mean() 221 | return loss, good_score, bad_score 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- 
/hearthstone/Multihead_Attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Attention import Attention 3 | 4 | 5 | class MultiHeadedAttention(nn.Module): 6 | """ 7 | Take in model size and number of heads. 8 | """ 9 | 10 | def __init__(self, h, d_model, dropout=0.1): 11 | super().__init__() 12 | assert d_model % h == 0 13 | 14 | # We assume d_v always equals d_k 15 | self.d_k = d_model // h 16 | self.h = h 17 | 18 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 19 | self.output_linear = nn.Linear(d_model, d_model) 20 | self.attention = Attention() 21 | 22 | self.dropout = nn.Dropout(p=dropout) 23 | 24 | def forward(self, query, key, value, mask=None): 25 | batch_size = query.size(0) 26 | 27 | # 1) Do all the linear projections in batch from d_model => h x d_k 28 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 29 | for l, x in zip(self.linear_layers, (query, key, value))] 30 | 31 | # 2) Apply attention on all the projected vectors in batch. 32 | x, attn = self.attention(query, key, value, mask=mask, dropout=None) 33 | 34 | # 3) "Concat" using a view and apply a final linear. 35 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 36 | return self.output_linear(x) -------------------------------------------------------------------------------- /hearthstone/Multihead_Combination.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from CombinationLayer import CombinationLayer 3 | 4 | 5 | class MultiHeadedCombination(nn.Module): 6 | """ 7 | Take in model size and number of heads. 
8 | """ 9 | 10 | def __init__(self, h, d_model, dropout=0.1): 11 | super().__init__() 12 | assert d_model % h == 0 13 | 14 | # We assume d_v always equals d_k 15 | self.d_k = d_model // h 16 | self.h = h 17 | 18 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 19 | self.output_linear = nn.Linear(d_model, d_model) 20 | self.combination = CombinationLayer() 21 | 22 | self.dropout = nn.Dropout(p=dropout) 23 | 24 | def forward(self, query, key, value, mask=None, batch_size=-1): 25 | if batch_size == -1: 26 | batch_size = query.size(0) 27 | 28 | # 1) Do all the linear projections in batch from d_model => h x d_k 29 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 30 | for l, x in zip(self.linear_layers, (query, key, value))] 31 | 32 | # 2) Apply attention on all the projected vectors in batch. 33 | x = self.combination(query, key, value, dropout=self.dropout) 34 | 35 | # 3) "Concat" using a view and apply a final linear. 36 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 37 | return self.output_linear(x) -------------------------------------------------------------------------------- /hearthstone/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Usage 3 | ### To train a new model 4 | 5 | ``` 6 | python3 run.py train 7 | ``` 8 | 9 | ### To predict the code 10 | ``` 11 | python3 run.py test 12 | ``` 13 | 14 | where the output is available at ```outval.txt```. 
15 | -------------------------------------------------------------------------------- /hearthstone/Radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = 
torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if 
not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 159 | step_size = group['lr'] * 
math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] 
= torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss -------------------------------------------------------------------------------- /hearthstone/ScheduledOptim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim(): 6 | '''A simple wrapper class for learning rate scheduling''' 7 | 8 | def __init__(self, optimizer, d_model, n_warmup_steps): 9 | self._optimizer = optimizer 10 | self.n_warmup_steps = n_warmup_steps 11 | self.n_current_steps = 0 12 | self.init_lr = np.power(d_model, -0.5) 13 | 14 | def step_and_update_lr(self): 15 | "Step with the inner optimizer" 16 | #self._update_learning_rate() 17 | self._optimizer.step() 18 | 19 | def zero_grad(self): 20 | "Zero out the gradients by the inner optimizer" 21 | self._optimizer.zero_grad() 22 | 23 | def _get_lr_scale(self): 24 | return np.min([ 25 | 
np.power(self.n_current_steps, -0.5), 26 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 27 | 28 | def _update_learning_rate(self): 29 | ''' Learning rate scheduling per step ''' 30 | 31 | self.n_current_steps += 1 32 | lr = self.init_lr * self._get_lr_scale() 33 | 34 | for param_group in self._optimizer.param_groups: 35 | param_group['lr'] = lr 36 | -------------------------------------------------------------------------------- /hearthstone/SubLayerConnection.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from LayerNorm import LayerNorm 3 | 4 | 5 | class SublayerConnection(nn.Module): 6 | """ 7 | A residual connection followed by a layer norm. 8 | Note for code simplicity the norm is first as opposed to last. 9 | """ 10 | 11 | def __init__(self, size, dropout): 12 | super(SublayerConnection, self).__init__() 13 | self.norm = LayerNorm(size) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x, sublayer): 17 | "Apply residual connection to any sublayer with the same size." 
18 | return x + self.dropout(sublayer(self.norm(x))) -------------------------------------------------------------------------------- /hearthstone/TokenEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TokenEmbedding(nn.Embedding): 5 | def __init__(self, vocab_size, embed_size=512): 6 | super().__init__(vocab_size, embed_size, padding_idx=0) -------------------------------------------------------------------------------- /hearthstone/Transfomer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from Multihead_Attention import MultiHeadedAttention 4 | from SubLayerConnection import SublayerConnection 5 | from DenseLayer import DenseLayer 6 | from ConvolutionForward import ConvolutionLayer 7 | from Multihead_Combination import MultiHeadedCombination 8 | 9 | 10 | class TransformerBlock(nn.Module): 11 | """ 12 | Bidirectional Encoder = Transformer (self-attention) 13 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 14 | """ 15 | 16 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): 17 | """ 18 | :param hidden: hidden size of transformer 19 | :param attn_heads: head sizes of multi-head attention 20 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 21 | :param dropout: dropout rate 22 | """ 23 | 24 | super().__init__() 25 | self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 26 | self.combination = MultiHeadedCombination(h=attn_heads, d_model=hidden) 27 | self.feed_forward = DenseLayer(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 28 | self.conv_forward = ConvolutionLayer(dmodel=hidden, layernum=hidden) 29 | self.sublayer1 = SublayerConnection(size=hidden, dropout=dropout) 30 | self.sublayer2 = SublayerConnection(size=hidden, dropout=dropout) 31 | self.sublayer3 = SublayerConnection(size=hidden, dropout=dropout) 32 | 
self.sublayer4 = SublayerConnection(size=hidden, dropout=dropout) 33 | self.dropout = nn.Dropout(p=dropout) 34 | 35 | def forward(self, x, mask, charEm, treemask=None, isTree=False): 36 | x = self.sublayer1(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 37 | x = self.sublayer2(x, lambda _x: self.combination.forward(_x, _x, charEm)) 38 | if isTree: 39 | x = self.sublayer3(x, lambda _x: self.attention.forward(_x, _x, _x, mask=treemask)) 40 | #x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 41 | x = self.sublayer4(x, self.feed_forward) 42 | else: 43 | x = self.sublayer3(x, lambda _x:self.conv_forward.forward(_x, mask)) 44 | return self.dropout(x) -------------------------------------------------------------------------------- /hearthstone/TreeConv.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from gelu import GELU 4 | class TreeConv(nn.Module): 5 | def __init__(self, kernel, dmodel): 6 | super(TreeConv ,self).__init__() 7 | self.kernel = kernel 8 | self.conv = nn.Conv2d(dmodel, dmodel, (1, kernel)) 9 | self.activate = GELU() 10 | def forward(self, state, inputad): 11 | tmp = [state] 12 | tmpState = state 13 | for i in range(self.kernel - 1): 14 | tmpState = torch.matmul(inputad, tmpState) 15 | tmp.append(tmpState) 16 | states = torch.stack(tmp, 2) 17 | convstates = self.activate(self.conv(states.permute(0, 3, 1, 2))) 18 | convstates = convstates.squeeze(3).permute(0, 2, 1) 19 | return convstates -------------------------------------------------------------------------------- /hearthstone/TreeConvGen.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from gelu import GELU 4 | class TreeConvGen(nn.Module): 5 | def __init__(self, kernel, dmodel): 6 | super(TreeConvGen ,self).__init__() 7 | self.kernel = kernel 8 | self.conv = nn.Conv2d(dmodel, dmodel, (1, 
kernel)) 9 | self.activate = GELU() 10 | def forward(self, state, inputad, inputgen): 11 | tmp = [] 12 | tmpState = state 13 | for i in range(self.kernel): 14 | tmpState = torch.matmul(inputad, tmpState) 15 | tmp.append(torch.matmul(inputgen, tmpState)) 16 | states = torch.stack(tmp, 2) 17 | convstates = self.activate(self.conv(states.permute(0, 3, 1, 2))) 18 | convstates = convstates.squeeze(3).permute(0, 2, 1) 19 | return convstates -------------------------------------------------------------------------------- /hearthstone/__pycache__/Attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Attention.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Attention.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Attention.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/CombinationLayer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/CombinationLayer.cpython-35.pyc -------------------------------------------------------------------------------- 
/hearthstone/__pycache__/CombinationLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/CombinationLayer.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/CombinationLayer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/CombinationLayer.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ConvolutionForward.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ConvolutionForward.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ConvolutionForward.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ConvolutionForward.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ConvolutionForward.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ConvolutionForward.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Dataset.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Dataset.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Dataset.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Dataset.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/DenseLayer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/DenseLayer.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/DenseLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/DenseLayer.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/DenseLayer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/DenseLayer.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Embedding.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Embedding.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Embedding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Embedding.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Embedding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Embedding.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/LayerNorm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/LayerNorm.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/LayerNorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/LayerNorm.cpython-37.pyc 
-------------------------------------------------------------------------------- /hearthstone/__pycache__/LayerNorm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/LayerNorm.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Model.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Model.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Model.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Attention.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Attention.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Attention.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Attention.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Combination.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Combination.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Combination.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Combination.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Multihead_Combination.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Multihead_Combination.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Radam.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Radam.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Radam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Radam.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Radam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Radam.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ScheduledOptim.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ScheduledOptim.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ScheduledOptim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ScheduledOptim.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/ScheduledOptim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/ScheduledOptim.cpython-38.pyc 
-------------------------------------------------------------------------------- /hearthstone/__pycache__/SubLayerConnection.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/SubLayerConnection.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/SubLayerConnection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/SubLayerConnection.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/SubLayerConnection.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/SubLayerConnection.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TokenEmbedding.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TokenEmbedding.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TokenEmbedding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TokenEmbedding.cpython-37.pyc -------------------------------------------------------------------------------- 
/hearthstone/__pycache__/TokenEmbedding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TokenEmbedding.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Transfomer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Transfomer.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Transfomer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Transfomer.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/Transfomer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/Transfomer.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConv.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConv.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConv.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConv.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConv.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConvGen.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConvGen.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConvGen.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConvGen.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/TreeConvGen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/TreeConvGen.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/decodeTrans.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/decodeTrans.cpython-35.pyc 
-------------------------------------------------------------------------------- /hearthstone/__pycache__/decodeTrans.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/decodeTrans.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/decodeTrans.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/decodeTrans.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnn.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnn.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnn.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnnnormal.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnnnormal.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnnnormal.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnnnormal.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gcnnnormal.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gcnnnormal.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gelu.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gelu.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gelu.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gelu.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/gelu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/gelu.cpython-38.pyc 
-------------------------------------------------------------------------------- /hearthstone/__pycache__/postionEmbedding.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/postionEmbedding.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/postionEmbedding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/postionEmbedding.cpython-37.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/postionEmbedding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/postionEmbedding.cpython-38.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/rightnTransfomer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/rightnTransfomer.cpython-35.pyc -------------------------------------------------------------------------------- /hearthstone/__pycache__/rightnTransfomer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/rightnTransfomer.cpython-37.pyc -------------------------------------------------------------------------------- 
# (repository-dump residue, preserved as comments)
# /hearthstone/__pycache__/rightnTransfomer.cpython-38.pyc: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/rightnTransfomer.cpython-38.pyc
# /hearthstone/__pycache__/vocab.cpython-35.pyc: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/vocab.cpython-35.pyc
# /hearthstone/__pycache__/vocab.cpython-37.pyc: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/vocab.cpython-37.pyc
# /hearthstone/__pycache__/vocab.cpython-38.pyc: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/__pycache__/vocab.cpython-38.pyc
# ---------------------------------------------------------------------------
# /hearthstone/cal.py
# ---------------------------------------------------------------------------
"""Print a histogram of ``np.sum`` over the entries of ``rulead.pkl``.

Loads ``rulead.pkl`` from the current working directory, prints how many
entries it holds, then prints a dict mapping each distinct per-entry sum to
the number of entries with that sum (keys in first-seen order).
"""
import pickle
from collections import Counter

import numpy as np

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('rulead.pkl', 'rb') as f:
    rule = pickle.load(f)
print(len(rule))

# Sum each entry exactly once (previously np.sum ran up to four times per entry).
# Index access is kept so any container supporting len()/[int] behaves as before.
res = [np.sum(rule[i]) for i in range(len(rule))]
# Counter preserves first-insertion order, so dict(...) prints the same mapping.
rest = dict(Counter(res))
print(rest)
# /hearthstone/char_voc.pkl:
# (repository-dump residue, preserved as comments)
# /hearthstone/char_voc.pkl: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/char_voc.pkl
# /hearthstone/code_voc.pkl: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/code_voc.pkl

# ---------------------------------------------------------------------------
# /hearthstone/decodeTrans.py
# ---------------------------------------------------------------------------
import torch.nn as nn

from Multihead_Attention import MultiHeadedAttention
from SubLayerConnection import SublayerConnection
from DenseLayer import DenseLayer
from ConvolutionForward import ConvolutionLayer
from Multihead_Combination import MultiHeadedCombination


class decodeTransformerBlock(nn.Module):
    """Decoder block: two attention sub-layers, then a feed-forward sub-layer.

    ``x`` first attends over ``inputleft`` (``attention1``), then over
    ``inputleft2`` (``attention2``); each of these steps and the final
    ``DenseLayer`` is wrapped in a ``SublayerConnection``. Dropout is applied
    to the block output.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of the block (``d_model``)
        :param attn_heads: number of heads of each multi-head attention
        :param feed_forward_hidden: inner size of the feed-forward layer, usually 4*hidden
        :param dropout: dropout rate
        """
        super().__init__()
        self.attention1 = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.attention2 = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = DenseLayer(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.sublayer1 = SublayerConnection(size=hidden, dropout=dropout)
        # sublayer2 is not used by forward(); it is kept (and kept in this
        # registration order) so the module/state_dict layout is unchanged.
        self.sublayer2 = SublayerConnection(size=hidden, dropout=dropout)
        self.sublayer3 = SublayerConnection(size=hidden, dropout=dropout)
        self.sublayer4 = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask, inputleft, leftmask, inputleft2, leftmask2):
        """Run the block.

        :param x: queries for both attention sub-layers
        :param mask: accepted for interface compatibility; not used here
        :param inputleft: keys/values for ``attention1``
        :param leftmask: mask passed to ``attention1``
        :param inputleft2: keys/values for ``attention2``
        :param leftmask2: mask passed to ``attention2``
        :return: ``dropout(x)`` after the three sub-layers
        """
        x = self.sublayer1(x, lambda _x: self.attention1.forward(_x, inputleft, inputleft, mask=leftmask))
        x = self.sublayer3(x, lambda _x: self.attention2.forward(_x, inputleft2, inputleft2, mask=leftmask2))
        x = self.sublayer4(x, self.feed_forward)
        return self.dropout(x)


# ---------------------------------------------------------------------------
# /hearthstone/gcnn.py
# ---------------------------------------------------------------------------
from torch import nn
import torch
from gelu import GELU
from SubLayerConnection import SublayerConnection
from Multihead_Combination import MultiHeadedCombination


class GCNN(nn.Module):
    """One degree-normalized graph-convolution step over ``[left; state]``.

    ``left`` is prepended to ``state`` along dim 1, the result is projected,
    propagated through the normalized adjacency ``inputad`` and merged back
    with a ``MultiHeadedCombination`` inside a ``SublayerConnection``; the
    first 50 positions along dim 1 are dropped from the output.
    """

    def __init__(self, dmodel):
        super(GCNN, self).__init__()
        # Registration order is kept as-is: it fixes parameter order and the
        # RNG consumption of weight initialisation.
        self.hiddensize = dmodel
        self.linear = nn.Linear(dmodel, dmodel)
        self.linearSecond = nn.Linear(dmodel, dmodel)
        # activate and dropout are not used by forward(); kept for compatibility.
        self.activate = GELU()
        self.dropout = nn.Dropout(p=0.1)
        self.subconnect = SublayerConnection(dmodel, 0.1)
        self.com = MultiHeadedCombination(8, dmodel)

    def forward(self, state, left, inputad):
        """
        :param state: node states, placed after ``left`` on dim 1
        :param left: states prepended to ``state`` on dim 1
        :param inputad: adjacency over the concatenated nodes
            (assumed (batch, N, N) with N = left.size(1) + state.size(1) — TODO confirm)
        :return: the 3-D output with its first 50 positions on dim 1 removed
        """
        projected = self.linear(torch.cat([left, state], dim=1))
        # Entry (i, j) of the adjacency is divided by sqrt(row_sum_i * col_sum_j);
        # clamping keeps the division finite for nodes without edges.
        row_scale = 1.0 / torch.sqrt(torch.sum(inputad, dim=-1, keepdim=True).clamp(min=1e-6))
        col_scale = 1.0 / torch.sqrt(torch.sum(inputad, dim=-2, keepdim=True).clamp(min=1e-6))
        norm_adj = col_scale * inputad * row_scale
        # The lambda closes over `projected`, which is never reassigned, so the
        # propagation always uses the projected (pre-combination) states.
        mixed = self.subconnect(projected, lambda _x: self.com(_x, _x, torch.matmul(norm_adj, projected)))
        out = self.linearSecond(mixed)
        # NOTE(review): 50 presumably equals left.size(1), i.e. this strips the
        # prepended rows; confirm before changing the length of `left`.
        return out[:, 50:, :]


# ---------------------------------------------------------------------------
# /hearthstone/gcnnnormal.py
# ---------------------------------------------------------------------------
from torch import nn
import torch
from gelu import GELU
from SubLayerConnection import SublayerConnection
from Multihead_Combination import MultiHeadedCombination


class GCNNM(nn.Module):
    """Rule-conditioned, degree-normalized graph-convolution step.

    First combines ``state`` with ``rule`` (``comb`` in ``subconnect1``), then
    projects it, propagates it through the normalized adjacency ``inputad`` and
    merges the propagated states back (``com`` in ``subconnect``).
    """

    def __init__(self, dmodel):
        super(GCNNM, self).__init__()
        # Registration order is kept as-is (parameter order / init RNG order).
        self.hiddensize = dmodel
        self.linear = nn.Linear(dmodel, dmodel)
        # linearSecond, activate and dropout are not used by forward(); kept so
        # the module/state_dict layout is unchanged.
        self.linearSecond = nn.Linear(dmodel, dmodel)
        self.activate = GELU()
        self.dropout = nn.Dropout(p=0.1)
        self.subconnect = SublayerConnection(dmodel, 0.1)
        self.com = MultiHeadedCombination(8, dmodel)
        self.comb = MultiHeadedCombination(8, dmodel)
        self.subconnect1 = SublayerConnection(dmodel, 0.1)

    def forward(self, state, inputad, rule):
        """
        :param state: node states
        :param inputad: adjacency matrix over the nodes of ``state``
        :param rule: extra input combined into ``state`` before propagation
        :return: updated node states (same leading shape as ``state`` — TODO confirm)
        """
        state = self.subconnect1(state, lambda _x: self.comb(_x, _x, rule, batch_size=1))
        state = self.linear(state)
        # Entry (i, j) is divided by sqrt(row_sum_i * col_sum_j); clamp avoids
        # division by zero for nodes without edges.
        row_scale = 1.0 / torch.sqrt(torch.sum(inputad, dim=-1, keepdim=True).clamp(min=1e-6))
        col_scale = 1.0 / torch.sqrt(torch.sum(inputad, dim=-2, keepdim=True).clamp(min=1e-6))
        norm_adj = col_scale * inputad * row_scale
        propagated = torch.matmul(norm_adj, state)
        return self.subconnect(state, lambda _x: self.com(_x, _x, propagated, batch_size=1))


# ---------------------------------------------------------------------------
# /hearthstone/gelu.py
# ---------------------------------------------------------------------------
import torch.nn as nn
import torch
import math


class GELU(nn.Module):
    """GELU activation (tanh approximation).

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    BERT (paper Section 3.4, last paragraph) uses GELU instead of ReLU.
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# (repository-dump residue, preserved as comments)
# /hearthstone/nl.pkl: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/nl.pkl
# /hearthstone/nl_voc.pkl: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/nl_voc.pkl
# /hearthstone/outval.txt: https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/outval.txt

# ---------------------------------------------------------------------------
# /hearthstone/postionEmbedding.py
# ---------------------------------------------------------------------------
import torch.nn as nn
import torch
import math


class PositionalEmbedding(nn.Module):
    """Fixed (untrained) sinusoidal positional encodings.

    Holds a ``(1, max_len, d_model)`` buffer ``pe`` with
    ``pe[0, pos, 2i] = sin(pos / 10000**(2i/d_model))`` and
    ``pe[0, pos, 2i+1] = cos(pos / 10000**(2i/d_model))``.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        # Was `pe.require_grad` (typo): it only set an unused attribute.
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine column; slicing
        # div_term avoids the shape mismatch (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the encodings of the first ``x.size(-1)`` positions, shape ``(1, L, d_model)``."""
        return self.pe[:, :x.size(-1)]


# ---------------------------------------------------------------------------
# /hearthstone/process.py
# ---------------------------------------------------------------------------
# NOTE(review): statement nesting below is reconstructed from a flattened dump
# that lost indentation; it was inferred from the code's logic.
import json
import os
import subprocess
import sys
from tqdm import tqdm
from nltk import word_tokenize


class Node:
    """Tree node: a name, a parent link (``father``) and an ordered child list."""

    def __init__(self, name):
        self.name = name
        self.father = None
        self.child = []


def _attach(parent, name):
    """Create ``Node(name)``, append it as the last child of ``parent`` and return it."""
    node = Node(name)
    node.father = parent
    parent.child.append(node)
    return node


def visitTree(node, ast):
    """Append the JSON value ``ast`` beneath ``node`` as a tree of ``Node``.

    - dict: one child per key, recursing into each value;
    - list: one ``"list"`` child holding a ``"stmt"`` child per item;
    - quoted or empty-quote string: a ``"str_"`` child whose single child is
      the text with quotes/``%`` removed and spaces turned into ``_``;
    - any other string: a leaf with that text;
    - bool: a leaf ``"True"``/``"False"``;
    - any other falsy value (``None``, ``0``, ...): a leaf ``"none"``.

    Any other value (e.g. a non-zero number) prints its type and exits the process.
    """
    if isinstance(ast, dict):
        for key in ast:
            visitTree(_attach(node, key), ast[key])
    elif isinstance(ast, list):
        holder = _attach(node, "list")
        for item in ast:
            visitTree(_attach(holder, "stmt"), item)
    elif isinstance(ast, str):
        if ast == "\'\'" or ast == "\"\"":
            ast = "empty_str"
        if "\'" in ast or "\"" in ast or ast == "empty_str":
            ast = ast.replace("\"", "").replace("\'", "").replace(" ", "_").replace("%", "")
            _attach(_attach(node, "str_"), ast)
        else:
            _attach(node, ast)
    elif isinstance(ast, bool):
        _attach(node, str(ast))
    elif not ast:
        _attach(node, "none")
    else:
        print(type(ast))
        # NOTE(review): exit status 0 on an unexpected value hides the failure
        # from callers; consider a non-zero status.
        sys.exit(0)


def printTree(r):
    """Serialize the tree rooted at ``r`` in pre-order.

    Each node emits ``"<name> "``, then its serialized children, then ``"^ "``
    (a leaf is ``"<name> ^ "``). Children are sorted by name first, and
    ``r.child`` is replaced by the sorted list.
    """
    head = r.name + " "
    if not r.child:
        return head + "^ "
    r.child = sorted(r.child, key=lambda c: c.name)
    return head + "".join(printTree(c) for c in r.child) + "^ "


with open("tables.json", "r") as tables_file:
    tables = json.load(tables_file)
# Schema lookup by database id (a later duplicate id overrides an earlier one).
tablename = {t['db_id']: t for t in tables}


def sovleNl(dbid, nl):
    """Build the model input string for question ``nl`` on database ``dbid``.

    For each table: ``"<table> table_end "``, then ``"<column> <type>_end "``
    for each of its columns, then ``"col_end "``; followed by the tokenized,
    lower-cased question and ``" query_end"``. A trailing ``?``/``.`` becomes its
    own token (any other ending is printed for inspection); quote marks become
    standalone ``"`` tokens; ``share`` becomes ``share_`` (column and question
    word) and ``females`` becomes ``female``.
    """
    if nl.endswith("?"):
        nl = nl[:-1] + " ?"
    elif nl.endswith("."):
        nl = nl[:-1] + " ."
    else:
        print(nl)

    tokens = []
    for word in nl.lower().strip().split():
        if word == "share":
            tokens.append('share_')
        # Was a separate `if`: "share" fell through and was emitted a second time.
        elif word == "females":
            tokens.append('female')
        # The closing curly quote is now detected too (it was already replaced below).
        elif "\"" in word or "\'" in word or "“" in word or "”" in word:
            word = word.replace("\"", " | ").replace("\'", " | ").replace("“", " | ").replace("”", " | ")
            for piece in word.split():
                if piece == "|":
                    tokens.append("\"")
                elif piece[-1] in ",?.!;?":
                    tokens += [piece[:-1].replace(",", ""), piece[-1]]
                else:
                    tokens.append(piece)
        elif word[-1] in ",?.!;?":
            tokens += [word[:-1].replace(",", ""), word[-1]]
        else:
            tokens.append(word.replace(",", ""))

    schema = tablename[dbid]
    ans = ""
    for i, table in enumerate(schema['table_names_original']):
        ans += table.lower() + " table_end "
        for j, y in enumerate(schema['column_names_original']):
            if y[0] == i:
                # Rename without mutating the shared schema (output unchanged).
                column = "share_" if y[1].lower() == "share" else y[1]
                ans += column.lower() + " " + schema['column_types'][j].lower() + "_end "
        # NOTE(review): emitted once per table here; the flattened dump does not
        # show whether this was per table or once after all tables — confirm.
        ans += "col_end "
    ans += " ".join(tokens)
    ans += " query_end"
    return ans


with open("train_spider.json", "r") as spider_file:
    lst = json.loads(spider_file.read())
for i, x in tqdm(enumerate(lst)):
    q = x['query'].lower()
    if i == 12:
        print(q)
    # process.js is run after the query is written to data.txt; its result is
    # read back from data.json (presumably the query's JSON AST — confirm).
    with open("data.txt", "w") as query_file:
        query_file.write(q)
    status = subprocess.call(["node", "process.js"])
    if status != 0:
        # Retry once with the same query (the former rewrite of INTERSECT/EXCEPT
        # to "union" before the retry is disabled).
        with open("data.txt", "w") as query_file:
            query_file.write(q)
        status = subprocess.call(["node", "process.js"])
        if status != 0:
            print(q)
            sys.exit(1)
    with open("data.json", "r") as ast_file:
        s = json.load(ast_file)
    r = Node("root")
    visitTree(r, s)
    with open("train_output/" + str(i + 1) + ".txt", "w") as out_file:
        out_file.write(printTree(r))
    with open("train_input/" + str(i + 1) + ".txt", "w") as in_file:
        in_file.write(sovleNl(x['db_id'], x['question']))

# ---------------------------------------------------------------------------
# /hearthstone/rightnTransfomer.py (truncated in this view; dump text kept verbatim below)
# ---------------------------------------------------------------------------
# 1 | import torch.nn as nn 2 | 3 | from Multihead_Attention import MultiHeadedAttention 4 | from SubLayerConnection import SublayerConnection 5 | from DenseLayer import DenseLayer 6 | from ConvolutionForward import ConvolutionLayer 7 | from Multihead_Combination import MultiHeadedCombination 8 | from TreeConv import TreeConv 9 | from gcnn import GCNN 10 | 11 | class rightTransformerBlock(nn.Module): 12 | """ 13 | Bidirectional Encoder = Transformer (self-attention) 14 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 15 | """ 16 | 17 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): 18 | """ 19 | :param hidden: hidden size of transformer 20 | :param attn_heads: head sizes of multi-head attention 21 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 22 | :param dropout: dropout rate 23 | """ 24 | 25 | super().__init__() 26 | self.attention1 = MultiHeadedAttention(h=attn_heads, d_model=hidden) 27 | self.attention2 = MultiHeadedAttention(h=attn_heads, d_model=hidden) 28 |
self.combination = MultiHeadedCombination(h=attn_heads, d_model=hidden) 29 | self.feed_forward = DenseLayer(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 30 | self.conv_forward = ConvolutionLayer(dmodel=hidden, layernum=hidden) 31 | self.Tconv_forward = GCNN(dmodel=hidden) 32 | self.sublayer1 = SublayerConnection(size=hidden, dropout=dropout) 33 | self.sublayer2 = SublayerConnection(size=hidden, dropout=dropout) 34 | self.sublayer3 = SublayerConnection(size=hidden, dropout=dropout) 35 | self.sublayer4 = SublayerConnection(size=hidden, dropout=dropout) 36 | self.dropout = nn.Dropout(p=dropout) 37 | 38 | def forward(self, x, mask, inputleft, leftmask, charEm, inputP): 39 | x = self.sublayer1(x, lambda _x: self.attention1.forward(_x, _x, _x, mask=mask)) 40 | x = self.sublayer2(x, lambda _x: self.combination.forward(_x, _x, charEm)) 41 | x = self.sublayer3(x, lambda _x: self.attention2.forward(_x, inputleft, inputleft, mask=leftmask)) 42 | x = self.sublayer4(x, lambda _x: self.Tconv_forward.forward(_x, inputleft, inputP)) 43 | #x = self.sublayer4(x, self.feed_forward) 44 | return self.dropout(x) -------------------------------------------------------------------------------- /hearthstone/rule.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/rule.pkl -------------------------------------------------------------------------------- /hearthstone/rulead.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone/rulead.pkl -------------------------------------------------------------------------------- /hearthstone/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | from Dataset import SumDataset 4 
| import os 5 | from tqdm import tqdm 6 | from Model import * 7 | import numpy as np 8 | import wandb 9 | from copy import deepcopy 10 | import pickle 11 | from ScheduledOptim import * 12 | import sys 13 | from Radam import RAdam 14 | import torch.nn.functional as F 15 | #from pythonBottom.run import finetune 16 | #from pythonBottom.run import pre 17 | #wandb.init("sql") 18 | class dotdict(dict): 19 | def __getattr__(self, name): 20 | return self[name] 21 | args = dotdict({ 22 | 'NlLen':50, 23 | 'CodeLen':200, 24 | 'batch_size':32, 25 | 'embedding_size':312, 26 | 'WoLen':15, 27 | 'Vocsize':100, 28 | 'Nl_Vocsize':100, 29 | 'max_step':3, 30 | 'margin':0.5, 31 | 'poolsize':50, 32 | 'Code_Vocsize':100, 33 | 'num_steps':50, 34 | 'rulenum':10, 35 | 'seed':0 36 | }) 37 | #os.environ["CUDA_VISIBLE_DEVICES"]="2, 3" 38 | #os.environ['CUDA_LAUNCH_BLOCKING']="1" 39 | def save_model(model, dirs='checkpointSearch/'): 40 | if not os.path.exists(dirs): 41 | os.makedirs(dirs) 42 | torch.save(model.state_dict(), dirs + 'best_model.ckpt') 43 | 44 | 45 | def load_model(model, dirs = 'checkpointSearch/'): 46 | assert os.path.exists(dirs + 'best_model.ckpt'), 'Weights for saved model not found' 47 | #cprint(dirs) 48 | model.load_state_dict(torch.load(dirs + 'best_model.ckpt')) 49 | use_cuda = torch.cuda.is_available() 50 | def gVar(data): 51 | tensor = data 52 | if isinstance(data, np.ndarray): 53 | tensor = torch.from_numpy(data) 54 | else: 55 | assert isinstance(tensor, torch.Tensor) 56 | if use_cuda: 57 | tensor = tensor.cuda() 58 | return tensor 59 | def getAntiMask(size): 60 | ans = np.zeros([size, size]) 61 | for i in range(size): 62 | for j in range(0, i + 1): 63 | ans[i, j] = 1.0 64 | return ans 65 | def getAdMask(size): 66 | ans = np.zeros([size, size]) 67 | for i in range(size - 1): 68 | ans[i, i + 1] = 1.0 69 | return ans 70 | def getRulePkl(vds): 71 | inputruleparent = [] 72 | inputrulechild = [] 73 | for i in range(len(vds.ruledict)): 74 | rule = 
vds.rrdict[i].strip().lower().split() 75 | inputrulechild.append(vds.pad_seq(vds.Get_Em(rule[2:], vds.Code_Voc), vds.Char_Len)) 76 | inputruleparent.append(vds.Code_Voc[rule[0].lower()]) 77 | return np.array(inputruleparent), np.array(inputrulechild) 78 | def evalacc(model, dev_set): 79 | antimask = gVar(getAntiMask(args.CodeLen)) 80 | a, b = getRulePkl(dev_set) 81 | tmpf = gVar(a).unsqueeze(0).repeat(2, 1).long() 82 | tmpc = gVar(b).unsqueeze(0).repeat(2, 1, 1).long() 83 | devloader = torch.utils.data.DataLoader(dataset=dev_set, batch_size=22, 84 | shuffle=False, drop_last=True, num_workers=1) 85 | model = model.eval() 86 | accs = [] 87 | tcard = [] 88 | antimask2 = antimask.unsqueeze(0).repeat(22, 1, 1).unsqueeze(1) 89 | rulead = gVar(pickle.load(open("rulead.pkl", "rb"))).float().unsqueeze(0).repeat(2, 1, 1) 90 | tmpindex = gVar(np.arange(len(dev_set.ruledict))).unsqueeze(0).repeat(2, 1).long() 91 | for devBatch in tqdm(devloader): 92 | for i in range(len(devBatch)): 93 | devBatch[i] = gVar(devBatch[i]) 94 | with torch.no_grad(): 95 | _, pre = model(devBatch[0], devBatch[1], devBatch[2], devBatch[3], devBatch[4], devBatch[6], devBatch[7], devBatch[8], tmpf, tmpc, tmpindex, rulead, antimask2, devBatch[5]) 96 | pred = pre.argmax(dim=-1) 97 | resmask = torch.gt(devBatch[5], 0) 98 | acc = (torch.eq(pred, devBatch[5]) * resmask).float()#.mean(dim=-1) 99 | predres = (1 - acc) * pred.float() * resmask.float() 100 | accsum = torch.sum(acc, dim=-1) 101 | resTruelen = torch.sum(resmask, dim=-1).float() 102 | print(torch.eq(accsum, resTruelen)) 103 | cnum = (torch.eq(accsum, resTruelen)).sum().float() 104 | acc = acc.sum(dim=-1) / resTruelen 105 | accs.append(acc.mean().item()) 106 | tcard.append(cnum.item()) 107 | #print(devBatch[5]) 108 | #print(predres) 109 | tnum = np.sum(tcard) 110 | acc = np.mean(accs) 111 | #wandb.log({"accuracy":acc}) 112 | return acc, tnum 113 | def train(): 114 | torch.manual_seed(args.seed) 115 | np.random.seed(args.seed) 116 | train_set = 
SumDataset(args, "train") 117 | print(len(train_set.rrdict)) 118 | a, b = getRulePkl(train_set) 119 | tmpf = gVar(a).unsqueeze(0).repeat(2, 1).long() 120 | tmpc = gVar(b).unsqueeze(0).repeat(2, 1, 1).long() 121 | rulead = gVar(pickle.load(open("rulead.pkl", "rb"))).float().unsqueeze(0).repeat(2, 1, 1) 122 | tmpindex = gVar(np.arange(len(train_set.ruledict))).unsqueeze(0).repeat(2, 1).long() 123 | args.Code_Vocsize = len(train_set.Code_Voc) 124 | args.Nl_Vocsize = len(train_set.Nl_Voc) 125 | args.Vocsize = len(train_set.Char_Voc) 126 | args.rulenum = len(train_set.ruledict) + args.NlLen 127 | dev_set = SumDataset(args, "val") 128 | test_set = SumDataset(args, "test") 129 | data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.batch_size, 130 | shuffle=True, drop_last=True, num_workers=1) 131 | model = Decoder(args) 132 | #load_model(model) 133 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 134 | optimizer = ScheduledOptim(optimizer, d_model=args.embedding_size, n_warmup_steps=4000) 135 | maxAcc = 0 136 | maxC = 0 137 | maxAcc2 = 0 138 | maxC2 = 0 139 | if torch.cuda.is_available(): 140 | print('using GPU') 141 | #os.environ["CUDA_VISIBLE_DEVICES"] = "3" 142 | model = model.cuda() 143 | model = nn.DataParallel(model, device_ids=[0, 1]) 144 | antimask = gVar(getAntiMask(args.CodeLen)) 145 | #model.to() 146 | for epoch in range(100000): 147 | j = 0 148 | for dBatch in tqdm(data_loader): 149 | if j % 3000 == 0: 150 | acc, tnum = evalacc(model, dev_set) 151 | acc2, tnum2 = evalacc(model, test_set) 152 | print("for dev " + str(acc) + " " + str(tnum) + " max is " + str(maxC)) 153 | print("for test " + str(acc2) + " " + str(tnum2) + " max is " + str(maxC2)) 154 | if maxC < tnum or maxC == tnum and maxAcc < acc: 155 | maxC = tnum 156 | maxAcc = acc 157 | print("find better acc " + str(maxAcc)) 158 | save_model(model.module, 'checkpointSearch/') 159 | if maxC2 < tnum2 or maxC2 == tnum2 and maxAcc2 < acc2: 160 | maxC2 = tnum2 161 | maxAcc2 = acc2 
162 | print("find better acc " + str(maxAcc2)) 163 | save_model(model.module, "test%s/"%args.seed) 164 | #exit(0) 165 | antimask2 = antimask.unsqueeze(0).repeat(args.batch_size, 1, 1).unsqueeze(1) 166 | model = model.train() 167 | for i in range(len(dBatch)): 168 | dBatch[i] = gVar(dBatch[i]) 169 | loss, _ = model(dBatch[0], dBatch[1], dBatch[2], dBatch[3], dBatch[4], dBatch[6], dBatch[7], dBatch[8], tmpf, tmpc, tmpindex, rulead, antimask2, dBatch[5]) 170 | loss = torch.mean(loss) + F.max_pool1d(loss.unsqueeze(0).unsqueeze(0), 2, 1).squeeze(0).squeeze(0).mean() + F.max_pool1d(loss.unsqueeze(0).unsqueeze(0), 3, 1).squeeze(0).squeeze(0).mean() + F.max_pool1d(loss.unsqueeze(0).unsqueeze(0), 4, 1).squeeze(0).squeeze(0).mean() 171 | optimizer.zero_grad() 172 | loss.backward() 173 | optimizer.step_and_update_lr() 174 | j += 1 175 | class Node: 176 | def __init__(self, name, d): 177 | self.name = name 178 | self.depth = d 179 | self.father = None 180 | self.child = [] 181 | self.sibiling = None 182 | self.expanded = False 183 | self.fatherlistID = 0 184 | class SearchNode: 185 | def __init__(self, ds): 186 | self.state = [ds.ruledict["start -> Module"]] 187 | self.prob = 0 188 | self.aprob = 0 189 | self.bprob = 0 190 | self.root = Node("Module", 2) 191 | self.inputparent = ["start"] 192 | self.parent = np.zeros([args.NlLen + args.CodeLen, args.NlLen + args.CodeLen]) 193 | #self.parent[args.NlLen] 194 | self.expanded = None 195 | self.ruledict = ds.rrdict 196 | self.expandedname = [] 197 | self.depth = [1] 198 | for x in ds.ruledict: 199 | self.expandedname.append(x.strip().split()[0]) 200 | self.everTreepath = [] 201 | def selcetNode(self, root): 202 | if not root.expanded and root.name in self.expandedname and root.name != "body" and self.state[root.fatherlistID] < len(self.ruledict): 203 | return root 204 | else: 205 | for x in root.child: 206 | ans = self.selcetNode(x) 207 | if ans: 208 | return ans 209 | if root.name == "body" and root.expanded == False: 210 | return 
root 211 | return None 212 | def selectExpandedNode(self): 213 | self.expanded = self.selcetNode(self.root) 214 | def getRuleEmbedding(self, ds, nl): 215 | inputruleparent = [] 216 | inputrulechild = [] 217 | for x in self.state: 218 | if x >= len(ds.rrdict): 219 | inputruleparent.append(ds.Get_Em(["value"], ds.Code_Voc)[0]) 220 | inputrulechild.append(ds.pad_seq(ds.Get_Em(["copyword"], ds.Code_Voc), ds.Char_Len)) 221 | else: 222 | rule = ds.rrdict[x].strip().lower().split() 223 | inputruleparent.append(ds.Get_Em([rule[0]], ds.Code_Voc)[0]) 224 | inputrulechild.append(ds.pad_seq(ds.Get_Em(rule[2:], ds.Code_Voc), ds.Char_Len)) 225 | tmp = [ds.pad_seq(ds.Get_Em(['start'], ds.Code_Voc), 10)] + self.everTreepath 226 | inputrule = ds.pad_seq(self.state, ds.Code_Len) 227 | inputrulechild = ds.pad_list(tmp, ds.Code_Len, 10) 228 | #inputrulechild = ds.pad_list(inputrulechild, ds.Code_Len, ds.Char_Len) 229 | inputruleparent = ds.pad_seq(ds.Get_Em(self.inputparent, ds.Code_Voc), ds.Code_Len) 230 | inputdepth = ds.pad_list(self.depth, ds.Code_Len, 40) 231 | #inputdepth = ds.pad_seq(self.depth, ds.Code_Len) 232 | return inputrule, inputrulechild, inputruleparent, inputdepth 233 | def getTreePath(self, ds): 234 | tmppath = [self.expanded.name.lower()] 235 | node = self.expanded.father 236 | while node: 237 | tmppath.append(node.name.lower()) 238 | node = node.father 239 | tmp = ds.pad_seq(ds.Get_Em(tmppath, ds.Code_Voc), 10) 240 | self.everTreepath.append(tmp) 241 | return ds.pad_list(self.everTreepath, ds.Code_Len, 10) 242 | def applyrule(self, rule, nl): 243 | if rule >= len(self.ruledict): 244 | if rule - len(self.ruledict) >= len(nl): 245 | return False 246 | if self.expanded.depth + 1 >= 40: 247 | nnode = Node(nl[rule - len(self.ruledict)], 39) 248 | else: 249 | nnode = Node(nl[rule - len(self.ruledict)], self.expanded.depth + 1) 250 | self.expanded.child.append(nnode) 251 | nnode.father = self.expanded 252 | nnode.fatherlistID = len(self.state) 253 | else: 254 | rules = 
self.ruledict[rule] 255 | if rules.strip().split()[0] != self.expanded.name: 256 | #print(self.expanded.name) 257 | return False 258 | #assert(rules.strip().split()[0] == self.expanded.name) 259 | if rules == self.expanded.name + " -> End ": 260 | self.expanded.expanded = True 261 | else: 262 | for x in rules.strip().split()[2:]: 263 | if self.expanded.depth + 1 >= 40: 264 | nnode = Node(x, 39) 265 | else: 266 | nnode = Node(x, self.expanded.depth + 1) 267 | #nnode = Node(x, self.expanded.depth + 1) 268 | self.expanded.child.append(nnode) 269 | nnode.father = self.expanded 270 | nnode.fatherlistID = len(self.state) 271 | #self.parent.append(self.expanded.fatherlistID) 272 | self.parent[args.NlLen + len(self.depth), args.NlLen + self.expanded.fatherlistID] = 1 273 | if rule >= len(self.ruledict): 274 | self.parent[args.NlLen + len(self.depth), rule - len(self.ruledict)] = 1 275 | if rule >= len(self.ruledict): 276 | self.state.append(len(self.ruledict) - 1) 277 | else: 278 | self.state.append(rule) 279 | self.inputparent.append(self.expanded.name.lower()) 280 | self.depth.append(self.expanded.depth) 281 | if self.expanded.name != "body": 282 | self.expanded.expanded = True 283 | return True 284 | def printTree(self, r): 285 | s = r.name + " "#print(r.name) 286 | if len(r.child) == 0: 287 | s += "^ " 288 | return s 289 | #r.child = sorted(r.child, key=lambda x:x.name) 290 | for c in r.child: 291 | s += self.printTree(c) 292 | s += "^ "#print(r.name + "^") 293 | return s 294 | def getTreestr(self): 295 | return self.printTree(self.root) 296 | 297 | 298 | beamss = [] 299 | def BeamSearch(inputnl, vds, model, beamsize, batch_size, k): 300 | args.batch_size = len(inputnl[0]) 301 | rulead = gVar(pickle.load(open("rulead.pkl", "rb"))).float().unsqueeze(0).repeat(2, 1, 1) 302 | a, b = getRulePkl(vds) 303 | tmpf = gVar(a).unsqueeze(0).repeat(2, 1).long() 304 | tmpc = gVar(b).unsqueeze(0).repeat(2, 1, 1).long() 305 | tmpindex = 
gVar(np.arange(len(vds.ruledict))).unsqueeze(0).repeat(2, 1).long() 306 | with torch.no_grad(): 307 | beams = {} 308 | for i in range(batch_size): 309 | beams[i] = [SearchNode(vds)] 310 | index = 0 311 | antimask = gVar(getAntiMask(args.CodeLen)) 312 | endnum = {} 313 | continueSet = {} 314 | while True: 315 | print(index) 316 | tmpbeam = {} 317 | ansV = {} 318 | if len(endnum) == args.batch_size: 319 | break 320 | if index >= args.CodeLen: 321 | break 322 | for p in range(beamsize): 323 | tmprule = [] 324 | tmprulechild = [] 325 | tmpruleparent = [] 326 | tmptreepath = [] 327 | tmpAd = [] 328 | validnum = [] 329 | tmpdepth = [] 330 | for i in range(args.batch_size): 331 | if p >= len(beams[i]): 332 | continue 333 | x = beams[i][p] 334 | #print(x.getTreestr()) 335 | x.selectExpandedNode() 336 | if x.expanded == None or len(x.state) >= args.CodeLen: 337 | ansV.setdefault(i, []).append(x) 338 | else: 339 | #print(x.expanded.name) 340 | validnum.append(i) 341 | a, b, c, d = x.getRuleEmbedding(vds, vds.nl[args.batch_size * k + i]) 342 | tmprule.append(a) 343 | tmprulechild.append(b) 344 | tmpruleparent.append(c) 345 | tmptreepath.append(x.getTreePath(vds)) 346 | #tmp = np.eye(vds.Code_Len)[x.parent] 347 | #tmp = np.concatenate([tmp, np.zeros([vds.Code_Len, vds.Code_Len])], axis=0)[:vds.Code_Len,:]#self.pad_list(tmp, self.Code_Len, self.Code_Len) 348 | tmpAd.append(x.parent) 349 | tmpdepth.append(d) 350 | #print("--------------------------") 351 | if len(tmprule) == 0: 352 | continue 353 | batch_size = len(tmprule) 354 | antimasks = antimask.unsqueeze(0).repeat(batch_size, 1, 1).unsqueeze(1) 355 | tmprule = np.array(tmprule) 356 | tmprulechild = np.array(tmprulechild) 357 | tmpruleparent = np.array(tmpruleparent) 358 | tmptreepath = np.array(tmptreepath) 359 | tmpAd = np.array(tmpAd) 360 | tmpdepth = np.array(tmpdepth) 361 | '''print(inputnl[3][:index + 1], tmprule[:index + 1]) 362 | assert(np.array_equal(inputnl[3][0][:index + 1], tmprule[0][:index + 1])) 363 | 
assert(np.array_equal(inputnl[4][0][:index + 1], tmpruleparent[0][:index + 1])) 364 | assert(np.array_equal(inputnl[5][0][:index + 1], tmprulechild[0][:index + 1])) 365 | assert(np.array_equal(inputnl[6][0][:index + 1], tmpAd[0][:index + 1])) 366 | assert(np.array_equal(inputnl[7][0][:index + 1], tmptreepath[0][:index + 1])) 367 | assert(np.array_equal(inputnl[8][0][:index + 1], tmpdepth[0][:index + 1]))''' 368 | #result = model(gVar(inputnl[0][validnum]), gVar(inputnl[1][validnum]), gVar(tmprule), gVar(tmpruleparent), gVar(tmprulechild), gVar(tmpAd), gVar(tmptreepath), gVar(tmpdepth), antimasks, None, "test") 369 | result = model(gVar(inputnl[0][validnum]), gVar(inputnl[1][validnum]), gVar(tmprule), gVar(tmpruleparent), gVar(tmprulechild), gVar(tmpAd), gVar(tmptreepath), None, tmpf, tmpc, tmpindex, rulead, antimasks, None, "test") 370 | results = result.data.cpu().numpy() 371 | #print(result, inputCode) 372 | currIndex = 0 373 | for j in range(args.batch_size): 374 | if j not in validnum: 375 | continue 376 | x = beams[j][p] 377 | tmpbeamsize = beamsize 378 | result = np.negative(results[currIndex, index]) 379 | currIndex += 1 380 | cresult = np.negative(result) 381 | indexs = np.argsort(result) 382 | for i in range(tmpbeamsize): 383 | if tmpbeamsize >= 20: 384 | break 385 | copynode = deepcopy(x) 386 | #if indexs[i] >= len(vds.rrdict): 387 | #print(cresult[indexs[i]]) 388 | c = copynode.applyrule(indexs[i], vds.nl[args.batch_size * k + j]) 389 | if not c: 390 | tmpbeamsize += 1 391 | continue 392 | copynode.prob = copynode.prob + np.log(cresult[indexs[i]]) 393 | tmpbeam.setdefault(j, []).append(copynode) 394 | #print(tmpbeam[0].prob) 395 | for i in range(args.batch_size): 396 | if i in ansV: 397 | if len(ansV[i]) == beamsize: 398 | endnum[i] = 1 399 | for j in range(args.batch_size): 400 | if j in tmpbeam: 401 | if j in ansV: 402 | for x in ansV[j]: 403 | tmpbeam[j].append(x) 404 | beams[j] = sorted(tmpbeam[j], key=lambda x: x.prob, reverse=True)[:beamsize] 405 | 
index += 1 406 | '''for p in range(beamsize): 407 | beam = [] 408 | nls = [] 409 | for i in range(len(beams)): 410 | if p >= len(beams[i]): 411 | beam.append(beams[i][len(beams[i]) - 1]) 412 | else: 413 | beam.append(beams[i][p]) 414 | nls.append(vds.nl[args.batch_size * k + i])''' 415 | #finetune(beam, k, nls, args.batch_size) 416 | #for i in range(len(beams)): 417 | # beamss.append(deepcopy(beams[i])) 418 | 419 | 420 | for i in range(len(beams)): 421 | mans = -1000000 422 | lst = beams[i] 423 | tmpans = 0 424 | for y in lst: 425 | #print(y.getTreestr()) 426 | if y.prob > mans: 427 | mans = y.prob 428 | tmpans = y 429 | beams[i] = tmpans 430 | #open("beams.pkl", "wb").write(pickle.dumps(beamss)) 431 | return beams 432 | #return beams 433 | def test(): 434 | #pre() 435 | dev_set = SumDataset(args, "test") 436 | print(len(dev_set)) 437 | args.Nl_Vocsize = len(dev_set.Nl_Voc) 438 | args.Code_Vocsize = len(dev_set.Code_Voc) 439 | args.Vocsize = len(dev_set.Char_Voc) 440 | args.rulenum = len(dev_set.ruledict) + args.NlLen 441 | args.batch_size = 22 442 | rdic = {} 443 | for x in dev_set.Nl_Voc: 444 | rdic[dev_set.Nl_Voc[x]] = x 445 | #print(dev_set.Nl_Voc) 446 | model = Decoder(args) 447 | if use_cuda: 448 | print('using GPU') 449 | #os.environ["CUDA_VISIBLE_DEVICES"] = "3" 450 | model = model.cuda() 451 | devloader = torch.utils.data.DataLoader(dataset=dev_set, batch_size=args.batch_size, 452 | shuffle=False, drop_last=False, num_workers=0) 453 | model = model.eval() 454 | load_model(model) 455 | f = open("outval.txt", "w") 456 | index = 0 457 | for x in tqdm(devloader): 458 | '''if index == 0: 459 | index += 1 460 | continue''' 461 | ans = BeamSearch((x[0], x[1], x[5], x[2], x[3], x[4], x[6], x[7], x[8]), dev_set, model, 15, args.batch_size, index) 462 | index += 1 463 | for i in range(args.batch_size): 464 | beam = ans[i] 465 | #print(beam[0].parent, beam[0].everTreepath, beam[0].state) 466 | f.write(beam.getTreestr()) 467 | f.write("\n") 468 | f.flush() 469 | 
#exit(0) 470 | #f.write(" ".join(ans.ans[1:-1])) 471 | #f.write("\n") 472 | #f.flush()#print(ans) 473 | if __name__ == "__main__": 474 | if sys.argv[1] == "train": 475 | train() 476 | args.seed = sys.argv[2] 477 | else: 478 | test() 479 | #test() 480 | 481 | 482 | 483 | 484 | -------------------------------------------------------------------------------- /hearthstone/solvetree.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import pickle 4 | lst = ["train", "dev"] 5 | rules = {"pad":0} 6 | onelist = ["list"] 7 | rulelist = [] 8 | fatherlist = [] 9 | fathername = [] 10 | depthlist = [] 11 | copynode = [] 12 | class Node: 13 | def __init__(self, name, s): 14 | self.name = name 15 | self.id = s 16 | self.father = None 17 | self.child = [] 18 | def parseTree(treestr): 19 | tokens = treestr.split() 20 | root = Node("root", 0) 21 | currnode = root 22 | for i, x in enumerate(tokens[1:]): 23 | if x != "^": 24 | nnode = Node(x, i + 1) 25 | nnode.father = currnode 26 | currnode.child.append(nnode) 27 | currnode = nnode 28 | else: 29 | currnode = currnode.father 30 | return root 31 | def getRule(node, nls, currId, d): 32 | global rules 33 | global onelist 34 | global rulelist 35 | global fatherlist 36 | global depthlist 37 | global copynode 38 | if len(node.child) == 0: 39 | return [], [] 40 | if " -> End " not in rules: 41 | rules[" -> End "] = len(rules) 42 | return [rules[" -> End "]] 43 | child = sorted(node.child, key=lambda x:x.name) 44 | if len(node.child) == 1 and node.child[0].name in nls: 45 | copynode.append(node.name) 46 | rulelist.append(10000 + nls.index(node.child[0].name)) 47 | fatherlist.append(currId) 48 | fathername.append(node.name) 49 | depthlist.append(d) 50 | currid = len(rulelist) - 1 51 | for x in child: 52 | getRule(x, nls, currId, d + 1) 53 | #rulelist.extend(a) 54 | #fatherlist.extend(b) 55 | else: 56 | if node.name not in onelist: 57 | rule = node.name + " -> " 58 | 
for x in child: 59 | rule += x.name + " " 60 | if rule in rules: 61 | rulelist.append(rules[rule]) 62 | else: 63 | rules[rule] = len(rules) 64 | rulelist.append(rules[rule]) 65 | fatherlist.append(currId) 66 | fathername.append(node.name) 67 | depthlist.append(d) 68 | currid = len(rulelist) - 1 69 | for x in child: 70 | getRule(x, nls, currid, d + 1) 71 | else: 72 | for x in (child): 73 | rule = node.name + " -> " + x.name 74 | if rule in rules: 75 | rulelist.append(rules[rule]) 76 | else: 77 | rules[rule] = len(rules) 78 | rulelist.append(rules[rule]) 79 | fatherlist.append(currId) 80 | fathername.append(node.name) 81 | depthlist.append(d) 82 | getRule(x, nls, len(rulelist) - 1, d + 1) 83 | rule = node.name + " -> End " 84 | if rule in rules: 85 | rulelist.append(rules[rule]) 86 | else: 87 | rules[rule] = len(rules) 88 | rulelist.append(rules[rule]) 89 | fatherlist.append(currId) 90 | fathername.append(node.name) 91 | depthlist.append(d) 92 | '''if node.name == "root": 93 | print('rr') 94 | print('rr') 95 | print(rulelist)''' 96 | '''rule = " -> End " 97 | if rule in rules: 98 | rulelist.append(rules[rule]) 99 | else: 100 | rules[rule] = len(rules) 101 | rulelist.append(rules[rule])''' 102 | #return rulelist, fatherlist 103 | for x in lst: 104 | inputdir = x + "_input/" 105 | outputdir = x + "_output/" 106 | wf = open(x + ".txt", "w") 107 | for i in tqdm(range(len(os.listdir(inputdir)))): 108 | fname = inputdir + str(i + 1) + ".txt" 109 | ofname = outputdir + str(i + 1) + ".txt" 110 | f = open(fname, "r") 111 | nls = f.read() 112 | f.close() 113 | f = open(ofname, "r") 114 | asts = f.read() 115 | f.close() 116 | wf.write(nls + "\n") 117 | #wf.write(asts + "\n") 118 | assert(len(asts.split()) == 2 * asts.split().count('^')) 119 | root = parseTree(asts) 120 | rulelist = [] 121 | fatherlist = [] 122 | fathername = [] 123 | depthlist = [] 124 | getRule(root, nls.split(), -1, 2) 125 | s = "" 126 | for x in rulelist: 127 | s += str(x) + " " 128 | wf.write(s + "\n") 129 
| s = "" 130 | for x in fatherlist: 131 | s += str(x) + " " 132 | wf.write(s + "\n") 133 | s = "" 134 | for x in depthlist: 135 | s += str(x) + " " 136 | wf.write(s + "\n") 137 | wf.write(" ".join(fathername) + "\n") 138 | #print(rules) 139 | #print(asts) 140 | wf.close() 141 | wf = open("rule.pkl", "wb") 142 | wf.write(pickle.dumps(rules)) 143 | wf.close() 144 | print(copynode) -------------------------------------------------------------------------------- /hearthstone/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import print_function 4 | import argparse 5 | from collections import Counter 6 | from itertools import chain 7 | 8 | 9 | class VocabEntry(object): 10 | def __init__(self): 11 | self.word2id = dict() 12 | self.unk_id = 3 13 | '''self.word2id[''] = 0 14 | self.word2id[''] = 1 15 | self.word2id[''] = 2 16 | self.word2id[''] = 3''' 17 | self.word2id[""] = 0 18 | #self.word2id["NothingHere"] = 0 19 | self.word2id["Unknown"] = 1 20 | '''self.word2id["Unknown"] = 1 21 | self.word2id["NothingHere"] = 0 22 | self.word2id["NoneCopy"] = 2 23 | self.word2id["CopyNode"] = 3 24 | self.word2id[""] = 4''' 25 | 26 | self.id2word = {v: k for k, v in self.word2id.items()} 27 | 28 | def __getitem__(self, word): 29 | return self.word2id.get(word, self.unk_id) 30 | 31 | def __contains__(self, word): 32 | return word in self.word2id 33 | 34 | def __setitem__(self, key, value): 35 | raise ValueError('vocabulary is readonly') 36 | 37 | def __len__(self): 38 | return len(self.word2id) 39 | 40 | def __repr__(self): 41 | return 'Vocabulary[size=%d]' % len(self) 42 | 43 | def id2word(self, wid): 44 | return self.id2word[wid] 45 | 46 | def add(self, word): 47 | if word not in self: 48 | wid = self.word2id[word] = len(self) 49 | self.id2word[wid] = word 50 | return wid 51 | else: 52 | return self[word] 53 | 54 | def is_unk(self, word): 55 | return word not in self 56 | 57 | @staticmethod 58 | def 
from_corpus(corpus, size, freq_cutoff=0): 59 | vocab_entry = VocabEntry() 60 | #print(list(chain(*corpus))) 61 | word_freq = Counter(chain(*corpus)) 62 | #print(word_freq) 63 | non_singletons = [w for w in word_freq if word_freq[w] > 1] 64 | singletons = [w for w in word_freq if word_freq[w] == 1] 65 | print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq), 66 | len(non_singletons))) 67 | print('singletons: %s' % singletons) 68 | 69 | top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size] 70 | words_not_included = [] 71 | for word in top_k_words: 72 | if len(vocab_entry) < size: 73 | if word_freq[word] >= freq_cutoff: 74 | vocab_entry.add(word) 75 | else: 76 | words_not_included.append(word) 77 | 78 | print('word types not included: %s' % words_not_included) 79 | 80 | return vocab_entry 81 | 82 | 83 | class Vocab(object): 84 | def __init__(self, **kwargs): 85 | self.entries = [] 86 | for key, item in kwargs.items(): 87 | assert isinstance(item, VocabEntry) 88 | self.__setattr__(key, item) 89 | 90 | self.entries.append(key) 91 | 92 | def __repr__(self): 93 | return 'Vocab(%s)' % (', '.join('%s %swords' % (entry, getattr(self, entry)) for entry in self.entries)) 94 | 95 | 96 | if __name__ == '__main__': 97 | raise NotImplementedError 98 | -------------------------------------------------------------------------------- /hearthstone_preprocess/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone_preprocess/.DS_Store -------------------------------------------------------------------------------- /hearthstone_preprocess/Code_Voc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone_preprocess/Code_Voc.pkl 
-------------------------------------------------------------------------------- /hearthstone_preprocess/README.md: -------------------------------------------------------------------------------- 1 | To run this preprocess: 2 | 3 | ```bash runall.sh``` -------------------------------------------------------------------------------- /hearthstone_preprocess/codead.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone_preprocess/codead.pkl -------------------------------------------------------------------------------- /hearthstone_preprocess/cp.sh: -------------------------------------------------------------------------------- 1 | cp *.pkl ../hearthstone 2 | cp *.txt ../hearthstone -------------------------------------------------------------------------------- /hearthstone_preprocess/dev_hs.in: -------------------------------------------------------------------------------- 1 | Assassin's Blade NAME_END 3 ATK_END -1 DEF_END 5 COST_END 4 DUR_END Weapon TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Common RARITY_END NIL 2 | Boulderfist Ogre NAME_END 6 ATK_END 7 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Free RARITY_END NIL 3 | Deadly Poison NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Free RARITY_END Give your weapon +2 Attack. 4 | Fire Elemental NAME_END 6 ATK_END 5 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: Deal 3 damage. 5 | Gnomish Inventor NAME_END 2 ATK_END 4 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: Draw a card. 6 | Heroic Strike NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Free RARITY_END Give your hero +4 Attack this turn. 
7 | Ironbark Protector NAME_END 8 ATK_END 8 DEF_END 8 COST_END -1 DUR_END Minion TYPE_END Druid PLAYER_CLS_END NIL RACE_END Common RARITY_END Taunt 8 | Mark of the Wild NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Druid PLAYER_CLS_END NIL RACE_END Free RARITY_END Give a minion Taunt and +2/+2. (+2 Attack/+2 Health) 9 | Multi-Shot NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Hunter PLAYER_CLS_END NIL RACE_END Free RARITY_END Deal $3 damage to two random enemy minions. 10 | Power Word: Shield NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Priest PLAYER_CLS_END NIL RACE_END Free RARITY_END Give a minion +2 Health. NL Draw a card. 11 | Sen'jin Shieldmasta NAME_END 3 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Free RARITY_END Taunt 12 | Sinister Strike NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Free RARITY_END Deal $3 damage to the enemy hero. 13 | Succubus NAME_END 4 ATK_END 3 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Free RARITY_END Battlecry: Discard a random card. 14 | War Golem NAME_END 7 ATK_END 7 DEF_END 7 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END NIL 15 | Acidmaw NAME_END 4 ATK_END 2 DEF_END 7 COST_END -1 DUR_END Minion TYPE_END Hunter PLAYER_CLS_END Beast RACE_END Legendary RARITY_END Whenever another minion takes damage, destroy it. 16 | Boar NAME_END 4 ATK_END 2 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Beast RACE_END NIL RARITY_END Charge 17 | Anodized Robo Cub NAME_END 2 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Druid PLAYER_CLS_END Mech RACE_END Common RARITY_END Taunt. Choose One - NL +1 Attack; or +1 Health. 
18 | Burrowing Mine NAME_END -1 ATK_END -1 DEF_END 0 COST_END -1 DUR_END Spell TYPE_END Warrior PLAYER_CLS_END NIL RACE_END NIL RARITY_END When you draw this, it explodes. You take 10 damage and draw a card. 19 | Crackle NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Common RARITY_END Deal $3-$6 damage. Overload: (1) 20 | Emergency Coolant NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Neutral PLAYER_CLS_END NIL RACE_END NIL RARITY_END Freeze a minion. 21 | Flying Machine NAME_END 1 ATK_END 4 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Mech RACE_END Common RARITY_END Windfury 22 | Goblin Auto-Barber NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Rogue PLAYER_CLS_END Mech RACE_END Common RARITY_END Battlecry: Give your weapon +1 Attack. 23 | Iron Sensei NAME_END 2 ATK_END 2 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Rogue PLAYER_CLS_END Mech RACE_END Rare RARITY_END At the end of your turn, give another friendly Mech +2/+2. 24 | Mal'Ganis NAME_END 9 ATK_END 7 DEF_END 9 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Legendary RARITY_END Your other Demons have +2/+2. NL Your hero is Immune. 25 | Mistress of Pain NAME_END 1 ATK_END 4 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Rare RARITY_END Whenever this minion deals damage, restore that much Health to your hero. 26 | Powermace NAME_END 3 ATK_END -1 DEF_END 3 COST_END 2 DUR_END Weapon TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Rare RARITY_END Deathrattle: Give a random friendly Mech +2/+2. 27 | Screwjank Clunker NAME_END 2 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Warrior PLAYER_CLS_END Mech RACE_END Rare RARITY_END Battlecry: Give a friendly Mech +2/+2. 
28 | Sneed's Old Shredder NAME_END 5 ATK_END 7 DEF_END 8 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Mech RACE_END Legendary RARITY_END Deathrattle: Summon a random legendary minion. 29 | Toshley NAME_END 5 ATK_END 7 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Legendary RARITY_END Battlecry and Deathrattle: Add a Spare Part card to your hand. 30 | Warbot NAME_END 1 ATK_END 3 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Warrior PLAYER_CLS_END Mech RACE_END Common RARITY_END Enrage: +1 Attack. 31 | Deathlord NAME_END 2 ATK_END 8 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Taunt. Deathrattle: Your opponent puts a minion from their deck into the battlefield. 32 | Nerub'ar Weblord NAME_END 1 ATK_END 4 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Minions with Battlecry cost (2) more. 33 | Spectral Knight NAME_END 4 ATK_END 6 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Can't be targeted by spells or Hero Powers. 34 | Wailing Soul NAME_END 3 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Silence your other minions. 35 | Amani Berserker NAME_END 2 ATK_END 3 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Enrage: +3 Attack 36 | Archmage Antonidas NAME_END 5 ATK_END 7 DEF_END 7 COST_END -1 DUR_END Minion TYPE_END Mage PLAYER_CLS_END NIL RACE_END Legendary RARITY_END Whenever you cast a spell, add a 'Fireball' spell to your hand. 37 | Bananas NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Neutral PLAYER_CLS_END NIL RACE_END NIL RARITY_END Give a minion +1/+1. 
38 | Blessed Champion NAME_END -1 ATK_END -1 DEF_END 5 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Rare RARITY_END Double a minion's Attack. 39 | Cabal Shadow Priest NAME_END 4 ATK_END 5 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Priest PLAYER_CLS_END NIL RACE_END Epic RARITY_END Battlecry: Take control of an enemy minion that has 2 or less Attack. 40 | Cone of Cold NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Mage PLAYER_CLS_END NIL RACE_END Common RARITY_END Freeze a minion and the minions next to it, and deal $1 damage to them. 41 | Defender of Argus NAME_END 2 ATK_END 3 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Give adjacent minions +1/+1 and Taunt. 42 | Doomhammer NAME_END 2 ATK_END -1 DEF_END 5 COST_END 8 DUR_END Weapon TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Epic RARITY_END Windfury, Overload: (2) 43 | Earth Shock NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Common RARITY_END Silence a minion, then deal $1 damage to it. 44 | Eye for an Eye NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END Secret: When your hero takes damage, deal that much damage to the enemy hero. 45 | Flare NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Hunter PLAYER_CLS_END NIL RACE_END Rare RARITY_END All minions lose Stealth. Destroy all enemy Secrets. Draw a card. 46 | Gorehowl NAME_END 7 ATK_END -1 DEF_END 7 COST_END 1 DUR_END Weapon TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Epic RARITY_END Attacking a minion costs 1 Attack instead of 1 Durability. 
47 | Hound NAME_END 1 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Hunter PLAYER_CLS_END Beast RACE_END NIL RARITY_END Charge 48 | Injured Blademaster NAME_END 4 ATK_END 7 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Deal 4 damage to HIMSELF. 49 | Knife Juggler NAME_END 3 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END After you summon a minion, deal 1 damage to a random enemy. 50 | Lightwell NAME_END 0 ATK_END 5 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Priest PLAYER_CLS_END NIL RACE_END Rare RARITY_END At the start of your turn, restore 3 Health to a damaged friendly character. 51 | Mana Wyrm NAME_END 1 ATK_END 3 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Mage PLAYER_CLS_END NIL RACE_END Common RARITY_END Whenever you cast a spell, gain +1 Attack. 52 | Mogu'shan Warden NAME_END 1 ATK_END 7 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Taunt 53 | Nourish NAME_END -1 ATK_END -1 DEF_END 5 COST_END -1 DUR_END Spell TYPE_END Druid PLAYER_CLS_END NIL RACE_END Rare RARITY_END Choose One - Gain 2 Mana Crystals; or Draw 3 cards. 54 | Preparation NAME_END -1 ATK_END -1 DEF_END 0 COST_END -1 DUR_END Spell TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Epic RARITY_END The next spell you cast this turn costs (3) less. 55 | Repentance NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END Secret: When your opponent plays a minion, reduce its Health to 1. 56 | Shadow of Nothing NAME_END 0 ATK_END 1 DEF_END 0 COST_END -1 DUR_END Minion TYPE_END Priest PLAYER_CLS_END NIL RACE_END Epic RARITY_END Mindgames whiffed! Your opponent had no minions! 57 | Slam NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Common RARITY_END Deal $2 damage to a minion. 
If it survives, draw a card. 58 | Spellbreaker NAME_END 4 ATK_END 3 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: Silence a minion. 59 | Sunfury Protector NAME_END 2 ATK_END 3 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Give adjacent minions Taunt. 60 | Tinkmaster Overspark NAME_END 3 ATK_END 3 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Legendary RARITY_END Battlecry: Transform another random minion into a 5/5 Devilsaur or a 1/1 Squirrel. 61 | Vaporize NAME_END -1 ATK_END -1 DEF_END 3 COST_END -1 DUR_END Spell TYPE_END Mage PLAYER_CLS_END NIL RACE_END Rare RARITY_END Secret: When a minion attacks your hero, destroy it. 62 | Worgen Infiltrator NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Stealth 63 | Blackwing Corruptor NAME_END 5 ATK_END 4 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: If you're holding a Dragon, deal 3 damage. 64 | Drakonid Crusher NAME_END 6 ATK_END 6 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Dragon RACE_END Common RARITY_END Battlecry: If your opponent has 15 or less Health, gain +3/+3. 65 | Imp NAME_END 1 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END NIL RARITY_END NIL 66 | Twilight Whelp NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Priest PLAYER_CLS_END Dragon RACE_END Common RARITY_END Battlecry: If you're holding a Dragon, gain +2 Health. 
67 | -------------------------------------------------------------------------------- /hearthstone_preprocess/dev_hs.out: -------------------------------------------------------------------------------- 1 | class AssassinsBlade(WeaponCard):§ def __init__(self):§ super().__init__("Assassin's Blade", 5, CHARACTER_CLASS.ROGUE, CARD_RARITY.COMMON)§§ def create_weapon(self, player):§ return Weapon(3, 4)§ 2 | class BoulderfistOgre(MinionCard):§ def __init__(self):§ super().__init__("Boulderfist Ogre", 6, CHARACTER_CLASS.ALL, CARD_RARITY.FREE)§§ def create_minion(self, player):§ return Minion(6, 7)§ 3 | class DeadlyPoison(SpellCard):§ def __init__(self):§ super().__init__("Deadly Poison", 1, CHARACTER_CLASS.ROGUE, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§§ player.weapon.base_attack += 2§ player.hero.change_temp_attack(2)§§ def can_use(self, player, game):§ return super().can_use(player, game) and player.weapon is not None§ 4 | class FireElemental(MinionCard):§ def __init__(self):§ super().__init__("Fire Elemental", 6, CHARACTER_CLASS.SHAMAN, CARD_RARITY.COMMON, battlecry=Battlecry(Damage(3), CharacterSelector(players=BothPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(6, 5)§ 5 | class GnomishInventor(MinionCard):§ def __init__(self):§ super().__init__("Gnomish Inventor", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Draw(), PlayerSelector()))§§ def create_minion(self, player):§ return Minion(2, 4)§ 6 | class HeroicStrike(SpellCard):§ def __init__(self):§ super().__init__("Heroic Strike", 2, CHARACTER_CLASS.WARRIOR, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§ player.hero.change_temp_attack(4)§ 7 | class IronbarkProtector(MinionCard):§ def __init__(self):§ super().__init__("Ironbark Protector", 8, CHARACTER_CLASS.DRUID, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(8, 8, taunt=True)§ 8 | class MarkOfTheWild(SpellCard):§ def 
__init__(self):§ super().__init__("Mark of the Wild", 2, CHARACTER_CLASS.DRUID, CARD_RARITY.FREE, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ self.target.change_attack(2)§ self.target.increase_health(2)§ self.target.taunt = True§ 9 | class MultiShot(SpellCard):§ def __init__(self):§ super().__init__("Multi-Shot", 4, CHARACTER_CLASS.HUNTER, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§§ targets = copy.copy(game.other_player.minions)§ for i in range(0, 2):§ target = game.random_choice(targets)§ targets.remove(target)§ target.damage(player.effective_spell_damage(3), self)§§ def can_use(self, player, game):§ return super().can_use(player, game) and len(game.other_player.minions) >= 2§ 10 | class PowerWordShield(SpellCard):§ def __init__(self):§ super().__init__("Power Word: Shield", 1, CHARACTER_CLASS.PRIEST, CARD_RARITY.FREE, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§§ self.target.increase_health(2)§ player.draw()§ 11 | class SenjinShieldmasta(MinionCard):§ def __init__(self):§ super().__init__("Sen'jin Shieldmasta", 4, CHARACTER_CLASS.ALL, CARD_RARITY.FREE)§§ def create_minion(self, player):§ return Minion(3, 5, taunt=True)§ 12 | class SinisterStrike(SpellCard):§ def __init__(self):§ super().__init__("Sinister Strike", 1, CHARACTER_CLASS.ROGUE, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§§ game.other_player.hero.damage(player.effective_spell_damage(3), self)§ 13 | class Succubus(MinionCard):§ def __init__(self):§ super().__init__("Succubus", 2, CHARACTER_CLASS.WARLOCK, CARD_RARITY.FREE, minion_type=MINION_TYPE.DEMON, battlecry=Battlecry(Discard(), PlayerSelector()))§§ def create_minion(self, player):§ return Minion(4, 3)§ 14 | class WarGolem(MinionCard):§ def __init__(self):§ super().__init__("War Golem", 7, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ 
def create_minion(self, player):§ return Minion(7, 7)§ 15 | class Acidmaw(MinionCard):§ def __init__(self):§ super().__init__("Acidmaw", 7, CHARACTER_CLASS.HUNTER, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.BEAST)§§ def create_minion(self, player):§ return Minion(4, 2, effects=[Effect(CharacterDamaged(MinionIsNotTarget(), BothPlayer()), [ActionTag(Kill(), TargetSelector())])])§ 16 | class Boar(MinionCard):§ def __init__(self):§ super().__init__("Boar", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False, minion_type=MINION_TYPE.BEAST)§§ def create_minion(self, player):§ return Minion(1, 1)§ 17 | class AnodizedRoboCub(MinionCard):§ def __init__(self):§ super().__init__("Anodized Robo Cub", 2, CHARACTER_CLASS.DRUID, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH, choices=[Choice(AttackMode(), Give([Buff(ChangeAttack(1))]), SelfSelector()), Choice(TankMode(), Give([Buff(ChangeHealth(1))]), SelfSelector())])§§ def create_minion(self, player):§ return Minion(2, 2, taunt=True)§ 18 | class BurrowingMine(SpellCard):§ def __init__(self):§ super().__init__("Burrowing Mine", 0, CHARACTER_CLASS.WARRIOR, CARD_RARITY.COMMON, False, effects=[Effect(Drawn(), ActionTag(Damage(10), HeroSelector())), Effect(Drawn(), ActionTag(Discard(query=CardQuery(source=CARD_SOURCE.LAST_DRAWN)), PlayerSelector())), Effect(Drawn(), ActionTag(Draw(), PlayerSelector()))])§§ def use(self, player, game):§ super().use(player, game)§ 19 | class Crackle(SpellCard):§ def __init__(self):§ super().__init__("Crackle", 2, CHARACTER_CLASS.SHAMAN, CARD_RARITY.COMMON, target_func=hearthbreaker.targeting.find_spell_target, overload=1)§§ def use(self, player, game):§ super().use(player, game)§§ self.target.damage(player.effective_spell_damage(game.random_amount(3, 6)), self)§ 20 | class EmergencyCoolant(SpellCard):§ def __init__(self):§ super().__init__("Emergency Coolant", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, 
game):§ super().use(player, game)§ self.target.add_buff(Buff(Frozen()))§ 21 | class FlyingMachine(MinionCard):§ def __init__(self):§ super().__init__("Flying Machine", 3, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(1, 4, windfury=True)§ 22 | class GoblinAutoBarber(MinionCard):§ def __init__(self):§ super().__init__("Goblin Auto-Barber", 2, CHARACTER_CLASS.ROGUE, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH, battlecry=Battlecry(IncreaseWeaponAttack(1), WeaponSelector()))§§ def create_minion(self, player):§ return Minion(3, 2)§ 23 | class IronSensei(MinionCard):§ def __init__(self):§ super().__init__("Iron Sensei", 3, CHARACTER_CLASS.ROGUE, CARD_RARITY.RARE, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(2, 2, effects=[Effect(TurnEnded(), ActionTag(Give([Buff(ChangeAttack(2)), Buff(ChangeHealth(2))]), MinionSelector(IsType(MINION_TYPE.MECH), picker=RandomPicker())))])§ 24 | class MalGanis(MinionCard):§ def __init__(self):§ super().__init__("Mal'Ganis", 9, CHARACTER_CLASS.WARLOCK, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return Minion(9, 7, auras=[Aura(ChangeHealth(2), MinionSelector(IsType(MINION_TYPE.DEMON))), Aura(ChangeAttack(2), MinionSelector(IsType(MINION_TYPE.DEMON))), Aura(Immune(), HeroSelector())])§ 25 | class MistressOfPain(MinionCard):§ def __init__(self):§ super().__init__("Mistress of Pain", 2, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return Minion(1, 4, effects=[Effect(DidDamage(), ActionTag(Heal(EventValue()), HeroSelector()))])§ 26 | class Powermace(WeaponCard):§ def __init__(self):§ super().__init__("Powermace", 3, CHARACTER_CLASS.SHAMAN, CARD_RARITY.RARE)§§ def create_weapon(self, player):§ return Weapon(3, 2, deathrattle=Deathrattle(Give([Buff(ChangeHealth(2)), Buff(ChangeAttack(2))]), 
MinionSelector(IsType(MINION_TYPE.MECH), picker=RandomPicker())))§ 27 | class ScrewjankClunker(MinionCard):§ def __init__(self):§ super().__init__("Screwjank Clunker", 4, CHARACTER_CLASS.WARRIOR, CARD_RARITY.RARE, minion_type=MINION_TYPE.MECH, battlecry=Battlecry(Give([Buff(ChangeHealth(2)), Buff(ChangeAttack(2))]), MinionSelector(IsType(MINION_TYPE.MECH), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(2, 5)§ 28 | class SneedsOldShredder(MinionCard):§ def __init__(self):§ super().__init__("Sneed's Old Shredder", 8, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(5, 7, deathrattle=Deathrattle(Summon(CardQuery(conditions=[IsRarity(CARD_RARITY.LEGENDARY), IsMinion()])), PlayerSelector()))§ 29 | class Toshley(MinionCard):§ def __init__(self):§ from hearthbreaker.cards.spells.neutral import spare_part_list§ super().__init__("Toshley", 6, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, battlecry=Battlecry(AddCard(CardQuery(source=CARD_SOURCE.LIST, source_list=spare_part_list)), PlayerSelector()))§§ def create_minion(self, player):§ from hearthbreaker.cards.spells.neutral import spare_part_list§ return Minion(5, 7, deathrattle=Deathrattle(AddCard(CardQuery(source=CARD_SOURCE.LIST, source_list=spare_part_list)), PlayerSelector()))§ 30 | class Warbot(MinionCard):§ def __init__(self):§ super().__init__("Warbot", 1, CHARACTER_CLASS.WARRIOR, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(1, 3, enrage=[Aura(ChangeAttack(1), SelfSelector())])§ 31 | class Deathlord(MinionCard):§ def __init__(self):§ super().__init__("Deathlord", 3, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(2, 8, taunt=True, deathrattle=Deathrattle(Summon(CardQuery(conditions=[IsMinion()], source=CARD_SOURCE.MY_DECK)), PlayerSelector(EnemyPlayer())))§ 32 | class NerubarWeblord(MinionCard):§ def __init__(self):§ 
super().__init__("Nerub'ar Weblord", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(1, 4, auras=[Aura(ManaChange(2), CardSelector(BothPlayer(), HasBattlecry()))])§ 33 | class SpectralKnight(MinionCard):§ def __init__(self):§ super().__init__("Spectral Knight", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(4, 6, spell_targetable=False)§ 34 | class WailingSoul(MinionCard):§ def __init__(self):§ super().__init__("Wailing Soul", 4, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(Silence(), MinionSelector()))§§ def create_minion(self, player):§ return Minion(3, 5)§ 35 | class AmaniBerserker(MinionCard):§ def __init__(self):§ super().__init__("Amani Berserker", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(2, 3, enrage=[Aura(ChangeAttack(3), SelfSelector())])§ 36 | class ArchmageAntonidas(MinionCard):§ def __init__(self):§ super().__init__("Archmage Antonidas", 7, CHARACTER_CLASS.MAGE, CARD_RARITY.LEGENDARY)§§ def create_minion(self, player):§ return Minion(5, 7, effects=[Effect(SpellCast(), ActionTag(AddCard(hearthbreaker.cards.Fireball()), PlayerSelector()))])§ 37 | class Bananas(SpellCard):§ def __init__(self):§ super().__init__("Bananas", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False, hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ self.target.change_attack(1)§ self.target.increase_health(1)§ 38 | class BlessedChampion(SpellCard):§ def __init__(self):§ super().__init__("Blessed Champion", 5, CHARACTER_CLASS.PALADIN, CARD_RARITY.RARE, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ self.target.add_buff(Buff(DoubleAttack()))§ 39 | class CabalShadowPriest(MinionCard):§ def __init__(self):§ super().__init__("Cabal Shadow Priest", 6, CHARACTER_CLASS.PRIEST, CARD_RARITY.EPIC, 
battlecry=Battlecry(Steal(), MinionSelector(AttackLessThanOrEqualTo(2), players=EnemyPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(4, 5)§ 40 | class ConeOfCold(SpellCard):§ def __init__(self):§ super().__init__("Cone of Cold", 4, CHARACTER_CLASS.MAGE, CARD_RARITY.COMMON, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§§ self.target.add_buff(Buff(Frozen()))§ index = self.target.index§§ if self.target.index < len(self.target.player.minions) - 1:§ minion = self.target.player.minions[index + 1]§ minion.damage(player.effective_spell_damage(1), self)§ minion.add_buff(Buff(Frozen()))§§ self.target.damage(player.effective_spell_damage(1), self)§§ if self.target.index > 0:§ minion = self.target.player.minions[index - 1]§ minion.damage(player.effective_spell_damage(1), self)§ minion.add_buff(Buff(Frozen()))§ 41 | class DefenderOfArgus(MinionCard):§ def __init__(self):§ super().__init__("Defender of Argus", 4, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(Give([ Buff(Taunt()), Buff(ChangeAttack(1)), Buff(ChangeHealth(1)) ]), MinionSelector(Adjacent())))§§ def create_minion(self, player):§ return Minion(2, 3)§ 42 | class Doomhammer(WeaponCard):§ def __init__(self):§ super().__init__("Doomhammer", 5, CHARACTER_CLASS.SHAMAN, CARD_RARITY.EPIC, overload=2)§§ def create_weapon(self, player):§ return Weapon(2, 8, buffs=[Buff(Windfury())])§ 43 | class EarthShock(SpellCard):§ def __init__(self):§ super().__init__("Earth Shock", 1, CHARACTER_CLASS.SHAMAN, CARD_RARITY.COMMON, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§§ self.target.silence()§ self.target.damage(player.effective_spell_damage(1), self)§ 44 | class EyeForAnEye(SecretCard):§ def __init__(self):§ super().__init__("Eye for an Eye", 1, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON)§§ def _reveal(self, character, attacker, amount):§ 
if character.is_hero():§ character.player.opponent.hero.damage(amount, self)§ super().reveal()§§ def activate(self, player):§ player.bind("character_damaged", self._reveal)§§ def deactivate(self, player):§ player.unbind("character_damaged", self._reveal)§ 45 | class Flare(SpellCard):§ def __init__(self):§ super().__init__("Flare", 2, CHARACTER_CLASS.HUNTER, CARD_RARITY.RARE)§§ def use(self, player, game):§ super().use(player, game)§ for minion in hearthbreaker.targeting.find_minion_spell_target(game, lambda m: m.stealth):§ minion.stealth = False§§ for secret in game.other_player.secrets:§ secret.deactivate(game.other_player)§§ game.other_player.secrets = []§ player.draw()§ 46 | class Gorehowl(WeaponCard):§ def __init__(self):§ super().__init__("Gorehowl", 7, CHARACTER_CLASS.WARRIOR, CARD_RARITY.EPIC)§§ def create_weapon(self, player):§ return Weapon(7, 1, effects=[Effect(CharacterAttack(And(IsHero(), TargetIsMinion())), [ActionTag(IncreaseDurability(), WeaponSelector()), ActionTag(IncreaseWeaponAttack(-1), WeaponSelector()), ActionTag(Give(BuffUntil(ChangeAttack(1), AttackCompleted())), HeroSelector())])])§ 47 | class Hound(MinionCard):§ def __init__(self):§ super().__init__("Hound", 1, CHARACTER_CLASS.HUNTER, CARD_RARITY.COMMON, False, minion_type=MINION_TYPE.BEAST)§§ def create_minion(self, player):§ return Minion(1, 1, charge=True)§ 48 | class InjuredBlademaster(MinionCard):§ def __init__(self):§ super().__init__("Injured Blademaster", 3, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(Damage(4), SelfSelector()))§§ def create_minion(self, player):§ return Minion(4, 7)§ 49 | class KnifeJuggler(MinionCard):§ def __init__(self):§ super().__init__("Knife Juggler", 2, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(3, 2, effects=[Effect(AfterAdded(), ActionTag(Damage(1), CharacterSelector(players=EnemyPlayer(), picker=RandomPicker(), condition=None)))])§ 50 | class Lightwell(MinionCard):§ def __init__(self):§ 
super().__init__("Lightwell", 2, CHARACTER_CLASS.PRIEST, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(0, 5, effects=[Effect(TurnStarted(), ActionTag(Heal(3), CharacterSelector(condition=IsDamaged(), picker=RandomPicker())))])§ 51 | class ManaWyrm(MinionCard):§ def __init__(self):§ super().__init__("Mana Wyrm", 1, CHARACTER_CLASS.MAGE, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(1, 3, effects=[Effect(SpellCast(), ActionTag(Give(ChangeAttack(1)), SelfSelector()))])§ 52 | class MogushanWarden(MinionCard):§ def __init__(self):§ super().__init__("Mogu'shan Warden", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(1, 7, taunt=True)§ 53 | class Nourish(SpellCard):§ def __init__(self):§ super().__init__("Nourish", 5, CHARACTER_CLASS.DRUID, CARD_RARITY.RARE)§§ def use(self, player, game):§ super().use(player, game)§§ class Gain2(ChoiceCard):§§ def __init__(self):§ super().__init__("Gain 2 mana crystals", 0, CHARACTER_CLASS.DRUID, CARD_RARITY.COMMON, False)§§ def use(self, player, game):§ if player.max_mana < 8:§ player.max_mana += 2§ player.mana += 2§ else:§ player.max_mana = 10§ player.mana += 2§§ class Draw3(ChoiceCard):§§ def __init__(self):§ super().__init__("Draw three cards", 0, CHARACTER_CLASS.DRUID, CARD_RARITY.COMMON, False)§§ def use(self, player, game):§ player.draw()§ player.draw()§ player.draw()§§ option = player.agent.choose_option([Gain2(), Draw3()], player)§ option.use(player, game)§ 54 | class Preparation(SpellCard):§ def __init__(self):§ super().__init__("Preparation", 0, CHARACTER_CLASS.ROGUE, CARD_RARITY.EPIC)§§ def use(self, player, game):§ super().use(player, game)§ player.add_aura(AuraUntil(ManaChange(-3), CardSelector(condition=IsSpell()), SpellCast()))§ 55 | class Repentance(SecretCard):§ def __init__(self):§ super().__init__("Repentance", 1, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON)§§ def _reveal(self, minion):§§ minion.set_health_to(1)§ 
super().reveal()§§ def activate(self, player):§ player.game.current_player.bind("minion_played", self._reveal)§§ def deactivate(self, player):§ player.game.current_player.unbind("minion_played", self._reveal)§ 56 | class ShadowOfNothing(MinionCard):§ def __init__(self):§ super().__init__("Shadow of Nothing", 0, CHARACTER_CLASS.PRIEST, CARD_RARITY.EPIC, False)§§ def create_minion(self, p):§ return Minion(0, 1)§ 57 | class Slam(SpellCard):§ def __init__(self):§ super().__init__("Slam", 2, CHARACTER_CLASS.WARRIOR, CARD_RARITY.COMMON, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ if self.target.health > player.effective_spell_damage(2) or self.target.divine_shield:§ self.target.damage(player.effective_spell_damage(2), self)§ player.draw()§ else:§ self.target.damage(player.effective_spell_damage(2), self)§ 58 | class Spellbreaker(MinionCard):§ def __init__(self):§ super().__init__("Spellbreaker", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.NONE, battlecry=Battlecry(Silence(), MinionSelector(players=BothPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(4, 3)§ 59 | class SunfuryProtector(MinionCard):§ def __init__(self):§ super().__init__("Sunfury Protector", 2, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(Give(Buff(Taunt())), MinionSelector(Adjacent())))§§ def create_minion(self, player):§ return Minion(2, 3)§ 60 | class TinkmasterOverspark(MinionCard):§ def __init__(self):§ super().__init__("Tinkmaster Overspark", 3, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, battlecry=Battlecry(Transform(CardQuery(source=CARD_SOURCE.LIST, source_list=[Devilsaur(), Squirrel()])), MinionSelector(players=BothPlayer(), picker=RandomPicker())))§§ def create_minion(self, player):§ return Minion(3, 3)§ 61 | class Vaporize(SecretCard):§ def __init__(self):§ super().__init__("Vaporize", 3, CHARACTER_CLASS.MAGE, CARD_RARITY.RARE)§§ def _reveal(self, 
attacker, target):§ if target is self.player.hero and attacker.is_minion() and not attacker.removed:§ attacker.die(self)§ attacker.game.check_delayed()§ super().reveal()§§ def activate(self, player):§ player.opponent.bind("character_attack", self._reveal)§§ def deactivate(self, player):§ player.opponent.unbind("character_attack", self._reveal)§ 62 | class WorgenInfiltrator(MinionCard):§ def __init__(self):§ super().__init__("Worgen Infiltrator", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(2, 1, stealth=True)§ 63 | class BlackwingCorruptor(MinionCard):§ def __init__(self):§ super().__init__("Blackwing Corruptor", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Damage(3), CharacterSelector(players=BothPlayer(), picker=UserPicker()), GreaterThan(Count(CardSelector(condition=IsType(MINION_TYPE.DRAGON))), value=0)))§§ def create_minion(self, player):§ return Minion(5, 4)§ 64 | class DrakonidCrusher(MinionCard):§ def __init__(self):§ super().__init__("Drakonid Crusher", 6, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DRAGON, battlecry=(Battlecry(Give([Buff(ChangeAttack(3)), Buff(ChangeHealth(3))]), SelfSelector(), Not(GreaterThan(Attribute('health', HeroSelector(EnemyPlayer())), value=15)))))§§ def create_minion(self, player):§ return Minion(6, 6)§ 65 | class Imp(MinionCard):§ def __init__(self):§ super().__init__("Imp", 1, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, False, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return Minion(1, 1)§ 66 | class TwilightWhelp(MinionCard):§ def __init__(self):§ super().__init__("Twilight Whelp", 1, CHARACTER_CLASS.PRIEST, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DRAGON, battlecry=(Battlecry(Give(Buff(ChangeHealth(2))), SelfSelector(), GreaterThan(Count(CardSelector(condition=IsType(MINION_TYPE.DRAGON))), value=0))))§§ def create_minion(self, player):§ return Minion(2, 1)§ 67 | 
-------------------------------------------------------------------------------- /hearthstone_preprocess/py3_asdl.simplified.txt: -------------------------------------------------------------------------------- 1 | ## ASDL's six builtin types are 2 | identifier, int, string, bytes, object, singleton 3 | 4 | 5 | mod = Module(stmt* body) 6 | | Interactive(stmt* body) 7 | | Expression(expr body) 8 | 9 | stmt = FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, expr? returns) 10 | | ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, expr* decorator_list) 11 | | Return(expr? value) 12 | 13 | | Delete(expr* targets) 14 | | Assign(expr* targets, expr value) 15 | | AugAssign(expr target, operator op, expr value) 16 | 17 | | For(expr target, expr iter, stmt* body, stmt* orelse) 18 | | While(expr test, stmt* body, stmt* orelse) 19 | | If(expr test, stmt* body, stmt* orelse) 20 | | With(withitem* items, stmt* body) 21 | 22 | | Raise(expr? exc, expr? cause) 23 | | Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody) 24 | | Assert(expr test, expr? msg) 25 | 26 | | Import(alias* names) 27 | | ImportFrom(identifier? module, alias* names, int? level) 28 | 29 | | Global(identifier* names) 30 | | Nonlocal(identifier* names) 31 | | Expr(expr value) 32 | | Pass 33 | | Break 34 | | Continue 35 | 36 | 37 | expr = BoolOp(boolop op, expr* values) 38 | | BinOp(expr left, operator op, expr right) 39 | | UnaryOp(unaryop op, expr operand) 40 | | Lambda(arguments args, expr body) 41 | | IfExp(expr test, expr body, expr orelse) 42 | | Dict(expr* keys, expr* values) 43 | | Set(expr* elts) 44 | | ListComp(expr elt, comprehension* generators) 45 | | SetComp(expr elt, comprehension* generators) 46 | | DictComp(expr key, expr value, comprehension* generators) 47 | | GeneratorExp(expr elt, comprehension* generators) 48 | | Await(expr value) 49 | | Yield(expr? 
value) 50 | | YieldFrom(expr value) 51 | | Compare(expr left, cmpop* ops, expr* comparators) 52 | | Call(expr func, expr* args, keyword* keywords) 53 | ## a number as a PyObject. 54 | | Num(object n) 55 | ## need to specify raw, unicode, etc? 56 | | Str(string s) 57 | | Bytes(bytes s) 58 | | NameConstant(singleton value) 59 | | Ellipsis 60 | 61 | ## the following expression can appear in assignment context 62 | | Attribute(expr value, identifier attr) 63 | | Subscript(expr value, slice slice) 64 | | Starred(expr value) 65 | | Name(identifier id) 66 | | List(expr* elts) 67 | | Tuple(expr* elts) 68 | 69 | expr_context = Load | Store | Del | AugLoad | AugStore | Param 70 | 71 | slice = Slice(expr? lower, expr? upper, expr? step) 72 | | ExtSlice(slice* dims) 73 | | Index(expr value) 74 | 75 | boolop = And | Or 76 | 77 | operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift 78 | | RShift | BitOr | BitXor | BitAnd | FloorDiv 79 | 80 | unaryop = Invert | Not | UAdd | USub 81 | 82 | cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn 83 | 84 | comprehension = comprehension(expr target, expr iter, expr* ifs) 85 | 86 | excepthandler = ExceptHandler(expr? type, identifier? name, stmt* body) 87 | 88 | arguments = arguments(arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults, arg? kwarg, expr* defaults) 89 | 90 | arg = arg(identifier arg, expr? annotation) 91 | 92 | keyword = keyword(identifier? arg, expr value) 93 | 94 | alias = alias(identifier name, identifier? asname) 95 | 96 | withitem = withitem(expr context_expr, expr? 
optional_vars) -------------------------------------------------------------------------------- /hearthstone_preprocess/rule.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone_preprocess/rule.pkl -------------------------------------------------------------------------------- /hearthstone_preprocess/rulead.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zysszy/TreeGen-Pytorch/d2f8d89f9e2570951913ccf0ac1547501d7d741e/hearthstone_preprocess/rulead.pkl -------------------------------------------------------------------------------- /hearthstone_preprocess/runComplex.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | sets = ["train", "test", "dev"] 3 | import re 4 | 5 | import ast 6 | import sys 7 | #reload(sys) 8 | #sys.setdefaultencoding('utf-8') 9 | class CodeVisitor: 10 | def __init__(self, name, edge=""): 11 | self.child = [] 12 | self.father = None 13 | self.edge = edge 14 | self.name = name 15 | def visitNode(node): 16 | #print(node[0], type(node).__name__) 17 | rootNode = CodeVisitor(type(node).__name__) 18 | tmpNode = rootNode 19 | for x in ast.iter_fields(node): 20 | if str(x[0]) == "ctx": 21 | continue 22 | #print(str(x[1])) 23 | if str(x[0]) == "vararg" or str(x[0]) == "kwarg" or str(x[0]) == 'defaults' or str(x[0]) == "decorator_list" or str(x[0]) == "starargs" or str(x[0]) == "kwargs": 24 | continue 25 | rootNode = tmpNode 26 | currnode = CodeVisitor(x[0]) 27 | rootNode.child.append(currnode) 28 | currnode.father = rootNode 29 | rootNode = currnode 30 | if isinstance(x[1], list): 31 | if len(x[1]) == 0: 32 | tmpnode = CodeVisitor("empty") 33 | rootNode.child.append(tmpnode) 34 | tmpnode.father = rootNode 35 | tmpnode.edge = x[0] 36 | for obj in x[1]: 37 | if isinstance(obj, 
ast.AST): 38 | tmpnode = visitNode(obj) 39 | rootNode.child.append(tmpnode) 40 | tmpnode.father = rootNode 41 | tmpnode.edge = x[0] 42 | elif isinstance(x[1], (int, complex)) or type(x[1]).__name__ == "float" or type(x[1]).__name__ == "long": 43 | #print(x[1], isinstance(x[1], bytes)) 44 | tmpStr = str(x[1]).replace("\n", "").replace("\r", "") 45 | if len(tmpStr.split()) == 0: 46 | tmpStr = "" 47 | if tmpStr[-1] == "^": 48 | tmpStr += "<>" 49 | tmpnode = CodeVisitor(tmpStr) 50 | tmpnode.father = rootNode 51 | rootNode.child.append(tmpnode) 52 | tmpnode.edge = x[0] 53 | elif isinstance(x[1], str) or type(x[1]).__name__ == "unicode": 54 | tmpStr = x[1] 55 | tmpStr = tmpStr.replace("\'", "").replace(" ", "").replace("-", "").replace(":", "") 56 | #tmpStr = "" if " " in x[1] else x[1].replace("\n", "").replace("\r", "") 57 | if "\t" in tmpStr: 58 | tmpStr = "" 59 | if len(tmpStr.split()) == 0: 60 | tmpStr = "" 61 | if tmpStr[-1] == "^": 62 | tmpStr += "<>" 63 | '''if x[0] == 'name': 64 | s = "namestr" 65 | else: 66 | s = x[0] 67 | tmpnodef = CodeVisitor(s) 68 | tmpnodef.father = rootNode 69 | rootNode.child.append(tmpnodef)''' 70 | tmpnodef = rootNode 71 | tmpnode = CodeVisitor(tmpStr) 72 | tmpnode.father = tmpnodef 73 | tmpnodef.child.append(tmpnode) 74 | tmpnode.edge = x[0] 75 | elif isinstance(x[1], ast.AST): 76 | tmpnode = visitNode(x[1]) 77 | rootNode.child.append(tmpnode) 78 | tmpnode.father = rootNode 79 | tmpnode.edge = x[0] 80 | elif not x[1]: 81 | continue 82 | else: 83 | print(type(x[1]), x[0]) 84 | sys.exit(1) 85 | return tmpNode 86 | def parseAst(codeStr): 87 | root_node = ast.parse(codeStr) 88 | #print(ast.dump(root_node)) 89 | return visitNode(root_node) 90 | def printTree(node): 91 | ans = "" 92 | ans += node.name + "\t" 93 | for x in node.child: 94 | ans += printTree(x) 95 | ans += "^" + "\t" 96 | return ans 97 | def tokenize_for_bleu_eval(code): 98 | code = re.sub(r'([^A-Za-z0-9_])', r' \1 ', code) 99 | #code = re.sub(r'([a-z])([A-Z])', r'\1 \2', 
code) 100 | code = re.sub(r'\s+', ' ', code) 101 | code = code.replace('"', '`') 102 | code = code.replace('\'', '`') 103 | tokens = [t for t in code.split(' ') if t] 104 | return tokens 105 | grammar_file = 'py3_asdl.simplified.txt' 106 | #asdl_text = open(grammar_file).read() 107 | #grammar = ASDLGrammar.from_text(asdl_text) 108 | #transition_system = Python3TransitionSystem(grammar) 109 | maxlen = [] 110 | for x in sets: 111 | fout = open(x + "_hs.out", "r") 112 | fin = open(x + "_hs.in", "r") 113 | wf = open(x + ".txt", "w") 114 | linesin = fin.readlines() 115 | for i, y in enumerate(fout): 116 | code = y.strip().replace("§", "\n") 117 | try: 118 | parsed_ast = ast.parse(code)#parse(code, error_recovery=True, version="2.7") 119 | root = visitNode(parsed_ast) 120 | #print(ast.dump(parsed_ast)) 121 | except Exception as e: 122 | print(e) 123 | print(code) 124 | assert(0) 125 | nl = linesin[i] 126 | i = nl.find("NAME_END") 127 | name = nl[:i].strip() 128 | nl = nl[i:] 129 | nl = tokenize_for_bleu_eval(code)[1] + " " + nl 130 | nl = nl.replace("", " ").replace("", " ").replace("", " 在 ").replace("", " 见 ").replace("+", " + ").replace("/", " / ").replace(":", "").replace("在", "").replace(".", " . 
").replace("(", " ( ").replace(")", " ) ").replace("见", "").replace(";", " ; ").replace(",", " , ").replace("#", " # ").replace("$", " $ ") 131 | lst = [] 132 | for x in nl.split(): 133 | if "-" in x and x[0] != "-": 134 | lst += x.replace("-", " - ").split() 135 | else: 136 | lst.append(x) 137 | nl = "\t".join(lst) 138 | maxlen.append(len(nl.split("\t"))) 139 | wf.write(nl + "\n") 140 | wf.write(printTree(root) + "\n") 141 | print(maxlen) 142 | -------------------------------------------------------------------------------- /hearthstone_preprocess/runall.sh: -------------------------------------------------------------------------------- 1 | python3 runComplex.py 2 | python3 solvetree.py 3 | bash cp.sh -------------------------------------------------------------------------------- /hearthstone_preprocess/solvetree.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import pickle 4 | import json 5 | import numpy as np 6 | lst = ["train", "dev", "test"] 7 | rules = {"pad":0} 8 | onelist =['body'] 9 | rulelist = [] 10 | fatherlist = [] 11 | fathername = [] 12 | depthlist = [] 13 | copynode = {} 14 | class Node: 15 | def __init__(self, name, s): 16 | self.name = name 17 | self.id = s 18 | self.father = None 19 | self.child = [] 20 | def parseTree(treestr): 21 | tokens = treestr.split() 22 | root = Node("Module", 0) 23 | currnode = root 24 | for i, x in enumerate(tokens[1:]): 25 | if x != "^": 26 | nnode = Node(x, i + 1) 27 | nnode.father = currnode 28 | currnode.child.append(nnode) 29 | currnode = nnode 30 | else: 31 | currnode = currnode.father 32 | return root 33 | maxnlnum = 40 34 | hascopy = {} 35 | def getcopyid(nls, name): 36 | global maxnlnum 37 | global hascopy 38 | lastcopyid = -1 39 | for i, x in enumerate(nls): 40 | if name.lower() == x.lower(): 41 | lastcopyid = i 42 | if i not in hascopy: 43 | hascopy[i] = 1 44 | return i + 10000 45 | if lastcopyid != -1: 46 | return lastcopyid 
+ 10000 47 | return -1 48 | rulead = np.zeros([1772, 1772]) 49 | astnode = {"pad": 0, "Unknown": 1} 50 | def getRule(node, nls, currId, d): 51 | global rules 52 | global onelist 53 | global rulelist 54 | global fatherlist 55 | global depthlist 56 | global copynode 57 | global rulead 58 | if node.name == "str_": 59 | assert(len(node.child) == 1) 60 | if len(node.child) == 0: 61 | return [], [] 62 | if " -> End " not in rules: 63 | rules[" -> End "] = len(rules) 64 | return [rules[" -> End "]] 65 | child = node.child#sorted(node.child, key=lambda x:x.name) 66 | if len(node.child) == 1 and len(node.child[0].child) == 0: 67 | node.child[0].name = node.child[0].name.replace("!", "") 68 | copyid = getcopyid(nls, node.child[0].name) 69 | if len(node.child) == 1 and len(node.child[0].child) == 0 and copyid != -1: 70 | if len(node.child[0].child) != 0: 71 | print(node.child[0].name) 72 | copynode[node.name] = 1 73 | rulelist.append(copyid) 74 | fatherlist.append(currId) 75 | # rulead[rulelist[currId], 1771] = 1 76 | # rulead[1771, rulelist[currId]] = 1 77 | fathername.append(node.name) 78 | depthlist.append(d) 79 | currid = len(rulelist) - 1 80 | for x in child: 81 | getRule(x, nls, currId, d + 1) 82 | #rulelist.extend(a) 83 | #fatherlist.extend(b) 84 | else: 85 | if node.name not in onelist: 86 | rule = node.name + " -> " 87 | for x in child: 88 | rule += x.name + " " 89 | if rule in rules: 90 | rulelist.append(rules[rule]) 91 | else: 92 | rules[rule] = len(rules) 93 | rulelist.append(rules[rule]) 94 | fatherlist.append(currId) 95 | fathername.append(node.name) 96 | depthlist.append(d) 97 | # if currId != -1: 98 | # rulead[rulelist[currId], rulelist[-1]] = 1 99 | # rulead[rulelist[-1], rulelist[currId]] = 1 100 | # else: 101 | # rulead[770, rulelist[-1]] = 1 102 | # rulead[rulelist[-1], 770] = 1 103 | currid = len(rulelist) - 1 104 | for x in child: 105 | getRule(x, nls, currid, d + 1) 106 | else: 107 | #assert(0) 108 | for x in (child): 109 | rule = node.name + " -> " + 
x.name 110 | if rule in rules: 111 | rulelist.append(rules[rule]) 112 | else: 113 | rules[rule] = len(rules) 114 | rulelist.append(rules[rule]) 115 | # rulead[rulelist[currId], rulelist[-1]] = 1 116 | # rulead[rulelist[-1], rulelist[currId]] = 1 117 | fatherlist.append(currId) 118 | fathername.append(node.name) 119 | depthlist.append(d) 120 | getRule(x, nls, len(rulelist) - 1, d + 1) 121 | rule = node.name + " -> End " 122 | if rule in rules: 123 | rulelist.append(rules[rule]) 124 | else: 125 | rules[rule] = len(rules) 126 | rulelist.append(rules[rule]) 127 | rulead[rulelist[currId], rulelist[-1]] = 1 128 | rulead[rulelist[-1], rulelist[currId]] = 1 129 | fatherlist.append(currId) 130 | fathername.append(node.name) 131 | depthlist.append(d) 132 | '''if node.name == "root": 133 | print('rr') 134 | print('rr') 135 | print(rulelist)''' 136 | '''rule = " -> End " 137 | if rule in rules: 138 | rulelist.append(rules[rule]) 139 | else: 140 | rules[rule] = len(rules) 141 | rulelist.append(rules[rule])''' 142 | #return rulelist, fatherlistd 143 | def getTableName(f): 144 | global tablename 145 | lines = f.readlines() 146 | tabname = [] 147 | dbid = "" 148 | tabname = [] 149 | colnames = [] 150 | for i in range(len(lines)): 151 | if i == 0: 152 | nl = lines[i].strip().split() 153 | if i == 1: 154 | originnl = lines[i].strip().split() 155 | if i == 2: 156 | dbid = lines[i].strip() 157 | for i, x in enumerate(tablename[dbid]['table_names_original']): 158 | tabname.append(x.lower()) 159 | for j, y in enumerate(tablename[dbid]['column_names_original']): 160 | if y[0] == i: 161 | if y[1].lower() == "share": 162 | y[1] = "share_" 163 | colnames.append(y[1].lower()) 164 | return nl, originnl, tabname, dbid, colnames 165 | 166 | for x in lst: 167 | #inputdir = x + "_input/" 168 | #outputdir = x + "_output/" 169 | wf = open(x + "_process.txt", "w") 170 | f = open(x + ".txt", "r") 171 | lines = f.readlines() 172 | f.close() 173 | for i in tqdm(range(int(len(lines) / 2))): 174 | #fname 
= inputdir + str(i + 1) + ".txt" 175 | #ofname = outputdir + str(i + 1) + ".txt" 176 | nls = lines[2 * i].split("\t")#getTableName(f) 177 | asts = lines[2 * i + 1].strip() 178 | #wf.write(asts + "\n") 179 | hascopy = {} 180 | print(asts.split().count("^")) 181 | assert(len(asts.split()) == 2 * asts.split().count('^')) 182 | root = parseTree(asts) 183 | rulelist = [] 184 | fatherlist = [] 185 | fathername = [] 186 | depthlist = [] 187 | getRule(root, nls, -1, 2) 188 | wf.write(" ".join(nls)) 189 | s = "" 190 | for x in rulelist: 191 | s += str(x) + " " 192 | wf.write(s + "\n") 193 | s = "" 194 | for x in fatherlist: 195 | s += str(x) + " " 196 | wf.write(s + "\n") 197 | s = "" 198 | for x in depthlist: 199 | s += str(x) + " " 200 | wf.write(s + "\n") 201 | wf.write(" ".join(fathername) + "\n") 202 | 203 | #print(rules) 204 | #print(asts) 205 | wf.close() 206 | wf = open("rule.pkl", "wb") 207 | open("rulead.pkl", "wb").write(pickle.dumps(rulead)) 208 | #rules["start -> Module"] = len(rules) 209 | #rules["start -> copyword"] = len(rules) 210 | codead = np.zeros([565, 565]) 211 | for x in rules: 212 | lst = x.strip().lower().split() 213 | tmp = [lst[0]] + lst[2:] 214 | for y in tmp: 215 | if y not in astnode: 216 | astnode[y] = len(astnode) 217 | pid = astnode[lst[0]] 218 | for s in lst[2:]: 219 | tid = astnode[s] 220 | # codead[pid, tid] = 1 221 | # codead[tid, pid] = 1 222 | open("Code_Voc.pkl", "wb").write(pickle.dumps(astnode)) 223 | open("codead.pkl", "wb").write(pickle.dumps(codead)) 224 | wf.write(pickle.dumps(rules)) 225 | wf.close() 226 | print(rules) 227 | print(astnode) 228 | -------------------------------------------------------------------------------- /hearthstone_preprocess/test_hs.in: -------------------------------------------------------------------------------- 1 | Archmage NAME_END 4 ATK_END 7 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Spell Damage +1 2 | Booty Bay Bodyguard NAME_END 5 
ATK_END 4 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Taunt 3 | Darkscale Healer NAME_END 4 ATK_END 5 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: Restore 2 Health to all friendly characters. 4 | Fiery War Axe NAME_END 3 ATK_END -1 DEF_END 2 COST_END 2 DUR_END Weapon TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Free RARITY_END NIL 5 | Frostwolf Warlord NAME_END 4 ATK_END 4 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: Gain +1/+1 for each other friendly minion on the battlefield. 6 | Hellfire NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Warlock PLAYER_CLS_END NIL RACE_END Free RARITY_END Deal $3 damage to ALL characters. 7 | Innervate NAME_END -1 ATK_END -1 DEF_END 0 COST_END -1 DUR_END Spell TYPE_END Druid PLAYER_CLS_END NIL RACE_END Free RARITY_END Gain 2 Mana Crystals this turn only. 8 | Magma Rager NAME_END 5 ATK_END 1 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Free RARITY_END NIL 9 | Mortal Coil NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Warlock PLAYER_CLS_END NIL RACE_END Common RARITY_END Deal $1 damage to a minion. If that kills it, draw a card. 10 | Polymorph NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Mage PLAYER_CLS_END NIL RACE_END Free RARITY_END Transform a minion into a 1/1 Sheep. 
11 | Searing Totem NAME_END 1 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Shaman PLAYER_CLS_END Totem RACE_END Free RARITY_END NIL 12 | Silverback Patriarch NAME_END 1 ATK_END 4 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Beast RACE_END Common RARITY_END Taunt 13 | Stormwind Knight NAME_END 2 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Charge 14 | Voodoo Doctor NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Free RARITY_END Battlecry: Restore 2 Health. 15 | Wrath of Air Totem NAME_END 0 ATK_END 2 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Shaman PLAYER_CLS_END Totem RACE_END Free RARITY_END Spell Damage +1 16 | Astral Communion NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Druid PLAYER_CLS_END NIL RACE_END Epic RARITY_END Gain 10 Mana Crystals. Discard your hand. 17 | Annoy-o-Tron NAME_END 1 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Mech RACE_END Common RARITY_END Taunt NL Divine Shield 18 | Burly Rockjaw Trogg NAME_END 3 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Whenever your opponent casts a spell, gain +2 Attack. 19 | Cogmaster's Wrench NAME_END 1 ATK_END -1 DEF_END 3 COST_END 3 DUR_END Weapon TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Epic RARITY_END Has +2 Attack while you have a Mech. 20 | Echo of Medivh NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Mage PLAYER_CLS_END NIL RACE_END Epic RARITY_END Put a copy of each friendly minion into your hand. 21 | Floating Watcher NAME_END 4 ATK_END 4 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Common RARITY_END Whenever your hero takes damage on your turn, gain +2/+2. 
22 | Gnomish Experimenter NAME_END 3 ATK_END 2 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Draw a card. If it's a minion, transform it into a Chicken. 23 | Iron Juggernaut NAME_END 6 ATK_END 5 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Warrior PLAYER_CLS_END Mech RACE_END Legendary RARITY_END Battlecry: Shuffle a Mine into your opponent's deck. When drawn, it explodes for 10 damage. 24 | Madder Bomber NAME_END 5 ATK_END 4 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Deal 6 damage randomly split between all other characters. 25 | Mini-Mage NAME_END 4 ATK_END 1 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Epic RARITY_END Stealth NL Spell Damage +1 26 | Piloted Sky Golem NAME_END 6 ATK_END 4 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Mech RACE_END Epic RARITY_END Deathrattle: Summon a random 4-Cost minion. 27 | Scarlet Purifier NAME_END 4 ATK_END 3 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: Deal 2 damage to all minions with Deathrattle. 28 | Siltfin Spiritwalker NAME_END 2 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Shaman PLAYER_CLS_END Murloc RACE_END Epic RARITY_END Whenever another friendly Murloc dies, draw a card. Overload: (1) 29 | Tinkertown Technician NAME_END 3 ATK_END 3 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Battlecry: If you have a Mech, gain +1/+1 and add a Spare Part to your hand. 30 | Vol'jin NAME_END 6 ATK_END 2 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Priest PLAYER_CLS_END NIL RACE_END Legendary RARITY_END Battlecry: Swap Health with another minion. 
31 | Death's Bite NAME_END 4 ATK_END -1 DEF_END 4 COST_END 2 DUR_END Weapon TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Common RARITY_END Deathrattle: Deal 1 damage to all minions. 32 | Maexxna NAME_END 2 ATK_END 8 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Beast RACE_END Legendary RARITY_END Destroy any minion damaged by this minion. 33 | Sludge Belcher NAME_END 3 ATK_END 5 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Taunt. NL Deathrattle: Summon a 1/2 Slime with Taunt. 34 | Voidcaller NAME_END 3 ATK_END 4 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Common RARITY_END Deathrattle: Put a random Demon from your hand into the battlefield. 35 | Alexstrasza NAME_END 8 ATK_END 8 DEF_END 9 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Dragon RACE_END Legendary RARITY_END Battlecry: Set a hero's remaining Health to 15. 36 | Arcane Golem NAME_END 4 ATK_END 2 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Charge. Battlecry: Give your opponent a Mana Crystal. 37 | Baine Bloodhoof NAME_END 4 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Legendary RARITY_END NIL 38 | Blade Flurry NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Rare RARITY_END Destroy your weapon and deal its damage to all enemies. 39 | Brawl NAME_END -1 ATK_END -1 DEF_END 5 COST_END -1 DUR_END Spell TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Epic RARITY_END Destroy all minions except one. (chosen randomly) 40 | Conceal NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Rogue PLAYER_CLS_END NIL RACE_END Common RARITY_END Give your minions Stealth until your next turn. 
41 | Defender NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END NIL 42 | Doomguard NAME_END 5 ATK_END 7 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Rare RARITY_END Charge. Battlecry: Discard two random cards. 43 | Earth Elemental NAME_END 7 ATK_END 8 DEF_END 5 COST_END -1 DUR_END Minion TYPE_END Shaman PLAYER_CLS_END NIL RACE_END Epic RARITY_END Taunt. Overload: (3) 44 | Explosive Trap NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Hunter PLAYER_CLS_END NIL RACE_END Common RARITY_END Secret: When your hero is attacked, deal $2 damage to all enemies. 45 | Flame of Azzinoth NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END NIL RARITY_END NIL 46 | Gnoll NAME_END 2 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END NIL RARITY_END Taunt 47 | Holy Wrath NAME_END -1 ATK_END -1 DEF_END 5 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Rare RARITY_END Draw a card and deal damage equal to its cost. 48 | Infernal NAME_END 6 ATK_END 6 DEF_END 6 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END Demon RACE_END Common RARITY_END NIL 49 | Kirin Tor Mage NAME_END 4 ATK_END 3 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Mage PLAYER_CLS_END NIL RACE_END Rare RARITY_END Battlecry: The next Secret you play this turn costs (0). 50 | Lightwarden NAME_END 1 ATK_END 2 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END Whenever a character is healed, gain +2 Attack. 51 | Mana Wraith NAME_END 2 ATK_END 2 DEF_END 2 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Rare RARITY_END ALL minions cost (1) more. 
52 | Misdirection NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Hunter PLAYER_CLS_END NIL RACE_END Rare RARITY_END Secret: When a character attacks your hero, instead he attacks another random character. 53 | Noble Sacrifice NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END Secret: When an enemy attacks, summon a 2/1 Defender as the new target. 54 | Power of the Wild NAME_END -1 ATK_END -1 DEF_END 2 COST_END -1 DUR_END Spell TYPE_END Druid PLAYER_CLS_END NIL RACE_END Common RARITY_END Choose One - Give your minions +1/+1; or Summon a 3/2 Panther. 55 | Redemption NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END Secret: When one of your minions dies, return it to life with 1 Health. 56 | Shadow Madness NAME_END -1 ATK_END -1 DEF_END 4 COST_END -1 DUR_END Spell TYPE_END Priest PLAYER_CLS_END NIL RACE_END Rare RARITY_END Gain control of an enemy minion with 3 or less Attack until end of turn. 57 | Siphon Soul NAME_END -1 ATK_END -1 DEF_END 6 COST_END -1 DUR_END Spell TYPE_END Warlock PLAYER_CLS_END NIL RACE_END Rare RARITY_END Destroy a minion. Restore #3 Health to your hero. 58 | Spellbender NAME_END 1 ATK_END 3 DEF_END 0 COST_END -1 DUR_END Minion TYPE_END Mage PLAYER_CLS_END NIL RACE_END Epic RARITY_END NIL 59 | Summoning Portal NAME_END 0 ATK_END 4 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Warlock PLAYER_CLS_END NIL RACE_END Common RARITY_END Your minions cost (2) less, but not less than (1). 60 | Thrallmar Farseer NAME_END 2 ATK_END 3 DEF_END 3 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END Windfury 61 | Upgrade! NAME_END -1 ATK_END -1 DEF_END 1 COST_END -1 DUR_END Spell TYPE_END Warrior PLAYER_CLS_END NIL RACE_END Rare RARITY_END If you have a weapon, give it +1/+1. Otherwise equip a 1/3 weapon. 
62 | Wisp NAME_END 1 ATK_END 1 DEF_END 0 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END NIL RACE_END Common RARITY_END NIL 63 | Black Whelp NAME_END 2 ATK_END 1 DEF_END 1 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Dragon RACE_END Common RARITY_END NIL 64 | Dragonkin Sorcerer NAME_END 3 ATK_END 5 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Dragon RACE_END Common RARITY_END Whenever you target this minion with a spell, gain +1/+1. 65 | Hungry Dragon NAME_END 5 ATK_END 6 DEF_END 4 COST_END -1 DUR_END Minion TYPE_END Neutral PLAYER_CLS_END Dragon RACE_END Common RARITY_END Battlecry: Summon a random 1-Cost minion for your opponent. 66 | Solemn Vigil NAME_END -1 ATK_END -1 DEF_END 5 COST_END -1 DUR_END Spell TYPE_END Paladin PLAYER_CLS_END NIL RACE_END Common RARITY_END Draw 2 cards. Costs (1) less for each minion that died this turn. 67 | -------------------------------------------------------------------------------- /hearthstone_preprocess/test_hs.out: -------------------------------------------------------------------------------- 1 | class Archmage(MinionCard):§ def __init__(self):§ super().__init__("Archmage", 6, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(4, 7, spell_damage=1)§ 2 | class BootyBayBodyguard(MinionCard):§ def __init__(self):§ super().__init__("Booty Bay Bodyguard", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(5, 4, taunt=True)§ 3 | class DarkscaleHealer(MinionCard):§ def __init__(self):§ super().__init__("Darkscale Healer", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Heal(2), CharacterSelector()))§§ def create_minion(self, player):§ return Minion(4, 5)§ 4 | class FieryWarAxe(WeaponCard):§ def __init__(self):§ super().__init__("Fiery War Axe", 2, CHARACTER_CLASS.WARRIOR, CARD_RARITY.FREE)§§ def create_weapon(self, player):§ return Weapon(3, 2)§ 5 | class FrostwolfWarlord(MinionCard):§ 
def __init__(self):§ super().__init__("Frostwolf Warlord", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Give([Buff(ChangeAttack(Count(MinionSelector()))), Buff(ChangeHealth(Count(MinionSelector())))]), SelfSelector()))§§ def create_minion(self, player):§ return Minion(4, 4)§ 6 | class Hellfire(SpellCard):§ def __init__(self):§ super().__init__("Hellfire", 4, CHARACTER_CLASS.WARLOCK, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§ targets = copy.copy(game.other_player.minions)§ targets.extend(game.current_player.minions)§ targets.append(game.other_player.hero)§ targets.append(game.current_player.hero)§ for minion in targets:§ minion.damage(player.effective_spell_damage(3), self)§ 7 | class Innervate(SpellCard):§ def __init__(self):§ super().__init__("Innervate", 0, CHARACTER_CLASS.DRUID, CARD_RARITY.FREE)§§ def use(self, player, game):§ super().use(player, game)§ if player.mana < 8:§ player.mana += 2§ else:§ player.mana = 10§ 8 | class MagmaRager(MinionCard):§ def __init__(self):§ super().__init__("Magma Rager", 3, CHARACTER_CLASS.ALL, CARD_RARITY.FREE)§§ def create_minion(self, player):§ return Minion(5, 1)§ 9 | class MortalCoil(SpellCard):§ def __init__(self):§ super().__init__("Mortal Coil", 1, CHARACTER_CLASS.WARLOCK, CARD_RARITY.COMMON, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ if self.target.health <= player.effective_spell_damage(1) and not self.target.divine_shield:§ self.target.damage(player.effective_spell_damage(1), self)§ player.draw()§ else:§ self.target.damage(player.effective_spell_damage(1), self)§ 10 | class Polymorph(SpellCard):§ def __init__(self):§ super().__init__("Polymorph", 4, CHARACTER_CLASS.MAGE, CARD_RARITY.FREE, target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ from hearthbreaker.cards.minions.mage import Sheep§ sheep = Sheep()§ minion = 
sheep.create_minion(None)§ minion.card = sheep§ self.target.replace(minion)§ 11 | class SearingTotem(MinionCard):§ def __init__(self):§ super().__init__("Searing Totem", 1, CHARACTER_CLASS.SHAMAN, CARD_RARITY.FREE, False, MINION_TYPE.TOTEM)§§ def create_minion(self, player):§ return Minion(1, 1)§ 12 | class SilverbackPatriarch(MinionCard):§ def __init__(self):§ super().__init__("Silverback Patriarch", 3, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.BEAST)§§ def create_minion(self, player):§ return Minion(1, 4, taunt=True)§ 13 | class StormwindKnight(MinionCard):§ def __init__(self):§ super().__init__("Stormwind Knight", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(2, 5, charge=True)§ 14 | class VoodooDoctor(MinionCard):§ def __init__(self):§ super().__init__("Voodoo Doctor", 1, CHARACTER_CLASS.ALL, CARD_RARITY.FREE, battlecry=Battlecry(Heal(2), CharacterSelector(players=BothPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(2, 1)§ 15 | class WrathOfAirTotem(MinionCard):§ def __init__(self):§ super().__init__("Wrath of Air Totem", 1, CHARACTER_CLASS.SHAMAN, CARD_RARITY.FREE, False, MINION_TYPE.TOTEM)§§ def create_minion(self, player):§ return Minion(0, 2, spell_damage=1)§ 16 | class AstralCommunion(SpellCard):§ def __init__(self):§ super().__init__("Astral Communion", 4, CHARACTER_CLASS.DRUID, CARD_RARITY.EPIC)§§ def use(self, player, game):§ super().use(player, game)§ for card in player.hand:§ card.unattach()§ player.trigger("card_discarded", card)§ player.hand = []§ player.max_mana = 10§ player.mana = 10§ 17 | class AnnoyoTron(MinionCard):§ def __init__(self):§ super().__init__("Annoy-o-Tron", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(1, 2, divine_shield=True, taunt=True)§ 18 | class BurlyRockjawTrogg(MinionCard):§ def __init__(self):§ super().__init__("Burly Rockjaw Trogg", 4, 
CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(3, 5, effects=[Effect(SpellCast(player=EnemyPlayer()), ActionTag(Give(ChangeAttack(2)), SelfSelector()))])§ 19 | class CogmastersWrench(WeaponCard):§ def __init__(self):§ super().__init__("Cogmaster's Wrench", 3, CHARACTER_CLASS.ROGUE, CARD_RARITY.EPIC)§§ def create_weapon(self, player):§ return Weapon(1, 3, buffs=[Buff(ChangeAttack(2), GreaterThan(Count(MinionSelector(IsType(MINION_TYPE.MECH))), value=0))])§ 20 | class EchoOfMedivh(SpellCard):§ def __init__(self):§ super().__init__("Echo of Medivh", 4, CHARACTER_CLASS.MAGE, CARD_RARITY.EPIC)§§ def use(self, player, game):§ super().use(player, game)§ for minion in sorted(copy.copy(player.minions), key=lambda minion: minion.born):§ if len(player.hand) < 10:§ player.hand.append(minion.card)§ 21 | class FloatingWatcher(MinionCard):§ def __init__(self):§ super().__init__("Floating Watcher", 5, CHARACTER_CLASS.WARLOCK, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return Minion(4, 4, effects=[Effect(CharacterDamaged(And(IsHero(), OwnersTurn())), ActionTag(Give([Buff(ChangeAttack(2)), Buff(ChangeHealth(2))]), SelfSelector()))])§ 22 | class GnomishExperimenter(MinionCard):§ def __init__(self):§ super().__init__("Gnomish Experimenter", 3, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=(Battlecry(Draw(), PlayerSelector()), Battlecry(Transform(GnomishChicken()), LastDrawnSelector(), Matches(LastDrawnSelector(), IsMinion()))))§§ def create_minion(self, player):§ return Minion(3, 2)§ 23 | class IronJuggernaut(MinionCard):§ def __init__(self):§ super().__init__("Iron Juggernaut", 6, CHARACTER_CLASS.WARRIOR, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.MECH, battlecry=Battlecry(AddCard(BurrowingMine(), add_to_deck=True), PlayerSelector(EnemyPlayer())))§§ def create_minion(self, player):§ return Minion(6, 5)§ 24 | class MadderBomber(MinionCard):§ def __init__(self):§ super().__init__("Madder 
Bomber", 5, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(Damage(1), CharacterSelector(players=BothPlayer(), picker=RandomPicker(6))))§§ def create_minion(self, player):§ return Minion(5, 4)§ 25 | class MiniMage(MinionCard):§ def __init__(self):§ super().__init__("Mini-Mage", 4, CHARACTER_CLASS.ALL, CARD_RARITY.EPIC)§§ def create_minion(self, player):§ return Minion(4, 1, stealth=True, spell_damage=1)§ 26 | class PilotedSkyGolem(MinionCard):§ def __init__(self):§ super().__init__("Piloted Sky Golem", 6, CHARACTER_CLASS.ALL, CARD_RARITY.EPIC, minion_type=MINION_TYPE.MECH)§§ def create_minion(self, player):§ return Minion(6, 4, deathrattle=Deathrattle(Summon(CardQuery(conditions=[ManaCost(4), IsMinion()])), PlayerSelector()))§ 27 | class ScarletPurifier(MinionCard):§ def __init__(self):§ super().__init__("Scarlet Purifier", 3, CHARACTER_CLASS.PALADIN, CARD_RARITY.RARE, battlecry=Battlecry(Damage(2), MinionSelector(MinionHasDeathrattle(), BothPlayer())))§§ def create_minion(self, player):§ return Minion(4, 3)§ 28 | class SiltfinSpiritwalker(MinionCard):§ def __init__(self):§ super().__init__("Siltfin Spiritwalker", 4, CHARACTER_CLASS.SHAMAN, CARD_RARITY.EPIC, minion_type=MINION_TYPE.MURLOC, overload=1)§§ def create_minion(self, player):§ return Minion(2, 5, effects=[Effect(MinionDied(IsType(MINION_TYPE.MURLOC)), ActionTag(Draw(), PlayerSelector()))])§ 29 | class TinkertownTechnician(MinionCard):§ def __init__(self):§ from hearthbreaker.cards.spells.neutral import spare_part_list§ super().__init__("Tinkertown Technician", 3, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=(Battlecry(Give([Buff(ChangeAttack(1)), Buff(ChangeHealth(1))]), SelfSelector(), GreaterThan(Count(MinionSelector(IsType(MINION_TYPE.MECH))), value=0)), Battlecry(AddCard(CardQuery(source=CARD_SOURCE.LIST, source_list=spare_part_list)), PlayerSelector(), GreaterThan(Count(MinionSelector(IsType(MINION_TYPE.MECH))), value=0))))§§ def create_minion(self, player):§ return Minion(3, 3)§ 30 
| class Voljin(MinionCard):§ def __init__(self):§ super().__init__("Vol'jin", 5, CHARACTER_CLASS.PRIEST, CARD_RARITY.LEGENDARY, battlecry=Battlecry(SwapStats("health", "health", True), MinionSelector(players=BothPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(6, 2)§ 31 | class DeathsBite(WeaponCard):§ def __init__(self):§ super().__init__("Death's Bite", 4, CHARACTER_CLASS.WARRIOR, CARD_RARITY.COMMON)§§ def create_weapon(self, player):§ return Weapon(4, 2, deathrattle=Deathrattle(Damage(1), MinionSelector(players=BothPlayer())))§ 32 | class Maexxna(MinionCard):§ def __init__(self):§ super().__init__("Maexxna", 6, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.BEAST)§§ def create_minion(self, player):§ return Minion(2, 8, effects=[Effect(DidDamage(), ActionTag(Kill(), TargetSelector(IsMinion())))])§ 33 | class SludgeBelcher(MinionCard):§ def __init__(self):§ super().__init__("Sludge Belcher", 5, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(3, 5, taunt=True, deathrattle=Deathrattle(Summon(Slime()), PlayerSelector()))§ 34 | class Voidcaller(MinionCard):§ def __init__(self):§ super().__init__("Voidcaller", 4, CHARACTER_CLASS.WARLOCK, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return Minion(3, 4, deathrattle=Deathrattle(Summon(CardQuery(conditions=[IsType(MINION_TYPE.DEMON)], source=CARD_SOURCE.MY_HAND)), PlayerSelector()))§ 35 | class Alexstrasza(MinionCard):§ def __init__(self):§ super().__init__("Alexstrasza", 9, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, minion_type=MINION_TYPE.DRAGON, battlecry=Battlecry(SetHealth(15), HeroSelector(players=BothPlayer(), picker=UserPicker())))§§ def create_minion(self, player):§ return Minion(8, 8)§ 36 | class ArcaneGolem(MinionCard):§ def __init__(self):§ super().__init__("Arcane Golem", 3, CHARACTER_CLASS.ALL, CARD_RARITY.RARE, battlecry=Battlecry(GiveManaCrystal(), 
PlayerSelector(players=EnemyPlayer())))§§ def create_minion(self, player):§ return Minion(4, 2, charge=True)§ 37 | class BaineBloodhoof(MinionCard):§ def __init__(self):§ super().__init__("Baine Bloodhoof", 4, CHARACTER_CLASS.ALL, CARD_RARITY.LEGENDARY, False)§§ def create_minion(self, player):§ return Minion(4, 5)§ 38 | class BladeFlurry(SpellCard):§ def __init__(self):§ super().__init__("Blade Flurry", 2, CHARACTER_CLASS.ROGUE, CARD_RARITY.RARE)§§ def use(self, player, game):§ super().use(player, game)§§ if player.weapon is not None:§ attack_power = player.effective_spell_damage(player.hero.calculate_attack())§ player.weapon.destroy()§§ for minion in copy.copy(game.other_player.minions):§ minion.damage(attack_power, self)§§ game.other_player.hero.damage(attack_power, self)§ 39 | class Brawl(SpellCard):§ def __init__(self):§ super().__init__("Brawl", 5, CHARACTER_CLASS.WARRIOR, CARD_RARITY.EPIC)§§ def can_use(self, player, game):§ return super().can_use(player, game) and len(player.minions) + len(player.opponent.minions) >= 2§§ def use(self, player, game):§ super().use(player, game)§§ minions = copy.copy(player.minions)§ minions.extend(game.other_player.minions)§§ if len(minions) > 1:§ survivor = game.random_choice(minions)§ for minion in minions:§ if minion is not survivor:§ minion.die(self)§ 40 | class Conceal(SpellCard):§ def __init__(self):§ super().__init__("Conceal", 1, CHARACTER_CLASS.ROGUE, CARD_RARITY.COMMON)§§ def use(self, player, game):§ super().use(player, game)§ for minion in player.minions:§ if not minion.stealth:§ minion.add_buff(BuffUntil(Stealth(), TurnStarted()))§ 41 | class DefenderMinion(MinionCard):§ def __init__(self):§ super().__init__("Defender", 1, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON)§§ def create_minion(self, p):§ return Minion(2, 1)§ 42 | class Doomguard(MinionCard):§ def __init__(self):§ super().__init__("Doomguard", 5, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE, minion_type=MINION_TYPE.DEMON, 
battlecry=Battlecry(Discard(amount=2), PlayerSelector()))§§ def create_minion(self, player):§ return Minion(5, 7, charge=True)§ 43 | class EarthElemental(MinionCard):§ def __init__(self):§ super().__init__("Earth Elemental", 5, CHARACTER_CLASS.SHAMAN, CARD_RARITY.EPIC, overload=3)§§ def create_minion(self, player):§ return Minion(7, 8, taunt=True)§ 44 | class ExplosiveTrap(SecretCard):§ def __init__(self):§ super().__init__("Explosive Trap", 2, CHARACTER_CLASS.HUNTER, CARD_RARITY.COMMON)§§ def activate(self, player):§ player.opponent.bind("character_attack", self._reveal)§§ def deactivate(self, player):§ player.opponent.unbind("character_attack", self._reveal)§§ def _reveal(self, attacker, target):§ if isinstance(target, Hero):§ game = attacker.player.game§ enemies = copy.copy(game.current_player.minions)§ enemies.append(game.current_player.hero)§ for enemy in enemies:§ enemy.damage(2, None)§ game.check_delayed()§ super().reveal()§ 45 | class FlameOfAzzinoth(MinionCard):§ def __init__(self):§ super().__init__("Flame of Azzinoth", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False)§§ def create_minion(self, player):§ return Minion(2, 1)§ 46 | class Gnoll(MinionCard):§ def __init__(self):§ super().__init__("Gnoll", 2, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False)§§ def create_minion(self, player):§ return Minion(2, 2, taunt=True)§ 47 | class HolyWrath(SpellCard):§ def __init__(self):§ super().__init__("Holy Wrath", 5, CHARACTER_CLASS.PALADIN, CARD_RARITY.RARE, target_func=hearthbreaker.targeting.find_spell_target)§§ def use(self, player, game):§ super().use(player, game)§§ fatigue = False§ if player.deck.left == 0:§ fatigue = True§§ player.draw()§ if not fatigue:§ cost = player.hand[-1].mana§ self.target.damage(player.effective_spell_damage(cost), self)§ 48 | class Infernal(MinionCard):§ def __init__(self):§ super().__init__("Infernal", 6, CHARACTER_CLASS.WARLOCK, CARD_RARITY.COMMON, False, minion_type=MINION_TYPE.DEMON)§§ def create_minion(self, player):§ return 
Minion(6, 6)§ 49 | class KirinTorMage(MinionCard):§ def __init__(self):§ super().__init__("Kirin Tor Mage", 3, CHARACTER_CLASS.MAGE, CARD_RARITY.RARE, battlecry=Battlecry(GiveAura([AuraUntil(ManaChange(-100), CardSelector(condition=IsSecret()), CardPlayed(IsSecret()))]), PlayerSelector()))§§ def create_minion(self, player):§ return Minion(4, 3)§ 50 | class Lightwarden(MinionCard):§ def __init__(self):§ super().__init__("Lightwarden", 1, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(1, 2, effects=[Effect(CharacterHealed(player=BothPlayer()), ActionTag(Give(ChangeAttack(2)), SelfSelector()))])§ 51 | class ManaWraith(MinionCard):§ def __init__(self):§ super().__init__("Mana Wraith", 2, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§ def create_minion(self, player):§ return Minion(2, 2, auras=[Aura(ManaChange(1), CardSelector(BothPlayer(), IsMinion()))])§ 52 | class Misdirection(SecretCard):§ def __init__(self):§ super().__init__("Misdirection", 2, CHARACTER_CLASS.HUNTER, CARD_RARITY.RARE)§§ def activate(self, player):§ player.opponent.bind("character_attack", self._reveal)§§ def deactivate(self, player):§ player.opponent.unbind("character_attack", self._reveal)§§ def _reveal(self, character, target):§ if isinstance(target, Hero) and not character.removed:§ game = character.player.game§ possibilities = copy.copy(game.current_player.minions)§ possibilities.extend(game.other_player.minions)§ possibilities.append(game.current_player.hero)§ possibilities.append(game.other_player.hero)§ possibilities.remove(character.current_target)§ character.current_target = game.random_choice(possibilities)§§ super().reveal()§ 53 | class NobleSacrifice(SecretCard):§ def __init__(self):§ super().__init__("Noble Sacrifice", 1, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON)§§ def _reveal(self, attacker, target):§ player = attacker.player.game.other_player§ if len(player.minions) < 7 and not attacker.removed:§ from hearthbreaker.cards.minions.paladin import 
DefenderMinion§ defender = DefenderMinion()§ defender.summon(player, player.game, len(player.minions))§ attacker.current_target = player.minions[-1]§ super().reveal()§§ def activate(self, player):§ player.opponent.bind("character_attack", self._reveal)§§ def deactivate(self, player):§ player.opponent.unbind("character_attack", self._reveal)§ 54 | class PowerOfTheWild(SpellCard):§ def __init__(self):§ super().__init__("Power of the Wild", 2, CHARACTER_CLASS.DRUID, CARD_RARITY.COMMON)§§ def use(self, player, game):§ super().use(player, game)§ option = player.agent.choose_option([LeaderOfThePack(), SummonPanther()], player)§ option.use(player, game)§ 55 | class Redemption(SecretCard):§ def __init__(self):§ super().__init__("Redemption", 1, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON)§§ def _reveal(self, minion, by):§ resurrection = minion.card.summon(minion.player, minion.game, min(minion.index, len(minion.player.minions)))§ if resurrection:§ resurrection.health = 1§ super().reveal()§§ def activate(self, player):§ player.bind("minion_died", self._reveal)§§ def deactivate(self, player):§ player.unbind("minion_died", self._reveal)§ 56 | class ShadowMadness(SpellCard):§ def __init__(self):§ super().__init__("Shadow Madness", 4, CHARACTER_CLASS.PRIEST, CARD_RARITY.RARE, target_func=hearthbreaker.targeting.find_enemy_minion_spell_target, filter_func=lambda target: target.calculate_attack() <= 3 and target.spell_targetable())§§ def use(self, player, game):§§ super().use(player, game)§§ minion = self.target.copy(player)§ minion.active = True§ minion.exhausted = False§§ self.target.remove_from_board()§ minion.add_to_board(len(player.minions))§§ minion.add_buff(BuffUntil(Stolen(), TurnEnded()))§§ def can_use(self, player, game):§ return super().can_use(player, game) and len(player.minions) < 7§ 57 | class SiphonSoul(SpellCard):§ def __init__(self):§ super().__init__("Siphon Soul", 6, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE, 
target_func=hearthbreaker.targeting.find_minion_spell_target)§§ def use(self, player, game):§ super().use(player, game)§ self.target.die(self)§ player.hero.heal(player.effective_heal_power(3), self)§ 58 | class Spellbender(SecretCard):§ def __init__(self):§ super().__init__("Spellbender", 3, CHARACTER_CLASS.MAGE, CARD_RARITY.EPIC)§ self.player = None§§ def _reveal(self, card, index):§ if card.is_spell() and len(self.player.minions) < 7 and card.target and card.target.is_minion():§ SpellbenderMinion().summon(self.player, self.player.game, len(self.player.minions))§ card.target = self.player.minions[-1]§ super().reveal()§§ def activate(self, player):§ player.game.current_player.bind("card_played", self._reveal)§ self.player = player§§ def deactivate(self, player):§ player.game.current_player.unbind("card_played", self._reveal)§ self.player = None§ 59 | class SummoningPortal(MinionCard):§ def __init__(self):§ super().__init__("Summoning Portal", 4, CHARACTER_CLASS.WARLOCK, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(0, 4, auras=[Aura(ManaChange(-2, 1, minimum=1), CardSelector(condition=IsMinion()))])§ 60 | class ThrallmarFarseer(MinionCard):§ def __init__(self):§ super().__init__("Thrallmar Farseer", 3, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(2, 3, windfury=True)§ 61 | class Upgrade(SpellCard):§ def __init__(self):§ super().__init__("Upgrade!", 1, CHARACTER_CLASS.WARRIOR, CARD_RARITY.RARE)§§ def use(self, player, game):§ super().use(player, game)§ from hearthbreaker.cards.weapons.warrior import HeavyAxe§ if player.weapon:§ player.weapon.durability += 1§ player.weapon.base_attack += 1§ else:§ heavy_axe = HeavyAxe().create_weapon(player)§ heavy_axe.equip(player)§ 62 | class Wisp(MinionCard):§ def __init__(self):§ super().__init__("Wisp", 0, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON)§§ def create_minion(self, player):§ return Minion(1, 1)§ 63 | class BlackWhelp(MinionCard):§ def __init__(self):§ 
super().__init__("Black Whelp", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, False, MINION_TYPE.DRAGON)§§ def create_minion(self, player):§ return Minion(2, 1)§ 64 | class DragonkinSorcerer(MinionCard):§ def __init__(self):§ super().__init__("Dragonkin Sorcerer", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DRAGON)§§ def create_minion(self, player):§ return Minion(3, 5, effects=[Effect(SpellTargeted(), [ActionTag(Give([Buff(ChangeAttack(1)), Buff(ChangeHealth(1))]), SelfSelector())])])§ 65 | class HungryDragon(MinionCard):§ def __init__(self):§ super().__init__("Hungry Dragon", 4, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.DRAGON, battlecry=(Battlecry(Summon(CardQuery(conditions=[ManaCost(1), IsMinion()])), PlayerSelector(EnemyPlayer()))))§§ def create_minion(self, player):§ return Minion(5, 6)§ 66 | class SolemnVigil(SpellCard):§ def __init__(self):§ super().__init__("Solemn Vigil", 5, CHARACTER_CLASS.PALADIN, CARD_RARITY.COMMON, buffs=[Buff(ManaChange(Count(DeadMinionSelector(players=BothPlayer())), -1))])§§ def use(self, player, game):§ super().use(player, game)§ for n in range(0, 2):§ player.draw()§ 67 | --------------------------------------------------------------------------------