├── .gitignore
├── LICENSE
├── README.md
├── _data
│   ├── dummy
│   └── nyt.json
├── _snapshot
│   ├── dummy
│   ├── nyt_1p.pt
│   └── nyt_2p.pt
├── dataset.py
├── imgs
│   ├── experiment.png
│   ├── overview.png
│   ├── poster.png
│   └── result.png
├── lib.py
├── main.py
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.pyc
.ipynb_checkpoints

_snapshot
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Tsu-Jui Fu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## [**2021/12/20 Update**] Reimplementation by Author
Thank you to everyone for your interest in this project, and apologies for the missing original preprocessed data.
It was lost at my previous lab, and I finally had time to reimplement it 😂.
I would also like to thank @LuoXukun for his [helpful reply about reproducing the results](https://github.com/tsujuifu/pytorch_graph-rel/issues/6#issuecomment-615836064).

| [NYT](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_data/nyt.json) | Precision | Recall | F1 |
| :-: | :-: | :-: | :-: |
| GraphRel1p ([Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)) | 62.9 | 57.3 | 60.0 |
| GraphRel1p ([Reimplementation](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_snapshot/nyt_1p.pt)) | 60.9 | 59.2 | 60.1 |
| GraphRel2p ([Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)) | 63.9 | 60.0 | 61.9 |
| GraphRel2p ([Reimplementation](https://github.com/tsujuifu/pytorch_graph-rel/blob/master/_snapshot/nyt_2p.pt)) | 63.1 | 60.2 | 61.6 |

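As a quick sanity check on the table, F1 is the harmonic mean of precision and recall, which is also how `eval_dl` in `main.py` computes it. The sketch below is illustrative only: the helper name `f1_score` is an assumption and not part of the repository. Since the table lists rounded values, recomputing F1 from the rounded precision and recall can differ from the reported number in the last digit (this seems to be the case for the GraphRel1p reimplementation row).

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0

# (precision, recall) pairs copied from the table, in percent.
rows = {
    "GraphRel1p (Paper)": (62.9, 57.3),
    "GraphRel1p (Reimplementation)": (60.9, 59.2),
    "GraphRel2p (Paper)": (63.9, 60.0),
    "GraphRel2p (Reimplementation)": (63.1, 60.2),
}

for name, (p, r) in rows.items():
    print(f"{name}: F1 = {f1_score(p, r):.1f}")
```
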
# [ACL'19 (Long)] GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction
A **PyTorch** implementation of GraphRel

[Paper](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf) | [Slide](https://tsujuifu.github.io/slides/acl19_graph-rel.pdf) | [Poster](https://github.com/tsujuifu/pytorch_graph-rel/raw/master/imgs/poster.png)

## Overview
GraphRel is an implementation of
"[GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction](https://tsujuifu.github.io/pubs/acl19_graph-rel.pdf)"
by [Tsu-Jui Fu](https://scholar.google.com/citations?user=mwFy9kkAAAAJ), [Peng-Hsuan Li](https://scholar.google.com/citations?user=sqYoxbsAAAAJ), and [Wei-Yun Ma](https://scholar.google.com/citations?user=AHG3DncAAAAJ),
in the Annual Meeting of the Association for Computational Linguistics (**ACL**) 2019 (Long).

In the first phase, we **adopt a bi-RNN and a GCN to extract both sequential and regional dependency** word features. Given the word features, we **predict relations for each word pair** and the entities for all words. Then, in the second phase, based on the relations predicted in the first phase, we build a complete relational graph for each relation and **apply a GCN on each graph to integrate each relation’s information** and further consider the interaction between entities and relations.

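To make the first phase more concrete, here is a minimal pure-Python sketch of how a dependency tree could be turned into the row-normalized forward and backward adjacency matrices that the GCN consumes, mirroring what `dataset.py` does with NumPy. The helper names `build_adjacency` and `propagate` and the toy `children` mapping are assumptions for illustration; in the repository the children come from spaCy's parser, and the actual `GCN` module additionally applies a learned linear projection and a ReLU, which are omitted here.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]

def build_adjacency(num_words: int, children: Dict[int, List[int]]) -> Tuple[Matrix, Matrix]:
    """Return (forward, backward) adjacency matrices with self-loops, row-normalized."""
    fw = [[0.0] * num_words for _ in range(num_words)]
    bw = [[0.0] * num_words for _ in range(num_words)]
    for i in range(num_words):
        fw[i][i] = bw[i][i] = 1.0        # self-loop
        for j in children.get(i, []):
            fw[i][j] = 1.0               # head -> child
            bw[j][i] = 1.0               # child -> head

    def normalize(m: Matrix) -> Matrix:
        # Every row has at least the self-loop, so the row sum is never zero.
        return [[v / sum(row) for v in row] for row in m]

    return normalize(fw), normalize(bw)

def propagate(adj: Matrix, feats: Matrix) -> Matrix:
    """One aggregation step: each word takes a weighted average of its neighbours' features."""
    dim = len(feats[0])
    return [[sum(adj[i][k] * feats[k][d] for k in range(len(feats))) for d in range(dim)]
            for i in range(len(adj))]

# Toy three-word sentence in which word 1 is the head of words 0 and 2.
fw, bw = build_adjacency(3, {1: [0, 2]})
feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(propagate(fw, feats))
print(propagate(bw, feats))
```

In the second phase the same aggregation idea is reused, except that the adjacency for each relation comes from the softmax over the first-phase relation predictions instead of from the dependency parse.
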
## Requirements
This code is implemented with **Python 3.8** and [PyTorch 1.7](https://pypi.org/project/torch/1.7.0).
+ [tqdm](https://pypi.org/project/tqdm), [spaCy](https://spacy.io)

## Usage
```
python -m spacy download en_core_web_lg
python main.py --arch=2p
```
Pass `--arch=1p` to train the first-phase model only. We also provide the [trained checkpoints](https://github.com/tsujuifu/pytorch_graph-rel/tree/master/_snapshot).

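Judging from `dataset.py`, `_data/nyt.json` appears to hold `train`, `val`, and `test` splits, each a list of items with a tokenized `sentence` and a `label` list of `[entity1, entity2, relation]` triples, where an entity is an inclusive `[begin, end]` token span. The sketch below shows a hypothetical item and how its spans map to the tag ids used for entity prediction (0 = outside any entity, 1 = B, 2 = I, 3 = E, 4 = S). The example sentence, the relation id, and the helper `to_tags` are illustrative assumptions and are not taken from the dataset.

```python
import json
from typing import List

def to_tags(num_words: int, spans: List[List[int]]) -> List[int]:
    """Convert inclusive [begin, end] entity spans into per-word tag ids."""
    tags = [0] * num_words                 # 0: not part of an entity
    for b, e in spans:
        if b == e:
            tags[b] = 4                    # 'S': single-word entity
        else:
            tags[b], tags[e] = 1, 3        # 'B' and 'E'
            for k in range(b + 1, e):
                tags[k] = 2                # 'I'
    return tags

# Hypothetical item; relation id 7 is a made-up placeholder.
item = {
    "sentence": ["Barack", "Obama", "was", "born", "in", "Hawaii"],
    "label": [[[0, 1], [5, 5], 7]],
}

spans = [ne for ne1, ne2, _ in item["label"] for ne in (ne1, ne2)]
print(json.dumps(item))
print(to_tags(len(item["sentence"]), spans))  # [1, 3, 0, 0, 0, 4]
```
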
## Citation
```
@inproceedings{fu2019graph-rel,
  author = {Tsu-Jui Fu and Peng-Hsuan Li and Wei-Yun Ma},
  title = {{GraphRel: Modeling Text as Relational Graphs for Joint Entity and Relation Extraction}},
  booktitle = {Annual Meeting of the Association for Computational Linguistics (ACL)},
  year = {2019}
}
```

## Acknowledgement
+ [copy_re](https://github.com/xiangrongzeng/copy_re)
--------------------------------------------------------------------------------
/_data/dummy:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/_snapshot/dummy:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/_snapshot/nyt_1p.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/_snapshot/nyt_1p.pt
--------------------------------------------------------------------------------
/_snapshot/nyt_2p.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/_snapshot/nyt_2p.pt
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from lib import *

class DS(T.utils.data.Dataset):
    def __init__(self, NLP, path, typ, max_len):
        super().__init__()

        self.NLP = NLP
        with open('./_data/%s.json'%(path), 'r') as f:
            self.dat = json.load(f)[typ]
        self.max_len = max_len

        # map each POS tag of the spaCy tagger to an index (0 is left for padding)
        self.POS = {}
        for p in self.NLP.pipe_labels['tagger']:
            self.POS[p] = len(self.POS)+1

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        item = self.dat[idx]
        sent, label = item['sentence'], item['label']

        s = ' '.join(sent)
        inp_sent, inp_pos = np.zeros((self.max_len, 300), dtype=np.float32), np.zeros((self.max_len, ), dtype=np.int64)
        dep_fw, dep_bw = np.zeros((self.max_len, self.max_len), dtype=np.float32), np.zeros((self.max_len, self.max_len), dtype=np.float32)
        # -1 marks padding, which the loss ignores
        ans_ne, ans_rel = np.ones((self.max_len, ), dtype=np.int64)*-1, np.ones((self.max_len, self.max_len), dtype=np.int64)*-1

        res = self.NLP(s)
        L = len(res)
        ans_ne[:L] = 0
        ans_rel[:L, :L] = 0

        for i, w in enumerate(res):
            inp_sent[i], inp_pos[i] = w.vector, self.POS[w.tag_]

            # self-loop plus one edge per dependency child (forward) and its reverse (backward)
            dep_fw[i][i], dep_bw[i][i] = 1, 1
            for c in w.children:
                dep_fw[i][c.i], dep_bw[c.i][i] = 1, 1

        # row-normalize the adjacency matrices over the real (non-padded) words
        dep_fw[:L] = dep_fw[:L]/dep_fw[:L].sum(axis=1, keepdims=True)
        dep_bw[:L] = dep_bw[:L]/dep_bw[:L].sum(axis=1, keepdims=True)

        def set_ne(ne):
            b, e = ne
            if b==e:
                ans_ne[b] = 4 # 'S'
            else:
                ans_ne[b], ans_ne[e] = 1, 3 # 'B', 'E'
                ans_ne[b+1:e] = 2 # 'I'

        for ne1, ne2, rel in label:
            set_ne(ne1)
            set_ne(ne2)
            ans_rel[ne1[0]:ne1[1]+1, ne2[0]:ne2[1]+1] = rel

        return s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel

if __name__=='__main__':
    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, 'nyt', typ, 120) for typ in ['train', 'val', 'test']]

    dl = T.utils.data.DataLoader(ds_tr, batch_size=64, shuffle=True, num_workers=32)
    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        print(len(s), inp_sent.shape, inp_pos.shape, dep_fw.shape, dep_bw.shape, ans_ne.shape, ans_rel.shape)
--------------------------------------------------------------------------------
/imgs/experiment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/experiment.png
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/overview.png
--------------------------------------------------------------------------------
/imgs/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/poster.png
--------------------------------------------------------------------------------
/imgs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsujuifu/pytorch_graph-rel/cec8a712fb3db1b609eb107a207b33baaaa9c019/imgs/result.png
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
import os, argparse, json, math

from datetime import datetime
from tqdm import tqdm

import numpy as np
import torch as T

import spacy
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from lib import *
from dataset import *
from model import *

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--path", default='nyt', type=str)

    parser.add_argument("--max_len", default=120, type=int)
    parser.add_argument("--num_ne", default=5, type=int)
    parser.add_argument("--num_rel", default=25, type=int)

    parser.add_argument("--size_hid", default=256, type=int)
    parser.add_argument("--layer_rnn", default=2, type=int)
    parser.add_argument("--layer_gcn", default=2, type=int)
    parser.add_argument("--dropout", default=0.5, type=float)
    parser.add_argument("--arch", default='2p', type=str, choices=['1p', '2p'])

    parser.add_argument("--size_epoch", default=40, type=int)
    parser.add_argument("--size_batch", default=64, type=int)
    parser.add_argument("--lr", default=8e-4, type=float)
    parser.add_argument("--lr_decay", default=0.9, type=float)
    parser.add_argument("--weight_loss", default=2.0, type=float)
    parser.add_argument("--weight_alpha", default=3.0, type=float)

    args = parser.parse_args()
    args.path_output = '_snapshot/_%s_%s_%s'%(args.path, args.arch, datetime.now().strftime('%Y%m%d%H%M%S'))

    return args

def train_dl(args, model, dl, optzr):
    def get_loss(weight_loss, out, ans):
        # flatten to (N, num_class) logits and (N, ) targets
        out, ans = out.flatten(0, len(out.shape)-2), ans.flatten(0, len(ans.shape)-1).cuda()
        ls = T.nn.functional.cross_entropy(out, ans, ignore_index=-1, reduction='none')
        # weight: 0 for padding (-1), 1 for class 0, weight_loss for every positive class
        weight = 1.0-(ans==-1).float()
        weight.masked_fill_(ans>0, weight_loss)
        ls = (ls*weight).sum() / (weight>0).sum()
        return ls

    ret = {'ls_ne': [], 'ls_rel': []}
    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        if args.arch=='1p':
            out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            ls_ne, ls_rel = get_loss(args.weight_loss, out_ne, ans_ne), get_loss(args.weight_loss, out_rel, ans_rel)
            ls = ls_ne + args.weight_alpha*ls_rel

        elif args.arch=='2p':
            out_ne1p, out_rel1p, out_ne2p, out_rel2p = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            ls_ne1p, ls_rel1p = get_loss(args.weight_loss, out_ne1p, ans_ne), get_loss(args.weight_loss, out_rel1p, ans_rel)
            ls_ne2p, ls_rel2p = get_loss(args.weight_loss, out_ne2p, ans_ne), get_loss(args.weight_loss, out_rel2p, ans_rel)
            # only the 2nd-phase losses are logged, but both phases are optimized
            ls_ne, ls_rel = ls_ne2p, ls_rel2p
            ls = (ls_ne1p+ls_ne2p) + args.weight_alpha*(ls_rel1p+ls_rel2p)

        optzr.zero_grad()
        ls.backward()
        optzr.step()
        ret['ls_ne'].append(ls_ne.item())
        ret['ls_rel'].append(ls_rel.item())
    ret = {k: float(np.average(l)) for k, l in ret.items()}

    return ret

def eval_dl(model, dl):
    ret = {'precision': [0, 0], 'recall': [0, 0], 'f1': 0}

    I = 0 # index of the current sample in dl.dataset (dl must not be shuffled)
    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        # use model.arch instead of relying on a global args
        with T.no_grad():
            if model.arch=='1p':
                out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
            elif model.arch=='2p':
                _, _, out_ne, out_rel = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())

        out_ne, out_rel = [T.argmax(out, dim=-1).detach().cpu().numpy() for out in [out_ne, out_rel]]
        for o_ne, o_rel in zip(out_ne, out_rel):
            l = len(dl.dataset.dat[I]['sentence'])+1

            # decode the predicted tags (1: 'B', 2: 'I', 3: 'E', 4: 'S') into entity spans
            ne, pos = {}, -1
            for i in range(l):
                v = o_ne[i]
                if v==4:
                    ne[i] = [i, i]
                    pos = -1
                elif v==1:
                    pos = i
                elif v==2:
                    pass
                elif v==3:
                    if pos!=-1:
                        for p in range(pos, i+1):
                            ne[p] = [pos, i]
                elif v==0:
                    pos = -1

            # a predicted triple is (end of entity 1, end of entity 2, relation)
            pd = set()
            for i in range(l):
                for j in range(l):
                    if o_rel[i][j]!=0 and i in ne and j in ne:
                        pd.add((ne[i][1], ne[j][1], o_rel[i][j]))

            gt = set()
            for ne1, ne2, rel in dl.dataset.dat[I]['label']:
                gt.add((ne1[1], ne2[1], rel))

            ret['precision'][0] += len(pd.intersection(gt))
            ret['precision'][1] += len(pd)
            ret['recall'][0] += len(pd.intersection(gt))
            ret['recall'][1] += len(gt)

            I += 1

    ret['precision'] = ret['precision'][0]/ret['precision'][1] if ret['precision'][1]>0 else 0
    ret['recall'] = ret['recall'][0]/ret['recall'][1] if ret['recall'][1]>0 else 0
    ret['f1'] = 2*ret['precision']*ret['recall']/(ret['precision']+ret['recall']) if (ret['precision']+ret['recall'])>0 else 0

    return ret

if __name__=='__main__':
    args = get_args()
    os.makedirs(args.path_output, exist_ok=True)
    with open('%s/args.json'%(args.path_output), 'w') as f:
        json.dump(vars(args), f, indent=2)
    print(args)

    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, args.path, typ, args.max_len) for typ in ['train', 'val', 'test']]
    dl_tr, dl_vl, dl_ts = [T.utils.data.DataLoader(ds, batch_size=args.size_batch,
                                                   shuffle=(ds is ds_tr), num_workers=32, pin_memory=True)
                           for ds in [ds_tr, ds_vl, ds_ts]]

    log = {'ls_tr': [], 'f1_vl': [], 'f1_ts': []}
    with open('%s/log.json'%(args.path_output), 'w') as f:
        json.dump(log, f, indent=2)

    model = GraphRel(len(ds_tr.POS)+1, args.num_ne, args.num_rel,
                     args.size_hid, args.layer_rnn, args.layer_gcn, args.dropout,
                     args.arch).cuda()
    T.save(model.state_dict(), '%s/model_0.pt'%(args.path_output))

    optzr = T.optim.AdamW(model.parameters(), lr=args.lr)
    for e in tqdm(range(args.size_epoch), ascii=True):
        model.train()
        ls_tr = train_dl(args, model, dl_tr, optzr)

        model.eval()
        f1_vl = eval_dl(model, dl_vl)
        f1_ts = eval_dl(model, dl_ts)

        log['ls_tr'].append(ls_tr)
        log['f1_vl'].append(f1_vl)
        log['f1_ts'].append(f1_ts)
        with open('%s/log.json'%(args.path_output), 'w') as f:
            json.dump(log, f, indent=2)
        T.save(model.state_dict(), '%s/model_%d.pt'%(args.path_output, e+1))
        print('Ep %d:'%(e+1), ls_tr, f1_vl, f1_ts)

        # exponential learning-rate decay after every epoch
        for pg in optzr.param_groups:
            pg['lr'] *= args.lr_decay
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from lib import *
from dataset import *

class GCN(T.nn.Module):
    def __init__(self, size_hid):
        super().__init__()

        self.size_hid = size_hid

        # project size_hid -> size_hid//2 so that forward and backward outputs concatenate back to size_hid
        self.W = T.nn.Parameter(T.FloatTensor(self.size_hid, self.size_hid//2))
        self.b = T.nn.Parameter(T.FloatTensor(self.size_hid//2, ))

        stdv = 1.0/math.sqrt(self.size_hid//2)
        self.W.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, inp, adj):
        out = T.matmul(inp, self.W) + self.b
        out = T.matmul(adj, out)
        out = T.nn.functional.relu(out)

        return out

    def __repr__(self):
        return self.__class__.__name__+'(size_hid=%d)'%(self.size_hid)

class GraphRel(T.nn.Module):
    def __init__(self, num_pos, num_ne, num_rel,
                 size_hid, layer_rnn, layer_gcn, dp,
                 arch='2p'):
        super().__init__()

        self.arch = arch

        self.emb_pos = T.nn.Embedding(num_pos, 15)

        # input: 300-d word vector + 15-d POS embedding
        self.rnn = T.nn.GRU(300+15, size_hid, num_layers=layer_rnn, dropout=dp,
                            batch_first=True, bidirectional=True)
        self.gcn_fw, self.gcn_bw = [T.nn.ModuleList([GCN(size_hid*2) for _ in range(layer_gcn)]),
                                    T.nn.ModuleList([GCN(size_hid*2) for _ in range(layer_gcn)])]

        self.rnn_ne = T.nn.GRU(size_hid*2, size_hid, batch_first=True)
        self.fc_ne = T.nn.Linear(size_hid, num_ne)

        self.fc_rf, self.fc_rb = [T.nn.Sequential(T.nn.Linear(size_hid*2, size_hid), T.nn.ReLU()),
                                  T.nn.Sequential(T.nn.Linear(size_hid*2, size_hid), T.nn.ReLU())]
        self.fc_rel = T.nn.Linear(size_hid*2, num_rel)

        if self.arch=='2p':
            # one forward/backward GCN pair per relation
            self.gcn2p_fw, self.gcn2p_bw = [T.nn.ModuleList([GCN(size_hid*2) for _ in range(num_rel)]),
                                            T.nn.ModuleList([GCN(size_hid*2) for _ in range(num_rel)])]

        self.dp = T.nn.Dropout(dp)

    def head(self, feat):
        feat_ne, _ = self.rnn_ne(feat)
        out_ne = self.fc_ne(feat_ne)

        # pair the feature of word i (rf) with that of word j (rb) for every (i, j)
        rf, rb = self.fc_rf(feat), self.fc_rb(feat)
        rf, rb = [rf.unsqueeze(2).expand([-1, -1, rf.shape[1], -1]),
                  rb.unsqueeze(1).expand([-1, rb.shape[1], -1, -1])]
        out_rel = self.fc_rel(T.cat([rf, rb], dim=3))

        return out_ne, out_rel

    def forward(self, inp_sent, inp_pos, dep_fw, dep_bw):
        inp = T.cat([inp_sent, self.emb_pos(inp_pos)], dim=2)
        inp = self.dp(inp)

        feat, _ = self.rnn(inp)
        for gf, gb in zip(self.gcn_fw, self.gcn_bw):
            of, ob = gf(feat, dep_fw), gb(feat, dep_bw)
            feat = self.dp(T.cat([of, ob], dim=2))

        out_ne, out_rel = self.head(feat)

        if self.arch=='1p':
            return out_ne, out_rel

        # 2p: use the predicted 1st-phase relation distribution as a weighted graph per relation
        feat1p, out_ne1p, out_rel1p = feat, out_ne, out_rel

        dep_fw = T.nn.functional.softmax(out_rel1p, dim=3)
        dep_bw = dep_fw.transpose(1, 2)

        feat2p = feat1p.clone()
        for i, (gf, gb) in enumerate(zip(self.gcn2p_fw, self.gcn2p_bw)):
            of, ob = gf(feat1p, dep_fw[:, :, :, i]), gb(feat1p, dep_bw[:, :, :, i])
            feat2p += self.dp(T.cat([of, ob], dim=2))

        out_ne2p, out_rel2p = self.head(feat2p)

        return out_ne1p, out_rel1p, out_ne2p, out_rel2p

if __name__=='__main__':
    NLP = spacy.load('en_core_web_lg')
    ds_tr, ds_vl, ds_ts = [DS(NLP, 'nyt', typ, 120) for typ in ['train', 'val', 'test']]
    dl = T.utils.data.DataLoader(ds_tr, batch_size=64,
                                 shuffle=True, num_workers=32, pin_memory=True)

    model = GraphRel(len(ds_tr.POS)+1, 5, 25,
                     256, 2, 2, 0.5,
                     '2p').cuda()

    for s, inp_sent, inp_pos, dep_fw, dep_bw, ans_ne, ans_rel in tqdm(dl, ascii=True):
        out = model(inp_sent.cuda(), inp_pos.cuda(), dep_fw.cuda(), dep_bw.cuda())
        print([o.shape for o in out])
--------------------------------------------------------------------------------