├── .gitignore
├── LICENSE
├── README.md
├── _data
│   ├── dummy
│   └── nyt.json
├── _snapshot
│   ├── dummy
│   ├── nyt_1p.pt
│   └── nyt_2p.pt
├── dataset.py
├── imgs
│   ├── experiment.png
│   ├── overview.png
│   ├── poster.png
│   └── result.png
├── lib.py
├── main.py
└── model.py

/.gitignore:
--------------------------------------------------------------------------------

__pycache__
*.pyc
.ipynb_checkpoints

_snapshot

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Tsu-Jui Fu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## [**2021/12/20 Update**] Reimplementation by Author
Thanks for everyone's interest in this project, and sorry that the original preprocessed data went missing.
It was lost at my previous lab, and I finally had time to reimplement it 😂.
I also want to thank @LuoXukun for his [helpful reply about reproducing the results](https://github.com/tsujuifu/pytorch_graph-rel/issues/6#issuecomment-615836064).

| Model ([NYT](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_data/nyt.json)) | Precision (%) | Recall (%) | F1 (%) |
| :-: | :-: | :-: | :-: |
| GraphRel1p ([Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)) | 62.9 | 57.3 | 60.0 |
| GraphRel1p ([Reimplementation](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_snapshot/nyt_1p.pt)) | 60.9 | 59.2 | 60.1 |
| GraphRel2p ([Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)) | 63.9 | 60.0 | 61.9 |
| GraphRel2p ([Reimplementation](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_snapshot/nyt_2p.pt)) | 63.1 | 60.2 | 61.6 |

# [ACL'19 (Long)] GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction
A **PyTorch** implementation of GraphRel

[Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf) | [Slide](https://tsujuifu.github.io/slides/acl19_graph-rel.pdf) | [Poster](https://github.com/tsujuifu/pytorch_graph-rel/raw/master/imgs/poster.png)

## Overview
This repository is an implementation of
"[GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)"
[Tsu-Jui Fu](https://scholar.google.com/citations?user=mwFy9kkAAAAJ), [Peng-Hsuan Li](https://scholar.google.com/citations?user=sqYoxbsAAAAJ), and [Wei-Yun Ma](https://scholar.google.com/citations?user=AHG3DncAAAAJ)
in the Annual Meeting of the Association for Computational Linguistics (**ACL**) 2019 (Long)

In the 1st phase, we **adopt a bi-RNN and a GCN to extract both sequential and regional dependency** word features. Given these word features, we **predict a relation for each word pair** and an entity tag for each word. Then, in the 2nd phase, we use the 1st-phase relation predictions to build a complete relational graph for each relation and **apply a GCN to each graph to integrate each relation's information**, which further models the interaction between entities and relations.

## Requirements
This code is implemented with **Python 3.8** and [PyTorch 1.7](https://pypi.org/project/torch/1.7.0).
+ [tqdm](https://pypi.org/project/tqdm), [spaCy](https://spacy.io)

## Usage
```
python -m spacy download en_core_web_lg
python main.py --arch=2p
```
We also provide the [trained checkpoints](https://github.com/tsujuifu/pytorch_graph-rel/tree/master/_snapshot).

## Citation
```
@inproceedings{fu2019graph-rel,
  author = {Tsu-Jui Fu and Peng-Hsuan Li and Wei-Yun Ma},
  title = {{GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction}},
  booktitle = {Annual Meeting of the Association for Computational Linguistics (ACL)},
  year = {2019}
}
```

## Acknowledgement
+ [copy_re](https://github.com/xiangrongzeng/copy_re)

--------------------------------------------------------------------------------
/_data/dummy:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/_snapshot/dummy:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/_snapshot/nyt_1p.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/_snapshot/nyt_1p.pt
--------------------------------------------------------------------------------
/_snapshot/nyt_2p.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/_snapshot/nyt_2p.pt
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------

from lib import *

class DS(T.utils.data.Dataset):
    def __init__(self, NLP, path, typ, max_len):
        super().__init__()

        self.NLP = NLP
        with open('./_data/%s.json'%(path), 'r') as f:
            self.dat = json.load(f)[typ]
        self.max_len = max_len

        # POS-tag vocabulary; index 0 is reserved for padding
        self.POS = {}
        for p in self.NLP.pipe_labels['tagger']:
            self.POS[p] = len(self.POS)+1

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        item = self.dat[idx]
        sent, label = item['sentence'], item['label']

        s = ' '.join(sent)
        inp_sent, inp_pos = np.zeros((self.max_len, 300), dtype=np.float32), np.zeros((self.max_len, ), dtype=np.int64)
        dep_fw, dep_bw = np.zeros((self.max_len, self.max_len), dtype=np.float32), np.zeros((self.max_len, self.max_len), dtype=np.float32)
        # -1 marks padding positions, which the loss ignores
        ans_ne, ans_rel = np.ones((self.max_len, ), dtype=np.int64)*-1, np.ones((self.max_len, self.max_len), dtype=np.int64)*-1

        res = self.NLP(s)
        L = len(res)
        ans_ne[:L] = 0
        ans_rel[:L, :L] = 0

        for i, w in enumerate(res):
            # 300-d spaCy word vector and POS-tag index of each token
            inp_sent[i], inp_pos[i] = w.vector, self.POS[w.tag_]

            # dependency edges with self-loops: head -> child in dep_fw, child -> head in dep_bw
            dep_fw[i][i], dep_bw[i][i] = 1, 1
            for c in w.children:
                dep_fw[i][c.i], dep_bw[c.i][i] = 1, 1

        # row-normalize both adjacency matrices over the real tokens
        dep_fw[:L], dep_bw[:L] = [dep_fw[:L]/dep_fw[:L].sum(axis=1, keepdims=True),
                                  dep_bw[:L]/dep_bw[:L].sum(axis=1, keepdims=True)]

        # entity tags: 0 = 'O', 1 = 'B', 2 = 'I', 3 = 'E', 4 = 'S'
        def set_ne(ne):
            b, e = ne
            if b==e:
                ans_ne[b] = 4 # 'S'
            else:
                ans_ne[b], ans_ne[e] = 1, 3 # 'B', 'E'
                ans_ne[b+1:e] = 2 # 'I'

        for ne1, ne2, rel in label:
            set_ne(ne1), set_ne(ne2)
            # every token pair across the two entity spans carries the relation label
            ans_rel[ne1[0]:ne1[1]+1, ne2[0]:ne2[1]+1] = rel

        return s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel

if __name__=='__main__':
    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, 'nyt', typ, 120) for typ in ['train', 'val', 'test']]

    dl = T.utils.data.DataLoader(ds_tr, batch_size=64,
                                 shuffle=True, num_workers=32)
    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        print(len(s), inp_sent.shape, inp_pos.shape, dep_fw.shape, dep_bw.shape, ans_ne.shape, ans_rel.shape)

--------------------------------------------------------------------------------
/imgs/experiment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/experiment.png
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/overview.png
--------------------------------------------------------------------------------
/imgs/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/poster.png
--------------------------------------------------------------------------------
/imgs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/result.png
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------

import os, argparse, json, math

from datetime import datetime
from tqdm import tqdm

import numpy as np
import torch as T

import spacy

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------

from lib import *
from dataset import *
from model import *

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--path", default='nyt', type=str)

    parser.add_argument("--max_len", default=120, type=int)
    parser.add_argument("--num_ne", default=5, type=int)
    parser.add_argument("--num_rel", default=25, type=int)

    parser.add_argument("--size_hid", default=256, type=int)
    parser.add_argument("--layer_rnn", default=2, type=int)
    parser.add_argument("--layer_gcn", default=2, type=int)
    parser.add_argument("--dropout", default=0.5, type=float)
    parser.add_argument("--arch", default='2p', type=str)

    parser.add_argument("--size_epoch", default=40, type=int)
    parser.add_argument("--size_batch", default=64, type=int)
    parser.add_argument("--lr", default=8e-4, type=float)
    parser.add_argument("--lr_decay", default=0.9, type=float)
    parser.add_argument("--weight_loss", default=2.0, type=float)
    parser.add_argument("--weight_alpha", default=3.0, type=float)

    args = parser.parse_args()
    args.path_output = '_snapshot/_%s_%s_%s'%(args.path, args.arch, datetime.now().strftime('%Y%m%d%H%M%S'))

    return args

def train_dl(args, model, dl, optzr):
    def get_loss(weight_loss, out, ans):
        # flatten to (N, C) logits and (N,) targets
        out, ans = out.flatten(0, len(out.shape)-2), ans.flatten(0, len(ans.shape)-1).cuda()
        ls = T.nn.functional.cross_entropy(out, ans, ignore_index=-1, reduction='none')
        # padding (-1) gets weight 0, class 0 gets 1, positive classes get weight_loss
        weight = 1.0-(ans==-1).float()
        weight.masked_fill_(ans>0, weight_loss)
        ls = (ls*weight).sum() / (weight>0).sum()
        return ls

    ret = {'ls_ne': [], 'ls_rel': []}
    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        if args.arch=='1p':
            out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            ls_ne, ls_rel = get_loss(args.weight_loss, out_ne, ans_ne), get_loss(args.weight_loss, out_rel, ans_rel)
            ls = ls_ne + args.weight_alpha*ls_rel

        elif args.arch=='2p':
            out_ne1p, out_rel1p, out_ne2p, out_rel2p = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            ls_ne1p, ls_rel1p = get_loss(args.weight_loss, out_ne1p, ans_ne), get_loss(args.weight_loss, out_rel1p, ans_rel)
            ls_ne2p, ls_rel2p = get_loss(args.weight_loss, out_ne2p, ans_ne), get_loss(args.weight_loss, out_rel2p, ans_rel)
            ls_ne, ls_rel = ls_ne2p, ls_rel2p
            ls = (ls_ne1p+ls_ne2p) + args.weight_alpha*(ls_rel1p+ls_rel2p)

        optzr.zero_grad()
        ls.backward()
        optzr.step()
        ret['ls_ne'].append(ls_ne.item()), ret['ls_rel'].append(ls_rel.item())
    ret = {k: float(np.average(l)) for k, l in ret.items()}

    return ret

def eval_dl(args, model, dl):
    ret = {'precision': [0, 0], 'recall': [0, 0], 'f1': 0}

    I = 0
    with T.no_grad():
        for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
            if args.arch=='1p':
                out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            elif args.arch=='2p':
                _, _, out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())

            out_ne, out_rel = [T.argmax(out, dim=-1).data.cpu().numpy() for out in [out_ne, out_rel]]
            for o_ne, o_rel in zip(out_ne, out_rel):
                l = len(dl.dataset.dat[I]['sentence'])+1

                # decode BIOES tags; every token of an entity maps to its [begin, end] span
                ne, pos = {}, -1
                for i in range(l):
                    v = o_ne[i]
                    if v==4:
                        ne[i] = [i, i]
                        pos = -1
                    elif v==1:
                        pos = i
                    elif v==2:
                        pass
                    elif v==3:
                        if pos!=-1:
                            for p in range(pos, i+1):
                                ne[p] = [pos, i]
                    elif v==0:
                        pos = -1

                # a triple is (end of head entity, end of tail entity, relation)
                pd = set()
                for i in range(l):
                    for j in range(l):
                        if o_rel[i][j]!=0 and i in ne and j in ne:
                            pd.add((ne[i][1], ne[j][1], o_rel[i][j]))

                gt = set()
                for ne1, ne2, rel in dl.dataset.dat[I]['label']:
                    gt.add((ne1[1], ne2[1], rel))

                ret['precision'][0] += len(pd.intersection(gt))
                ret['precision'][1] += len(pd)
                ret['recall'][0] += len(pd.intersection(gt))
                ret['recall'][1] += len(gt)

                I += 1

    ret['precision'] = ret['precision'][0]/ret['precision'][1] if ret['precision'][1]>0 else 0
    ret['recall'] = ret['recall'][0]/ret['recall'][1] if ret['recall'][1]>0 else 0
    ret['f1'] = 2*ret['precision']*ret['recall']/(ret['precision']+ret['recall']) if (ret['precision']+ret['recall'])>0 else 0

    return ret

if __name__=='__main__':
    args = get_args()
    os.makedirs(args.path_output, exist_ok=True)
    json.dump(vars(args), open('%s/args.json'%(args.path_output), 'w'), indent=2)
    print(args)

    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, args.path, typ, args.max_len) for typ in ['train', 'val', 'test']]
    dl_tr, dl_vl, dl_ts = [T.utils.data.DataLoader(ds, batch_size=args.size_batch,
                                                   shuffle=(ds is ds_tr), num_workers=32, pin_memory=True) \
                           for ds in [ds_tr, ds_vl, ds_ts]]

    log = {'ls_tr': [], 'f1_vl': [], 'f1_ts': []}
    json.dump(log, open('%s/log.json'%(args.path_output), 'w'), indent=2)

    model = GraphRel(len(ds_tr.POS)+1, args.num_ne, args.num_rel,
                     args.size_hid, args.layer_rnn, args.layer_gcn, args.dropout,
                     args.arch).cuda()
    T.save(model.state_dict(), '%s/model_0.pt'%(args.path_output))

    optzr = T.optim.AdamW(model.parameters(), lr=args.lr)
    for e in tqdm(range(args.size_epoch), ascii=True):
        model.train()
        ls_tr = train_dl(args, model, dl_tr, optzr)

        model.eval()
        f1_vl = eval_dl(args, model, dl_vl)
        f1_ts = eval_dl(args, model, dl_ts)

        log['ls_tr'].append(ls_tr), log['f1_vl'].append(f1_vl), log['f1_ts'].append(f1_ts)
        json.dump(log, open('%s/log.json'%(args.path_output), 'w'), indent=2)
        T.save(model.state_dict(), '%s/model_%d.pt'%(args.path_output, e+1))
        print('Ep %d:'%(e+1), ls_tr, f1_vl, f1_ts)

        # decay the learning rate once per epoch
        for pg in optzr.param_groups:
            pg['lr'] *= args.lr_decay

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------

from lib import *
from dataset import *

class GCN(T.nn.Module):
    def __init__(self, size_hid):
        super().__init__()

        self.size_hid = size_hid

        # maps size_hid -> size_hid//2, so concatenating the fw and bw outputs restores size_hid
        self.W = T.nn.Parameter(T.FloatTensor(self.size_hid, self.size_hid//2))
        self.b = T.nn.Parameter(T.FloatTensor(self.size_hid//2, ))

        stdv = 1.0/math.sqrt(self.size_hid//2)
        self.W.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, inp, adj):
        # ReLU(A (X W + b))
        out = T.matmul(inp, self.W) + self.b
        out = T.matmul(adj, out)
        out = T.nn.functional.relu(out)

        return out

    def __repr__(self):
        return self.__class__.__name__+'(size_hid=%d)'%(self.size_hid)

class GraphRel(T.nn.Module):
    def __init__(self, num_pos, num_ne, num_rel,
                 size_hid, layer_rnn, layer_gcn, dp,
                 arch='2p'):
        super().__init__()

        self.arch = arch

        self.emb_pos = T.nn.Embedding(num_pos, 15)

        self.rnn = T.nn.GRU(300+15, size_hid, num_layers=layer_rnn, dropout=dp,
                            batch_first=True, bidirectional=True)
        self.gcn_fw, self.gcn_bw = [T.nn.ModuleList([GCN(size_hid*2) for _ in range(layer_gcn)]),
                                    T.nn.ModuleList([GCN(size_hid*2) for _ in range(layer_gcn)])]

        self.rnn_ne = T.nn.GRU(size_hid*2, size_hid, batch_first=True)
        self.fc_ne = T.nn.Linear(size_hid, num_ne)

        self.fc_rf, self.fc_rb = [T.nn.Sequential(*[T.nn.Linear(size_hid*2, size_hid), T.nn.ReLU()]),
                                  T.nn.Sequential(*[T.nn.Linear(size_hid*2, size_hid), T.nn.ReLU()])]
        self.fc_rel = T.nn.Linear(size_hid*2, num_rel)

        if self.arch=='2p':
            self.gcn2p_fw, self.gcn2p_bw = [T.nn.ModuleList([GCN(size_hid*2) for _ in range(num_rel)]),
                                            T.nn.ModuleList([GCN(size_hid*2) for _ in range(num_rel)])]

        self.dp = T.nn.Dropout(dp)

    def head(self, feat):
        feat_ne, _ = self.rnn_ne(feat)
        out_ne = self.fc_ne(feat_ne)

        # pairwise relation logits: out_rel[b, i, j] scores word i (head) against word j (tail)
        rf, rb = self.fc_rf(feat), self.fc_rb(feat)
        rf, rb = [rf.unsqueeze(2).expand([-1, -1, rf.shape[1], -1]),
                  rb.unsqueeze(1).expand([-1, rb.shape[1], -1, -1])]
        out_rel = self.fc_rel(T.cat([rf, rb], dim=3))

        return out_ne, out_rel

    def forward(self, inp_sent, inp_pos, dep_fw, dep_bw):
        inp = T.cat([inp_sent, self.emb_pos(inp_pos)], dim=2)
        inp = self.dp(inp)

        # 1st phase: bi-GRU, then bi-directional GCN layers over the dependency graph
        feat, _ = self.rnn(inp)
        for gf, gb in zip(self.gcn_fw, self.gcn_bw):
            of, ob = gf(feat, dep_fw), gb(feat, dep_bw)
            feat = self.dp(T.cat([of, ob], dim=2))

        out_ne, out_rel = self.head(feat)

        if self.arch=='1p':
            return out_ne, out_rel

        # 2nd phase: one weighted graph per relation, built from the 1st-phase relation probabilities
        feat1p, out_ne1p, out_rel1p = feat, out_ne, out_rel

        dep_fw = T.nn.functional.softmax(out_rel1p, dim=3)
        dep_bw = dep_fw.transpose(1, 2)

        feat2p = feat1p.clone()
        for i, (gf, gb) in enumerate(zip(self.gcn2p_fw, self.gcn2p_bw)):
            of, ob = gf(feat1p, dep_fw[:, :, :, i]), gb(feat1p, dep_bw[:, :, :, i])
            feat2p += self.dp(T.cat([of, ob], dim=2))

        out_ne2p, out_rel2p = self.head(feat2p)

        return out_ne1p, out_rel1p, out_ne2p, out_rel2p

if __name__=='__main__':
    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, 'nyt', typ, 120) for typ in ['train', 'val', 'test']]
    dl = T.utils.data.DataLoader(ds_tr, batch_size=64,
                                 shuffle=True, num_workers=32, pin_memory=True)

    model = GraphRel(len(ds_tr.POS)+1, 5, 25,
                     256, 2, 2, 0.5,
                     '2p').cuda()

    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        out = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
        print([o.shape for o in out])

--------------------------------------------------------------------------------
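The BIOES decoding and triple matching in `eval_dl` (main.py) can be checked without a GPU, spaCy, or the NYT data. The sketch below is a pure-Python restatement under these assumptions: tags use the ids from `dataset.py` (0 = O, 1 = B, 2 = I, 3 = E, 4 = S), relation predictions are a nested list of ints, and `decode_entities`, `extract_triples` and `prf` are hypothetical helper names, not functions in the repository.

```python
from typing import Dict, List, Set, Tuple

Triple = Tuple[int, int, int]

def decode_entities(tags: List[int]) -> Dict[int, Tuple[int, int]]:
    """Map each token inside a predicted entity to its (begin, end) span,
    following the same BIOES rules as eval_dl in main.py."""
    ne: Dict[int, Tuple[int, int]] = {}
    pos = -1  # index of the most recent 'B' tag, or -1 if no span is open
    for i, v in enumerate(tags):
        if v == 4:  # 'S': single-token entity
            ne[i] = (i, i)
            pos = -1
        elif v == 1:  # 'B': open a span
            pos = i
        elif v == 3 and pos != -1:  # 'E': close the open span
            for p in range(pos, i + 1):
                ne[p] = (pos, i)
        elif v == 0:  # 'O': drop any open span
            pos = -1
        # 'I' (2) and an 'E' with no open span leave the state unchanged
    return ne

def extract_triples(tags: List[int], rel: List[List[int]]) -> Set[Triple]:
    """Collect (end of head entity, end of tail entity, relation) for every
    non-zero relation between two tokens that both belong to an entity."""
    ne = decode_entities(tags)
    out: Set[Triple] = set()
    n = len(tags)
    for i in range(n):
        for j in range(n):
            if rel[i][j] != 0 and i in ne and j in ne:
                out.add((ne[i][1], ne[j][1], rel[i][j]))
    return out

def prf(pred: Set[Triple], gold: Set[Triple]) -> Tuple[float, float, float]:
    """Precision, recall and F1 over exact triple matches."""
    hit = len(pred & gold)
    p = hit / len(pred) if pred else 0.0
    r = hit / len(gold) if gold else 0.0
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f

# Example: "Barack Obama was born in Hawaii" with an arbitrary relation id 7.
tags = [1, 3, 0, 0, 0, 4]
rel = [[0] * 6 for _ in range(6)]
rel[0][5] = rel[1][5] = 7  # dataset.py labels every token pair across the two spans
pred = extract_triples(tags, rel)
print(pred, prf(pred, {(1, 5, 7)}))
```

Because a triple keeps only the last token of each entity, a prediction whose span boundaries are wrong but whose end tokens are right still counts as correct. Note also that `eval_dl` sums hits and counts over the whole split before dividing (micro-averaging), whereas `prf` here scores a single sentence.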
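`DS.__getitem__` in dataset.py turns spaCy's dependency parse into two row-normalized adjacency matrices with self-loops: `dep_fw` links each head to its children and `dep_bw` links each child to its head, and rows beyond the sentence length stay zero. The sketch below rebuilds that construction from a plain list of head indices so it can be checked without spaCy. The input format (one head index per token, with the root pointing to itself, as spaCy's `Token.head` does for the root) and the name `build_adjacency` are assumptions for illustration; it uses nested lists instead of NumPy arrays.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def build_adjacency(heads: List[int], max_len: int) -> Tuple[Matrix, Matrix]:
    """Return (dep_fw, dep_bw) as max_len x max_len matrices.

    heads[i] is the index of token i's head (the root points to itself).
    dep_fw[h][c] = 1 for each head h and child c, dep_bw is its transpose,
    both get self-loops, and each real row is divided by its sum.
    """
    n = len(heads)
    dep_fw = [[0.0] * max_len for _ in range(max_len)]
    dep_bw = [[0.0] * max_len for _ in range(max_len)]
    for i, h in enumerate(heads):
        dep_fw[i][i] = dep_bw[i][i] = 1.0  # self-loop
        if h != i:  # skip the root's self-reference
            dep_fw[h][i] = 1.0  # head -> child
            dep_bw[i][h] = 1.0  # child -> head
    for mat in (dep_fw, dep_bw):
        for i in range(n):  # padding rows beyond n stay all-zero
            total = sum(mat[i])  # >= 1 thanks to the self-loop
            mat[i] = [v / total for v in mat[i]]
    return dep_fw, dep_bw

# "She eats apples": both "She" and "apples" attach to the root "eats".
fw, bw = build_adjacency([1, 1, 1], max_len=4)
print(fw[1], bw[0])
```

Row normalization makes each GCN layer in model.py average over a token's neighbors (including itself) instead of summing, so tokens with many dependents do not produce larger activations.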
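`get_loss` inside `train_dl` (main.py) reweights the per-element cross-entropy: padding targets (-1) get weight 0, the "no entity / no relation" class 0 gets weight 1, and every positive class gets `--weight_loss` (2.0 by default). The weighted sum is then divided by the number of non-padding elements, not by the sum of the weights. The pure-Python sketch below restates only that weighting step on losses that are assumed to be already computed; `weighted_mean_loss` is a hypothetical name, and raising on an all-padding input is an added choice (the PyTorch code would divide by zero there).

```python
from typing import Sequence

def weighted_mean_loss(losses: Sequence[float], targets: Sequence[int],
                       weight_loss: float = 2.0) -> float:
    """Mirror the weighting in get_loss: padding (-1) -> 0, class 0 -> 1,
    positive classes -> weight_loss; divide by the count of weighted items."""
    weights = [0.0 if t == -1 else (weight_loss if t > 0 else 1.0) for t in targets]
    num_valid = sum(1 for w in weights if w > 0)
    if num_valid == 0:
        raise ValueError("all targets are padding")
    return sum(l * w for l, w in zip(losses, weights)) / num_valid

# One 'O' token, one positive token, one padding position.
print(weighted_mean_loss([1.0, 0.5, 3.0], [0, 2, -1]))
```

Since the denominator counts elements rather than summing weights, raising `--weight_loss` increases the contribution of positive labels without renormalizing, so the overall loss magnitude grows with it.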