├── README.md
├── codeTrans_bcb.py
├── createclone_bcb.py
└── models.py

/README.md:
--------------------------------------------------------------------------------
## Efficient Transformer with Code Token Learner for Code Clone Detection

The official implementation of "Efficient Transformer with Code Token Learner for Code Clone Detection".

Requires:
pytorch
javalang
torch_geometric
anytree
einops
tqdm

## Data
BigCloneBench snippets and clone pairs are in BCB.zip (refer to [FA-AST](https://github.com/jacobwwh/graphmatch_clone)).

## Running
python codeTrans_bcb.py

This runs training, validation, and testing, and writes the per-epoch similarity scores to files under bcbresult/.
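
## Data format
Each line of BCB/traindata.txt, devdata.txt, and testdata.txt is read by createpairdata() in createclone_bcb.py: two snippet paths relative to BCB/ (written with a leading dot) followed by a label, 1 for a clone pair and -1 for a non-clone pair. An illustrative line (the file ids here are made up):

./bigclonebenchdata/123456.java ./bigclonebenchdata/654321.java 1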
--------------------------------------------------------------------------------
/codeTrans_bcb.py:
--------------------------------------------------------------------------------
import os
import random
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange
from createclone_bcb import createast, creategmndata, createseparategraph
import models

parser = argparse.ArgumentParser()
parser.add_argument("--cuda", default=True)
parser.add_argument("--dataset", default='bcb')
parser.add_argument("--data_setting", default='0')
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument("--inp_dim_num", type=int, default=128)
parser.add_argument("--num_heads", type=int, default=4)
parser.add_argument("--head_dim_num", type=int, default=16)
parser.add_argument("--mlp_dim_num", type=int, default=128)
parser.add_argument("--num_epochs", type=int, default=5)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--threshold", type=float, default=0.1)
args = parser.parse_args()

device = torch.device('cuda:0' if args.cuda and torch.cuda.is_available() else 'cpu')
astdict, vocablen, vocabdict = createast()
treedict = createseparategraph(astdict, vocablen, vocabdict, device)
traindata, validdata, testdata = creategmndata(args.data_setting, treedict, vocablen, vocabdict, device)
print(len(traindata))

model = models.CloneTrans(vocablen, args.inp_dim_num, args.num_heads, args.head_dim_num,
                          args.dropout, args.num_layers, device=device)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()


def create_batches(data):
    # reshuffle the training pairs each epoch before slicing them into batches
    random.shuffle(data)
    batches = [data[graph:graph + args.batch_size] for graph in range(0, len(data), args.batch_size)]
    return batches


@torch.no_grad()
def test(dataset):
    model.eval()
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    results = []
    for data, label in dataset:
        label = torch.tensor(label, dtype=torch.float, device=device)
        x1, x2, edge_index1, edge_index2, df1, df2, bf1, bf2 = data

        x1 = torch.tensor(x1, dtype=torch.long, device=device)
        x2 = torch.tensor(x2, dtype=torch.long, device=device)

        edge_index1 = torch.tensor(edge_index1, dtype=torch.long, device=device)
        edge_index2 = torch.tensor(edge_index2, dtype=torch.long, device=device)

        df1 = torch.tensor(df1, dtype=torch.long, device=device)
        df2 = torch.tensor(df2, dtype=torch.long, device=device)
        bf1 = torch.tensor(bf1, dtype=torch.long, device=device)
        bf2 = torch.tensor(bf2, dtype=torch.long, device=device)

        data1 = [x1, edge_index1, df1, bf1]
        data2 = [x2, edge_index2, df2, bf2]
        prediction1, prediction2, _, _ = model(data1, data2)
        output = F.cosine_similarity(prediction1, prediction2)
        results.append(output.item())
        # compare the raw cosine similarity against the threshold;
        # torch.sign() would collapse it to -1/0/1 and make the threshold meaningless
        prediction = output.item()

        if prediction > args.threshold and label.item() == 1:
            tp += 1
        if prediction <= args.threshold and label.item() == -1:
            tn += 1
        if prediction > args.threshold and label.item() == -1:
            fp += 1
        if prediction <= args.threshold and label.item() == 1:
            fn += 1
    print(tp, tn, fp, fn)
    if tp + fp == 0:
        print('precision is undefined')
        return results
    p = tp / (tp + fp)
    if tp + fn == 0:
        print('recall is undefined')
        return results
    r = tp / (tp + fn)
    f1 = 2 * p * r / (p + r)
    print('precision', p)
    print('recall', r)
    print('F1', f1)
    return results


os.makedirs('bcbresult', exist_ok=True)
epochs = trange(args.num_epochs, leave=True, desc="Epoch")
for epoch in epochs:
    print(epoch)
    model.train()
    batches = create_batches(traindata)
    totalloss = 0.0
    main_index = 0.0
    for index, batch in tqdm(enumerate(batches), total=len(batches), desc="Batches"):
        optimizer.zero_grad()
        batchloss = 0
        for data, label in batch:
            label_t = torch.tensor(label, dtype=torch.float, device=device)
            x1, x2, edge_index1, edge_index2, df1, df2, bf1, bf2 = data
            x1 = torch.tensor(x1, dtype=torch.long, device=device)
            x2 = torch.tensor(x2, dtype=torch.long, device=device)
            edge_index1 = torch.tensor(edge_index1, dtype=torch.long, device=device)
            edge_index2 = torch.tensor(edge_index2, dtype=torch.long, device=device)

            df1 = torch.tensor(df1, dtype=torch.long, device=device)
            df2 = torch.tensor(df2, dtype=torch.long, device=device)
            bf1 = torch.tensor(bf1, dtype=torch.long, device=device)
            bf2 = torch.tensor(bf2, dtype=torch.long, device=device)

            data1 = [x1, edge_index1, df1, bf1]
            data2 = [x2, edge_index2, df2, bf2]

            prediction1, prediction2, _, _ = model(data1, data2)
            cossim = F.cosine_similarity(prediction1, prediction2)
            batchloss = batchloss + criterion(cossim, label_t)
        # the graph is rebuilt for every batch, so retain_graph is not needed
        batchloss.backward()
        optimizer.step()
        loss = batchloss.item()
        totalloss += loss
        main_index = main_index + len(batch)
        loss = totalloss / main_index
        epochs.set_description("Epoch (Loss=%g)" % round(loss, 5))

    devresults = test(validdata)
    with open('bcbresult/dev_epoch_' + str(epoch + 1), mode='w') as devfile:
        for res in devresults:
            devfile.write(str(res) + '\n')

    testresults = test(testdata)
    with open('bcbresult/test_epoch_' + str(epoch + 1), mode='w') as resfile:
        for res in testresults:
            resfile.write(str(res) + '\n')
--------------------------------------------------------------------------------
/createclone_bcb.py:
--------------------------------------------------------------------------------
import os
import javalang
import javalang.tree
import javalang.ast
import javalang.util
from javalang.ast import Node
from anytree import AnyNode
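
# Pipeline overview:
#   createast()           parses every BigCloneBench snippet into a javalang AST
#                         and builds the token vocabulary
#   createseparategraph() re-indexes each AST with anytree and derives node
#                         features, undirected edges, and DFS/BFS id orderings
#   creategmndata()       pairs the graphs according to the train/dev/test index
#                         files and attaches the clone labels (1 / -1)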


def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        # javalang stores modifier sets ({'public', 'static'}, ...) as plain sets
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def getnodes(node, nodelist):
    nodelist.append(node)
    children = get_child(node)
    for child in children:
        getnodes(child, nodelist)


def createtree(root, node, nodelist, parent=None):
    id = len(nodelist)
    token, children = get_token(node), get_child(node)
    if id == 0:
        root.token = token
        root.data = node
    else:
        newnode = AnyNode(id=id, token=token, data=node, parent=parent)
    # append unconditionally so that every visited AST node consumes an id;
    # appending only in the else-branch would hand out duplicate ids
    nodelist.append(node)
    for child in children:
        if id == 0:
            createtree(root, child, nodelist, parent=root)
        else:
            createtree(root, child, nodelist, parent=newnode)


def getnodeid_depthfirst(node, nodeidindex):
    nodeidindex.append(node.id)
    for child in node.children:
        getnodeid_depthfirst(child, nodeidindex)


def getnodeid_breadthfirst(node, nodeidindex, is_root=True):
    if is_root:
        nodeidindex.append(node.id)
    for child in node.children:
        nodeidindex.append(child.id)
    for child in node.children:
        getnodeid_breadthfirst(child, nodeidindex, is_root=False)


def getnodeandedge_astonly(node, nodeindexlist, vocabdict, src, tgt):
    token = node.token
    nodeindexlist.append([vocabdict[token]])
    # add each tree edge in both directions so the AST graph is undirected
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        getnodeandedge_astonly(child, nodeindexlist, vocabdict, src, tgt)


def getnodeandedge(node, nodeindexlist, vocabdict, src, tgt, edgetype):
    token = node.token
    nodeindexlist.append([vocabdict[token]])
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        edgetype.append([0])
        src.append(child.id)
        tgt.append(node.id)
        edgetype.append([0])
        getnodeandedge(child, nodeindexlist, vocabdict, src, tgt, edgetype)


def countnodes(node, ifcount, whilecount, forcount, blockcount):
    # debug helper (unused in the pipeline)
    token = node.token
    if token == 'IfStatement':
        ifcount += 1
    if token == 'WhileStatement':
        whilecount += 1
    if token == 'ForStatement':
        forcount += 1
    if token == 'BlockStatement':
        blockcount += 1
    print(ifcount, whilecount, forcount, blockcount)
    for child in node.children:
        countnodes(child, ifcount, whilecount, forcount, blockcount)


def createast():
    asts = []
    paths = []
    alltokens = []
    dirname = 'BCB/bigclonebenchdata/'
    for rt, dirs, files in os.walk(dirname):
        for file in files:
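            # each snippet file holds a single Java method; parse_member_declaration()
            # parses a class member rather than a full compilation unit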
            programfile = open(os.path.join(rt, file), encoding='utf-8')
            programtext = programfile.read()
            programtokens = javalang.tokenizer.tokenize(programtext)
            parser = javalang.parse.Parser(programtokens)
            programast = parser.parse_member_declaration()
            paths.append(os.path.join(rt, file))
            asts.append(programast)
            get_sequence(programast, alltokens)
            programfile.close()
    astdict = dict(zip(paths, asts))
    ifcount = 0
    whilecount = 0
    forcount = 0
    blockcount = 0
    docount = 0
    switchcount = 0
    for token in alltokens:
        if token == 'IfStatement':
            ifcount += 1
        if token == 'WhileStatement':
            whilecount += 1
        if token == 'ForStatement':
            forcount += 1
        if token == 'BlockStatement':
            blockcount += 1
        if token == 'DoStatement':
            docount += 1
        if token == 'SwitchStatement':
            switchcount += 1
    print(ifcount, whilecount, forcount, blockcount, docount, switchcount)
    print('allnodes ', len(alltokens))
    alltokens = list(set(alltokens))
    vocabsize = len(alltokens)
    tokenids = range(vocabsize)
    vocabdict = dict(zip(alltokens, tokenids))
    print(vocabsize)
    return astdict, vocabsize, vocabdict


def createseparategraph(astdict, vocablen, vocabdict, device):
    print(len(astdict))
    for path, tree in astdict.items():
        nodelist = []
        newtree = AnyNode(id=0, token=None, data=None)
        createtree(newtree, tree, nodelist)

        x = []
        edgesrc = []
        edgetgt = []
        getnodeandedge_astonly(newtree, x, vocabdict, edgesrc, edgetgt)

        depth_first_id = []
        breadth_first_id = []
        getnodeid_depthfirst(newtree, depth_first_id)
        getnodeid_breadthfirst(newtree, breadth_first_id, True)

        edge_index = [edgesrc, edgetgt]
        astlength = len(x)

        astdict[path] = [[x, edge_index], astlength, depth_first_id, breadth_first_id]

    return astdict


def creategmndata(id, treedict, vocablen, vocabdict, device):
    indexdir = 'BCB/'
    if id == '0':
        trainfile = open(indexdir + 'traindata.txt')
        validfile = open(indexdir + 'devdata.txt')
        testfile = open(indexdir + 'testdata.txt')
    elif id == '11':
        trainfile = open(indexdir + 'traindata11.txt')
        validfile = open(indexdir + 'devdata.txt')
        testfile = open(indexdir + 'testdata.txt')
    else:
        print('unknown data setting: ' + id)
        quit()
    trainlist = trainfile.readlines()
    validlist = validfile.readlines()
    testlist = testfile.readlines()
    trainfile.close()
    validfile.close()
    testfile.close()
    print('train data')
    traindata = createpairdata(treedict, trainlist, device=device)
    print('valid data')
    validdata = createpairdata(treedict, validlist, device=device)
    print('test data')
    testdata = createpairdata(treedict, testlist, device=device)
    return traindata, validdata, testdata


def createpairdata(treedict, pathlist, device):
    datalist = []
    for line in pathlist:
        pairinfo = line.split()
        code1path = 'BCB' + pairinfo[0].strip('.')
        code2path = 'BCB' + pairinfo[1].strip('.')
        label = int(pairinfo[2])
        data1 = treedict[code1path]
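        # a treedict entry is [[node_features, edge_index], num_nodes,
        # depth_first_ids, breadth_first_ids], as built by createseparategraph()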
        data2 = treedict[code2path]
        x1, edge_index1, ast1length, df1, bf1 = data1[0][0], data1[0][1], data1[1], data1[2], data1[3]
        x2, edge_index2, ast2length, df2, bf2 = data2[0][0], data2[0][1], data2[1], data2[2], data2[3]

        data = [[x1, x2, edge_index1, edge_index2, df1, df2, bf1, bf2], label]
        datalist.append(data)
    return datalist


if __name__ == '__main__':
    astdict, vocabsize, vocabdict = createast()
    treedict = createseparategraph(astdict, vocabsize, vocabdict, device='cpu')
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as torch_init
from torch_geometric.nn import GCNConv
from einops import rearrange, repeat


def weights_init_random(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('Linear') != -1:
        torch_init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined.
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU()
        )

    def forward(self, x):
        return self.net(x)


class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        tmp_mask = torch.zeros(b, self.heads, n, n, device=x.device, requires_grad=False)
        index = torch.topk(dots, k=max(n // 3, 1), dim=-1, largest=True)[1]
        tmp_mask.scatter_(-1, index, 1.)
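        # sparse attention: keep roughly the top n/3 scores in each query row and
        # push the rest to -inf so that softmax assigns them zero weight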
        attn = torch.where(tmp_mask > 0, dots, torch.full_like(dots, float('-inf')))
        attn = self.attend(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class InterAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity()

    def forward(self, x, y):
        b, nx, _ = x.size()
        b, ny, _ = y.size()
        h = self.heads

        # q: y, kv: x
        q_y = self.to_q(y)
        kv_x = self.to_kv(x).chunk(2, dim=-1)
        qkv_y = (q_y,) + kv_x
        q_y, k_x, v_x = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv_y)

        dots_yx = torch.einsum('b h i d, b h j d -> b h i j', q_y, k_x) * self.scale

        tmp_mask_yx = torch.zeros(b, self.heads, ny, nx, device=x.device, requires_grad=False)
        index = torch.topk(dots_yx, k=max(nx // 4, 1), dim=-1, largest=True)[1]
        tmp_mask_yx.scatter_(-1, index, 1.)
        attn_yx = torch.where(tmp_mask_yx > 0, dots_yx, torch.full_like(dots_yx, float('-inf')))
        attn_yx = self.attend(attn_yx)

        out_y = torch.einsum('b h i j, b h j d -> b h i d', attn_yx, v_x)
        out_y = rearrange(out_y, 'b h n d -> b n (h d)')

        # q: x, kv: y
        q_x = self.to_q(x)
        kv_y = self.to_kv(y).chunk(2, dim=-1)
        qkv_x = (q_x,) + kv_y
        q_x, k_y, v_y = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv_x)

        dots_xy = torch.einsum('b h i d, b h j d -> b h i j', q_x, k_y) * self.scale

        tmp_mask_xy = torch.zeros(b, self.heads, nx, ny, device=y.device, requires_grad=False)
        index = torch.topk(dots_xy, k=max(ny // 4, 1), dim=-1, largest=True)[1]
        tmp_mask_xy.scatter_(-1, index, 1.)
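        # same top-k sparsification in the x-attends-to-y direction: each token of
        # x keeps only its ~ny/4 strongest matches in the other snippet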
        attn_xy = torch.where(tmp_mask_xy > 0, dots_xy, torch.full_like(dots_xy, float('-inf')))
        attn_xy = self.attend(attn_xy)

        out_x = torch.einsum('b h i j, b h j d -> b h i d', attn_xy, v_y)
        out_x = rearrange(out_x, 'b h n d -> b n (h d)')

        return self.to_out(out_x), self.to_out(out_y)


class InterTransformer(torch.nn.Module):
    def __init__(self, num_inp, depth, heads, dim_head, num_out):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(num_inp),
                InterAttention(num_inp, heads=heads, dim_head=dim_head),
                PreNorm(num_inp, FeedForward(num_inp, num_out)),
            ]))

    def forward(self, x, y):
        for norm, attn, ff in self.layers:
            x = norm(x)
            y = norm(y)
            # attention
            out_x, out_y = attn(x, y)
            # feed forward
            x = ff(out_x) + x
            y = ff(out_y) + y
        return x, y


class SelfTransformer(torch.nn.Module):
    def __init__(self, num_inp, depth, heads, dim_head, num_out):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(num_inp, SelfAttention(num_inp, heads=heads, dim_head=dim_head)),
                PreNorm(num_inp, FeedForward(num_inp, num_out)),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            # attention
            x = attn(x) + x
            # feed forward
            x = ff(x) + x
        return x


class FeatureEmbedding(torch.nn.Module):
    def __init__(self, vocablen, num_inp):
        super().__init__()
        self.embed = nn.Embedding(vocablen, num_inp)
        self.code_token = nn.Parameter(torch.randn(1, num_inp))
        torch_init.xavier_uniform_(self.code_token)

        # a GCN produces one score per node for each of the 8 learned code tokens
        self.embedding = GCNConv(num_inp, 8)

    def forward(self, x_input, PE):
        x, edge_index, _, _ = x_input
        encodes = self.embed(x).permute(1, 0, 2)                    # [1, n_nodes, num_inp]
        n, t, d = encodes.size()
        att = self.embedding(encodes, edge_index).permute(0, 2, 1)  # [1, 8, n_nodes]
        att = F.softmax(att, -1)

        encodes = encodes + PE
        # token learner: each of the 8 tokens is an attention-weighted sum of nodes
        learned_tokens = torch.einsum('nkt,ntd->nkd', att, encodes)

        code_tokens = repeat(self.code_token[None, ...], '() t d -> n t d', n=n)
        x = torch.cat((code_tokens, learned_tokens), dim=1)
        return x, att


class CloneTrans(torch.nn.Module):
    def __init__(self, vocablen, num_inp, heads, dim_head, dropout, num_layers, device):
        super().__init__()
        self.device = device

        self.embedding = FeatureEmbedding(vocablen, num_inp)

        self.intra_transformer = SelfTransformer(num_inp, num_layers, heads, dim_head, num_inp)

        self.inter_transformer = InterTransformer(num_inp, 1, heads, dim_head, num_inp)

        self.mlp = nn.Linear(num_inp, num_inp)

        # note: the dropout argument is accepted for interface compatibility but
        # is currently unused by the submodules
        channels = int(num_inp / 2)
        self.channels = channels
        df_pe = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("df_pe", df_pe)
        bf_pe = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("bf_pe", bf_pe)

    def get_depth_first_PE(self, df):
        # node ids arrive as integer tensors; cast to float for the sinusoidal encoding
        df = df.float()
        sin_inp_x = torch.einsum("i,j->ij", df, self.df_pe)
        emb = get_emb(sin_inp_x)
        emb_df = torch.zeros((df.shape[0], self.channels), device=df.device, dtype=df.dtype)
        emb_df[:, :self.channels] = emb
        return emb_df

    def get_breadth_first_PE(self, bf):
        bf = bf.float()
        sin_inp_x = torch.einsum("i,j->ij", bf, self.bf_pe)
        emb = get_emb(sin_inp_x)
        emb_bf = torch.zeros((bf.shape[0], self.channels), device=bf.device, dtype=bf.dtype)
        emb_bf[:, :self.channels] = emb
        return emb_bf

    def forward(self, x_input, y_input):
        df_x, bf_x = x_input[2], x_input[3]
        df_y, bf_y = y_input[2], y_input[3]

        PE_df_x = self.get_depth_first_PE(df_x)
        PE_df_y = self.get_depth_first_PE(df_y)
        PE_bf_x = self.get_breadth_first_PE(bf_x)
        PE_bf_y = self.get_breadth_first_PE(bf_y)

        PE_x = torch.cat([PE_df_x, PE_bf_x], -1)
        PE_y = torch.cat([PE_df_y, PE_bf_y], -1)

        x_embed, x_att = self.embedding(x_input, PE_x)
        y_embed, y_att = self.embedding(y_input, PE_y)

        x_trans_out = self.intra_transformer(x_embed)
        y_trans_out = self.intra_transformer(y_embed)
        x_out, y_out = self.inter_transformer(x_trans_out, y_trans_out)

        # the prepended code token summarizes each snippet
        x_out = self.mlp(x_out[:, 0])
        y_out = self.mlp(y_out[:, 0])

        return x_out, y_out, x_att, y_att
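

if __name__ == '__main__':
    # Minimal smoke test (not part of the original training pipeline): builds two
    # dummy AST graphs and runs a single forward pass. The shapes mirror what
    # createclone_bcb.py and codeTrans_bcb.py produce; all values here are made up.
    vocablen, num_inp, heads, dim_head = 100, 128, 4, 16
    model = CloneTrans(vocablen, num_inp, heads, dim_head, dropout=0.5, num_layers=3, device='cpu')

    def dummy_input(t):
        x = torch.randint(0, vocablen, (t, 1))        # one vocab id per AST node
        src = torch.arange(t - 1)
        tgt = torch.arange(1, t)
        edge_index = torch.stack([torch.cat([src, tgt]), torch.cat([tgt, src])])  # undirected chain
        df = torch.arange(t)                          # stand-in depth-first ids
        bf = torch.arange(t)                          # stand-in breadth-first ids
        return [x, edge_index, df, bf]

    out1, out2, _, _ = model(dummy_input(12), dummy_input(15))
    print(F.cosine_similarity(out1, out2).item())     # similarity of the two code embeddings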
--------------------------------------------------------------------------------