├── .gitignore
├── README.md
├── assets
│   └── group_qr_code_20240621.jpg
├── bert.py
├── chatglm
│   ├── arguments.py
│   ├── data
│   │   ├── ID_num_dic.json
│   │   ├── calc_map.py
│   │   ├── format_data.py
│   │   ├── positive_negetive_balance.py
│   │   ├── train.json
│   │   ├── train_balance.json
│   │   ├── valid.json
│   │   └── valid_balance.json
│   ├── ds_test_finetune.sh
│   ├── ds_test_ptv2.sh
│   ├── finetune_test.py
│   ├── ft_ds_config.json
│   ├── output
│   │   └── test
│   │       └── trans.py
│   ├── test.py
│   ├── trainer.py
│   └── trainer_seq2seq.py
├── claude
│   └── claude.py
├── galactica
│   ├── calc_map.py
│   ├── gala_base.py
│   ├── gala_finetune.py
│   └── process.py
├── glm
│   ├── dataset.py
│   ├── ds_config_glm_10b.json
│   ├── ds_config_glm_2b.json
│   ├── finetune_glm_10b_ds.py
│   ├── finetune_glm_ds.py
│   ├── process.py
│   ├── run_finetune_ds.sh
│   ├── run_finetune_ds_10b.sh
│   └── test_glm.py
├── gpt-api
│   ├── gpt.py
│   └── map.py
├── net_emb.py
├── requirements.txt
├── result.txt
├── rf
│   ├── model_main.py
│   ├── model_rf.py
│   ├── process_data.py
│   ├── process_kddcup_data.py
│   └── set_param.py
├── rule.py
├── settings.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 |
3 | data/
4 | out/
5 | glm/saved/
6 | .idea/
7 | chatglm/output/
8 | galactica/lightning_logs/
9 | galactica/result/
10 | galactica/saved_model/
11 | processed_data/
12 | saved_model/
13 |
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # paper-source-trace
2 |
3 | ## Prerequisites
4 | - Linux
5 | - Python 3.9
6 | - PyTorch 1.10.0+cu111
7 |
8 | ## Getting Started
9 |
10 | ### Installation
11 |
12 | Clone this repo.
13 |
14 | ```bash
15 | git clone https://github.com/THUDM/paper-source-trace.git
16 | cd paper-source-trace
17 | ```
18 |
19 | Please install dependencies by
20 |
21 | ```bash
22 | pip install -r requirements.txt
23 | ```
24 |
25 | ## PST Dataset
26 | The dataset can be downloaded from [BaiduPan](https://pan.baidu.com/s/1I_HZXBx7U0UsRHJL5JJagw?pwd=bft3) with password bft3, [Aliyun](https://open-data-set.oss-cn-beijing.aliyuncs.com/oag-benchmark/kddcup-2024/PST/PST.zip) or [DropBox](https://www.dropbox.com/scl/fi/namx1n55xzqil4zbkd5sv/PST.zip?rlkey=impcbm2acqmqhurv2oj0xxysx&dl=1).
27 | The paper XML files are generated from the paper PDFs via the [Grobid](https://grobid.readthedocs.io/en/latest/Introduction/) APIs.
28 |
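For orientation, here is a minimal sketch of how the baselines in this repo read these Grobid TEI XML files (the parsing pattern is taken from ``bert.py``; the file path is a placeholder):

```python
from bs4 import BeautifulSoup  # the "xml" parser additionally requires lxml

with open("data/PST/paper-xml/<paper_id>.xml", encoding="utf-8") as f:
    bs = BeautifulSoup(f.read(), "xml")

# Each reference is a <biblStruct>; its xml:id ("b0", "b1", ...) is the
# bibliography index used throughout the baselines.
for ref in bs.find_all("biblStruct"):
    if "xml:id" in ref.attrs and ref.analytic and ref.analytic.title:
        print(ref.attrs["xml:id"], ref.analytic.title.text.lower())
```
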
29 | ## Run Baselines for [KDD Cup 2024](https://www.biendata.xyz/competition/pst_kdd_2024/)
30 | First, download DBLP dataset from [AMiner](https://opendata.aminer.cn/dataset/DBLP-Citation-network-V16.zip).
31 | Put the unzipped PST directory into ``data/`` and the unzipped DBLP dataset into ``data/PST/``.
32 |
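The files expected under ``data/PST/`` include ``paper-xml/``, ``paper_source_trace_train_ans.json``, ``paper_source_trace_valid_wo_ans.json``, and ``submission_example_valid.json`` (these names are what ``bert.py`` loads). A quick sanity check of the layout, assuming the default ``data/`` location:

```python
from os.path import join, exists

data_dir = join("data", "PST")
for name in ["paper-xml", "paper_source_trace_train_ans.json",
             "paper_source_trace_valid_wo_ans.json", "submission_example_valid.json"]:
    print(name, "OK" if exists(join(data_dir, name)) else "MISSING")
```
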
33 | ```bash
34 | cd $project_path
35 | export CUDA_VISIBLE_DEVICES='?' # specify which GPU(s) to use
36 | export PYTHONPATH="`pwd`:$PYTHONPATH"
37 |
38 | # Method 1: Random Forest
39 | python rf/process_kddcup_data.py
40 | python rf/model_rf.py # output at out/kddcup/rf/
41 |
42 | # Method 2: Network Embedding
43 | python net_emb.py # output at out/kddcup/prone/
44 |
45 | # Method 3: SciBERT
46 | python bert.py # output at out/kddcup/scibert/
47 | ```
48 |
49 | ## Results on Validation Set
50 |
51 | | Method | MAP |
52 | |-------|-------|
53 | | Random Forest | 0.21420 |
54 | | ProNE | 0.21668 |
55 | | SciBERT | 0.29489 |
56 |
57 | ## Citation
58 |
59 | If you find this repo useful in your research, please cite the following papers:
60 |
61 | ```
62 | @article{zhang2024pst,
63 | title={PST-Bench: Tracing and Benchmarking the Source of Publications},
64 | author={Fanjin Zhang and Kun Cao and Yukuo Cen and Jifan Yu and Da Yin and Jie Tang},
65 | journal={arXiv preprint arXiv:2402.16009},
66 | year={2024}
67 | }
68 |
69 | @inproceedings{zhang2024oag,
70 | title={OAG-bench: a human-curated benchmark for academic graph mining},
71 | author={Fanjin Zhang and Shijie Shi and Yifan Zhu and Bo Chen and Yukuo Cen and Jifan Yu and Yelin Chen and Lulu Wang and Qingfei Zhao and Yuqing Cheng and Tianyi Han and Yuwei An and Dan Zhang and Weng Lam Tam and Kun Cao and Yunhe Pang and Xinyu Guan and Huihui Yuan and Jian Song and Xiaoyan Li and Yuxiao Dong and Jie Tang},
72 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
73 | pages={6214--6225},
74 | year={2024}
75 | }
76 | ```
77 |
78 | ## Paper Sharing Group
79 |
80 | Hello everyone,
81 |
82 | We've created an online WeChat paper-sharing group in which each member is expected to share two computer science papers every week, with reward and penalty mechanisms for those who do or do not share as required. You are free to join or leave at any time. Welcome to join us! (You can get the up-to-date QR code from this channel: https://t.me/+apOrPEOLGixiNjdl)
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/assets/group_qr_code_20240621.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/paper-source-trace/d0896c62508ad4bb00d28b28b8411c92034a409d/assets/group_qr_code_20240621.jpg
--------------------------------------------------------------------------------
/bert.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | from tqdm import tqdm
4 | from collections import defaultdict as dd
5 | from bs4 import BeautifulSoup
6 | from fuzzywuzzy import fuzz
7 | import numpy as np
8 | import torch
9 | from transformers import AutoTokenizer
10 | from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
11 | from transformers.optimization import AdamW
12 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
13 | from tqdm import trange
14 | from sklearn.metrics import classification_report, precision_recall_fscore_support, average_precision_score
15 | import logging
16 |
17 | import utils
18 | import settings
19 |
20 |
21 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
22 | datefmt = '%m/%d/%Y %H:%M:%S',
23 | level = logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | MAX_SEQ_LENGTH=512
28 |
29 |
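# Pipeline (see __main__ at the bottom): prepare_bert_input() builds
# (citation-context, label) pairs from the Grobid XML files, train()
# fine-tunes a (Sci)BERT binary classifier on those pairs, and
# gen_kddcup_valid_submission_bert() scores every reference of each
# validation paper to produce the submission file.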
30 | def prepare_train_test_data_for_bert(year=2023):
31 | x_train = []
32 | y_train = []
33 | x_valid = []
34 | y_valid = []
35 | x_test = []
36 | y_test = []
37 |
38 | truths = utils.load_json(settings.DATA_TRACE_DIR, "paper_source_trace_{}_final_filtered.json".format(year))
39 | pid_to_source_titles = dd(list)
40 | for paper in tqdm(truths):
41 | pid = paper["_id"]
42 | for ref in paper["refs_trace"]:
43 | pid_to_source_titles[pid].append(ref["title"].lower())
44 |
45 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
46 | papers_train = utils.load_json(data_year_dir, "paper_source_trace_train.json")
47 | papers_valid = utils.load_json(data_year_dir, "paper_source_trace_valid.json")
48 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json")
49 |
50 | pids_train = {p["_id"] for p in papers_train}
51 | pids_valid = {p["_id"] for p in papers_valid}
52 | pids_test = {p["_id"] for p in papers_test}
53 |
54 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
55 | files = []
56 | for f in os.listdir(in_dir):
57 | if f.endswith(".xml"):
58 | files.append(f)
59 |
60 | files = sorted(files)
61 | for file in tqdm(files):
62 | f = open(join(in_dir, file), encoding='utf-8')
63 | cur_pid = file.split(".")[0]
64 | if cur_pid not in pids_train and cur_pid not in pids_valid and cur_pid not in pids_test:
65 | continue
66 | xml = f.read()
67 | bs = BeautifulSoup(xml, "xml")
68 |
69 | source_titles = pid_to_source_titles[cur_pid]
70 | if len(source_titles) == 0:
71 | continue
72 |
73 | references = bs.find_all("biblStruct")
74 | bid_to_title = {}
75 | n_refs = 0
76 | for ref in references:
77 | if "xml:id" not in ref.attrs:
78 | continue
79 | bid = ref.attrs["xml:id"]
80 | if ref.analytic is None:
81 | continue
82 | if ref.analytic.title is None:
83 | continue
84 | bid_to_title[bid] = ref.analytic.title.text.lower()
85 | b_idx = int(bid[1:]) + 1
86 | if b_idx > n_refs:
87 | n_refs = b_idx
88 |
89 | flag = False
90 |
91 | cur_pos_bib = set()
92 |
93 | for bid in bid_to_title:
94 | cur_ref_title = bid_to_title[bid]
95 | for label_title in source_titles:
96 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
97 | flag = True
98 | cur_pos_bib.add(bid)
99 |
100 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib
101 |
102 | if not flag:
103 | continue
104 |
105 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0:
106 | continue
107 |
108 | bib_to_contexts = utils.find_bib_context(xml)
109 |
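        # Down-sample negatives to a fixed 10:1 negative:positive ratio; replace=True permits duplicates when few negatives exist.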
110 | n_pos = len(cur_pos_bib)
111 | n_neg = n_pos * 10
112 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True)
113 |
114 | if cur_pid in pids_train:
115 | cur_x = x_train
116 | cur_y = y_train
117 | elif cur_pid in pids_valid:
118 | cur_x = x_valid
119 | cur_y = y_valid
120 | elif cur_pid in pids_test:
121 | cur_x = x_test
122 | cur_y = y_test
123 | else:
124 | continue
125 | # raise Exception("cur_pid not in train/valid/test")
126 |
127 | for bib in cur_pos_bib:
128 | cur_context = " ".join(bib_to_contexts[bib])
129 | cur_x.append(cur_context)
130 | cur_y.append(1)
131 |
132 | for bib in cur_neg_bib_sample:
133 | cur_context = " ".join(bib_to_contexts[bib])
134 | cur_x.append(cur_context)
135 | cur_y.append(0)
136 |
137 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid), "len(x_test)", len(x_test))
138 |
139 | with open(join(data_year_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f:
140 | for line in x_train:
141 | f.write(line + "\n")
142 |
143 | with open(join(data_year_dir, "bib_context_valid.txt"), "w", encoding="utf-8") as f:
144 | for line in x_valid:
145 | f.write(line + "\n")
146 |
147 | with open(join(data_year_dir, "bib_context_test.txt"), "w", encoding="utf-8") as f:
148 | for line in x_test:
149 | f.write(line + "\n")
150 |
151 | with open(join(data_year_dir, "bib_context_train_label.txt"), "w", encoding="utf-8") as f:
152 | for line in y_train:
153 | f.write(str(line) + "\n")
154 |
155 | with open(join(data_year_dir, "bib_context_valid_label.txt"), "w", encoding="utf-8") as f:
156 | for line in y_valid:
157 | f.write(str(line) + "\n")
158 |
159 | with open(join(data_year_dir, "bib_context_test_label.txt"), "w", encoding="utf-8") as f:
160 | for line in y_test:
161 | f.write(str(line) + "\n")
162 |
163 |
164 |
165 | def prepare_bert_input():
166 | x_train = []
167 | y_train = []
168 | x_valid = []
169 | y_valid = []
170 |
171 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
172 | papers = utils.load_json(data_dir, "paper_source_trace_train_ans.json")
173 | n_papers = len(papers)
174 | papers = sorted(papers, key=lambda x: x["_id"])
175 | n_train = int(n_papers * 2 / 3)
176 | # n_valid = n_papers - n_train
177 |
178 | papers_train = papers[:n_train]
179 | papers_valid = papers[n_train:]
180 |
181 | pids_train = {p["_id"] for p in papers_train}
182 | pids_valid = {p["_id"] for p in papers_valid}
183 |
184 | in_dir = join(data_dir, "paper-xml")
185 | files = []
186 | for f in os.listdir(in_dir):
187 | if f.endswith(".xml"):
188 | files.append(f)
189 |
190 | pid_to_source_titles = dd(list)
191 | for paper in tqdm(papers):
192 | pid = paper["_id"]
193 | for ref in paper["refs_trace"]:
194 | pid_to_source_titles[pid].append(ref["title"].lower())
195 |
196 | # files = sorted(files)
197 | # for file in tqdm(files):
198 | for cur_pid in tqdm(pids_train | pids_valid):
199 | # cur_pid = file.split(".")[0]
200 | # if cur_pid not in pids_train and cur_pid not in pids_valid:
201 | # continue
202 | f = open(join(in_dir, cur_pid + ".xml"), encoding='utf-8')
203 | xml = f.read()
204 | bs = BeautifulSoup(xml, "xml")
205 |
206 | source_titles = pid_to_source_titles[cur_pid]
207 | if len(source_titles) == 0:
208 | continue
209 |
210 | references = bs.find_all("biblStruct")
211 | bid_to_title = {}
212 | n_refs = 0
213 | for ref in references:
214 | if "xml:id" not in ref.attrs:
215 | continue
216 | bid = ref.attrs["xml:id"]
217 | if ref.analytic is None:
218 | continue
219 | if ref.analytic.title is None:
220 | continue
221 | bid_to_title[bid] = ref.analytic.title.text.lower()
222 | b_idx = int(bid[1:]) + 1
223 | if b_idx > n_refs:
224 | n_refs = b_idx
225 |
226 | flag = False
227 |
228 | cur_pos_bib = set()
229 |
230 | for bid in bid_to_title:
231 | cur_ref_title = bid_to_title[bid]
232 | for label_title in source_titles:
233 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
234 | flag = True
235 | cur_pos_bib.add(bid)
236 |
237 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib
238 |
239 | if not flag:
240 | continue
241 |
242 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0:
243 | continue
244 |
245 | bib_to_contexts = utils.find_bib_context(xml)
246 |
247 | n_pos = len(cur_pos_bib)
248 | n_neg = n_pos * 10
249 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True)
250 |
251 | if cur_pid in pids_train:
252 | cur_x = x_train
253 | cur_y = y_train
254 | elif cur_pid in pids_valid:
255 | cur_x = x_valid
256 | cur_y = y_valid
257 | else:
258 | continue
259 | # raise Exception("cur_pid not in train/valid/test")
260 |
261 | for bib in cur_pos_bib:
262 | cur_context = " ".join(bib_to_contexts[bib])
263 | cur_x.append(cur_context)
264 | cur_y.append(1)
265 |
266 | for bib in cur_neg_bib_sample:
267 | cur_context = " ".join(bib_to_contexts[bib])
268 | cur_x.append(cur_context)
269 | cur_y.append(0)
270 |
271 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid))
272 |
273 |
274 | with open(join(data_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f:
275 | for line in x_train:
276 | f.write(line + "\n")
277 |
278 | with open(join(data_dir, "bib_context_valid.txt"), "w", encoding="utf-8") as f:
279 | for line in x_valid:
280 | f.write(line + "\n")
281 |
282 | with open(join(data_dir, "bib_context_train_label.txt"), "w", encoding="utf-8") as f:
283 | for line in y_train:
284 | f.write(str(line) + "\n")
285 |
286 | with open(join(data_dir, "bib_context_valid_label.txt"), "w", encoding="utf-8") as f:
287 | for line in y_valid:
288 | f.write(str(line) + "\n")
289 |
290 |
291 | class BertInputItem(object):
292 | """An item with all the necessary attributes for finetuning BERT."""
293 |
294 | def __init__(self, text, input_ids, input_mask, segment_ids, label_id):
295 | self.text = text
296 | self.input_ids = input_ids
297 | self.input_mask = input_mask
298 | self.segment_ids = segment_ids
299 | self.label_id = label_id
300 |
301 |
302 | def convert_examples_to_inputs(example_texts, example_labels, max_seq_length, tokenizer, verbose=0):
303 | """Loads a data file into a list of `InputBatch`s."""
304 |
305 | input_items = []
306 | examples = zip(example_texts, example_labels)
307 | for (ex_index, (text, label)) in enumerate(examples):
308 |
309 | # Create a list of token ids
310 | input_ids = tokenizer.encode(f"[CLS] {text} [SEP]")
311 | if len(input_ids) > max_seq_length:
312 | input_ids = input_ids[:max_seq_length]
313 |
314 | # All our tokens are in the first input segment (id 0).
315 | segment_ids = [0] * len(input_ids)
316 |
317 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
318 | # tokens are attended to.
319 | input_mask = [1] * len(input_ids)
320 |
321 | # Zero-pad up to the sequence length.
322 | padding = [0] * (max_seq_length - len(input_ids))
323 | input_ids += padding
324 | input_mask += padding
325 | segment_ids += padding
326 |
327 | assert len(input_ids) == max_seq_length
328 | assert len(input_mask) == max_seq_length
329 | assert len(segment_ids) == max_seq_length
330 |
331 | label_id = label
332 |
333 | input_items.append(
334 | BertInputItem(text=text,
335 | input_ids=input_ids,
336 | input_mask=input_mask,
337 | segment_ids=segment_ids,
338 | label_id=label_id))
339 |
340 | return input_items
341 |
342 |
343 | def get_data_loader(features, max_seq_length, batch_size, shuffle=True):
344 |
345 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
346 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
347 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
348 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
349 | data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
350 |
351 | dataloader = DataLoader(data, shuffle=shuffle, batch_size=batch_size)
352 | return dataloader
353 |
354 |
355 | def evaluate(model, dataloader, device, criterion):
356 | model.eval()
357 |
358 | eval_loss = 0
359 | nb_eval_steps = 0
360 | predicted_labels, correct_labels = [], []
361 |
362 | for step, batch in enumerate(tqdm(dataloader, desc="Evaluation iteration")):
363 | batch = tuple(t.to(device) for t in batch)
364 | input_ids, input_mask, segment_ids, label_ids = batch
365 |
366 | with torch.no_grad():
367 | r = model(input_ids, attention_mask=input_mask,
368 | token_type_ids=segment_ids, labels=label_ids)
369 | # tmp_eval_loss = r[0]
370 | logits = r[1]
371 | # print("logits", logits)
372 | tmp_eval_loss = criterion(logits, label_ids)
373 |
374 | outputs = np.argmax(logits.to('cpu'), axis=1)
375 | label_ids = label_ids.to('cpu').numpy()
376 |
377 | predicted_labels += list(outputs)
378 | correct_labels += list(label_ids)
379 |
380 | eval_loss += tmp_eval_loss.mean().item()
381 | nb_eval_steps += 1
382 |
383 | eval_loss = eval_loss / nb_eval_steps
384 |
385 | correct_labels = np.array(correct_labels)
386 | predicted_labels = np.array(predicted_labels)
387 |
388 | return eval_loss, correct_labels, predicted_labels
389 |
390 |
391 | def train(year=2023, model_name="scibert"):
392 | print("model name", model_name)
393 | train_texts = []
394 | dev_texts = []
395 | train_labels = []
396 | dev_labels = []
397 | data_year_dir = join(settings.DATA_TRACE_DIR, "PST")
398 | print("data_year_dir", data_year_dir)
399 |
400 | with open(join(data_year_dir, "bib_context_train.txt"), "r", encoding="utf-8") as f:
401 | for line in f:
402 | train_texts.append(line.strip())
403 | with open(join(data_year_dir, "bib_context_valid.txt"), "r", encoding="utf-8") as f:
404 | for line in f:
405 | dev_texts.append(line.strip())
406 |
407 | with open(join(data_year_dir, "bib_context_train_label.txt"), "r", encoding="utf-8") as f:
408 | for line in f:
409 | train_labels.append(int(line.strip()))
410 | with open(join(data_year_dir, "bib_context_valid_label.txt"), "r", encoding="utf-8") as f:
411 | for line in f:
412 | dev_labels.append(int(line.strip()))
413 |
414 |
415 | print("Train size:", len(train_texts))
416 | print("Dev size:", len(dev_texts))
417 |
418 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
419 |
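    # Inverse-frequency class weights, weight_c = N / (2 * count_c), to counter the 10:1 negative sampling from data preparation.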
420 | class_weight = len(train_labels) / (2 * np.bincount(train_labels))
421 | class_weight = torch.Tensor(class_weight).to(device)
422 | print("Class weight:", class_weight)
423 |
424 | if model_name == "bert":
425 | BERT_MODEL = "bert-base-uncased"
426 | elif model_name == "scibert":
427 | BERT_MODEL = "allenai/scibert_scivocab_uncased"
428 | else:
429 | raise NotImplementedError
430 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
431 |
432 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
433 | model.to(device)
434 |
435 | criterion = torch.nn.CrossEntropyLoss(weight=class_weight)
436 |
437 | train_features = convert_examples_to_inputs(train_texts, train_labels, MAX_SEQ_LENGTH, tokenizer, verbose=0)
438 | dev_features = convert_examples_to_inputs(dev_texts, dev_labels, MAX_SEQ_LENGTH, tokenizer)
439 |
440 | BATCH_SIZE = 16
441 | train_dataloader = get_data_loader(train_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=True)
442 | dev_dataloader = get_data_loader(dev_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
443 |
444 | GRADIENT_ACCUMULATION_STEPS = 1
445 | NUM_TRAIN_EPOCHS = 20
446 | LEARNING_RATE = 5e-5
447 | WARMUP_PROPORTION = 0.1
448 | MAX_GRAD_NORM = 5
449 |
450 | num_train_steps = int(len(train_dataloader.dataset) / BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS * NUM_TRAIN_EPOCHS)
451 | num_warmup_steps = int(WARMUP_PROPORTION * num_train_steps)
452 |
453 | param_optimizer = list(model.named_parameters())
454 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
455 | optimizer_grouped_parameters = [
456 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
457 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
458 | ]
459 |
460 | optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, correct_bias=False)
461 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
462 |
463 | OUTPUT_DIR = join(settings.OUT_DIR, "kddcup", model_name)
464 | os.makedirs(OUTPUT_DIR, exist_ok=True)
465 |
466 | MODEL_FILE_NAME = "pytorch_model.bin"
467 | PATIENCE = 5
468 |
469 | loss_history = []
470 | no_improvement = 0
471 | for _ in trange(int(NUM_TRAIN_EPOCHS), desc="Epoch"):
472 | model.train()
473 | tr_loss = 0
474 | nb_tr_examples, nb_tr_steps = 0, 0
475 | for step, batch in enumerate(tqdm(train_dataloader, desc="Training iteration")):
476 | batch = tuple(t.to(device) for t in batch)
477 | input_ids, input_mask, segment_ids, label_ids = batch
478 |
479 | outputs = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids, labels=label_ids)
480 | # loss = outputs[0]
481 | logits = outputs[1]
482 |
483 | loss = criterion(logits, label_ids)
484 |
485 | if GRADIENT_ACCUMULATION_STEPS > 1:
486 | loss = loss / GRADIENT_ACCUMULATION_STEPS
487 |
488 | loss.backward()
489 | tr_loss += loss.item()
490 |
491 | if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
492 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
493 |
494 | optimizer.step()
495 | optimizer.zero_grad()
496 | scheduler.step()
497 |
498 | dev_loss, _, _ = evaluate(model, dev_dataloader, device, criterion)
499 |
500 | print("Loss history:", loss_history)
501 | print("Dev loss:", dev_loss)
502 |
503 | if len(loss_history) == 0 or dev_loss < min(loss_history):
504 | no_improvement = 0
505 | model_to_save = model.module if hasattr(model, 'module') else model
506 | output_model_file = os.path.join(OUTPUT_DIR, MODEL_FILE_NAME)
507 | torch.save(model_to_save.state_dict(), output_model_file)
508 | else:
509 | no_improvement += 1
510 |
511 | if no_improvement >= PATIENCE:
512 | print("No improvement on development set. Finish training.")
513 | break
514 |
515 | loss_history.append(dev_loss)
516 |
517 |
518 | def eval_test_papers_bert(year=2023, model_name="scibert"):
519 | print("model name", model_name)
520 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
521 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json")
522 | pids_test = {p["_id"] for p in papers_test}
523 |
524 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
525 | files = []
526 | for f in os.listdir(in_dir):
527 | cur_pid = f.split(".")[0]
528 | if f.endswith(".xml") and cur_pid in pids_test:
529 | files.append(f)
530 |
531 | truths = papers_test
532 | pid_to_source_titles = dd(list)
533 | for paper in tqdm(truths):
534 | pid = paper["_id"]
535 | for ref in paper["refs_trace"]:
536 | pid_to_source_titles[pid].append(ref["title"].lower())
537 |
538 | if model_name == "bert":
539 | BERT_MODEL = "bert-base-uncased"
540 | elif model_name == "scibert":
541 | BERT_MODEL = "allenai/scibert_scivocab_uncased"
542 | else:
543 | raise NotImplementedError
544 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
545 |
546 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
547 | print("device", device)
548 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
549 | # model.load_state_dict(torch.load(join(settings.OUT_DIR, model_name, "pytorch_model.bin")))
550 | # model.load_state_dict(torch.load(join(settings.OUT_DIR, "bert", "pytorch_model.bin")))
551 | model.to(device)
552 | model.eval()
553 |
554 | BATCH_SIZE = 16
555 | metrics = []
556 | f_idx = 0
557 |
558 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
559 |
560 | for paper in tqdm(papers_test):
561 | cur_pid = paper["_id"]
562 | file = join(xml_dir, cur_pid + ".tei.xml")
563 | f = open(file, encoding='utf-8')
564 |
565 | xml = f.read()
566 | bs = BeautifulSoup(xml, "xml")
567 | f.close()
568 |
569 | source_titles = pid_to_source_titles[cur_pid]
570 | if len(source_titles) == 0:
571 | continue
572 |
573 | references = bs.find_all("biblStruct")
574 | bid_to_title = {}
575 | n_refs = 0
576 | for ref in references:
577 | if "xml:id" not in ref.attrs:
578 | continue
579 | bid = ref.attrs["xml:id"]
580 | if ref.analytic is None:
581 | continue
582 | if ref.analytic.title is None:
583 | continue
584 | bid_to_title[bid] = ref.analytic.title.text.lower()
585 | b_idx = int(bid[1:]) + 1
586 | if b_idx > n_refs:
587 | n_refs = b_idx
588 |
589 | bib_to_contexts = utils.find_bib_context(xml)
590 | bib_sorted = sorted(bib_to_contexts.keys())
591 |
592 | for bib in bib_sorted:
593 | cur_bib_idx = int(bib[1:])
594 | if cur_bib_idx + 1 > n_refs:
595 | n_refs = cur_bib_idx + 1
596 |
597 | y_true = [0] * n_refs
598 | y_score = [0] * n_refs
599 |
600 | flag = False
601 | for bid in bid_to_title:
602 | cur_ref_title = bid_to_title[bid]
603 | for label_title in source_titles:
604 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
605 | flag = True
606 | b_idx = int(bid[1:])
607 | y_true[b_idx] = 1
608 |
609 | if not flag:
610 | continue
611 |
612 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]
613 |
614 | test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer)
615 | test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
616 |
617 | predicted_scores = []
618 | for step, batch in enumerate(test_dataloader):
619 | batch = tuple(t.to(device) for t in batch)
620 | input_ids, input_mask, segment_ids, label_ids = batch
621 |
622 | with torch.no_grad():
623 | r = model(input_ids, attention_mask=input_mask,
624 | token_type_ids=segment_ids, labels=label_ids)
625 | tmp_eval_loss = r[0]
626 | logits = r[1]
627 |
628 | cur_pred_scores = logits[:, 1].to('cpu').numpy()
629 | predicted_scores.extend(cur_pred_scores)
630 |
631 | try:
632 | for ii in range(len(predicted_scores)):
633 | bib_idx = int(bib_sorted[ii][1:])
634 | # print("bib_idx", bib_idx)
635 | y_score[bib_idx] = predicted_scores[ii]
636 | except IndexError as e:
637 | metrics.append(0)
638 | continue
639 |
640 | cur_map = average_precision_score(y_true, y_score)
641 | metrics.append(cur_map)
642 | f_idx += 1
643 | if f_idx % 20 == 0:
644 | print("map until now", np.mean(metrics), len(metrics), cur_map)
645 |
646 | print("bert average map", np.mean(metrics), len(metrics))
647 |
648 |
649 | def gen_kddcup_valid_submission_bert(model_name="scibert"):
650 | print("model name", model_name)
651 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
652 | papers = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json")
653 |
654 | if model_name == "bert":
655 | BERT_MODEL = "bert-base-uncased"
656 | elif model_name == "scibert":
657 | BERT_MODEL = "allenai/scibert_scivocab_uncased"
658 | else:
659 | raise NotImplementedError
660 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
661 |
662 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json")
663 |
664 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
665 | print("device", device)
666 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)
667 | model.load_state_dict(torch.load(join(settings.OUT_DIR, "kddcup", model_name, "pytorch_model.bin")))
668 |
669 | model.to(device)
670 | model.eval()
671 |
672 | BATCH_SIZE = 16
673 | # metrics = []
674 | # f_idx = 0
675 |
676 | xml_dir = join(data_dir, "paper-xml")
677 | sub_dict = {}
678 |
679 | for paper in tqdm(papers):
680 | cur_pid = paper["_id"]
681 | file = join(xml_dir, cur_pid + ".xml")
682 | f = open(file, encoding='utf-8')
683 | xml = f.read()
684 | bs = BeautifulSoup(xml, "xml")
685 | f.close()
686 |
687 | references = bs.find_all("biblStruct")
688 | bid_to_title = {}
689 | n_refs = 0
690 | for ref in references:
691 | if "xml:id" not in ref.attrs:
692 | continue
693 | bid = ref.attrs["xml:id"]
694 | if ref.analytic is None:
695 | continue
696 | if ref.analytic.title is None:
697 | continue
698 | bid_to_title[bid] = ref.analytic.title.text.lower()
699 | b_idx = int(bid[1:]) + 1
700 | if b_idx > n_refs:
701 | n_refs = b_idx
702 |
703 | bib_to_contexts = utils.find_bib_context(xml)
704 | # bib_sorted = sorted(bib_to_contexts.keys())
705 | bib_sorted = ["b" + str(ii) for ii in range(n_refs)]
706 |
707 | y_score = [0] * n_refs
708 |
709 | assert len(sub_example_dict[cur_pid]) == n_refs
710 | # continue
711 |
712 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]
713 |
714 | test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer)
715 | test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
716 |
717 | predicted_scores = []
718 | for step, batch in enumerate(test_dataloader):
719 | batch = tuple(t.to(device) for t in batch)
720 | input_ids, input_mask, segment_ids, label_ids = batch
721 |
722 | with torch.no_grad():
723 | r = model(input_ids, attention_mask=input_mask,
724 | token_type_ids=segment_ids, labels=label_ids)
725 | tmp_eval_loss = r[0]
726 | logits = r[1]
727 |
728 | cur_pred_scores = logits[:, 1].to('cpu').numpy()
729 | predicted_scores.extend(cur_pred_scores)
730 |
731 | for ii in range(len(predicted_scores)):
732 | bib_idx = int(bib_sorted[ii][1:])
733 | # print("bib_idx", bib_idx)
734 | y_score[bib_idx] = float(utils.sigmoid(predicted_scores[ii]))
735 |
736 | sub_dict[cur_pid] = y_score
737 |
738 | utils.dump_json(sub_dict, join(settings.OUT_DIR, "kddcup", model_name), "valid_submission_scibert.json")
739 |
740 |
741 | if __name__ == "__main__":
742 | # prepare_train_test_data_for_bert()
743 | prepare_bert_input()
744 | train(model_name="scibert")
745 | # eval_test_papers_bert(model_name="scibert")
746 | gen_kddcup_valid_submission_bert(model_name="scibert")
747 |
--------------------------------------------------------------------------------
/chatglm/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: str = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 | quantization_bit: Optional[int] = field(
54 | default=None
55 | )
56 | pre_seq_len: Optional[int] = field(
57 | default=None
58 | )
59 | prefix_projection: bool = field(
60 | default=False
61 | )
62 |
63 |
64 | @dataclass
65 | class DataTrainingArguments:
66 | """
67 | Arguments pertaining to what data we are going to input our model for training and eval.
68 | """
69 |
70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
71 |
72 | dataset_name: Optional[str] = field(
73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
74 | )
75 | dataset_config_name: Optional[str] = field(
76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77 | )
78 | prompt_column: Optional[str] = field(
79 | default=None,
80 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
81 | )
82 | response_column: Optional[str] = field(
83 | default=None,
84 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
85 | )
86 | history_column: Optional[str] = field(
87 | default=None,
88 | metadata={"help": "The name of the column in the datasets containing the history of chat."},
89 | )
90 | train_file: Optional[str] = field(
91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
92 | )
93 | validation_file: Optional[str] = field(
94 | default=None,
95 | metadata={
96 | "help": (
97 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
98 | )
99 | },
100 | )
101 | test_file: Optional[str] = field(
102 | default=None,
103 | metadata={
104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
105 | },
106 | )
107 | overwrite_cache: bool = field(
108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109 | )
110 | preprocessing_num_workers: Optional[int] = field(
111 | default=None,
112 | metadata={"help": "The number of processes to use for the preprocessing."},
113 | )
114 | max_source_length: Optional[int] = field(
115 | default=1024,
116 | metadata={
117 | "help": (
118 | "The maximum total input sequence length after tokenization. Sequences longer "
119 | "than this will be truncated, sequences shorter will be padded."
120 | )
121 | },
122 | )
123 | max_target_length: Optional[int] = field(
124 | default=128,
125 | metadata={
126 | "help": (
127 | "The maximum total sequence length for target text after tokenization. Sequences longer "
128 | "than this will be truncated, sequences shorter will be padded."
129 | )
130 | },
131 | )
132 | val_max_target_length: Optional[int] = field(
133 | default=None,
134 | metadata={
135 | "help": (
136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
139 | "during ``evaluate`` and ``predict``."
140 | )
141 | },
142 | )
143 | pad_to_max_length: bool = field(
144 | default=False,
145 | metadata={
146 | "help": (
147 | "Whether to pad all samples to model maximum sentence length. "
148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
149 | "efficient on GPU but very bad for TPU."
150 | )
151 | },
152 | )
153 | max_train_samples: Optional[int] = field(
154 | default=None,
155 | metadata={
156 | "help": (
157 | "For debugging purposes or quicker training, truncate the number of training examples to this "
158 | "value if set."
159 | )
160 | },
161 | )
162 | max_eval_samples: Optional[int] = field(
163 | default=None,
164 | metadata={
165 | "help": (
166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167 | "value if set."
168 | )
169 | },
170 | )
171 | max_predict_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": (
175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
176 | "value if set."
177 | )
178 | },
179 | )
180 | num_beams: Optional[int] = field(
181 | default=None,
182 | metadata={
183 | "help": (
184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
185 | "which is used during ``evaluate`` and ``predict``."
186 | )
187 | },
188 | )
189 | ignore_pad_token_for_loss: bool = field(
190 | default=True,
191 | metadata={
192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
193 | },
194 | )
195 | source_prefix: Optional[str] = field(
196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197 | )
198 |
199 | forced_bos_token: Optional[str] = field(
200 | default=None,
201 | metadata={
202 | "help": (
203 |                 "The token to force as the first generated token after the decoder_start_token_id. "
204 |                 "Useful for multilingual models like mBART where the first generated token "
205 |                 "needs to be the target language token (usually it is the target language token)."
206 | )
207 | },
208 | )
209 |
210 |
211 |
212 | def __post_init__(self):
213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
214 | raise ValueError("Need either a dataset name or a training/validation/test file.")
215 | else:
216 | if self.train_file is not None:
217 | extension = self.train_file.split(".")[-1]
218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219 | if self.validation_file is not None:
220 | extension = self.validation_file.split(".")[-1]
221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222 | if self.val_max_target_length is None:
223 | self.val_max_target_length = self.max_target_length
224 |
225 |
--------------------------------------------------------------------------------
/chatglm/data/calc_map.py:
--------------------------------------------------------------------------------
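# Compute mean average precision (MAP) from ChatGLM prediction files:
# each key in the loaded dict groups one paper's examples, and every
# example carries a ground-truth "labels" field ("Yes"/"No") and a
# model "predict" field; AP is computed per paper and then averaged.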
1 | import json
2 | import warnings
3 | warnings.filterwarnings('ignore')
4 | from sklearn.metrics import average_precision_score
5 | # base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/v2/"
6 | # base_path = "output/ptuning/"
7 | base_path = "output/finetune/"
8 |
9 | num = 5
10 | dic_list = [base_path + "generated_predictions.txt" for i in range(1, 2)]
11 | # dic_list = [base_path + str(i) + "000/generated_predictions.txt" for i in range(1, 6)]
12 | for i in range(len(dic_list)):
13 | result_list = []
14 | with open(dic_list[i], "r") as read_file:
15 | result_dic = json.load(read_file)
16 | for key in result_dic.keys():
17 | pre = []
18 | res = []
19 | for jtem in result_dic[key]:
20 | data = json.loads(jtem.strip())
21 | if data["labels"] == "Yes":
22 | res.append(1)
23 | else:
24 | res.append(0)
25 | if "yes" in data["predict"][0] or "Yes" in data["predict"][0] or "YES" in data["predict"][0]:
26 | pre.append(1)
27 | else:
28 | pre.append(0)
29 | result_list.append(average_precision_score(res, pre))
30 | print(f"{i}:map={sum(result_list)/len(result_list)}")
--------------------------------------------------------------------------------
/chatglm/data/format_data.py:
--------------------------------------------------------------------------------
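# Flatten the nested test.json (paper id -> list of examples) into
# line-delimited JSON (test2.json), recording in ID_num_dic.json which
# output line numbers belong to each paper id for later regrouping.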
1 | import json
2 | with open("test.json", "r") as read_file:
3 | data_dic = json.load(read_file)
4 | write_data = open("test2.json", "w")
5 | write_ID_num = open("ID_num_dic.json", "w")
6 | ID_num_dic = {}
7 | n = 0
8 | for item in data_dic:
9 | this_data = data_dic[item]
10 | ID_num_dic[item] = []
11 | for jtem in this_data:
12 | ID_num_dic[item].append(n)
13 | n += 1
14 | write_data.write(str(json.dumps(jtem)) + '\n')
15 | json.dump(ID_num_dic, write_ID_num, indent=2)  # json.dump writes to the open file handle (json.dumps only returns a string)
16 | write_data.close()
17 | write_ID_num.close()
--------------------------------------------------------------------------------
/chatglm/data/positive_negetive_balance.py:
--------------------------------------------------------------------------------
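# Balance each split by down-sampling "No" examples to match the number
# of "Yes" examples, then shuffle and write <split>_balance.json.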
1 | import json
2 | import random
3 |
4 | target_list = ["train", "valid"]
5 | for item in target_list:
6 |     yes_list, no_list = [], []  # reset per split so one split's examples do not leak into another's balanced file
7 |     with open("chatglm/data/" + item + '.json', "r") as read_file:
8 | all_lines = read_file.readlines()
9 | for data in all_lines:
10 | data_dic = json.loads(data.strip())
11 | if data_dic["summary"] == "Yes":
12 | yes_list.append(data.strip())
13 | else:
14 | no_list.append(data.strip())
15 | no_list = random.sample(no_list, len(yes_list))
16 | all_list = yes_list+no_list
17 | random.shuffle(all_list)
18 | with open("chatglm/data/" + item+"_balance.json", "w") as write_file:
19 | for jtem in all_list:
20 | write_file.write(jtem+"\n")
21 |
22 |
--------------------------------------------------------------------------------
/chatglm/ds_test_finetune.sh:
--------------------------------------------------------------------------------
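# Batched prediction with a fully fine-tuned ChatGLM checkpoint, launched via DeepSpeed (ZeRO stage 3; see ft_ds_config.json).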
1 |
2 | LR=1e-4
3 |
4 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
5 |
6 | deepspeed --include localhost:1,2,6,7 --master_port $MASTER_PORT finetune_test.py \
7 | --deepspeed ft_ds_config.json \
8 | --do_predict \
9 | --test_file data/test2.json \
10 | --prompt_column content \
11 | --response_column summary \
12 | --overwrite_cache \
13 | --model_name_or_path /data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/chatglm-ft/checkpoint-5000/ \
14 | --output_dir output/finetune/ \
15 | --overwrite_output_dir \
16 | --max_source_length 64 \
17 | --max_target_length 64 \
18 | --per_device_train_batch_size 4 \
19 | --per_device_eval_batch_size 1 \
20 | --gradient_accumulation_steps 1 \
21 | --predict_with_generate \
22 | --max_steps 5000 \
23 | --logging_steps 10 \
24 | --save_steps 1000 \
25 | --learning_rate $LR \
26 | --fp16
27 |
28 |
--------------------------------------------------------------------------------
/chatglm/ds_test_ptv2.sh:
--------------------------------------------------------------------------------
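# Prediction with ChatGLM under P-Tuning v2: prefix length 128, 4-bit quantization, single GPU.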
1 | PRE_SEQ_LEN=128
2 | LR=2e-2
3 |
4 | CUDA_VISIBLE_DEVICES=2 python3 test.py \
5 | --do_predict \
6 | --test_file data/test2.json \
7 | --prompt_column content \
8 | --response_column summary \
9 | --overwrite_cache \
10 | --model_name_or_path /data/caokun/huge_model/chatGLM/model/ \
11 | --output_dir output/ptuning/ \
12 | --overwrite_output_dir \
13 | --max_source_length 64 \
14 | --max_target_length 64 \
15 | --per_device_train_batch_size 1 \
16 | --per_device_eval_batch_size 1 \
17 | --gradient_accumulation_steps 16 \
18 | --predict_with_generate \
19 | --max_steps 3000 \
20 | --logging_steps 10 \
21 | --save_steps 1000 \
22 | --learning_rate $LR \
23 | --pre_seq_len $PRE_SEQ_LEN \
24 | --quantization_bit 4 \
25 | # --ptuning_checkpoint /data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/v2/checkpoint-1000/
26 |
27 |
--------------------------------------------------------------------------------
/chatglm/finetune_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/paper-source-trace/d0896c62508ad4bb00d28b28b8411c92034a409d/chatglm/finetune_test.py
--------------------------------------------------------------------------------
/chatglm/ft_ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 3,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
/chatglm/output/test/trans.py:
--------------------------------------------------------------------------------
1 | # -*-coding:gbk -*-
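# Compute Accuracy / Precision / Recall / F1 from generated_predictions.txt
# files: each line holds a labels field and a predict field, and the
# trailing-character checks below distinguish "Yes" from "No".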
2 | import os
3 | base_path = "v2/"
4 | all_file = os.listdir(base_path)
5 | for ktem in all_file:
6 | result = [0, 0, 0, 0] #"YesYes, YesNo, NoYes, NoNo"
7 | with open(base_path + ktem + "/generated_predictions.txt", "r") as read_file:
8 | all_lines = read_file.readlines()
9 | for item in all_lines:
10 | data = item.strip()
11 | data_list = data.split(",")
12 | if data_list[0][-2] == "s":
13 | if data_list[1][-3] == "s":
14 | result[0] += 1
15 | else:
16 | result[1] += 1
17 | else:
18 | if data_list[1][-3] == "s":
19 | result[2] += 1
20 | else:
21 | result[3] += 1
22 | Accuracy = (result[0]+result[3])/(sum(result))
23 | Precision = result[0]/(result[0]+result[2])
24 | Recall = result[0]/(result[0]+result[1])
25 | print(ktem+":")
26 | print("Accuracy:" + str(Accuracy))
27 | print("Precision:"+ str(Precision))
28 | print("Recall:"+ str(Recall))
29 | print("F1:"+ str((2*Precision*Recall)/(Precision+Recall)))
30 | print(result)
31 |
32 |
33 |
--------------------------------------------------------------------------------
/chatglm/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | Seq2SeqTrainingArguments,
41 | set_seed,
42 | )
43 | from trainer_seq2seq import Seq2SeqTrainer
44 |
45 | from arguments import ModelArguments, DataTrainingArguments
46 | logger = logging.getLogger(__name__)
47 |
48 | def main():
49 |
50 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
51 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
52 | # If we pass only one argument to the script and it's the path to a json file,
53 | # let's parse it to get our arguments.
54 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
55 | else:
56 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
57 |
58 | # Setup logging
59 | logging.basicConfig(
60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
61 | datefmt="%m/%d/%Y %H:%M:%S",
62 | handlers=[logging.StreamHandler(sys.stdout)],
63 | )
64 |
65 | if training_args.should_log:
66 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
67 | transformers.utils.logging.set_verbosity_info()
68 |
69 | log_level = training_args.get_process_log_level()
70 | logger.setLevel(log_level)
71 | # datasets.utils.logging.set_verbosity(log_level)
72 | transformers.utils.logging.set_verbosity(log_level)
73 | transformers.utils.logging.enable_default_handler()
74 | transformers.utils.logging.enable_explicit_format()
75 |
76 | # Log on each process the small summary:
77 | logger.warning(
78 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
79 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
80 | )
81 | logger.info(f"Training/evaluation parameters {training_args}")
82 |
83 | # Set seed before initializing model.
84 | set_seed(training_args.seed)
85 |
86 | # Load dataset
87 | data_files = {}
88 | if data_args.train_file is not None:
89 | data_files["train"] = data_args.train_file
90 | extension = data_args.train_file.split(".")[-1]
91 | if data_args.validation_file is not None:
92 | data_files["validation"] = data_args.validation_file
93 | extension = data_args.validation_file.split(".")[-1]
94 | if data_args.test_file is not None:
95 | data_files["test"] = data_args.test_file
96 | extension = data_args.test_file.split(".")[-1]
97 |
98 | raw_datasets = load_dataset(
99 | extension,
100 | data_files=data_files,
101 | cache_dir=model_args.cache_dir,
102 | use_auth_token=True if model_args.use_auth_token else None,
103 | )
104 | # Load pretrained model and tokenizer
105 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
106 | config.pre_seq_len = model_args.pre_seq_len
107 | config.prefix_projection = model_args.prefix_projection
108 |
109 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
110 |
111 | if model_args.ptuning_checkpoint is not None:
112 | # Evaluation
113 | # Loading extra state dict of prefix encoder
114 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^load model~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
115 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
116 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
117 | new_prefix_state_dict = {}
118 | for k, v in prefix_state_dict.items():
119 | if k.startswith("transformer.prefix_encoder."):
120 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
121 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
122 | else:
123 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
124 |
125 | if model_args.quantization_bit is not None:
126 | print(f"Quantized to {model_args.quantization_bit} bit")
127 | model = model.quantize(model_args.quantization_bit)
128 | if model_args.pre_seq_len is not None:
129 | # P-tuning v2
130 | model = model.half()
131 | model.transformer.prefix_encoder.float()
132 | else:
133 | # Finetune
134 | model = model.float()
135 |
136 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
137 |
138 | # Preprocessing the datasets.
139 | # We need to tokenize inputs and targets.
140 | if training_args.do_train:
141 | column_names = raw_datasets["train"].column_names
142 | elif training_args.do_eval:
143 | column_names = raw_datasets["validation"].column_names
144 | elif training_args.do_predict:
145 | column_names = raw_datasets["test"].column_names
146 | else:
147 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
148 | return
149 |
150 | # Get the column names for input/target.
151 | prompt_column = data_args.prompt_column
152 | response_column = data_args.response_column
153 | history_column = data_args.history_column
154 |
155 | # Temporarily set max_target_length for training.
156 | max_target_length = data_args.max_target_length
157 |
158 | def preprocess_function_eval(examples):
159 | inputs, targets = [], []
160 | for i in range(len(examples[prompt_column])):
161 | if examples[prompt_column][i] and examples[response_column][i]:
162 | query = examples[prompt_column][i]
163 | if history_column is None or len(examples[history_column][i]) == 0:
164 | prompt = query
165 | else:
166 | prompt = ""
167 | history = examples[history_column][i]
168 | for turn_idx, (old_query, response) in enumerate(history):
169 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
170 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
171 | inputs.append(prompt)
172 | targets.append(examples[response_column][i])
173 |
174 | inputs = [prefix + inp for inp in inputs]
175 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
176 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
177 |
178 | if data_args.ignore_pad_token_for_loss:
179 | labels["input_ids"] = [
180 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
181 | ]
182 | model_inputs["labels"] = labels["input_ids"]
183 |
184 | return model_inputs
185 |
186 | def preprocess_function_train(examples):
187 | max_seq_length = data_args.max_source_length + data_args.max_target_length
188 |
189 | model_inputs = {
190 | "input_ids": [],
191 | "labels": [],
192 | }
193 | for i in range(len(examples[prompt_column])):
194 | if examples[prompt_column][i] and examples[response_column][i]:
195 | query, answer = examples[prompt_column][i], examples[response_column][i]
196 |
197 | if history_column is None:
198 | prompt = query
199 | else:
200 | prompt = ""
201 | history = examples[history_column][i]
202 | for turn_idx, (old_query, response) in enumerate(history):
203 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
204 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
205 |
206 | prompt = prefix + prompt
207 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
208 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
209 |
210 | if len(a_ids) > data_args.max_source_length - 1:
211 | a_ids = a_ids[: data_args.max_source_length - 1]
212 |
213 | if len(b_ids) > data_args.max_target_length - 2:
214 | b_ids = b_ids[: data_args.max_target_length - 2]
215 |
216 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
217 |
218 | context_length = input_ids.index(tokenizer.bos_token_id)
219 | mask_position = context_length - 1
220 | labels = [-100] * context_length + input_ids[mask_position+1:]
221 |
222 | pad_len = max_seq_length - len(input_ids)
223 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
224 | labels = labels + [tokenizer.pad_token_id] * pad_len
225 | if data_args.ignore_pad_token_for_loss:
226 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
227 |
228 | model_inputs["input_ids"].append(input_ids)
229 | model_inputs["labels"].append(labels)
230 |
231 | return model_inputs
232 |
233 | def print_dataset_example(example):
234 | print("input_ids",example["input_ids"])
235 | print("inputs", tokenizer.decode(example["input_ids"]))
236 | print("label_ids", example["labels"])
237 | print("labels", tokenizer.decode(example["labels"]))
238 |
239 | if training_args.do_train:
240 | if "train" not in raw_datasets:
241 | raise ValueError("--do_train requires a train dataset")
242 | train_dataset = raw_datasets["train"]
243 | if data_args.max_train_samples is not None:
244 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
245 | train_dataset = train_dataset.select(range(max_train_samples))
246 | with training_args.main_process_first(desc="train dataset map pre-processing"):
247 | train_dataset = train_dataset.map(
248 | preprocess_function_train,
249 | batched=True,
250 | num_proc=data_args.preprocessing_num_workers,
251 | remove_columns=column_names,
252 | load_from_cache_file=not data_args.overwrite_cache,
253 | desc="Running tokenizer on train dataset",
254 | )
255 | print_dataset_example(train_dataset[0])
256 |
257 | if training_args.do_eval:
258 | max_target_length = data_args.val_max_target_length
259 | if "validation" not in raw_datasets:
260 | raise ValueError("--do_eval requires a validation dataset")
261 | eval_dataset = raw_datasets["validation"]
262 | if data_args.max_eval_samples is not None:
263 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
264 | eval_dataset = eval_dataset.select(range(max_eval_samples))
265 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
266 | eval_dataset = eval_dataset.map(
267 | preprocess_function_eval,
268 | batched=True,
269 | num_proc=data_args.preprocessing_num_workers,
270 | remove_columns=column_names,
271 | load_from_cache_file=not data_args.overwrite_cache,
272 | desc="Running tokenizer on validation dataset",
273 | )
274 | print_dataset_example(eval_dataset[0])
275 |
276 | if training_args.do_predict:
277 | max_target_length = data_args.val_max_target_length
278 | if "test" not in raw_datasets:
279 | raise ValueError("--do_predict requires a test dataset")
280 | predict_dataset = raw_datasets["test"]
281 | if data_args.max_predict_samples is not None:
282 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
283 | predict_dataset = predict_dataset.select(range(max_predict_samples))
284 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
285 | predict_dataset = predict_dataset.map(
286 | preprocess_function_eval,
287 | batched=True,
288 | num_proc=data_args.preprocessing_num_workers,
289 | remove_columns=column_names,
290 | load_from_cache_file=not data_args.overwrite_cache,
291 | desc="Running tokenizer on prediction dataset",
292 | )
293 | print_dataset_example(predict_dataset[0])
294 |
295 | # Data collator
296 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
297 | data_collator = DataCollatorForSeq2Seq(
298 | tokenizer,
299 | model=model,
300 | label_pad_token_id=label_pad_token_id,
301 | pad_to_multiple_of=None,
302 | padding=False
303 | )
304 |
305 | # Metric
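    |     # ROUGE is computed over jieba-segmented words (joined by spaces, as the
    |     # rouge package expects), and BLEU-4 over characters with smoothing, which
    |     # suits Chinese references and predictions.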
306 | def compute_metrics(eval_preds):
307 | preds, labels = eval_preds
308 | if isinstance(preds, tuple):
309 | preds = preds[0]
310 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
311 | if data_args.ignore_pad_token_for_loss:
312 | # Replace -100 in the labels as we can't decode them.
313 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
314 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
315 |
316 | score_dict = {
317 | "rouge-1": [],
318 | "rouge-2": [],
319 | "rouge-l": [],
320 | "bleu-4": []
321 | }
322 | for pred, label in zip(decoded_preds, decoded_labels):
323 | hypothesis = list(jieba.cut(pred))
324 | reference = list(jieba.cut(label))
325 | rouge = Rouge()
326 |             scores = rouge.get_scores(' '.join(hypothesis) or "-", ' '.join(reference))  # the "-" fallback guards against an empty hypothesis, which Rouge rejects
327 | result = scores[0]
328 |
329 | for k, v in result.items():
330 | score_dict[k].append(round(v["f"] * 100, 4))
331 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
332 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
333 |
334 | for k, v in score_dict.items():
335 | score_dict[k] = float(np.mean(v))
336 | return score_dict
337 |
338 | # Override the decoding parameters of Seq2SeqTrainer
339 | training_args.generation_max_length = (
340 | training_args.generation_max_length
341 | if training_args.generation_max_length is not None
342 | else data_args.val_max_target_length
343 | )
344 | training_args.generation_num_beams = (
345 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
346 | )
347 | # Initialize our Trainer
348 | trainer = Seq2SeqTrainer(
349 | model=model,
350 | args=training_args,
351 | train_dataset=train_dataset if training_args.do_train else None,
352 | eval_dataset=eval_dataset if training_args.do_eval else None,
353 | tokenizer=tokenizer,
354 | data_collator=data_collator,
355 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
356 | save_prefixencoder=model_args.pre_seq_len is not None
357 | )
358 |
359 | # Training
360 | if training_args.do_train:
361 | checkpoint = None
362 | if training_args.resume_from_checkpoint is not None:
363 | checkpoint = training_args.resume_from_checkpoint
364 | # elif last_checkpoint is not None:
365 | # checkpoint = last_checkpoint
366 | model.gradient_checkpointing_enable()
367 | model.enable_input_require_grads()
368 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
369 | # trainer.save_model() # Saves the tokenizer too for easy upload
370 |
371 | metrics = train_result.metrics
372 | max_train_samples = (
373 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
374 | )
375 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
376 |
377 | trainer.log_metrics("train", metrics)
378 | trainer.save_metrics("train", metrics)
379 | trainer.save_state()
380 |
381 | # Evaluation
382 | results = {}
383 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
384 | if training_args.do_eval:
385 | logger.info("*** Evaluate ***")
386 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
387 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
388 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
389 |
390 | trainer.log_metrics("eval", metrics)
391 | trainer.save_metrics("eval", metrics)
392 |
393 | if training_args.do_predict:
394 | logger.info("*** Predict ***")
395 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
396 | metrics = predict_results.metrics
397 | max_predict_samples = (
398 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
399 | )
400 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
401 |
402 | trainer.log_metrics("predict", metrics)
403 | trainer.save_metrics("predict", metrics)
404 |
405 | if trainer.is_world_process_zero():
406 | if training_args.predict_with_generate:
407 | predictions = tokenizer.batch_decode(
408 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
409 | )
410 | predictions = [pred.strip() for pred in predictions]
411 | labels = tokenizer.batch_decode(
412 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
413 | )
414 | labels = [label.strip() for label in labels]
415 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
416 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
417 | for p, l in zip(predictions, labels):
418 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
419 | writer.write(f"{res}\n")
420 | return results
421 |
422 |
423 | def _mp_fn(index):
424 | # For xla_spawn (TPUs)
425 | main()
426 |
427 |
428 | if __name__ == "__main__":
429 | main()
430 |
--------------------------------------------------------------------------------
/chatglm/trainer_seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | from torch import nn
19 | from torch.utils.data import Dataset
20 |
21 | from transformers.deepspeed import is_deepspeed_zero3_enabled
22 | from trainer import Trainer
23 | from transformers.trainer_utils import PredictionOutput
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class Seq2SeqTrainer(Trainer):
31 | def evaluate(
32 | self,
33 | eval_dataset: Optional[Dataset] = None,
34 | ignore_keys: Optional[List[str]] = None,
35 | metric_key_prefix: str = "eval",
36 | **gen_kwargs
37 | ) -> Dict[str, float]:
38 | """
39 | Run evaluation and returns metrics.
40 |
41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42 | (pass it to the init `compute_metrics` argument).
43 |
44 | You can also subclass and override this method to inject custom behavior.
45 |
46 | Args:
47 | eval_dataset (`Dataset`, *optional*):
48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50 | method.
51 | ignore_keys (`List[str]`, *optional*):
52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56 | "eval_bleu" if the prefix is `"eval"` (default)
57 | max_length (`int`, *optional*):
58 | The maximum target length to use when predicting with the generate method.
59 | num_beams (`int`, *optional*):
60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61 | beam search.
62 | gen_kwargs:
63 | Additional `generate` specific kwargs.
64 |
65 | Returns:
66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67 | dictionary also contains the epoch number which comes from the training state.
68 | """
69 |
70 | gen_kwargs = gen_kwargs.copy()
71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72 | gen_kwargs["max_length"] = self.args.generation_max_length
73 | gen_kwargs["num_beams"] = (
74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75 | )
76 |         self._gen_kwargs = gen_kwargs  # stash for prediction_step, which the base Trainer calls without generate kwargs
77 |
78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79 |
80 | def predict(
81 | self,
82 | test_dataset: Dataset,
83 | ignore_keys: Optional[List[str]] = None,
84 | metric_key_prefix: str = "test",
85 | **gen_kwargs
86 | ) -> PredictionOutput:
87 | """
88 | Run prediction and returns predictions and potential metrics.
89 |
90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91 | will also return metrics, like in `evaluate()`.
92 |
93 | Args:
94 | test_dataset (`Dataset`):
95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`
97 | ignore_keys (`List[str]`, *optional*):
98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99 | gathering predictions.
100 |             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
101 |                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102 |                 "test_bleu" if the prefix is `"test"` (default)
103 | max_length (`int`, *optional*):
104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*):
106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107 | beam search.
108 | gen_kwargs:
109 | Additional `generate` specific kwargs.
110 |
111 |
112 |
113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114 | padding in a token classification task) the predictions will be padded (on the right) to allow for
115 | concatenation into one array. The padding index is -100.
116 |
117 |
118 |
119 | Returns: *NamedTuple* A namedtuple with the following keys:
120 |
121 | - predictions (`np.ndarray`): The predictions on `test_dataset`.
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124 | labels).
125 | """
126 |
127 | gen_kwargs = gen_kwargs.copy()
128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129 | gen_kwargs["max_length"] = self.args.generation_max_length
130 | gen_kwargs["num_beams"] = (
131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132 | )
133 | self._gen_kwargs = gen_kwargs
134 |
135 |
136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137 |
138 | def prediction_step(
139 | self,
140 | model: nn.Module,
141 | inputs: Dict[str, Union[torch.Tensor, Any]],
142 | prediction_loss_only: bool,
143 | ignore_keys: Optional[List[str]] = None,
144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145 | """
146 | Perform an evaluation step on `model` using `inputs`.
147 |
148 | Subclass and override to inject custom behavior.
149 |
150 | Args:
151 | model (`nn.Module`):
152 | The model to evaluate.
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154 | The inputs and targets of the model.
155 |
156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157 | argument `labels`. Check your model's documentation for all accepted arguments.
158 | prediction_loss_only (`bool`):
159 | Whether or not to return the loss only.
160 |
161 | Return:
162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163 | labels (each being optional).
164 | """
165 |
166 | if not self.args.predict_with_generate or prediction_loss_only:
167 | return super().prediction_step(
168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169 | )
170 |
171 | has_labels = "labels" in inputs
172 | inputs = self._prepare_inputs(inputs)
173 |
174 | # XXX: adapt synced_gpus for fairscale as well
175 | gen_kwargs = self._gen_kwargs.copy()
176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177 | gen_kwargs["max_length"] = self.model.config.max_length
178 | gen_kwargs["num_beams"] = (
179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180 | )
181 |         default_synced_gpus = is_deepspeed_zero3_enabled()
182 | gen_kwargs["synced_gpus"] = (
183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184 | )
185 |
186 | if "attention_mask" in inputs:
187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188 | if "position_ids" in inputs:
189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190 | if "global_attention_mask" in inputs:
191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192 |
193 | # prepare generation inputs
194 | # some encoder-decoder models can have varying encoder's and thus
195 | # varying model input names
196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197 | generation_inputs = inputs[self.model.encoder.main_input_name]
198 | else:
199 | generation_inputs = inputs[self.model.main_input_name]
200 |
201 | gen_kwargs["input_ids"] = generation_inputs
202 | generated_tokens = self.model.generate(**gen_kwargs)
203 |         generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]  # drop the echoed prompt; generate returns prompt + continuation
204 |
205 | # in case the batch is shorter than max length, the output should be padded
206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
209 | gen_kwargs["max_new_tokens"] + 1
210 | ):
211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
212 |
213 | loss = None
214 |
215 | if self.args.prediction_loss_only:
216 | return (loss, None, None)
217 |
218 | if has_labels:
219 | labels = inputs["labels"]
220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
223 | gen_kwargs["max_new_tokens"] + 1
224 | ):
225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
226 | else:
227 | labels = None
228 |
229 | return (loss, generated_tokens, labels)
230 |
231 | def _pad_tensors_to_max_len(self, tensor, max_length):
232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
233 | # If PAD token is not defined at least EOS token has to be defined
234 | pad_token_id = (
235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236 | )
237 | else:
238 | if self.model.config.pad_token_id is not None:
239 | pad_token_id = self.model.config.pad_token_id
240 | else:
241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
242 |
243 | padded_tensor = pad_token_id * torch.ones(
244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
245 | )
246 | padded_tensor[:, : tensor.shape[-1]] = tensor
247 | return padded_tensor
248 |
--------------------------------------------------------------------------------
/claude/claude.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import json
3 |
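    | # Query a Claude-compatible chat-completions endpoint for every item in
    | # test.json and record {"labels": ..., "predict": ...} pairs per paper id.
    | # Items already present in result/claude.json are skipped, so an interrupted
    | # run can resume.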
4 | url = "" # Your url
5 |
6 | headers = {
7 | "Authorization": "", # Your Authorization
8 | "content-type": "application/json"
9 | }
10 | with open("result/claude.json", "r") as read_file:
11 | finished_dic = json.load(read_file).keys()
12 | result_dic = {}
13 | write_file = open("result/claude.json", "a")
14 | try:
15 | with open("test.json", "r") as read_file:
16 | data_dic = json.load(read_file)
17 | for item in data_dic.keys():
18 | if item in finished_dic:
19 | continue
20 | result_list = []
21 | for jtem in data_dic[item]:
22 | this_data = json.loads(jtem.strip())
23 |             content, summary = this_data["content"], this_data["summary"]
24 | data = {
25 | "messages": [
26 | {
27 | "role": "user",
28 | "content": content,
29 | }
30 | ],
31 | "model": "claude-instant-1-100k",
32 | "max_tokens_to_sample": 300,
33 | }
34 | response = requests.post(url, headers=headers, json=data)
35 | answer = json.loads(response.text)["choices"][0]["message"]["content"]
36 | result_list.append({"labels": summery, "predict":answer})
37 | result_dic[item] = result_list
38 | except Exception:
39 |     pass  # on any request/parsing error, stop looping but still save what was gathered
40 | finally:
41 |     # rewrite the whole file: appending a second JSON object would corrupt it
42 |     with open("result/claude.json", "w") as write_file:
43 |         json.dump(result_dic, write_file, indent=2)
--------------------------------------------------------------------------------
/galactica/calc_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import defaultdict as dd
4 | from tqdm import tqdm
5 | import warnings
6 | warnings.filterwarnings('ignore')
7 | from sklearn.metrics import average_precision_score
8 |
9 |
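    | # Mean average precision (MAP): for each paper id, the Yes/No predictions over
    | # its references are scored against the 0/1 labels with
    | # average_precision_score, and the per-paper APs are then averaged.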
10 | result_file = "galactica_lora_result_2.json"
11 | pid_to_labels = dd(list)
12 | pid_to_predict = dd(list)
13 | with open(os.path.join("result", result_file), "r") as rf:
14 | for i, line in tqdm(enumerate(rf)):
15 | cur_item = json.loads(line.strip())
16 | pid = cur_item["pid"]
17 | cur_label = cur_item["label"]
18 | pid_to_labels[pid].append(cur_label)
19 | cur_predict = cur_item["predict"]
20 | try:
21 | s_idx = cur_predict.index("The answer is")
22 | cur_predict = cur_predict[s_idx + 13:]
23 | except:
24 | pass
25 | if "no" in cur_predict or "No" in cur_predict or "NO" in cur_predict:
26 | pid_to_predict[pid].append(0)
27 | else:
28 | pid_to_predict[pid].append(1)
29 |
30 | maps = []
31 | for pid in tqdm(pid_to_labels):
32 | cur_labels = pid_to_labels[pid]
33 |     if sum(cur_labels) == 0:  # AP is undefined for papers with no positive reference
34 | continue
35 | cur_predict = pid_to_predict[pid]
36 | cur_map = average_precision_score(cur_labels, cur_predict)
37 | maps.append(cur_map)
38 |
39 | print(maps)
40 | print(f"{i}:map={sum(maps)/len(maps)}")
41 |
42 |
43 | """
44 | dic_list = ["result/gala_standard.json"]
45 | for i in range(len(dic_list)):
46 | result_list = []
47 | with open(dic_list[i], "r") as read_file:
48 | result_dic = json.load(read_file)
49 | for key in result_dic.keys():
50 | pre = []
51 | res = []
52 | for jtem in result_dic[key]:
53 | # data = json.loads(jtem.strip())
54 | data = jtem
55 | if data["labels"] == "Yes":
56 | res.append(1)
57 | else:
58 | res.append(0)
59 | if "no" in data["predict"] or "No" in data["predict"] or "NO" in data["predict"]:
60 | pre.append(0)
61 | elif "yes" in data["predict"] or "Yes" in data["predict"] or "YES" in data["predict"]:
62 | pre.append(1)
63 | else:
64 | pre.append(0)
65 | if sum(res) == 0:
66 | continue
67 | cur_map = average_precision_score(res, pre)
68 | print(cur_map)
69 | result_list.append(cur_map)
70 | print(f"{i}:map={sum(result_list)/len(result_list)}")
71 | """
--------------------------------------------------------------------------------
/galactica/gala_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
3 | import warnings
4 | warnings.filterwarnings('ignore')
5 | import galai as gal
6 | import json
7 | from tqdm import tqdm
8 | from galai.notebook_utils import *
9 |
10 | model = gal.load_model("standard")
11 | # with open("../chatglm/data/test.json", "r") as read_file:
12 | # all_lines = read_file.readlines()
13 | os.makedirs("result", exist_ok=True)
14 |
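    | # Zero-shot predictions with the off-the-shelf "standard" Galactica model.
    | # Each context in test.jsonl already ends with the "Is the current reference
    | # important? Please answer Yes or No." question (appended in process.py), so
    | # the raw generation is stored and mapped to 0/1 later in calc_map.py.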
15 | write_file_2 = open("result/galactica_standard_4.json", "a")
16 | # with open("../chatglm/data/test.json", "r") as read_file:
17 | # data_dic = json.load(read_file)
18 | result_dic = {}
19 | with open("data/test.jsonl", "r") as rf:
20 | for i, line in tqdm(enumerate(rf)):
21 | item = json.loads(line.strip())
22 | context = item["context"]
23 | answer = item["label"]
24 | item_new = item
25 | if len(context) <= 3:
26 | item_new["predict"] = 0
27 | else:
28 | result = model.generate([" ".join(context.split()[-200:])])
29 | item_new["predict"] = result[0]
30 | write_file_2.write(json.dumps(item_new, ensure_ascii=False) + "\n")
31 | write_file_2.flush()
32 |
33 | write_file_2.close()
34 |
35 |
36 | """
37 | write_file = open("result/gala_standard.json", "w")
38 | write_file_2 = open("result/gala_standard_2.json", "a")
39 | with open("../chatglm/data/test.json", "r") as read_file:
40 | data_dic = json.load(read_file)
41 | result_dic = {}
42 | for item in tqdm(data_dic.keys()):
43 | data_list = data_dic[item]
44 | result_dic[item] = []
45 | for jtem in data_list:
46 | # dic = json.loads(jtem.strip())
47 | dic = jtem
48 | content = dic["content"]
49 | answer = dic["summary"]
50 | if len(content) <= 3:
51 | result = {"labels": answer, "predict": "No"}
52 | else:
53 | result = model.generate([" ".join(content.split()[-200:])])
54 | result = {"labels": answer, "predict":result[0]}
55 | result_with_id = {"id": item, "labels": answer, "predict":result["predict"]}
56 | write_file_2.write(json.dumps(result_with_id, ensure_ascii=False) + "\n")
57 | write_file_2.flush()
58 | result_dic[item].append(result)
59 | json.dump(result_dic, write_file, indent=2)
60 | write_file.close()
61 | write_file_2.close()
62 | """
63 |
--------------------------------------------------------------------------------
/galactica/gala_finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
3 | from xturing.datasets.instruction_dataset import InstructionDataset
4 | from xturing.models import BaseModel
5 | import json
6 | from tqdm import tqdm
7 |
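    | # Finetune Galactica (optionally with LoRA) on instruction-format data via
    | # xturing, then generate Yes/No importance predictions for data/test.jsonl.
    | # The finetune/save/load calls are currently commented out, so this run uses
    | # the base model as-is.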
8 | data_dic = {}
9 | # with open("data/train_balance_2.json", "r") as read_file:
10 | # all_lines = read_file.readlines()
11 | instruction_dataset = InstructionDataset("data/pst_data")
12 | # Initializes the model
13 | #model = BaseModel.load("/root/huge_model/galactica/galactica")
14 | # model = BaseModel.create("galactica_lora")
15 | model = BaseModel.create("galactica")
16 |
17 | # Finetune the model
18 | # model.finetune(dataset=instruction_dataset)
19 | # model.load("saved_model/")
20 |
21 | # Once the model has been finetuned, you can run inference
22 | output = model.generate(texts=["Why LLM models are becoming so important?"])
23 | print("Generated output by the model: {}".format(output))
24 |
25 | # Save the model
26 | # model.save("saved_model/")
27 |
28 | # write_file = open("result/galactica_lora_result.json", "w")
29 | # write_file_2 = open("result/galactica_lora_result_2.json", "a")
30 | write_file_2 = open("result/galactica_standard_3.json", "a")
31 | # with open("../chatglm/data/test.json", "r") as read_file:
32 | # data_dic = json.load(read_file)
33 | result_dic = {}
34 | with open("data/test.jsonl", "r") as rf:
35 | for i, line in tqdm(enumerate(rf)):
36 | item = json.loads(line.strip())
37 | context = item["context"]
38 | answer = item["label"]
39 | item_new = item
40 | if len(context) <= 3:
41 | item_new["predict"] = 0
42 | else:
43 | result = model.generate(texts=[" ".join(context.split()[-200:])])
44 | item_new["predict"] = result[0]
45 | write_file_2.write(json.dumps(item_new, ensure_ascii=False) + "\n")
46 | write_file_2.flush()
47 |
48 | write_file_2.close()
49 |
50 | """
51 | for item in tqdm(data_dic.keys()):
52 | data_list = data_dic[item]
53 | result_dic[item] = []
54 | for jtem in data_list:
55 | dic = jtem
56 | content = dic["content"]
57 | answer = dic["summary"]
58 | if len(content) <= 3:
59 | result = {"labels": answer, "predict": "No"}
60 | else:
61 | result = model.generate(texts=[content])
62 | result = {"labels": answer, "predict":result[0]}
63 | result_with_id = {"id": item, "labels": answer, "predict":result["predict"]}
64 | write_file_2.write(json.dumps(result_with_id, ensure_ascii=False) + "\n")
65 | result_dic[item].append(result)
66 | json.dump(result_dic, write_file, indent=2)
67 | write_file.close()
68 | write_file_2.close()
69 | """
--------------------------------------------------------------------------------
/galactica/process.py:
--------------------------------------------------------------------------------
1 | #import json
2 | #import jsonlines
3 | #with open("datasets/train_balance.json", "r") as read_file:
4 | # all_lines = read_file.readlines()
5 | #write_file = jsonlines.open("datasets/train_balance_3.jsonl", "w")
6 | #for i, item in enumerate(all_lines):
7 | # data_dic = json.loads(item.strip())
8 | # new_dic = {"id":f"seed_task_{i}", "name":f"{i}", "instruction":data_dic["content"], "instances":[{"input":"", "output":data_dic["summary"], "is_classification": False}]}
9 | # write_file.write(json.dumps(new_dic))
10 | import json
11 | # import random
12 | import numpy as np
13 | from tqdm import tqdm
14 | from collections import defaultdict as dd
15 | from os.path import join
16 | import os
17 | from bs4 import BeautifulSoup
18 | from fuzzywuzzy import fuzz
19 |
20 | import utils
21 | import settings
22 |
23 | from datasets import Dataset, DatasetDict
24 |
25 | # Convert the alpaca JSON dataset to HF format
26 |
27 |
28 | # Right now only the HuggingFace datasets are supported, that's why the JSON Alpaca dataset
29 | # needs to be converted to the HuggingFace format. In addition, this HF dataset should have 3 columns for instruction finetuning: instruction, text and target.
30 | def preprocess_alpaca_json_data(alpaca_dataset_path: str):
31 | """Creates a dataset given the alpaca JSON dataset. You can download it here: https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json
32 |
33 | :param alpaca_dataset_path: path of the Alpaca dataset
34 | """
35 | read_file = open(alpaca_dataset_path, "r")
36 | instructions = []
37 | inputs = []
38 | outputs = []
39 | for item in read_file:
40 | dic = json.loads(item.strip())
41 | instructions.append(dic["instruction"])
42 | inputs.append(dic["input"])
43 | outputs.append(dic["output"])
44 |
45 | data_dict = {
46 | "train": {"instruction": instructions, "text": inputs, "target": outputs}
47 | }
48 |
49 | dataset = DatasetDict()
50 | # using your `Dict` object
51 | for k, v in data_dict.items():
52 | dataset[k] = Dataset.from_dict(v)
53 |
54 | dataset.save_to_disk(str("galactica/data/pst_data"))
55 |
56 |
57 | def make_balance_data_for_galactica():
58 | target_list = ["train", "valid"]
59 | yes_list, no_list = [], []
60 | for item in target_list:
61 | with open("glm/data/" + item + '.json', "r") as read_file:
62 | all_lines = read_file.readlines()
63 | for data in all_lines:
64 | data_dic = json.loads(data.strip())
65 | if data_dic["label"] == 1:
66 | yes_list.append(data_dic)
67 | else:
68 | no_list.append(data_dic)
69 | np.random.seed(42)
70 |     no_list = np.random.choice(no_list, len(yes_list), replace=False).tolist()  # downsample negatives to match the number of positives
71 | all_list = yes_list+no_list
72 | print(len(all_list))
73 | np.random.shuffle(all_list)
74 | with open("galactica/data/" + item+"_balance_gala.json", "w") as write_file:
75 | for jtem in all_list:
76 | new_item = {}
77 | new_item["output"] = "Yes" if jtem["label"] else "No"
78 | new_item["instruction"] = jtem["inputs_pretokenized"][:-7]
79 | new_item["input"] = ""
80 | write_file.write(json.dumps(new_item)+"\n")
81 |
82 |
83 | def gen_test_data_json_lines(year=2023):
84 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
85 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json")
86 | pids_test = {p["_id"] for p in papers_test}
87 |
88 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
89 | files = []
90 | for f in os.listdir(in_dir):
91 | cur_pid = f.split(".")[0]
92 | if f.endswith(".xml") and cur_pid in pids_test:
93 | files.append(f)
94 |
95 | truths = papers_test
96 | pid_to_source_titles = dd(list)
97 | for paper in tqdm(truths):
98 | pid = paper["_id"]
99 | for ref in paper["refs_trace"]:
100 | pid_to_source_titles[pid].append(ref["title"].lower())
101 |
102 |
103 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
104 | wf = open("galactica/data/test.jsonl", "w")
105 |
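    |     # For each test paper: parse its TEI XML, map bib ids to reference titles,
    |     # mark a reference positive when its title fuzzy-matches (ratio >= 80) a
    |     # ground-truth source title, and emit one JSON line per cited reference
    |     # with its citation context plus the Yes/No question.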
106 | for paper in tqdm(papers_test):
107 | cur_pid = paper["_id"]
108 | file = join(xml_dir, cur_pid + ".tei.xml")
109 | f = open(file, encoding='utf-8')
110 |
111 | xml = f.read()
112 | bs = BeautifulSoup(xml, "xml")
113 | f.close()
114 |
115 | source_titles = pid_to_source_titles[cur_pid]
116 | if len(source_titles) == 0:
117 | continue
118 |
119 | references = bs.find_all("biblStruct")
120 | bid_to_title = {}
121 | n_refs = 0
122 | for ref in references:
123 | if "xml:id" not in ref.attrs:
124 | continue
125 | bid = ref.attrs["xml:id"]
126 | if ref.analytic is None:
127 | continue
128 | if ref.analytic.title is None:
129 | continue
130 | bid_to_title[bid] = ref.analytic.title.text.lower()
131 | b_idx = int(bid[1:]) + 1
132 | if b_idx > n_refs:
133 | n_refs = b_idx
134 |
135 | bib_to_contexts = utils.find_bib_context(xml)
136 | bib_sorted = sorted(bib_to_contexts.keys())
137 |
138 | for bib in bib_sorted:
139 | cur_bib_idx = int(bib[1:])
140 | if cur_bib_idx + 1 > n_refs:
141 | n_refs = cur_bib_idx + 1
142 |
143 | y_true = [0] * n_refs
144 | y_score = [0] * n_refs
145 |
146 | flag = False
147 | for bid in bid_to_title:
148 | cur_ref_title = bid_to_title[bid]
149 | for label_title in source_titles:
150 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
151 | flag = True
152 | b_idx = int(bid[1:])
153 | y_true[b_idx] = 1
154 |
155 | if not flag:
156 | continue
157 |
158 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]
159 | # print(bid_to_title)
160 |
161 | for bi, cur_bib in enumerate(bib_sorted):
162 | new_item = {"pid": cur_pid, "bib_id": cur_bib}
163 | cur_context = contexts_sorted[bi]
164 | cur_context = cur_context + ". Is the current reference important? Please answer Yes or No. The answer is "
165 | cur_label = y_true[int(cur_bib[1:])]
166 | new_item["label"] = cur_label
167 | new_item["context"] = cur_context
168 |             try:
169 |                 new_item["title"] = bid_to_title[cur_bib]
170 |             except KeyError:  # some bib entries have no parsed title
171 |                 pass
172 | wf.write(json.dumps(new_item) + "\n")
173 | wf.flush()
174 |
175 | wf.close()
176 |
177 |
178 | # preprocess_alpaca_json_data("data/train_balance_2.json")
179 |
180 | # make_balance_data_for_galactica()
181 | # preprocess_alpaca_json_data("galactica/data/train_balance_gala.json")
182 | gen_test_data_json_lines(2023)
--------------------------------------------------------------------------------
/glm/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from torch.utils.data import Dataset
4 |
5 |
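    | # Reads a jsonl file of multiple-choice examples. Each line is expected to
    | # carry "inputs_pretokenized" (the context), "choices_pretokenized" (the
    | # candidate answers) and an integer "label" indexing the correct choice.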
6 | class TensorDataset(Dataset):
7 | def __init__(self, path, tokenizer):
8 | self.variable_num_choices = True
9 | self.example_list = []
10 | with open(path, "r", encoding="utf-8") as file:
11 | for idx, line in enumerate(file):
12 | item = json.loads(line)
13 | item["idx"] = str(idx)
14 | item["answer"]= item["choices_pretokenized"][item["label"]]
15 | self.example_list.append(item)
16 | # self.example_list = self.example_list[:200] # debug
17 | self.examples = {example["idx"]: example for example in self.example_list}
18 | print(f"Creating {len(self.example_list)} examples")
19 | self.dataset_name = "multichoice-" + os.path.basename(path).split(".")[0]
20 |
21 | contexts = [x["inputs_pretokenized"][-500:] for x in self.example_list]
22 | candidates = [x["choices_pretokenized"] for x in self.example_list]
23 | self.labels = [x["label"] for x in self.example_list]
24 | answers = [x["answer"] for x in self.example_list]
25 | self.input_dict = {"contexts": contexts, "candidates": candidates, "labels": self.labels, "answers": answers}
26 |
27 |
28 | def __len__(self):
29 | return len(self.example_list)
30 |
31 | def __getitem__(self, idx):
32 | return {k: v[idx] for k, v in self.input_dict.items()}
33 |
--------------------------------------------------------------------------------
/glm/ds_config_glm_10b.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 4,
3 | "gradient_accumulation_steps": 2,
4 | "steps_per_print": 50,
5 | "gradient_clipping": 1.0,
6 | "zero_optimization": {
7 | "stage": 1,
8 | "contiguous_gradients": false,
9 | "overlap_comm": true,
10 | "reduce_scatter": true,
11 | "reduce_bucket_size": 5e7,
12 | "allgather_bucket_size": 5e7,
13 | "cpu_offload": true
14 | },
15 | "zero_allow_untested_optimizer": true,
16 | "fp16": {
17 | "enabled": true,
18 | "loss_scale": 0,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "optimizer": {
24 | "type": "Adam",
25 | "params": {
26 | "lr": 1e-7,
27 | "betas": [
28 | 0.9,
29 | 0.95
30 | ],
31 | "eps": 1e-8,
32 | "weight_decay": 1e-2
33 | }
34 | },
35 | "activation_checkpointing": {
36 | "partition_activations": false,
37 | "contiguous_memory_optimization": false
38 | },
39 | "wall_clock_breakdown": false
40 | }
41 |
--------------------------------------------------------------------------------
/glm/ds_config_glm_2b.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 6,
3 | "steps_per_print": 2000,
4 | "optimizer": {
5 | "type": "Adam",
6 | "params": {
7 | "lr": 0.0001,
8 | "betas": [
9 | 0.8,
10 | 0.999
11 | ],
12 | "eps": 1e-8,
13 | "weight_decay": 3e-7
14 | }
15 | },
16 | "gradient_clipping": 1.0,
17 | "prescale_gradients": false,
18 | "fp16": {
19 | "enabled": true,
20 | "fp16_master_weights_and_grads": false,
21 | "loss_scale": 0,
22 | "loss_scale_window": 500,
23 | "hysteresis": 2,
24 | "min_loss_scale": 1,
25 | "initial_scale_power": 15
26 | },
27 | "wall_clock_breakdown": false,
28 | "zero_optimization": {
29 | "stage": 0,
30 | "allgather_partitions": true,
31 | "reduce_scatter": true,
32 | "allgather_bucket_size": 50000000,
33 | "reduce_bucket_size": 50000000,
34 | "overlap_comm": true,
35 | "contiguous_gradients": true,
36 | "cpu_offload": false
37 | }
38 | }
--------------------------------------------------------------------------------
/glm/finetune_glm_10b_ds.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from datetime import datetime
4 | import numpy as np
5 | from sklearn.metrics import top_k_accuracy_score
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMultipleChoice
9 | import deepspeed
10 | from glm.dataset import TensorDataset
11 | from utils import Log
12 |
13 | BATCH_SIZE = 8
14 |
15 | def add_argument():
16 |
17 |     parser = argparse.ArgumentParser(description='GLM finetuning with DeepSpeed')
18 |
19 | #data
20 | # cuda
21 | parser.add_argument('--with_cuda',
22 | default=False,
23 | action='store_true',
24 | help='use CPU in case there\'s no GPU support')
25 | parser.add_argument('--use_ema',
26 | default=False,
27 | action='store_true',
28 | help='whether use exponential moving average')
29 |
30 | # train
31 | parser.add_argument('-b',
32 | '--batch_size',
33 | default=8,
34 | type=int,
35 |                         help='mini-batch size (default: 8)')
36 | parser.add_argument('-e',
37 | '--epochs',
38 | default=50,
39 | type=int,
40 |                         help='number of total epochs (default: 50)')
41 | parser.add_argument('--local_rank',
42 | type=int,
43 | default=-1,
44 | help='local rank passed from distributed launcher')
45 |
46 | parser.add_argument('--log-interval',
47 | type=int,
48 | default=2000,
49 | help="output logging information at a given interval")
50 |
51 | parser.add_argument('--moe',
52 | default=False,
53 | action='store_true',
54 | help='use deepspeed mixture of experts (moe)')
55 |
56 | parser.add_argument('--ep-world-size',
57 | default=1,
58 | type=int,
59 | help='(moe) expert parallel world size')
60 | parser.add_argument('--num-experts',
61 | type=int,
62 | nargs='+',
63 | default=[
64 | 1,
65 | ],
66 | help='number of experts list, MoE related.')
67 | parser.add_argument(
68 | '--mlp-type',
69 | type=str,
70 | default='standard',
71 | help=
72 | 'Only applicable when num-experts > 1, accepts [standard, residual]')
73 | parser.add_argument('--top-k',
74 | default=1,
75 | type=int,
76 | help='(moe) gating top 1 and 2 supported')
77 | parser.add_argument(
78 | '--min-capacity',
79 | default=0,
80 | type=int,
81 | help=
82 | '(moe) minimum capacity of an expert regardless of the capacity_factor'
83 | )
84 | parser.add_argument(
85 | '--noisy-gate-policy',
86 | default=None,
87 | type=str,
88 | help=
89 | '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter'
90 | )
91 | parser.add_argument(
92 | '--moe-param-group',
93 | default=False,
94 | action='store_true',
95 | help=
96 | '(moe) create separate moe param groups, required when using ZeRO w. MoE'
97 | )
98 |
99 | # Include DeepSpeed configuration arguments
100 | parser = deepspeed.add_config_arguments(parser)
101 |
102 | args = parser.parse_args()
103 |
104 | return args
105 |
106 |
107 | def glm_get_params_for_weight_decay_optimization(module):
108 | weight_decay_params = {'params': []}
109 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
110 | for module_ in module.modules():
111 | if isinstance(module_, (torch.nn.LayerNorm)):
112 | no_weight_decay_params['params'].extend(
113 | [p for p in list(module_._parameters.values())
114 | if p is not None and p.requires_grad])
115 | else:
116 | weight_decay_params['params'].extend(
117 | [p for n, p in list(module_._parameters.items())
118 | if p is not None and p.requires_grad and n != 'bias'])
119 | no_weight_decay_params['params'].extend(
120 | [p for n, p in list(module_._parameters.items())
121 | if p is not None and p.requires_grad and n == 'bias'])
122 |
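    |     # NOTE: the trailing comma below makes this a 1-tuple holding only the
    |     # weight-decay group; the no-decay (bias/LayerNorm) group built above is
    |     # dropped, so those parameters are never handed to the optimizer.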
123 | return weight_decay_params,
124 |
125 |
126 | def get_optimizer_param_groups(model):
127 | # Build parameter groups (weight decay and non-decay).
128 | # while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)):
129 | # model = model.module
130 | param_groups = glm_get_params_for_weight_decay_optimization(model)
131 |
132 | # Add model parallel attribute if it is not set.
133 | for param_group in param_groups:
134 | # print('## param_group', len(param_group['params']))
135 | for param in param_group['params']:
136 | if not hasattr(param, 'model_parallel'):
137 | param.model_parallel = False
138 |
139 | return param_groups
140 |
141 |
142 | def train(model_engine, dataloader, tokenizer, fp16):
143 | loss_total = 0
144 | n_item = 0
145 | for data in dataloader:
146 | # print("here3")
147 | # data = {k: v.to(model_engine.local_rank) for k, v in batch.items()}
148 | # if fp16:
149 | # data = {k: v.half() for k, v in data.items()}
150 | # outputs = model_engine(**data)
151 | # loss = outputs.loss
152 | # model_engine.backward(loss)
153 | # model_engine.step()
154 |
155 | # loss_total += loss.item()
156 | # cur_n_item = outputs.logits.shape[0]
157 | # n_item += cur_n_item
158 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True)
159 | inputs = tokenizer.build_inputs_for_generation(inputs, targets=data["answers"], max_gen_length=2, padding=False)
160 | inputs = inputs.to(model_engine.local_rank)
161 |
162 | outputs = model_engine(**inputs)
163 | loss = outputs.loss
164 | model_engine.backward(loss)
165 | model_engine.step()
166 |
167 | loss_total += loss.item()
168 | cur_n_item = outputs.logits.shape[0]
169 | n_item += cur_n_item
170 |
171 | return loss_total / n_item
172 |
173 |
174 | def eval(dataloader, model_infer_ds, tokenizer, fp16):
175 | # model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True)
176 | # model_infer.load_state_dict(model.state_dict())
177 | # model_infer = model_infer.half()
178 | # model_infer = model_infer.half().cuda() # half?
179 | # model_infer.eval()
180 |
181 | # ds_engine = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True)
182 | # model_ds = ds_engine.module
183 |
184 |     pred_scores = np.empty((0, 6))  # scores are padded/truncated to 6 candidates per example
185 | labels = []
186 |
187 | with torch.no_grad():
188 | for data in dataloader:
189 | label = data["labels"]
190 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True)
191 | inputs = tokenizer.build_inputs_for_multiple_choice(inputs, data["candidates"])
192 | inputs = inputs.to('cuda')
193 | # if fp16:
194 | # inputs = {k: v.half() if torch.is_tensor(v) else v for k, v in inputs.items()}
195 | # outputs = model_infer(**inputs)
196 | outputs = model_infer_ds(**inputs)
197 | logits = outputs.logits
198 | score = logits.detach().cpu().numpy()
199 | if score.shape[0] > 6:
200 | score = score[:6]
201 | elif score.shape[0] < 6:
202 | score = np.concatenate((score, np.zeros((6 - score.shape[0], score.shape[1]))), axis=0)
203 | pred_scores = np.concatenate((pred_scores, np.transpose(score)), axis=0)
204 | labels.extend(label)
205 |
206 | hit_1 = top_k_accuracy_score(labels, pred_scores, k=1)
207 | hit_3 = top_k_accuracy_score(labels, pred_scores, k=3)
208 | hit_5 = top_k_accuracy_score(labels, pred_scores, k=5)
209 |
210 | # del model_infer
211 |
212 | return [hit_1, hit_3, hit_5]
213 |
214 |
215 | def finetune_ds(model_name="THUDM/glm-2b"):
216 | deepspeed.init_distributed()
217 |
218 | # if torch.distributed.get_rank() != 0:
219 | # # might be downloading data, let rank 0 download first
220 | # torch.distributed.barrier()
221 |
222 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
223 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
224 | # tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
225 | # model = AutoModelForSeq2SeqLM.from_pretrained(model_name, local_files_only=True)
226 | model = model.half()
227 |
228 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True)
229 | model_infer = model_infer.half()
230 |
231 | args = add_argument()
232 |
233 | exp_name = "finetune-" + model_name.split('/')[-1] + '-ds-'
234 | exp_name = exp_name + str(datetime.now())
235 |
236 | # Dataset
237 | train_dataset_path = "glm/data/train.json"
238 | valid_dataset_path = "glm/data/valid.json"
239 | test_dataset_path = "glm/data/test.json"
240 | log_path = "./log/" + exp_name + ".log"
241 | os.makedirs("./log", exist_ok=True)
242 |
243 | train_dataset = TensorDataset(train_dataset_path, tokenizer)
244 | valid_dataset = TensorDataset(valid_dataset_path, tokenizer)
245 | test_dataset = TensorDataset(test_dataset_path, tokenizer)
246 |
247 | # Log
248 |     log = Log(file_path=log_path)
249 |
250 | # initialize the dataloader
251 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
252 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
253 | test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
254 |
255 | parameters = get_optimizer_param_groups(model)
256 |
257 | print("here1")
258 | model_engine, optimizer, trainloader, __ = deepspeed.initialize(
259 | args=args, model=model, model_parameters=parameters, training_data=train_dataset)
260 | print("here2")
261 |
262 | ds_engine_infer = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True)
263 |
264 | fp16 = model_engine.fp16_enabled()
265 | print(f'fp16={fp16}')
266 |
267 | # ds_engine_infer._load_from_state_dict(model.state_dict())
268 | ds_engine_infer.module.load_state_dict(model.state_dict())
269 | # model_infer_ds = ds_engine_infer.module
270 | # res_valid = eval(valid_loader, model_infer_ds, tokenizer, fp16)
271 | # log.log('valid epoch {} result: {}'.format(-1, str(res_valid)))
272 | # res_test = eval(test_loader, model_infer_ds, tokenizer, fp16)
273 | # log.log('test epoch {} result: {}'.format(-1, str(res_test)))
274 |
275 | valid_best = 0
276 | test_best = []
277 | best_epoch = 0
278 | out_dir = "glm/saved"
279 | os.makedirs(out_dir, exist_ok=True)
280 | for epoch in range(args.epochs):
281 | log.log('begin epoch {}'.format(epoch))
282 | loss_train = train(model_engine, trainloader, tokenizer, fp16)
283 | log.log('train epoch {} end, loss {}'.format(epoch, str(loss_train)))
284 |
285 | # ds_engine_infer._load_from_state_dict(model.state_dict())
286 |
287 | # ds_engine_infer.module.load_state_dict(model.state_dict())
288 | # model_infer_ds = ds_engine_infer.module
289 | # res_valid = eval(valid_loader, model_infer_ds, tokenizer, fp16)
290 | # log.log('valid epoch {} result: {}'.format(epoch, str(res_valid)))
291 | # res_test = eval(test_loader, model_infer_ds, tokenizer, fp16)
292 | # log.log('test epoch {} result: {}'.format(epoch, str(res_test)))
293 |
294 | # if res_valid[0] > valid_best:
295 | # valid_best = res_valid[0]
296 | # test_best = res_test
297 | # best_epoch = epoch
298 | # log.log('best epoch {} result: {}'.format(best_epoch, str(test_best)))
299 |
300 | ds_engine_infer.module.load_state_dict(model.state_dict())
301 | torch.save(ds_engine_infer.module, os.path.join(out_dir, "{}-epoch-{}.pt".format("glm-10b", epoch)))
302 |
303 |
304 |
305 | if __name__ == "__main__":
306 | model_name = "THUDM/glm-10b"
307 | finetune_ds(model_name)
308 |
--------------------------------------------------------------------------------
/glm/finetune_glm_ds.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from datetime import datetime
4 | import numpy as np
5 | from sklearn.metrics import top_k_accuracy_score
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMultipleChoice
10 | import deepspeed
11 |
12 | from glm.dataset import TensorDataset
13 | from utils import Log
14 |
15 | BATCH_SIZE = 8
16 |
17 |
18 | def add_argument():
19 |
20 |     parser = argparse.ArgumentParser(description='GLM finetuning with DeepSpeed')
21 |
22 | #data
23 | # cuda
24 | parser.add_argument('--with_cuda',
25 | default=False,
26 | action='store_true',
27 | help='use CPU in case there\'s no GPU support')
28 | parser.add_argument('--use_ema',
29 | default=False,
30 | action='store_true',
31 | help='whether use exponential moving average')
32 |
33 | # train
34 | parser.add_argument('-b',
35 | '--batch_size',
36 | default=8,
37 | type=int,
38 |                         help='mini-batch size (default: 8)')
39 | parser.add_argument('-e',
40 | '--epochs',
41 | default=10,
42 | type=int,
43 |                         help='number of total epochs (default: 10)')
44 | parser.add_argument('--local_rank',
45 | type=int,
46 | default=-1,
47 | help='local rank passed from distributed launcher')
48 |
49 | parser.add_argument('--log-interval',
50 | type=int,
51 | default=2000,
52 | help="output logging information at a given interval")
53 |
54 | parser.add_argument('--moe',
55 | default=False,
56 | action='store_true',
57 | help='use deepspeed mixture of experts (moe)')
58 |
59 | parser.add_argument('--ep-world-size',
60 | default=1,
61 | type=int,
62 | help='(moe) expert parallel world size')
63 | parser.add_argument('--num-experts',
64 | type=int,
65 | nargs='+',
66 | default=[
67 | 1,
68 | ],
69 | help='number of experts list, MoE related.')
70 | parser.add_argument(
71 | '--mlp-type',
72 | type=str,
73 | default='standard',
74 | help=
75 | 'Only applicable when num-experts > 1, accepts [standard, residual]')
76 | parser.add_argument('--top-k',
77 | default=1,
78 | type=int,
79 | help='(moe) gating top 1 and 2 supported')
80 | parser.add_argument(
81 | '--min-capacity',
82 | default=0,
83 | type=int,
84 | help=
85 | '(moe) minimum capacity of an expert regardless of the capacity_factor'
86 | )
87 | parser.add_argument(
88 | '--noisy-gate-policy',
89 | default=None,
90 | type=str,
91 | help=
92 | '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter'
93 | )
94 | parser.add_argument(
95 | '--moe-param-group',
96 | default=False,
97 | action='store_true',
98 | help=
99 | '(moe) create separate moe param groups, required when using ZeRO w. MoE'
100 | )
101 |
102 | # Include DeepSpeed configuration arguments
103 | parser = deepspeed.add_config_arguments(parser)
104 |
105 | args = parser.parse_args()
106 |
107 | return args
108 |
109 |
110 | def glm_get_params_for_weight_decay_optimization(module):
111 | weight_decay_params = {'params': []}
112 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
113 | for module_ in module.modules():
114 | if isinstance(module_, (torch.nn.LayerNorm)):
115 | no_weight_decay_params['params'].extend(
116 | [p for p in list(module_._parameters.values())
117 | if p is not None and p.requires_grad])
118 | else:
119 | weight_decay_params['params'].extend(
120 | [p for n, p in list(module_._parameters.items())
121 | if p is not None and p.requires_grad and n != 'bias'])
122 | no_weight_decay_params['params'].extend(
123 | [p for n, p in list(module_._parameters.items())
124 | if p is not None and p.requires_grad and n == 'bias'])
125 |
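    |     # NOTE: as in finetune_glm_10b_ds.py, only the weight-decay group is
    |     # returned here; the no-decay (bias/LayerNorm) group is dropped and those
    |     # parameters are never handed to the optimizer.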
126 | return weight_decay_params,
127 |
128 |
129 | def get_optimizer_param_groups(model):
130 | # Build parameter groups (weight decay and non-decay).
131 | # while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)):
132 | # model = model.module
133 | param_groups = glm_get_params_for_weight_decay_optimization(model)
134 |
135 | # Add model parallel attribute if it is not set.
136 | for param_group in param_groups:
137 | # print('## param_group', len(param_group['params']))
138 | for param in param_group['params']:
139 | if not hasattr(param, 'model_parallel'):
140 | param.model_parallel = False
141 |
142 | return param_groups
143 |
144 |
145 | def train(model_engine, dataloader, tokenizer, fp16):
146 | loss_total = 0
147 | n_item = 0
148 | for data in dataloader:
149 | # data = {k: v.to(model_engine.local_rank) for k, v in batch.items()}
150 | # if fp16:
151 | # data = {k: v.half() for k, v in data.items()}
152 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True)
153 | inputs = tokenizer.build_inputs_for_generation(inputs, targets=data["answers"], max_gen_length=2, padding=False)
154 | inputs = inputs.to(model_engine.local_rank)
155 |
156 | outputs = model_engine(**inputs)
157 | loss = outputs.loss
158 | model_engine.backward(loss)
159 | model_engine.step()
160 |
161 | loss_total += loss.item()
162 | cur_n_item = outputs.logits.shape[0]
163 | n_item += cur_n_item
164 |
165 | return loss_total / n_item
166 |
167 |
168 | def eval(dataloader, model_infer_ds, tokenizer, fp16):
169 |     pred_scores = np.empty((0, 6))  # scores are padded/truncated to 6 candidates per example
170 | labels = []
171 |
172 | with torch.no_grad():
173 | for data in dataloader:
174 | label = data["labels"]
175 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True)
176 | inputs = tokenizer.build_inputs_for_multiple_choice(inputs, data["candidates"])
177 | inputs = inputs.to('cuda')
178 | # if fp16:
179 | # inputs = {k: v.half() if torch.is_tensor(v) else v for k, v in inputs.items()}
180 | # outputs = model_infer(**inputs)
181 | outputs = model_infer_ds(**inputs)
182 | logits = outputs.logits
183 | score = logits.detach().cpu().numpy()
184 | if score.shape[0] > 6:
185 | score = score[:6]
186 | elif score.shape[0] < 6:
187 | score = np.concatenate((score, np.zeros((6 - score.shape[0], score.shape[1]))), axis=0)
188 | pred_scores = np.concatenate((pred_scores, np.transpose(score)), axis=0)
189 | labels.extend(label)
190 |
191 | hit_1 = top_k_accuracy_score(labels, pred_scores, k=1)
192 | hit_3 = top_k_accuracy_score(labels, pred_scores, k=3)
193 | hit_5 = top_k_accuracy_score(labels, pred_scores, k=5)
194 |
195 | # del model_infer
196 |
197 | return [hit_1, hit_3, hit_5]
198 |
199 |
200 | def finetune_ds(model_name="THUDM/glm-2b"):
201 | deepspeed.init_distributed()
202 |
203 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
204 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
205 | model = model.half()
206 |
207 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True)
208 | model_infer = model_infer.half()
209 |
210 | args = add_argument()
211 |
212 |     exp_name = "finetune-" + model_name.split('/')[-1] + '-ds-'
213 |     exp_name = exp_name + datetime.now().strftime("%Y%m%d-%H%M%S")  # timestamp without spaces/colons for a safe file name
214 |
215 | # Dataset
216 | train_dataset_path = "glm/data/train.json"
217 | valid_dataset_path = "glm/data/valid.json"
218 | test_dataset_path = "glm/data/test.json"
219 | log_path = "glm/log/" + exp_name + ".log"
220 | os.makedirs("glm/log", exist_ok=True)
221 |
222 | train_dataset = TensorDataset(train_dataset_path, tokenizer)
223 | valid_dataset = TensorDataset(valid_dataset_path, tokenizer)
224 | test_dataset = TensorDataset(test_dataset_path, tokenizer)
225 |
226 | # Log
227 |     log = Log(file_path=log_path)
228 |
229 | # initialize the dataloader
230 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
231 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
232 | test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
233 |
234 | parameters = get_optimizer_param_groups(model)
235 |
236 |     print("initializing DeepSpeed training engine...")
237 | model_engine, optimizer, trainloader, __ = deepspeed.initialize(
238 | args=args, model=model, model_parameters=parameters, training_data=train_dataset)
239 |     print("DeepSpeed training engine initialized")
240 |
241 | ds_engine_infer = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True)
242 |
243 | fp16 = model_engine.fp16_enabled()
244 | print(f'fp16={fp16}')
245 |
246 | valid_best = 0
247 | test_best = []
248 | best_epoch = 0
249 | out_dir = "glm/saved"
250 | os.makedirs(out_dir, exist_ok=True)
251 | for epoch in range(args.epochs):
252 | log.log('begin epoch {}'.format(epoch))
253 | loss_train = train(model_engine, trainloader, tokenizer, fp16)
254 | log.log('train epoch {} end, loss {}'.format(epoch, str(loss_train)))
255 |
256 |         ds_engine_infer.module.load_state_dict(model.state_dict())  # sync the latest fine-tuned weights into the inference engine
257 |
258 | torch.save(ds_engine_infer.module, os.path.join(out_dir, "{}-epoch-{}.pt".format("glm-2b", epoch)))
259 |
260 |
261 | if __name__ == "__main__":
262 | model_name = "THUDM/glm-2b"
263 | # model_name = "~/.cache/huggingface/hub/models--THUDM--glm-10b/snapshots/696788d4f82ac96b90823555f547d1e754839ff4"
264 | # model_name = "~/.cache/huggingface/hub/models--THUDM--glm-2b/snapshots/774fda883d7ad028b8effc3c65afec510fce9634"
265 | finetune_ds(model_name)
266 |
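267 | # Launch sketch (mirrors glm/run_finetune_ds.sh):
268 | #   deepspeed --master_port 23619 glm/finetune_glm_ds.py \
269 | #       --deepspeed --deepspeed_config glm/ds_config_glm_2b.json
270 | # args.epochs is assumed to be defined by add_argument earlier in this file.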
--------------------------------------------------------------------------------
/glm/process.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | import json
4 | from tqdm import tqdm
5 | from collections import defaultdict as dd
6 | from bs4 import BeautifulSoup
7 | import numpy as np
8 | from fuzzywuzzy import fuzz
9 |
10 | import utils
11 | import settings
12 |
13 |
14 | def prepare_train_test_data_for_glm(year=2023):
15 | x_train = []
16 | y_train = []
17 | x_valid = []
18 | y_valid = []
19 | x_test = []
20 | y_test = []
21 |
22 | truths = utils.load_json(settings.DATA_TRACE_DIR, "paper_source_trace_{}_final_filtered.json".format(year))
23 | pid_to_source_titles = dd(list)
24 | for paper in tqdm(truths):
25 | pid = paper["_id"]
26 | for ref in paper["refs_trace"]:
27 | pid_to_source_titles[pid].append(ref["title"].lower())
28 |
29 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
30 | papers_train = utils.load_json(data_year_dir, "paper_source_trace_train.json")
31 | papers_valid = utils.load_json(data_year_dir, "paper_source_trace_valid.json")
32 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json")
33 |
34 | pids_train = {p["_id"] for p in papers_train}
35 | pids_valid = {p["_id"] for p in papers_valid}
36 | pids_test = {p["_id"] for p in papers_test}
37 |
38 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
39 | files = []
40 | for f in os.listdir(in_dir):
41 | if f.endswith(".xml"):
42 | files.append(f)
43 |
44 | files = sorted(files)
45 | for file in tqdm(files):
46 | f = open(join(in_dir, file), encoding='utf-8')
47 | cur_pid = file.split(".")[0]
48 | if cur_pid not in pids_train and cur_pid not in pids_valid and cur_pid not in pids_test:
49 | continue
50 | xml = f.read()
51 | bs = BeautifulSoup(xml, "xml")
52 |
53 | source_titles = pid_to_source_titles[cur_pid]
54 | if len(source_titles) == 0:
55 | continue
56 |
57 | references = bs.find_all("biblStruct")
58 | bid_to_title = {}
59 | n_refs = 0
60 | for ref in references:
61 | if "xml:id" not in ref.attrs:
62 | continue
63 | bid = ref.attrs["xml:id"]
64 | if ref.analytic is None:
65 | continue
66 | if ref.analytic.title is None:
67 | continue
68 | bid_to_title[bid] = ref.analytic.title.text.lower()
69 | b_idx = int(bid[1:]) + 1
70 | if b_idx > n_refs:
71 | n_refs = b_idx
72 |
73 | flag = False
74 |
75 | cur_pos_bib = set()
76 |
77 | for bid in bid_to_title:
78 | cur_ref_title = bid_to_title[bid]
79 | for label_title in source_titles:
80 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
81 | flag = True
82 | cur_pos_bib.add(bid)
83 |
84 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib
85 |
86 | if not flag:
87 | continue
88 |
89 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0:
90 | continue
91 |
92 | bib_to_contexts = utils.find_bib_context(xml)
93 |
94 | n_pos = len(cur_pos_bib)
95 |         n_neg = n_pos * 10  # sample 10 negatives per positive, with replacement
96 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True)
97 |
98 | if cur_pid in pids_train:
99 | cur_x = x_train
100 | cur_y = y_train
101 | elif cur_pid in pids_valid:
102 | cur_x = x_valid
103 | cur_y = y_valid
104 | elif cur_pid in pids_test:
105 | cur_x = x_test
106 | cur_y = y_test
107 | else:
108 | continue
109 | # raise Exception("cur_pid not in train/valid/test")
110 |
111 | for bib in cur_pos_bib:
112 | cur_context = "The context is: " + " ".join(bib_to_contexts[bib]) + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]."
113 | cur_x.append(cur_context)
114 | cur_y.append(1)
115 |
116 | for bib in cur_neg_bib_sample:
117 | cur_context = "The context is: " + " ".join(bib_to_contexts[bib]) + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]."
118 | cur_x.append(cur_context)
119 | cur_y.append(0)
120 |
121 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid), "len(x_test)", len(x_test))
122 |
123 | out_dir = "glm/data/"
124 | os.makedirs(out_dir, exist_ok=True)
125 |
126 | with open(join(out_dir, "train.json"), "w") as f:
127 | for i in range(len(x_train)):
128 | f.write(json.dumps({"inputs_pretokenized": x_train[i], "choices_pretokenized": ["No", "Yes"], "label": y_train[i]}) + "\n")
129 |
130 |
131 | with open(join(out_dir, "valid.json"), "w") as f:
132 | for i in range(len(x_valid)):
133 | f.write(json.dumps({"inputs_pretokenized": x_valid[i], "choices_pretokenized": ["No", "Yes"], "label": y_valid[i]}) + "\n")
134 |
135 | with open(join(out_dir, "test.json"), "w") as f:
136 | for i in range(len(x_test)):
137 | f.write(json.dumps({"inputs_pretokenized": x_test[i], "choices_pretokenized": ["No", "Yes"], "label": y_test[i]}) + "\n")
138 |
139 |
140 | if __name__ == "__main__":
141 | prepare_train_test_data_for_glm()
142 |
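143 | # Each output line is a single JSON record; an illustrative (made-up) example:
144 | #   {"inputs_pretokenized": "The context is: ... The answer is [MASK].",
145 | #    "choices_pretokenized": ["No", "Yes"], "label": 1}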
--------------------------------------------------------------------------------
/glm/run_finetune_ds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed --master_port 23619 glm/finetune_glm_ds.py --deepspeed --deepspeed_config glm/ds_config_glm_2b.json "$@"
4 |
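5 | # Any extra CLI flags are forwarded to glm/finetune_glm_ds.py via "$@".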
--------------------------------------------------------------------------------
/glm/run_finetune_ds_10b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | deepspeed --master_port 23620 --include localhost:2,6,7 glm/finetune_glm_10b_ds.py --deepspeed --deepspeed_config glm/ds_config_glm_10b.json "$@"
4 | 
--------------------------------------------------------------------------------
/glm/test_glm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | from collections import defaultdict as dd
4 | import numpy as np
5 | import torch
6 | from fuzzywuzzy import fuzz
7 | from bs4 import BeautifulSoup
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForMultipleChoice
10 | from sklearn.metrics import average_precision_score
11 |
12 | import utils
13 | import settings
14 |
15 |
16 | def eval_test(model_name, ckpt_epoch=1, year=2023, role="test", ft=True):
17 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
18 |
19 | out_dir = "glm/saved"
20 | prefix = model_name.split("/")[1].lower()
21 | model_path = os.path.join(out_dir, "{}-epoch-{}.pt".format(prefix, ckpt_epoch))
22 | if not ft:
23 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True)
24 | # model_infer = AutoModelForMultipleChoice.from_pretrained("/home/zhangfanjin/.cache/huggingface/hub/models--THUDM--glm-10b/snapshots/696788d4f82ac96b90823555f547d1e754839ff4", trust_remote_code=True)
25 | model_infer.to('cuda')
26 |     elif prefix in ("glm-10b", "glm-2b"):
27 |         # both sizes load the full pickled module saved by the finetune scripts;
28 |         # map_location already places it on the GPU
29 |         model_infer = torch.load(model_path, map_location=torch.device('cuda'))
30 | # model_infer.to('cuda')
31 | model_infer.eval()
32 |     print("model loaded successfully")
33 |
34 | papers_test = utils.load_json(join(settings.DATA_TRACE_DIR, str(year)), "paper_source_trace_{}.json".format(role))
35 | pids_test = {p["_id"] for p in papers_test}
36 |
37 | truths = papers_test
38 | pid_to_source_titles = dd(list)
39 | for paper in tqdm(truths):
40 | pid = paper["_id"]
41 | for ref in paper["refs_trace"]:
42 | pid_to_source_titles[pid].append(ref["title"].lower())
43 |
44 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
45 | candidates = ["No", "Yes"]
46 | metrics = []
47 | f_idx = 0
48 |
49 | for paper in tqdm(papers_test):
50 | cur_pid = paper["_id"]
51 | file = join(xml_dir, cur_pid + ".tei.xml")
52 | f = open(file, encoding='utf-8')
53 |
54 | xml = f.read()
55 | bs = BeautifulSoup(xml, "xml")
56 | f.close()
57 |
58 | source_titles = pid_to_source_titles[cur_pid]
59 | if len(source_titles) == 0:
60 | continue
61 |
62 | references = bs.find_all("biblStruct")
63 | bid_to_title = {}
64 | n_refs = 0
65 | for ref in references:
66 | if "xml:id" not in ref.attrs:
67 | continue
68 | bid = ref.attrs["xml:id"]
69 | if ref.analytic is None:
70 | continue
71 | if ref.analytic.title is None:
72 | continue
73 | bid_to_title[bid] = ref.analytic.title.text.lower()
74 | b_idx = int(bid[1:]) + 1
75 | if b_idx > n_refs:
76 | n_refs = b_idx
77 |
78 | bib_to_contexts = utils.find_bib_context(xml)
79 | bib_sorted = sorted(bib_to_contexts.keys())
80 |
81 | for bib in bib_sorted:
82 | cur_bib_idx = int(bib[1:])
83 | if cur_bib_idx + 1 > n_refs:
84 | n_refs = cur_bib_idx + 1
85 |
86 | y_true = [0] * n_refs
87 | y_score = [0] * n_refs
88 |
89 | flag = False
90 | for bid in bid_to_title:
91 | cur_ref_title = bid_to_title[bid]
92 | for label_title in source_titles:
93 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
94 | flag = True
95 | b_idx = int(bid[1:])
96 | y_true[b_idx] = 1
97 |
98 | if not flag:
99 | continue
100 |
101 | contexts_sorted = ["The context is: " + " ".join(bib_to_contexts[bib])
102 | + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]."
103 | for bib in bib_sorted]
104 | contexts_sorted = [x[-500:] for x in contexts_sorted]
105 |
106 | predicted_scores = []
107 | for cur_context in contexts_sorted:
108 | token = tokenizer([cur_context], return_tensors="pt", padding=True)
109 | inputs = tokenizer.build_inputs_for_multiple_choice(token, [candidates])
110 | inputs = inputs.to('cuda')
111 | outputs = model_infer(**inputs)
112 | logits = outputs.logits
113 | # print("logits", logits.shape)
114 | score = logits.detach().cpu().numpy().tolist()
115 | # print("score", score)
116 | predicted_scores.append(score[0][1])
117 |
118 | try:
119 | for ii in range(len(predicted_scores)):
120 | bib_idx = int(bib_sorted[ii][1:])
121 | y_score[bib_idx] = predicted_scores[ii]
122 | except IndexError as e:
123 | metrics.append(0)
124 | continue
125 |
126 | cur_map = average_precision_score(y_true, y_score)
127 | metrics.append(cur_map)
128 | f_idx += 1
129 | if f_idx % 20 == 0:
130 | print("map until now", np.mean(metrics), len(metrics), cur_map)
131 |
132 | map_avg = np.mean(metrics)
133 | print("epoch {} average map".format(ckpt_epoch), map_avg, len(metrics))
134 | return np.mean(metrics)
135 |
136 |
137 | if __name__ == "__main__":
138 | # eval_test(model_name="THUDM/GLM-2b", ckpt_epoch=1)
139 |
140 | """
141 | model_name = "THUDM/GLM-2b"
142 | prefix = model_name.split("/")[1].lower()
143 | wf = open("glm/saved/valid_map_{}.txt".format(prefix), "w")
144 | for i in range(10):
145 | cur_map = eval_test(model_name=model_name, ckpt_epoch=i, role="valid")
146 | wf.write("{}\t{}\n".format(i, cur_map))
147 | wf.flush()
148 | wf.close()
149 | """
150 |
151 | # eval_test(model_name="THUDM/GLM-2b", ckpt_epoch=3)
152 | eval_test(model_name="THUDM/GLM-10b", ckpt_epoch=1)
153 |
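154 | # Checkpoints are expected under glm/saved/ with names like glm-2b-epoch-1.pt,
155 | # as written by the finetune scripts; prefix is derived from model_name, so
156 | # "THUDM/GLM-10b" resolves to glm-10b-epoch-{ckpt_epoch}.pt.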
--------------------------------------------------------------------------------
/gpt-api/gpt.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import json
3 | openai.api_key = "" # your api key
4 | openai.api_base = "" # your api base
5 | model_list = ["gpt-3.5-turbo","gpt-4"]
6 | for model in model_list:
7 |     # Resume: reload previously finished queries so they are skipped below
8 |     with open("result/" + model + "2.json", "r") as read_dic:
9 |         result_dic = json.load(read_dic)
10 |     finished_dic = result_dic.keys()
11 | try:
12 | with open("test.json", "r") as read_file:
13 | data_dic = json.load(read_file)
14 | for this_data in data_dic.keys():
15 | if this_data in finished_dic:
16 | continue
17 | result_list = []
18 | for jtem in data_dic[this_data]:
19 | question = jtem["content"]
20 | result = jtem["summary"]
21 | flag = True
22 | while flag:
23 | try:
24 | chat_completion = openai.ChatCompletion.create(model=model, messages=[{"role": "user", "content":question}], stream=True)
25 | completion_text = ""
26 | for event in chat_completion:
27 | if len(event["choices"]) > 0:
28 | completion_text += event["choices"][0]["delta"].get("content", "")
29 | result_list.append(json.dumps({"labels":result, "predict":completion_text})+"\n")
30 | flag = False
31 | except openai.error.APIError:
32 | continue
33 | result_dic[this_data] = result_list
34 |     except Exception:
35 |         pass  # persist whatever has been collected even if something fails mid-run
36 |     finally:
37 |         # rewrite the merged dict; appending ("a") would leave concatenated JSON documents
38 |         with open("result/" + model + "2.json", "w") as write_dic:
39 |             json.dump(result_dic, write_dic, indent=2)
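40 | 
41 | # Each entry in result_dic maps a paper id to a list of JSON-encoded strings of
42 | # the form {"labels": <ground-truth Yes/No>, "predict": <model completion>}.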
--------------------------------------------------------------------------------
/gpt-api/map.py:
--------------------------------------------------------------------------------
1 | import json
2 | import warnings
3 | warnings.filterwarnings('ignore')
4 | from sklearn.metrics import average_precision_score
5 | # base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/v2/"
6 | base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/finetune/"  # unused below
7 | num = 5  # unused below
8 | dic_list = ["result/gpt-3.5-turbo2.json", "result/gpt-4.json"]
9 | for i in range(len(dic_list)):
10 | result_list = []
11 | with open(dic_list[i], "r") as read_file:
12 | result_dic = json.load(read_file)
13 | for item in result_dic.keys():
14 | this_list = result_dic[item]
15 | pre = []
16 | res = []
17 | for jtem in this_list:
18 | if jtem["labels"] == "Yes":
19 | res.append(1)
20 | else:
21 | res.append(0)
22 | #if data["predict"] == "Yes":
23 | # pre.append(1)
24 | #else:
25 | # pre.append(0)
26 | if "yes" in jtem["predict"] or "Yes" in jtem["predict"]:
27 | pre.append(1)
28 | elif "not important" in jtem["predict"]:
29 | pre.append(0)
30 | elif "important" in jtem["predict"]:
31 | pre.append(1)
32 | else:
33 | pre.append(0)
34 | result_list.append(average_precision_score(res, pre))
35 |     print(f"{dic_list[i]}: map={sum(result_list)/len(result_list)}")
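36 | 
37 | # Note: predictions are collapsed to hard 0/1 votes before
38 | # average_precision_score, so per-paper AP reflects the predicted positive set
39 | # rather than a graded ranking.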
--------------------------------------------------------------------------------
/net_emb.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import os
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 | from fuzzywuzzy import fuzz
7 | from cogdl import pipeline
8 | from bs4 import BeautifulSoup
9 | from sklearn.metrics import average_precision_score
10 |
11 | import utils
12 | import settings
13 |
14 | import logging
15 |
16 | logger = logging.getLogger(__name__)
17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp
18 |
19 |
20 |
21 | def extract_paper_citation_graph():
22 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
23 | papers = []
24 | dblp_fname = "DBLP-Citation-network-V15.1.json"
25 | parse_err_cnt = 0
26 |
27 | wf = open(join(data_dir, "dblp_pids.txt"), "w")
28 | with open(join(data_dir, dblp_fname), "r", encoding="utf-8") as myFile:
29 | for i, line in enumerate(myFile):
30 | if len(line) <= 2:
31 | continue
32 | if i % 10000 == 0:
33 | logger.info("reading papers %d, parse err cnt %d", i, parse_err_cnt)
34 |             try:
35 |                 paper_tmp = json.loads(line.strip())
36 |                 wf.write(paper_tmp["id"] + "\n")
37 |                 wf.flush()
38 |                 papers.append(paper_tmp)  # append inside try: a parse error must not re-append the previous paper
39 |             except Exception:
40 |                 parse_err_cnt += 1
41 | wf.close()
42 |
43 | paper_dict_filter = {}
44 | for paper in tqdm(papers):
45 | paper_dict_filter[paper["id"]] = paper.get("references", [])
46 |
47 | logger.info("number of papers after filtering %d", len(paper_dict_filter))
48 | utils.dump_json(paper_dict_filter, data_dir, "dblp_papers_refs.json")
49 |
50 |
51 | def merge_paper_references():
52 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
53 | paper_dict_open = utils.load_json(data_dir, "dblp_papers_refs.json")
54 | papers_train = utils.load_json(data_dir, "paper_source_trace_train_ans.json")
55 | papers_valid = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json")
56 |
57 | for paper in tqdm(papers_train + papers_valid):
58 | pid = paper["_id"]
59 | cur_refs = paper.get("references", [])
60 | if len(cur_refs) == 0:
61 | continue
62 | refs_open = paper_dict_open.get(pid, [])
63 | refs_update = list(set(cur_refs + refs_open))
64 | paper_dict_open[pid] = refs_update
65 |
66 | utils.dump_json(paper_dict_open, data_dir, "dblp_papers_refs_merged.json")
67 |
68 |
69 | def gen_paper_emb(year=2023, method="prone"):
70 | print("method", method)
71 | paper_dict = utils.load_json(settings.DATA_TRACE_DIR, "dblp_papers_refs_merged_{}.json".format(year))
72 | pids_set = set()
73 | edges = []
74 | for pid in tqdm(paper_dict):
75 | pids_set.add(pid)
76 | for ref_id in paper_dict[pid]:
77 | pids_set.add(ref_id)
78 | edges.append([pid, ref_id])
79 | edges.append([ref_id, pid])
80 |
81 | pids_sorted = sorted(list(pids_set))
82 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids_sorted)}
83 | edges = [[pid_to_idx[pid], pid_to_idx[ref_id]] for pid, ref_id in edges]
84 |
85 | if method == "prone":
86 | generator = pipeline("generate-emb", model="prone")
87 | elif method == "line":
88 | generator = pipeline("generate-emb", model="line", walk_length=5, walk_num=5)
89 | elif method == "netsmf":
90 | generator = pipeline("generate-emb", model="netsmf", window_size=5, num_round=5)
91 | else:
92 | raise NotImplementedError
93 |
94 | # generate embedding by an unweighted graph
95 | edge_index = np.array(edges)
96 |     print("generate_emb...", edge_index.shape)
97 | outputs = generator(edge_index)
98 | print("outputs", outputs.shape)
99 |
100 | out_dir = join(settings.OUT_DIR, method)
101 | os.makedirs(out_dir, exist_ok=True)
102 | with open(join(out_dir, "paper_id.txt"), "w") as f:
103 | for pid in pids_sorted:
104 | f.write(pid + "\n")
105 | f.flush()
106 | np.savez(join(out_dir, "paper_emb_{}.npz".format(method)), emb=outputs)
107 |
108 |
109 | def gen_paper_emb_kddcup(method="prone"):
110 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
111 | print("method", method)
112 | paper_dict = utils.load_json(data_dir, "dblp_papers_refs_merged.json")
113 | pids_set = set()
114 | edges = []
115 | for pid in tqdm(paper_dict):
116 | pids_set.add(pid)
117 | for ref_id in paper_dict[pid]:
118 | pids_set.add(ref_id)
119 | edges.append([pid, ref_id])
120 | edges.append([ref_id, pid])
121 |
122 |
123 | pids_sorted = sorted(list(pids_set))
124 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids_sorted)}
125 | edges = [[pid_to_idx[pid], pid_to_idx[ref_id]] for pid, ref_id in edges]
126 |
127 | if method == "prone":
128 | generator = pipeline("generate-emb", model="prone")
129 | elif method == "line":
130 | generator = pipeline("generate-emb", model="line", walk_length=5, walk_num=5)
131 | elif method == "netsmf":
132 | generator = pipeline("generate-emb", model="netsmf", window_size=5, num_round=5)
133 | else:
134 | raise NotImplementedError
135 |
136 | # generate embedding by an unweighted graph
137 | edge_index = np.array(edges)
138 |     print("generate_emb...", edge_index.shape)
139 | outputs = generator(edge_index)
140 | print("outputs", outputs.shape)
141 |
142 | out_dir = join(settings.OUT_DIR, "kddcup", method)
143 | os.makedirs(out_dir, exist_ok=True)
144 | with open(join(out_dir, "paper_id.txt"), "w") as f:
145 | for pid in pids_sorted:
146 | f.write(pid + "\n")
147 | f.flush()
148 | np.savez(join(out_dir, "paper_emb_{}.npz".format(method)), emb=outputs)
149 |
150 |
151 | def eval_node_sim(year=2023, method="prone"):
152 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
153 | test_papers = utils.load_json(data_year_dir, "paper_source_trace_test.json")
154 | pids = []
155 | with open(join(settings.OUT_DIR, method, "paper_id.txt"), "r") as f:
156 | for line in f:
157 | pids.append(line.strip())
158 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids)}
159 | emb = np.load(join(settings.OUT_DIR, method, "paper_emb_{}.npz".format(method)))["emb"]
160 |
161 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
162 | metrics = []
163 | f_idx = 0
164 | for paper in tqdm(test_papers):
165 | pid = paper["_id"]
166 | file = join(xml_dir, pid + ".tei.xml")
167 | f = open(file, encoding='utf-8')
168 |
169 | xml = f.read()
170 | bs = BeautifulSoup(xml, "xml")
171 |
172 | references = bs.find_all("biblStruct")
173 | bid_to_title = {}
174 | n_refs = 0
175 | for ref in references:
176 | if "xml:id" not in ref.attrs:
177 | continue
178 | bid = ref.attrs["xml:id"]
179 | if ref.analytic is None:
180 | continue
181 | if ref.analytic.title is None:
182 | continue
183 | bid_to_title[bid] = ref.analytic.title.text.lower()
184 | b_idx = int(bid[1:]) + 1
185 | if b_idx > n_refs:
186 | n_refs = b_idx
187 |
188 | bib_to_contexts = utils.find_bib_context(xml)
189 | bib_sorted = sorted(bib_to_contexts.keys())
190 |
191 | for bib in bib_sorted:
192 | cur_bib_idx = int(bib[1:])
193 | if cur_bib_idx + 1 > n_refs:
194 | n_refs = cur_bib_idx + 1
195 |
196 | f.close()
197 |
198 | ref_id_to_score = {}
199 | ref_id_to_label = {}
200 | cur_emb = emb[pid_to_idx[pid]]
201 | cur_refs = paper.get("references", [])
202 | ref_truths = set([x["_id"] for x in paper.get("refs_trace", [])])
203 | for ref in cur_refs:
204 | ref_emb = emb[pid_to_idx[ref]]
205 | cur_sim = np.dot(cur_emb, ref_emb)
206 | cur_sim = 1/(1 + np.exp(-cur_sim))
207 | ref_id_to_score[ref] = cur_sim
208 | if ref in ref_truths:
209 | ref_id_to_label[ref] = 1
210 | else:
211 | ref_id_to_label[ref] = 0
212 |
213 |         ref_id_to_score_sorted = sorted(ref_id_to_score.items(), key=lambda x: x[1], reverse=True)
214 |         ref_labels = [ref_id_to_label[x[0]] for x in ref_id_to_score_sorted]
215 |         truth_id_not_in = ref_truths - set(cur_refs)  # ground-truth sources absent from the reference list
216 |         n_limit = n_refs - len(truth_id_not_in)
217 |         scores = [x[1] for x in ref_id_to_score_sorted][:n_limit]
218 |         labels = ref_labels[:n_limit]
219 |         if len(truth_id_not_in) > 0:  # absent sources enter with score 0, penalizing MAP
220 |             scores += [0] * len(truth_id_not_in)
221 |             labels += [1] * len(truth_id_not_in)
222 |         if len(scores) < n_refs:  # pad with negatives up to the full reference count
223 |             scores += [0] * (n_refs - len(scores))
224 |             labels += [0] * (n_refs - len(labels))
225 |
226 | cur_map = average_precision_score(labels, scores)
227 | metrics.append(cur_map)
228 | f_idx += 1
229 | if f_idx % 20 == 0:
230 | print("map until now", np.mean(metrics), len(metrics), cur_map)
231 |
232 | print("prone map", np.mean(metrics), len(metrics))
233 |
234 |
235 | def eval_node_sim_kddcup(method="prone", role="valid"):
236 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
237 | papers = utils.load_json(data_dir, "paper_source_trace_{}_wo_ans.json".format(role))
238 | out_dir = join(settings.OUT_DIR, "kddcup", method)
239 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json")
240 |
241 | pids = []
242 | with open(join(out_dir, "paper_id.txt"), "r") as f:
243 | for line in f:
244 | pids.append(line.strip())
245 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids)}
246 | emb = np.load(join(out_dir, "paper_emb_{}.npz".format(method)))["emb"]
247 |
248 | xml_dir = join(data_dir, "paper-xml")
249 | sub_dict = {}
250 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json")
251 |
252 | for paper in tqdm(papers):
253 | cur_pid = paper["_id"]
254 | file = join(xml_dir, cur_pid + ".xml")
255 | f = open(file, encoding='utf-8')
256 | xml = f.read()
257 | bs = BeautifulSoup(xml, "xml")
258 | f.close()
259 |
260 | ref_ids = paper.get("references", [])
261 | cur_title_to_pid = {}
262 | for ref_id in ref_ids:
263 | if ref_id in paper_info_more:
264 | cur_title_to_pid[paper_info_more[ref_id]["title"].lower()] = ref_id
265 |
266 | references = bs.find_all("biblStruct")
267 | bid_to_title = {}
268 | n_refs = 0
269 | cur_title_to_b_idx = {}
270 | for ref in references:
271 | if "xml:id" not in ref.attrs:
272 | continue
273 | bid = ref.attrs["xml:id"]
274 | if ref.analytic is None:
275 | continue
276 | if ref.analytic.title is None:
277 | continue
278 | bid_to_title[bid] = ref.analytic.title.text.lower()
279 | b_idx = int(bid[1:]) + 1
280 | cur_title_to_b_idx[ref.analytic.title.text.lower()] = b_idx - 1
281 | if b_idx > n_refs:
282 | n_refs = b_idx
283 |
284 | assert len(sub_example_dict[cur_pid]) == n_refs
285 | y_score = [0] * n_refs
286 |
287 | cur_emb = emb[pid_to_idx[cur_pid]]
288 |
289 | for r_idx, ref_id in enumerate(ref_ids):
290 | if ref_id not in paper_info_more:
291 | continue
292 | cur_title = paper_info_more[ref_id].get("title", "").lower()
293 | if len(cur_title) == 0:
294 | continue
295 | cur_b_idx = None
296 | for b_title in cur_title_to_b_idx:
297 | cur_sim = fuzz.ratio(cur_title, b_title)
298 | if cur_sim >= 80:
299 | cur_b_idx = cur_title_to_b_idx[b_title]
300 | break
301 | if cur_b_idx is None:
302 | continue
303 | ref_emb = emb[pid_to_idx[ref_id]]
304 | cur_sim = np.dot(cur_emb, ref_emb)
305 | cur_sim = utils.sigmoid(cur_sim)
306 | y_score[cur_b_idx] = float(cur_sim)
307 |
308 | print(y_score)
309 | sub_dict[cur_pid] = y_score
310 |
311 | utils.dump_json(sub_dict, out_dir, f"{role}_sub_{method}.json")
312 |
313 |
314 | if __name__ == "__main__":
315 | method = "prone"
316 | extract_paper_citation_graph()
317 | merge_paper_references()
318 | # gen_paper_emb(method=method)
319 | gen_paper_emb_kddcup(method=method)
320 | # eval_node_sim(method=method)
321 | eval_node_sim_kddcup(method=method)
322 |
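323 | # Scoring recap: a candidate reference is ranked by sigmoid(dot(emb[paper],
324 | # emb[ref])), the inner product of the two node embeddings squashed into
325 | # (0, 1), as computed in eval_node_sim / eval_node_sim_kddcup above.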
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | beautifulsoup4==4.11.1
2 | cogdl==0.5.3
3 | fuzzywuzzy==0.18.0
4 | numpy==1.21.6
5 | scikit_learn==1.2.1
6 | tqdm==4.36.1
7 | transformers==4.22.2
8 | python-Levenshtein==0.20.9
9 | lxml==4.8.0
10 | numba==0.56.4
11 |
--------------------------------------------------------------------------------
/result.txt:
--------------------------------------------------------------------------------
1 | fine-tune glm-2b epoch 1 average map 0.15643463634349988 280
2 |
--------------------------------------------------------------------------------
/rf/model_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import set_param
4 | import pickle
5 | from sklearn import svm
6 | from sklearn.ensemble import RandomForestClassifier
7 | import json
8 | from sklearn.metrics import average_precision_score
9 | import warnings
10 | from sklearn.linear_model import LogisticRegression
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 |
15 | def calculate_TPFN(pre_result, label):
16 | TP, TN, FP, FN = 0, 0, 0, 0
17 | total = len(label)
18 | for item, jtem in zip(pre_result, label):
19 | if item == 1 and jtem == 1:
20 | TP += 1
21 |         elif item == 0 and jtem == 1:
22 |             FN += 1  # predicted 0 but label 1 -> false negative
23 |         elif item == 1 and jtem == 0:
24 |             FP += 1  # predicted 1 but label 0 -> false positive
25 | else:
26 | TN += 1
27 | print("Accuracy:", (TP + TN) / total)
28 | print("Precision:", TP / (TP + FP))
29 | print("Recall:", TP / (TP + FN))
30 |
31 |
32 | # mode = "valid"
33 | # mode = "train"
34 | mode = "test"
35 | # model_type = "SVM"
36 | model_type = "RandomForest"
37 | # model_type = "LR"
38 |
39 |
40 | params = set_param.Args(model_type)
41 | if mode == "train":
42 | data = np.loadtxt(open(f"processed_data/train_data.csv", "rb"), delimiter=",")
43 | label = np.loadtxt(open(f"processed_data/train_label.csv", "rb"), delimiter=",")
44 | if model_type == "SVM":
45 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter,
46 | tol=params.tol, probability=True)
47 | elif model_type == "RandomForest":
48 | model = RandomForestClassifier(n_estimators=params.n_estimators)
49 | elif model_type == "LR":
50 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class)
51 | model.fit(data, label)
52 | save_model = pickle.dumps(model)
53 | os.makedirs("saved_model", exist_ok=True)
54 | write_file = open(f"saved_model/{model_type}.pkl", 'wb')
55 | write_file.write(save_model)
56 | write_file.close()
57 | pre_result = model.predict(data)
58 | calculate_TPFN(pre_result, label)
59 | elif mode == "test":
60 | with open('processed_data/test_data.json') as read_file:
61 | data_dic = json.load(read_file)
62 | with open('processed_data/test_label.json') as read_file:
63 | label_dic = json.load(read_file)
64 | with open(f'saved_model/{model_type}.pkl', 'rb') as read_file:
65 | model = pickle.load(read_file)
66 |
67 | map_list = []
68 | total_pre_list = []
69 | total_label = []
70 | for item in data_dic.keys():
71 | this_data = data_dic[item]
72 | this_label = [jtem[0] for jtem in label_dic[item]]
73 | pre_result = model.predict_proba(this_data)
74 | pre_result = [jtem[1] for jtem in pre_result]
75 | total_pre_list += [(1 if jtem >= 0.5 else 0) for jtem in pre_result]
76 | total_label += this_label
77 | map_list.append(average_precision_score(this_label, pre_result))
78 | calculate_TPFN(total_pre_list, total_label)
79 | print("MAP:", sum(map_list) / len(map_list))
80 |     # show feature importances
81 | # feature_importances = model.feature_importances_
82 | # print(feature_importances)
83 |
84 |
85 | elif mode == "valid":
86 | data = np.loadtxt(open(f"processed_data/valid/valid_data.csv", "rb"), delimiter=",")
87 | label = np.loadtxt(open(f"processed_data/valid/valid_label.csv", "rb"), delimiter=",")
88 | if model_type == "SVM":
89 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter,
90 | tol=params.tol, probability=True)
91 | elif model_type == "RandomForest":
92 | model = RandomForestClassifier(n_estimators=params.n_estimators)
93 | elif model_type == "LR":
94 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class)
95 | model.fit(data, label)
96 | predict = model.predict(data)
97 | calculate_TPFN(predict, label)
98 | else:
99 | print("mode input error!")
100 |
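101 | # Typical flow: run once with mode = "train" to fit and pickle the classifier
102 | # into saved_model/<model_type>.pkl, then with mode = "test" to reload it and
103 | # report accuracy/precision/recall plus the mean average precision over papers.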
--------------------------------------------------------------------------------
/rf/model_rf.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | import numpy as np
4 | from tqdm import tqdm
5 | from bs4 import BeautifulSoup
6 | import set_param
7 | import pickle
8 | from sklearn import svm
9 | from sklearn.ensemble import RandomForestClassifier
10 | import json
11 | from sklearn.metrics import average_precision_score
12 | import warnings
13 | from sklearn.linear_model import LogisticRegression
14 | from fuzzywuzzy import fuzz
15 |
16 | warnings.filterwarnings("ignore")
17 |
18 | import utils
19 | import settings
20 |
21 |
22 | def calculate_TPFN(pre_result, label):
23 | TP, TN, FP, FN = 0, 0, 0, 0
24 | total = len(label)
25 | for item, jtem in zip(pre_result, label):
26 | if item == 1 and jtem == 1:
27 | TP += 1
28 |         elif item == 0 and jtem == 1:
29 |             FN += 1  # predicted 0 but label 1 -> false negative
30 |         elif item == 1 and jtem == 0:
31 |             FP += 1  # predicted 1 but label 0 -> false positive
32 | else:
33 | TN += 1
34 | print("Accuracy:", (TP + TN) / total)
35 | print("Precision:", TP / (TP + FP))
36 | print("Recall:", TP / (TP + FN))
37 |
38 |
39 | def train_classifier(model_type = "RandomForest"):
40 | params = set_param.Args(model_type)
41 | feature_dir = join(settings.OUT_DIR, "kddcup", "rf")
42 | data = np.loadtxt(open(join(feature_dir, "train_data.csv"), "rb"), delimiter=",")
43 | label = np.loadtxt(open(join(feature_dir, "train_label.csv"), "rb"), delimiter=",")
44 | if model_type == "SVM":
45 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter,
46 | tol=params.tol, probability=True)
47 | elif model_type == "RandomForest":
48 | model = RandomForestClassifier(n_estimators=params.n_estimators)
49 | elif model_type == "LR":
50 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class)
51 | model.fit(data, label)
52 | save_model = pickle.dumps(model)
53 | # os.makedirs("saved_model", exist_ok=True)
54 | write_file = open(join(feature_dir, "{}.pkl".format(model_type)), 'wb')
55 | write_file.write(save_model)
56 | write_file.close()
57 | pre_result = model.predict(data)
58 | calculate_TPFN(pre_result, label)
59 |
60 |
61 | def eval_classifier(model_type="RandomForest", role="valid"):
62 | feature_dir = join(settings.OUT_DIR, "kddcup", "rf")
63 | with open(join(feature_dir, f"{role}_data.json")) as read_file:
64 | eval_features = json.load(read_file)
65 |
66 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
67 | papers = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json")
68 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json")
69 |
70 |
71 | with open(join(feature_dir, "{}.pkl".format(model_type)), 'rb') as read_file:
72 | model = pickle.loads(read_file.read())
73 |
74 | xml_dir = join(data_dir, "paper-xml")
75 | sub_dict = {}
76 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json")
77 |
78 | for paper in tqdm(papers):
79 | cur_pid = paper["_id"]
80 | file = join(xml_dir, cur_pid + ".xml")
81 | f = open(file, encoding='utf-8')
82 | xml = f.read()
83 | bs = BeautifulSoup(xml, "xml")
84 | f.close()
85 |
86 | ref_ids = paper.get("references", [])
87 | cur_title_to_pid = {}
88 | for ref_id in ref_ids:
89 | if ref_id in paper_info_more:
90 | cur_title_to_pid[paper_info_more[ref_id]["title"].lower()] = ref_id
91 |
92 | references = bs.find_all("biblStruct")
93 | bid_to_title = {}
94 | n_refs = 0
95 | cur_title_to_b_idx = {}
96 | for ref in references:
97 | if "xml:id" not in ref.attrs:
98 | continue
99 | bid = ref.attrs["xml:id"]
100 | if ref.analytic is None:
101 | continue
102 | if ref.analytic.title is None:
103 | continue
104 | bid_to_title[bid] = ref.analytic.title.text.lower()
105 | b_idx = int(bid[1:]) + 1
106 | cur_title_to_b_idx[ref.analytic.title.text.lower()] = b_idx - 1
107 | if b_idx > n_refs:
108 | n_refs = b_idx
109 |
110 | assert len(sub_example_dict[cur_pid]) == n_refs
111 | y_score = [0] * n_refs
112 | cur_feature = eval_features[cur_pid]
113 |
114 | for r_idx, ref_id in enumerate(ref_ids):
115 | if ref_id not in paper_info_more:
116 | continue
117 | cur_title = paper_info_more[ref_id].get("title", "").lower()
118 | if len(cur_title) == 0:
119 | continue
120 | cur_ref_feature = cur_feature[r_idx]
121 | if len(cur_ref_feature) == 0:
122 | continue
123 | cur_b_idx = None
124 | for b_title in cur_title_to_b_idx:
125 | cur_sim = fuzz.ratio(cur_title, b_title)
126 | if cur_sim >= 80:
127 | cur_b_idx = cur_title_to_b_idx[b_title]
128 | break
129 | if cur_b_idx is None:
130 | continue
131 | cur_prob = model.predict_proba([cur_ref_feature])[0][1]
132 | y_score[cur_b_idx] = float(cur_prob)
133 |
134 | # print(y_score)
135 | sub_dict[cur_pid] = y_score
136 |
137 | utils.dump_json(sub_dict, feature_dir, f"{role}_sub_{model_type}.json")
138 |
139 |
140 | if __name__ == "__main__":
141 | train_classifier()
142 | eval_classifier("RandomForest", "valid")
143 |
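144 | # The features consumed here (train_data.csv / train_label.csv /
145 | # valid_data.json under OUT_DIR/kddcup/rf) are produced by
146 | # rf/process_kddcup_data.py via extract_train_features / extract_valid_features.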
--------------------------------------------------------------------------------
/rf/process_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import csv
4 | import pymongo
5 | from bson import ObjectId
6 | import random
7 | from lxml import etree
8 | from fuzzywuzzy import fuzz
9 | import re
10 | from tqdm import tqdm
11 |
12 | random.seed(1)  # Random seed deciding which papers are sampled for train/valid. If it changes, the data must be re-queried from the database. (The query code is kept, but host, username and password were removed; fill them in if needed.)
13 |
14 |
15 | class Process_data(object):
16 | def __init__(self, paper_dic, mode):
17 | self.paper_dic = paper_dic
18 | paper_id = paper_dic['_id']
19 | paper_positive_id = [item['_id'] for item in paper_dic['refs_trace']]
20 | self.authors_set = set([item.get('name') for item in paper_dic.get('authors', {})])
21 |         # parse the XML to get the element tree and the listBibl node
22 | try:
23 | path = f'data/paper-xml/{paper_id}.tei.xml'
24 | self.tree = etree.parse(path)
25 | root = self.tree.getroot()
26 | listBibl = root.xpath("//*[local-name()='listBibl']")[0]
27 | self.biblStruct = listBibl.getchildren()
28 | self.num_ref = len(self.biblStruct)
29 | except OSError:
30 | self.tree = None
31 | self.num_ref = 0
32 |             print('xml not found: ' + paper_id)
33 |         # reciprocal of the paper's reference count
34 | self.reference_num = self.get_reciprocal_of_reference_num()
35 |         if mode == 'test':  # train/valid randomly sample references to build examples; test puts all references into query_list
36 | query_list = paper_dic.get('references', [])
37 | else:
38 | references = paper_dic.get('references', [])
39 | for item in paper_positive_id:
40 | try:
41 | references.remove(item)
42 | except ValueError:
43 | continue
44 |             query_list = random.sample(references, min(max(len(paper_positive_id), 1),
45 |                                                        len(references)))  # at least one negative, at most len(references) (positives can outnumber references!)
46 | query_list += paper_positive_id
47 | self.data_list = []
48 | self.label_list = []
49 | for i, item in enumerate(query_list):
50 | this_data = []
51 | # query = {'_id': ObjectId(item)}
52 | # self.query_result = list(collection.find(query))
53 | self.query_result = saved_data.get(item, {})
54 | reference_place_list = self.get_referenced_place_num(item)
55 | if len(reference_place_list) == 0:
56 |                 continue  # empty list: the title's reference index was not found in the XML; skip it for now
57 | this_data.append(self.get_referenced_num())
58 | this_data.append(self.get_common_authors(item))
59 | this_data.append(self.reference_num)
60 | this_data.append(self.key_words())
61 | this_data += reference_place_list
62 | self.data_list.append(this_data)
63 | self.label_list.append([1] if item in paper_positive_id else [0])
64 |
65 |     # ONE: citation count
66 | def get_referenced_num(self):
67 | return self.query_result.get('n_citation', 0)
68 |
69 |     # TWO, SIX, EIGHT: citation position, whether it appears in a figure/table, citation count / total citations
70 | # 0 abstract
71 | # 1 introduction
72 | # 2 related work
73 | # 3 method
74 | # 4 graph and figure
75 | # 5 result
76 | # 6 others
77 | def get_referenced_place_num(self, paper_id):
78 |         # retrieve the title from the database
79 | # query = {"_id": ObjectId(paper_id)}
80 | # title = list(collection.find(query))[0]['title']
81 | title = self.query_result.get('title', '')
82 |         # look up the reference index in the XML
83 | if self.tree is None:
84 |             return [0] * 8  # no XML available: all-zero place features
85 |
86 | paper_number = -1
87 | for i, item in enumerate(self.biblStruct):
88 | this_test = item.xpath('.//*[local-name()="title"]')
89 | this_text = this_test[0].text
90 | if this_text is None:
91 | try:
92 | this_text = this_test[1].text
93 | except IndexError:
94 | this_text = ''
95 | try:
96 | score = fuzz.partial_ratio(title, this_text)
97 | except ValueError:
98 | score = 0
99 | if score >= 80:
100 | paper_number = i + 1
101 | break
102 | place_num = [0 for i in range(8)]
103 | self.paper_number = paper_number
104 | if paper_number == -1:
105 | return place_num
106 |         # use the index to locate where the paper is cited in the XML
107 | nodes = self.tree.xpath(f"//*[contains(text(), '[{paper_number}]')]")
108 | reference_times = len(nodes)
109 |
110 | for item in nodes:
111 | found_text = ''
112 | this_node = item
113 | while found_text == '':
114 | this_node = this_node.getparent()
115 | if this_node is None:
116 | break
117 | if this_node.xpath("local-name()") == 'figure':
118 | place_num[4] = 1
119 | it_children = this_node.iterchildren()
120 | for jtem in it_children:
121 | node = this_node
122 | if jtem.xpath("local-name()") == 'head':
123 | found_text = node.text
124 | n_num = jtem.attrib.get('n')
125 | node = this_node
126 | if n_num is None:
127 | break
128 | while not n_num.isdigit():
129 | node = node.getprevious()
130 | if node is None:
131 | break
132 | node_children = node.iterchildren()
133 | for ktem in node_children:
134 | if ktem.xpath("local-name()") == 'head':
135 | n = ktem.attrib.get('n')
136 | if n is not None and n.isdigit():
137 | n_num = ktem.attrib.get('n')
138 | found_text = ktem.text
139 | break
140 | break
141 |
142 | if this_node is None or found_text == '':
143 | place_num[6] = 1
144 | continue
145 | if found_text is not None:
146 | found_text = found_text.lower()
147 | if fuzz.partial_ratio('abstract', found_text) >= 60:
148 | place_num[0] = 1
149 | elif fuzz.partial_ratio('introduction', found_text) >= 60:
150 | place_num[1] = 1
151 | elif fuzz.partial_ratio('related work', found_text) >= 60:
152 | place_num[2] = 1
153 | elif fuzz.partial_ratio('method', found_text) >= 60:
154 | place_num[3] = 1
155 | elif fuzz.partial_ratio('result', found_text) >= 60 or fuzz.partial_ratio('experiment', found_text) >= 60:
156 | place_num[5] = 1
157 | else:
158 | place_num[6] = 1
159 |         pattern = re.compile(r'\[\d+\]')  # bracketed citation markers like "[12]"; r'[\d+]' matched any single digit or '+'
160 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]",
161 | namespaces={"re": "http://exslt.org/regular-expressions"},
162 | pattern=pattern.pattern)
163 | total_ref_num = len(nodes)
164 | if not total_ref_num == 0:
165 | place_num[7] = reference_times / total_ref_num
166 | return place_num
167 |
168 |     # FOUR: overlapping authors
169 | def get_common_authors(self, paper_id):
170 | # ref_authors_set = set([item.get('name') for item in self.query_result.get('authors', {})])
171 | ref_authors_set = set(self.query_result.get('authors', []))
172 | if not len(self.authors_set & ref_authors_set) == 0:
173 | return 1
174 | else:
175 | return 0
176 |
177 |     # FIVE: cue keywords
178 | def key_words(self):
179 | if self.paper_number == -1:
180 | return 0
181 |         pattern = re.compile(r'\[\d+\]')  # bracketed citation markers like "[12]"
182 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]",
183 | namespaces={"re": "http://exslt.org/regular-expressions"},
184 | pattern=pattern.pattern)
185 | key_words_list = ['motivated by', 'inspired by']
186 | for item in nodes:
187 | if item.xpath('local-name()') == 'ref':
188 | node_text = item.getparent().text
189 | else:
190 | node_text = item.text
191 | if node_text is None:
192 | return 0
193 | node_text = node_text.lower()
194 | for jtem in key_words_list:
195 | pattern = re.compile(fr"{jtem}")
196 | match = pattern.search(node_text)
197 | if match is not None:
198 | return 1
199 | return 0
200 |
201 | # SEVEN
202 |
203 | def get_reciprocal_of_reference_num(self):
204 | if self.num_ref == 0:
205 | return 0
206 | else:
207 | return 1 / self.num_ref
208 |
209 |
210 | # mode = 'valid'
211 | # mode = 'train'
212 | mode = 'test'
213 | year = 2023
214 | with open(f"data/{year}/paper_source_trace_{mode}.json", 'r', encoding='utf-8') as read_file:
215 | data_dic = json.load(read_file)
216 | all_id = [item['_id'] for item in data_dic]
217 | data_list = []
218 | label_list = []
219 |
220 | # Connect to the database
221 | # client = pymongo.MongoClient(host='', port=-1, authSource='', username='', password='')
222 | # db = client.get_database('')
223 | # collection = db[""]
224 | with open('processed_data/saved_data2.json', 'r') as read_file:
225 | saved_data = json.load(read_file)
226 | n = 0
227 | # Build the features via Process_data, then write them to CSV or JSON files
228 | if mode == 'test':
229 | total_data_dic = {}
230 | total_label_dic = {}
231 | for i, item in tqdm(enumerate(all_id), total=len(all_id)):
232 | process_data = Process_data(data_dic[i], mode)
233 | this_data, this_label = process_data.data_list, process_data.label_list
234 | total_data_dic[item] = this_data
235 | total_label_dic[item] = this_label
236 | with open(f'processed_data/{mode}_data.json', 'w') as write_file:
237 | write_file.write(json.dumps(total_data_dic))
238 | with open(f'processed_data/{mode}_label.json', 'w') as write_file:
239 | write_file.write(json.dumps(total_label_dic))
240 | else:
241 | # train/valid:
242 | for i, item in tqdm(enumerate(all_id), total=len(all_id)):
243 | process_data = Process_data(data_dic[i], mode)
244 | this_data, this_label = process_data.data_list, process_data.label_list
245 | data_list += this_data
246 | label_list += this_label
247 | with open(f'processed_data/{mode}_label.csv', 'w', newline='') as write_file:
248 | writer = csv.writer(write_file)
249 | writer.writerows(label_list)
250 | with open(f'processed_data/{mode}_data.csv', 'w', newline='') as write_file:
251 | writer = csv.writer(write_file)
252 | writer.writerows(data_list)
253 |
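254 | # Feature vector layout (12 dims): [n_citation, common_authors, 1/num_refs,
255 | # keyword_flag] + the 8 citation-place features (abstract, introduction,
256 | # related work, method, figure/table, result, others, citation-frequency ratio).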
--------------------------------------------------------------------------------
/rf/process_kddcup_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from os.path import join
4 | import csv
5 | import random
6 | from lxml import etree
7 | from fuzzywuzzy import fuzz
8 | import re
9 | from collections import defaultdict as dd
10 | from tqdm import tqdm
11 |
12 | import utils
13 | import settings
14 |
15 | import logging
16 |
17 | logger = logging.getLogger(__name__)
18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp
19 |
20 |
21 | random.seed(1)
22 |
23 |
24 | class Process_data(object):
25 | def __init__(self, paper_dic, mode, paper_info_more):
26 | self.paper_dic = paper_dic
27 | paper_id = paper_dic['_id']
28 | if mode == "train":
29 | paper_positive_id = [item['_id'] for item in paper_dic['refs_trace']]
30 | self.authors_set = set([item.get('name') for item in paper_dic.get('authors', {})])
31 |         # parse the XML to get the element tree and the listBibl node
32 | try:
33 | path = f'data/PST/paper-xml/{paper_id}.xml'
34 | self.tree = etree.parse(path)
35 | root = self.tree.getroot()
36 | listBibl = root.xpath("//*[local-name()='listBibl']")[0]
37 | self.biblStruct = listBibl.getchildren()
38 | self.num_ref = len(self.biblStruct)
39 | except OSError:
40 | self.tree = None
41 | self.num_ref = 0
42 |             print('xml not found: ' + paper_id)
43 |         # reciprocal of the paper's reference count
44 | self.reference_num = self.get_reciprocal_of_reference_num()
45 |         if mode == 'test':  # train/valid randomly sample references to build examples; test puts all references into query_list
46 | query_list = paper_dic.get('references', [])
47 | else:
48 | references = paper_dic.get('references', [])
49 | for item in paper_positive_id:
50 | try:
51 | references.remove(item)
52 | except ValueError:
53 | continue
54 |             query_list = random.sample(references, min(max(len(paper_positive_id), 1),
55 |                                                        len(references)))  # at least one negative, at most len(references) (positives can outnumber references!)
56 | query_list += paper_positive_id
57 | self.data_list = []
58 | self.label_list = []
59 | for i, item in enumerate(query_list):
60 | this_data = []
61 | self.query_result = paper_info_more.get(item, {})
62 | reference_place_list = self.get_referenced_place_num(item)
63 | if len(reference_place_list) == 0:
64 | self.data_list.append([])
65 |                 continue  # empty list: the title's reference index was not found in the XML; skip it for now
66 | this_data.append(self.get_referenced_num())
67 | this_data.append(self.get_common_authors(item))
68 | this_data.append(self.reference_num)
69 | this_data.append(self.key_words())
70 | this_data += reference_place_list
71 | self.data_list.append(this_data)
72 | if mode == "train":
73 | self.label_list.append([1] if item in paper_positive_id else [0])
74 |
75 |     # ONE: citation count
76 | def get_referenced_num(self):
77 | return self.query_result.get('n_citation', 0)
78 |
79 |     # TWO, SIX, EIGHT: citation position, whether it appears in a figure/table, citation count / total citations
80 | # 0 abstract
81 | # 1 introduction
82 | # 2 related work
83 | # 3 method
84 | # 4 graph and figure
85 | # 5 result
86 | # 6 others
87 | def get_referenced_place_num(self, paper_id):
88 | title = self.query_result.get('title', '')
89 |         # look up the reference index in the XML
90 | if self.tree is None:
91 |             return [0] * 8  # no XML available: all-zero place features
92 |
93 | paper_number = -1
94 | for i, item in enumerate(self.biblStruct):
95 | this_test = item.xpath('.//*[local-name()="title"]')
96 | this_text = this_test[0].text
97 | if this_text is None:
98 | try:
99 | this_text = this_test[1].text
100 | except IndexError:
101 | this_text = ''
102 | try:
103 | score = fuzz.partial_ratio(title, this_text)
104 | except ValueError:
105 | score = 0
106 | if score >= 80:
107 | paper_number = i + 1
108 | break
109 | place_num = [0 for i in range(8)]
110 | self.paper_number = paper_number
111 | if paper_number == -1:
112 | return place_num
113 |         # use the index to locate where the paper is cited in the XML
114 | nodes = self.tree.xpath(f"//*[contains(text(), '[{paper_number}]')]")
115 | reference_times = len(nodes)
116 |
117 | for item in nodes:
118 | found_text = ''
119 | this_node = item
120 | while found_text == '':
121 | this_node = this_node.getparent()
122 | if this_node is None:
123 | break
124 | if this_node.xpath("local-name()") == 'figure':
125 | place_num[4] = 1
126 | it_children = this_node.iterchildren()
127 | for jtem in it_children:
128 | node = this_node
129 | if jtem.xpath("local-name()") == 'head':
130 | found_text = node.text
131 | n_num = jtem.attrib.get('n')
132 | node = this_node
133 | if n_num is None:
134 | break
135 | while not n_num.isdigit():
136 | node = node.getprevious()
137 | if node is None:
138 | break
139 | node_children = node.iterchildren()
140 | for ktem in node_children:
141 | if ktem.xpath("local-name()") == 'head':
142 | n = ktem.attrib.get('n')
143 | if n is not None and n.isdigit():
144 | n_num = ktem.attrib.get('n')
145 | found_text = ktem.text
146 | break
147 | break
148 |
149 | if this_node is None or found_text == '':
150 | place_num[6] = 1
151 | continue
152 | if found_text is not None:
153 | found_text = found_text.lower()
154 | if fuzz.partial_ratio('abstract', found_text) >= 60:
155 | place_num[0] = 1
156 | elif fuzz.partial_ratio('introduction', found_text) >= 60:
157 | place_num[1] = 1
158 | elif fuzz.partial_ratio('related work', found_text) >= 60:
159 | place_num[2] = 1
160 | elif fuzz.partial_ratio('method', found_text) >= 60:
161 | place_num[3] = 1
162 | elif fuzz.partial_ratio('result', found_text) >= 60 or fuzz.partial_ratio('experiment', found_text) >= 60:
163 | place_num[5] = 1
164 | else:
165 | place_num[6] = 1
166 |         pattern = re.compile(r'\[\d+\]')  # bracketed citation markers like "[12]"; r'[\d+]' matched any single digit or '+'
167 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]",
168 | namespaces={"re": "http://exslt.org/regular-expressions"},
169 | pattern=pattern.pattern)
170 | total_ref_num = len(nodes)
171 | if not total_ref_num == 0:
172 | place_num[7] = reference_times / total_ref_num
173 | return place_num
174 |
175 |     # FOUR: overlapping authors
176 | def get_common_authors(self, paper_id):
177 | # ref_authors_set = set([item.get('name') for item in self.query_result.get('authors', {})])
178 | ref_authors_set = set(self.query_result.get('authors', []))
179 | if not len(self.authors_set & ref_authors_set) == 0:
180 | return 1
181 | else:
182 | return 0
183 |
184 |     # FIVE: cue keywords
185 | def key_words(self):
186 | if self.paper_number == -1:
187 | return 0
188 |         pattern = re.compile(r'\[\d+\]')  # bracketed citation markers like "[12]"
189 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]",
190 | namespaces={"re": "http://exslt.org/regular-expressions"},
191 | pattern=pattern.pattern)
192 | key_words_list = ['motivated by', 'inspired by']
193 | for item in nodes:
194 | if item.xpath('local-name()') == 'ref':
195 | node_text = item.getparent().text
196 | else:
197 | node_text = item.text
198 | if node_text is None:
199 | return 0
200 | node_text = node_text.lower()
201 | for jtem in key_words_list:
202 | pattern = re.compile(fr"{jtem}")
203 | match = pattern.search(node_text)
204 | if match is not None:
205 | return 1
206 | return 0
207 |
208 | # SEVEN
209 |
210 | def get_reciprocal_of_reference_num(self):
211 | if self.num_ref == 0:
212 | return 0
213 | else:
214 | return 1 / self.num_ref
215 |
216 |
217 | def extract_paper_info_from_dblp():
218 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
219 | papers_train = utils.load_json(data_dir, "paper_source_trace_train_ans.json")
220 | papers_valid = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json")
221 |
222 | paper_dict_open = {}
223 | dblp_fname = "DBLP-Citation-network-V15.1.json"
224 | with open(join(data_dir, dblp_fname), "r", encoding="utf-8") as myFile:
225 | for i, line in enumerate(myFile):
226 | if len(line) <= 2:
227 | continue
228 | if i % 10000 == 0:
229 | logger.info("reading papers %d", i)
230 | paper_tmp = json.loads(line.strip())
231 | paper_dict_open[paper_tmp["id"]] = paper_tmp
232 |
233 | paper_dict_hit = dd(dict)
234 | for paper in tqdm(papers_train + papers_valid):
235 | cur_pid = paper["_id"]
236 | ref_ids = paper.get("references", [])
237 | pids = [cur_pid] + ref_ids
238 | for pid in pids:
239 | if pid not in paper_dict_open:
240 | continue
241 | cur_paper_info = paper_dict_open[pid]
242 | cur_authors = [a.get("name", "") for a in cur_paper_info.get("authors", [])]
243 | n_citation = cur_paper_info.get("n_citation", 0)
244 | title = cur_paper_info.get("title", "")
245 | paper_dict_hit[pid] = {"authors": cur_authors, "n_citation": n_citation, "title": title}
246 |
247 | print("number of papers after filtering", len(paper_dict_hit))
248 | utils.dump_json(paper_dict_hit, data_dir, "paper_info_hit_from_dblp.json")
249 |
250 |
251 | def extract_train_features():
252 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
253 | with open(join(data_dir, "paper_source_trace_train_ans.json"), 'r', encoding='utf-8') as read_file:
254 | data_dic = json.load(read_file)
255 | all_id = [item['_id'] for item in data_dic]
256 | data_list = []
257 | label_list = []
258 |
259 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json")
260 |
261 | for i, item in tqdm(enumerate(all_id), total=len(all_id)):
262 | process_data = Process_data(data_dic[i], "train", paper_info_more)
263 | this_data, this_label = process_data.data_list, process_data.label_list
264 | data_list += this_data
265 | label_list += this_label
266 |
267 | out_dir = join(settings.OUT_DIR, "kddcup", "rf")
268 | os.makedirs(out_dir, exist_ok=True)
269 |
270 | with open(join(out_dir, "train_label.csv"), 'w', newline='') as f:
271 | writer = csv.writer(f)
272 | writer.writerows(label_list)
273 |
274 | with open(join(out_dir, "train_data.csv"), 'w', newline='') as write_file:
275 | writer = csv.writer(write_file)
276 | writer.writerows(data_list)
277 |
278 |
279 | def extract_valid_features():
280 | data_dir = join(settings.DATA_TRACE_DIR, "PST")
281 | with open(join(data_dir, "paper_source_trace_valid_wo_ans.json"), 'r', encoding='utf-8') as read_file:
282 | data_dic = json.load(read_file)
283 | all_id = [item['_id'] for item in data_dic]
284 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json")
285 |
286 | total_data_dic = {}
287 | for i, item in tqdm(enumerate(all_id), total=len(all_id)):
288 | process_data = Process_data(data_dic[i], "test", paper_info_more)
289 | n_refs = len(data_dic[i].get('references', []))
290 | this_data, this_label = process_data.data_list, process_data.label_list
291 | total_data_dic[item] = this_data
292 | assert len(this_data) == n_refs
293 |
294 | out_dir = join(settings.OUT_DIR, "kddcup", "rf")
295 | os.makedirs(out_dir, exist_ok=True)
296 | utils.dump_json(total_data_dic, out_dir, "valid_data.json")
297 |
298 |
299 | if __name__ == "__main__":
300 | extract_paper_info_from_dblp()
301 | extract_train_features()
302 | extract_valid_features()
303 |
--------------------------------------------------------------------------------
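
`extract_paper_info_from_dblp` above loads every DBLP record into `paper_dict_open` before filtering, which needs memory proportional to the whole multi-gigabyte dump. A lower-memory variant is sketched below; the function name `extract_hit_papers` and the collect-ids-first strategy are illustrative, not the repo's code, while the file name and the skip-short-lines convention mirror the loader above.

```python
import json
from os.path import join

def extract_hit_papers(data_dir, papers, dblp_fname="DBLP-Citation-network-V15.1.json"):
    # Collect every pid of interest: the train/valid papers plus their references.
    wanted = set()
    for paper in papers:
        wanted.add(paper["_id"])
        wanted.update(paper.get("references", []))

    # Stream the dump once and keep only matching records.
    hits = {}
    with open(join(data_dir, dblp_fname), encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if len(line) <= 2:  # same short-line skip as the loader above
                continue
            record = json.loads(line)
            if record["id"] in wanted:
                hits[record["id"]] = {
                    "authors": [a.get("name", "") for a in record.get("authors", [])],
                    "n_citation": record.get("n_citation", 0),
                    "title": record.get("title", ""),
                }
    return hits
```
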
/rf/set_param.py:
--------------------------------------------------------------------------------
1 | class Args(object):
2 | def __init__(self, model_type):
3 | if model_type == "SVM":
4 | self.C = 1
5 | self.kernel = "sigmoid"
6 | self.verbose = 2
7 | self.max_iter = 100
8 | self.tol = 0.1
9 | elif model_type == "RandomForest":
10 | self.n_estimators = 100
11 | elif model_type == "LR":
12 | self.solver = 'saga'
13 | self.multi_class = 'ovr'
--------------------------------------------------------------------------------
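
`Args` is a bare hyperparameter namespace with no behavior of its own; presumably `model_main.py` maps it onto scikit-learn estimators. A minimal sketch of that mapping, under that assumption (`build_model` and the estimator choices are illustrative, not the repo's actual wiring):

```python
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from set_param import Args  # import path assumed


def build_model(model_type):
    args = Args(model_type)
    if model_type == "SVM":
        return SVC(C=args.C, kernel=args.kernel, verbose=args.verbose,
                   max_iter=args.max_iter, tol=args.tol)
    if model_type == "RandomForest":
        return RandomForestClassifier(n_estimators=args.n_estimators)
    if model_type == "LR":
        return LogisticRegression(solver=args.solver, multi_class=args.multi_class)
    raise ValueError("unknown model_type: " + model_type)
```
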
/rule.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from os.path import join
4 | from tqdm import tqdm
5 | from collections import defaultdict as dd
6 | from bs4 import BeautifulSoup
7 | from fuzzywuzzy import fuzz
8 | import numpy as np
9 | from sklearn.metrics import average_precision_score
10 |
11 | import utils
12 | import settings
13 |
14 |
15 | def extract_one_paper_via_rule(xml):
16 | bs = BeautifulSoup(xml, "xml")
17 | ref = []
18 | importantlist = []
19 | for item in bs.find_all(type='bibr'):
20 | if "target" not in item.attrs:
21 | continue
22 | item_str = "[{}](\"{}\")".format(item.attrs["target"], item.get_text())
23 | try:
24 | refer = item.attrs["target"][1:]
25 | ref.append((item_str, refer))  # record the (citation marker, bib id) pair
28 | except IndexError:
29 | continue
30 | xml = xml.lower()
31 | s2 = [ii for ii in range(len(xml)) if xml.startswith('motivated by', ii)]  # all offsets of the cue phrase
32 | s3 = [ii for ii in range(len(xml)) if xml.startswith('inspired by', ii)]
33 | s = s2 + s3
34 | pos_to_signal = {}
35 | for i in s2:
36 | pos_to_signal[i] = "motivated by"
37 | for i in s3:
38 | pos_to_signal[i] = "inspired by"
39 |
40 | for i in ref:
41 | cur_bibr, idx = i
42 | p_ref = [ii for ii in range(len(xml)) if xml.startswith(cur_bibr.lower(), ii)]  # offsets of this marker (xml was lowercased above)
44 | for j in p_ref:
45 | for k in s:
46 | if abs(j-k) < 100:
47 | importantlist.append(idx)  # marker sits within 100 chars of a cue phrase
49 | break
50 | return importantlist
51 |
52 |
53 | def find_paper_source_by_rule(year=2023):
54 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year))
55 | truths = utils.load_json(data_year_dir, "paper_source_trace_test.json")
56 | pid_to_source_titles = dd(list)
57 | for paper in tqdm(truths):
58 | pid = paper["_id"]
59 | for ref in paper["refs_trace"]:
60 | pid_to_source_titles[pid].append(ref["title"].lower())
61 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml")
62 | metrics = []
63 | p_idx = 0
64 |
65 | for paper in tqdm(truths):
66 | cur_pid = paper["_id"]
67 | file = join(xml_dir, cur_pid + ".tei.xml")
68 | with open(file, encoding='utf-8') as f:
69 | xml = f.read()
70 | bs = BeautifulSoup(xml, "xml")
72 |
73 | references = bs.find_all("biblStruct")
74 | bid_to_title = {}
75 | n_refs = 0
76 | for ref in references:
77 | if "xml:id" not in ref.attrs:
78 | continue
79 | bid = ref.attrs["xml:id"]
80 | if ref.analytic is None:
81 | continue
82 | if ref.analytic.title is None:
83 | continue
84 | bid_to_title[bid] = ref.analytic.title.text.lower()
85 | b_idx = int(bid[1:]) + 1
86 | if b_idx > n_refs:
87 | n_refs = b_idx
88 |
89 | bib_to_contexts = utils.find_bib_context(xml)
90 | bib_sorted = sorted(bib_to_contexts.keys())
91 |
92 | for bib in bib_sorted:
93 | cur_bib_idx = int(bib[1:])
94 | if cur_bib_idx + 1 > n_refs:
95 | n_refs = cur_bib_idx + 1
96 |
97 | y_true = []
98 | y_score = []
99 | try:
100 | source_titles = pid_to_source_titles[cur_pid]
101 | if len(source_titles) == 0:
102 | # no labeled source papers for this pid; skip it
103 | continue
105 | pred_sources = extract_one_paper_via_rule(xml)
106 | y_true = [0]* n_refs
107 | y_score = [0]* n_refs
108 |
109 | for bid in bid_to_title:
110 | cur_ref_title = bid_to_title[bid]
111 | for label_title in source_titles:
112 | if fuzz.ratio(cur_ref_title, label_title) >= 80:
113 | b_idx = int(bid[1:])
114 | y_true[b_idx] = 1
115 | break
116 | for ii in pred_sources:
117 | y_score[int(ii[1:])] = 1
118 | if sum(y_true) == 0:
119 | metrics.append(0)
120 | continue
121 | cur_map = average_precision_score(y_true, y_score)
122 | # print("cur_map", cur_map)
123 | metrics.append(cur_map)
124 | except IndexError:
125 | metrics.append(0)
126 | continue
127 | p_idx += 1
128 | if p_idx % 20 == 0:
129 | print("map until now", np.mean(metrics), len(metrics), cur_map)
130 |
131 | print("average map", np.mean(metrics), len(metrics))
132 |
133 |
134 | if __name__ == '__main__':
135 | find_paper_source_by_rule()
136 |
--------------------------------------------------------------------------------
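
`find_paper_source_by_rule` scores one paper at a time: `y_true` marks the references whose titles fuzzy-match a labeled source title (ratio >= 80), `y_score` marks the references the cue-phrase rule flagged, and the reported number is the mean of the per-paper average precisions (MAP). A toy check with made-up labels:

```python
from sklearn.metrics import average_precision_score

# Five references; b1 and b4 are the labeled sources, the rule flagged b1 and b2.
y_true = [0, 1, 0, 0, 1]
y_score = [0, 1, 1, 0, 0]
print(average_precision_score(y_true, y_score))  # 0.45
```

Papers whose references contain no matched source contribute 0, so the final MAP is sensitive to the recall of the fuzzy title matching as well as to the rule itself.
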
/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import abspath, dirname, join
3 |
4 |
5 | PROJ_DIR = abspath(dirname(__file__))
6 | DATA_DIR = join(PROJ_DIR, "data")
7 | OUT_DIR = join(PROJ_DIR, "out")
8 |
9 | os.makedirs(DATA_DIR, exist_ok=True)
10 | os.makedirs(OUT_DIR, exist_ok=True)
11 | DATA_TRACE_DIR = DATA_DIR
12 |
--------------------------------------------------------------------------------
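
For reference, the directory layout the scripts above expect, reconstructed from the paths they open and write (nothing beyond those files is implied):

```text
data/
├── PST/
│   ├── paper_source_trace_train_ans.json
│   ├── paper_source_trace_valid_wo_ans.json
│   ├── DBLP-Citation-network-V15.1.json
│   └── paper_info_hit_from_dblp.json      # written by extract_paper_info_from_dblp
├── paper-xml/                             # <pid>.tei.xml per paper
└── 2023/
    └── paper_source_trace_test.json
out/
└── kddcup/
    └── rf/
        ├── train_data.csv
        ├── train_label.csv
        └── valid_data.json
```
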
/utils.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import json
3 | import numpy as np
4 | import pickle
5 | from collections import defaultdict as dd
6 | from bs4 import BeautifulSoup
7 | from datetime import datetime
8 |
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp
13 |
14 |
15 | def load_json(rfdir, rfname):
16 | logger.info('loading %s ...', rfname)
17 | with open(join(rfdir, rfname), 'r', encoding='utf-8') as rf:
18 | data = json.load(rf)
19 | logger.info('%s loaded', rfname)
20 | return data
21 |
22 |
23 | def dump_json(obj, wfdir, wfname):
24 | logger.info('dumping %s ...', wfname)
25 | with open(join(wfdir, wfname), 'w', encoding='utf-8') as wf:
26 | json.dump(obj, wf, indent=4, ensure_ascii=False)
27 | logger.info('%s dumped.', wfname)
28 |
29 |
30 | def serialize_embedding(embedding):
31 | return pickle.dumps(embedding)
32 |
33 |
34 | def deserialize_embedding(s):
35 | return pickle.loads(s)
36 |
37 |
38 | def find_bib_context(xml, dist=100):
39 | bs = BeautifulSoup(xml, "xml")
40 | bib_to_context = dd(list)
41 | bibr_strs_to_bid_id = {}
42 | for item in bs.find_all(type='bibr'):
43 | if "target" not in item.attrs:
44 | continue
45 | bib_id = item.attrs["target"][1:]
46 | item_str = "[{}](\"{}\")".format(item.attrs["target"], item.get_text())
47 | bibr_strs_to_bid_id[item_str] = bib_id
48 |
49 | for item_str in bibr_strs_to_bid_id:
50 | bib_id = bibr_strs_to_bid_id[item_str]
51 | cur_bib_context_pos_start = [ii for ii in range(len(xml)) if xml.startswith(item_str, ii)]  # every offset of this marker in the raw string
52 | for pos in cur_bib_context_pos_start:
53 | bib_to_context[bib_id].append(xml[pos - dist: pos + dist].replace("\n", " ").replace("\r", " ").strip())
54 | return bib_to_context
55 |
56 |
57 | def sigmoid(x):
58 | return 1 / (1 + np.exp(-x))
59 |
60 |
61 | class Log:
62 | def __init__(self, file_path):
63 | self.file_path = file_path
64 | self.f = open(file_path, 'w+')
65 |
66 | def log(self, s):
67 | self.f.write(str(datetime.now()) + "\t" + s + '\n')
68 | self.f.flush()
--------------------------------------------------------------------------------
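
Note that `find_bib_context` does not search for the `<ref>` tags themselves: it scans the raw string for a bracketed rendering of each marker, `[#bid]("text")`, so it only returns contexts when the stored XML actually contains that rendering. A synthetic snippet that includes one, just to show the output shape (the tag layout is simplified relative to real TEI):

```python
xml = ('<body><p>Our approach is motivated by [#b0]("[1]") and builds on '
       '<ref type="bibr" target="#b0">[1]</ref>.</p></body>')

print(find_bib_context(xml, dist=30))
# defaultdict(<class 'list'>, {'b0': ['>Our approach is motivated by [#b0]("[1]") and builds on <re']})
```

Each context is a `2 * dist`-character window around the marker with newlines stripped, keyed by the bib id without its leading `#`.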