├── .gitignore ├── README.md ├── assets └── group_qr_code_20240621.jpg ├── bert.py ├── chatglm ├── arguments.py ├── data │ ├── ID_num_dic.json │ ├── calc_map.py │ ├── format_data.py │ ├── positive_negetive_balance.py │ ├── train.json │ ├── train_balance.json │ ├── valid.json │ └── valid_balance.json ├── ds_test_finetune.sh ├── ds_test_ptv2.sh ├── finetune_test.py ├── ft_ds_config.json ├── output │ └── test │ │ └── trans.py ├── test.py ├── trainer.py └── trainer_seq2seq.py ├── claude └── claude.py ├── galactica ├── calc_map.py ├── gala_base.py ├── gala_finetune.py └── process.py ├── glm ├── dataset.py ├── ds_config_glm_10b.json ├── ds_config_glm_2b.json ├── finetune_glm_10b_ds.py ├── finetune_glm_ds.py ├── process.py ├── run_finetune_ds.sh ├── run_finetune_ds_10b.sh └── test_glm.py ├── gpt-api ├── gpt.py └── map.py ├── net_emb.py ├── requirements.txt ├── result.txt ├── rf ├── model_main.py ├── model_rf.py ├── process_data.py ├── process_kddcup_data.py └── set_param.py ├── rule.py ├── settings.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 3 | data/ 4 | out/ 5 | glm/saved/ 6 | .idea/ 7 | chatglm/output/ 8 | galactica/lightning_logs/ 9 | galactica/result/ 10 | galactica/saved_model/ 11 | processed_data/ 12 | saved_model/ 13 | 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # paper-source-trace 2 | 3 | ## Prerequisites 4 | - Linux 5 | - Python 3.9 6 | - PyTorch 1.10.0+cu111 7 | 8 | ## Getting Started 9 | 10 | ### Installation 11 | 12 | Clone this repo. 13 | 14 | ```bash 15 | git clone https://github.com/THUDM/paper-source-trace.git 16 | cd paper-source-trace 17 | ``` 18 | 19 | Then install the dependencies: 20 | 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## PST Dataset 26 | The dataset can be downloaded from [BaiduPan](https://pan.baidu.com/s/1I_HZXBx7U0UsRHJL5JJagw?pwd=bft3) (password: bft3), [Aliyun](https://open-data-set.oss-cn-beijing.aliyuncs.com/oag-benchmark/kddcup-2024/PST/PST.zip), or [Dropbox](https://www.dropbox.com/scl/fi/namx1n55xzqil4zbkd5sv/PST.zip?rlkey=impcbm2acqmqhurv2oj0xxysx&dl=1). 27 | The paper XML files are TEI documents generated from the paper PDFs with the [Grobid](https://grobid.readthedocs.io/en/latest/Introduction/) API (a minimal invocation sketch is given at the end of this README). 28 | 29 | ## Run Baselines for [KDD Cup 2024](https://www.biendata.xyz/competition/pst_kdd_2024/) 30 | First, download the DBLP dataset from [AMiner](https://opendata.aminer.cn/dataset/DBLP-Citation-network-V16.zip). 31 | Put the unzipped PST directory into ``data/`` and the unzipped DBLP dataset into ``data/PST/``. 32 | 33 | ```bash 34 | cd $project_path 35 | export CUDA_VISIBLE_DEVICES='?'
# specify which GPU(s) to use 36 | export PYTHONPATH="`pwd`:$PYTHONPATH" 37 | 38 | # Method 1: Random Forest 39 | python rf/process_kddcup_data.py 40 | python rf/model_rf.py # output at out/kddcup/rf/ 41 | 42 | # Method 2: Network Embedding 43 | python net_emb.py # output at out/kddcup/prone/ 44 | 45 | # Method 3: SciBERT 46 | python bert.py # output at out/kddcup/scibert/ 47 | ``` 48 | 49 | ## Results on Validation Set 50 | 51 | | Method | MAP | 52 | |-------|-------| 53 | | Random Forest | 0.21420 | 54 | | ProNE | 0.21668 | 55 | | SciBERT | 0.29489 | 56 | MAP is computed per paper with scikit-learn's ``average_precision_score`` and then averaged over papers; a short evaluation sketch is given at the end of this README. 57 | ## Citation 58 | 59 | If you find this repo useful in your research, please cite the following papers: 60 | 61 | ``` 62 | @article{zhang2024pst, 63 | title={PST-Bench: Tracing and Benchmarking the Source of Publications}, 64 | author={Fanjin Zhang and Kun Cao and Yukuo Cen and Jifan Yu and Da Yin and Jie Tang}, 65 | journal={arXiv preprint arXiv:2402.16009}, 66 | year={2024} 67 | } 68 | 69 | @inproceedings{zhang2024oag, 70 | title={OAG-bench: a human-curated benchmark for academic graph mining}, 71 | author={Fanjin Zhang and Shijie Shi and Yifan Zhu and Bo Chen and Yukuo Cen and Jifan Yu and Yelin Chen and Lulu Wang and Qingfei Zhao and Yuqing Cheng and Tianyi Han and Yuwei An and Dan Zhang and Weng Lam Tam and Kun Cao and Yunhe Pang and Xinyu Guan and Huihui Yuan and Jian Song and Xiaoyan Li and Yuxiao Dong and Jie Tang}, 72 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 73 | pages={6214--6225}, 74 | year={2024} 75 | } 76 | ``` 77 | 78 | ## Paper Sharing Group 79 | 80 | Hello everyone, 81 | 82 | We've created an online WeChat paper-sharing group in which each member shares two computer science papers every week; the group has reward and penalty mechanisms for members who do or do not share papers as required. You are free to join or leave at any time. Welcome to join us! (You can get the up-to-date QR code from this Telegram channel: https://t.me/+apOrPEOLGixiNjdl) 83 | 84 |

85 | <img src="assets/group_qr_code_20240621.jpg" alt="Paper-sharing group QR code"> 86 |
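As mentioned in the PST Dataset section, the paper XML files are TEI documents produced by Grobid. The sketch below shows one plausible way to regenerate such a file against a locally running Grobid server; the server URL and file names are illustrative assumptions, not part of this repo.

```python
# Sketch: convert a PDF to TEI XML with a local Grobid server.
# Assumptions: Grobid is running at localhost:8070 and "paper.pdf" exists;
# both are hypothetical, adjust them to your setup.
import requests

GROBID_URL = "http://localhost:8070/api/processFulltextDocument"

with open("paper.pdf", "rb") as pdf:
    resp = requests.post(GROBID_URL, files={"input": pdf})
resp.raise_for_status()

with open("paper.xml", "w", encoding="utf-8") as out:
    out.write(resp.text)  # TEI XML with <biblStruct> reference entries
```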
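The MAP numbers in the results table are computed per paper and then averaged: each baseline assigns a score to every reference of a paper, average precision is computed against the annotated sources, and the mean over papers is reported (this mirrors `bert.py` and `chatglm/data/calc_map.py`). A minimal sketch with made-up labels and scores:

```python
# Sketch of the MAP evaluation; the labels and scores below are fabricated.
import numpy as np
from sklearn.metrics import average_precision_score

# One (y_true, y_score) pair per paper: y_true[i] = 1 iff reference i is an
# annotated paper source, y_score[i] is the model's score for reference i.
papers = [
    ([0, 1, 0, 0], [0.1, 0.8, 0.3, 0.2]),
    ([1, 0, 0], [0.6, 0.5, 0.1]),
]
ap = [average_precision_score(y_true, y_score) for y_true, y_score in papers]
print("MAP:", float(np.mean(ap)))  # mean average precision over papers
```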

87 | -------------------------------------------------------------------------------- /assets/group_qr_code_20240621.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/paper-source-trace/d0896c62508ad4bb00d28b28b8411c92034a409d/assets/group_qr_code_20240621.jpg -------------------------------------------------------------------------------- /bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from tqdm import tqdm 4 | from collections import defaultdict as dd 5 | from bs4 import BeautifulSoup 6 | from fuzzywuzzy import fuzz 7 | import numpy as np 8 | import torch 9 | from transformers import AutoTokenizer 10 | from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup 11 | from transformers.optimization import AdamW 12 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 13 | from tqdm import trange 14 | from sklearn.metrics import classification_report, precision_recall_fscore_support, average_precision_score 15 | import logging 16 | 17 | import utils 18 | import settings 19 | 20 | 21 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 22 | datefmt = '%m/%d/%Y %H:%M:%S', 23 | level = logging.INFO) 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | MAX_SEQ_LENGTH=512 28 | 29 | 30 | def prepare_train_test_data_for_bert(year=2023): 31 | x_train = [] 32 | y_train = [] 33 | x_valid = [] 34 | y_valid = [] 35 | x_test = [] 36 | y_test = [] 37 | 38 | truths = utils.load_json(settings.DATA_TRACE_DIR, "paper_source_trace_{}_final_filtered.json".format(year)) 39 | pid_to_source_titles = dd(list) 40 | for paper in tqdm(truths): 41 | pid = paper["_id"] 42 | for ref in paper["refs_trace"]: 43 | pid_to_source_titles[pid].append(ref["title"].lower()) 44 | 45 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 46 | papers_train = utils.load_json(data_year_dir, "paper_source_trace_train.json") 47 | papers_valid = utils.load_json(data_year_dir, "paper_source_trace_valid.json") 48 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json") 49 | 50 | pids_train = {p["_id"] for p in papers_train} 51 | pids_valid = {p["_id"] for p in papers_valid} 52 | pids_test = {p["_id"] for p in papers_test} 53 | 54 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 55 | files = [] 56 | for f in os.listdir(in_dir): 57 | if f.endswith(".xml"): 58 | files.append(f) 59 | 60 | files = sorted(files) 61 | for file in tqdm(files): 62 | f = open(join(in_dir, file), encoding='utf-8') 63 | cur_pid = file.split(".")[0] 64 | if cur_pid not in pids_train and cur_pid not in pids_valid and cur_pid not in pids_test: 65 | continue 66 | xml = f.read() 67 | bs = BeautifulSoup(xml, "xml") 68 | 69 | source_titles = pid_to_source_titles[cur_pid] 70 | if len(source_titles) == 0: 71 | continue 72 | 73 | references = bs.find_all("biblStruct") 74 | bid_to_title = {} 75 | n_refs = 0 76 | for ref in references: 77 | if "xml:id" not in ref.attrs: 78 | continue 79 | bid = ref.attrs["xml:id"] 80 | if ref.analytic is None: 81 | continue 82 | if ref.analytic.title is None: 83 | continue 84 | bid_to_title[bid] = ref.analytic.title.text.lower() 85 | b_idx = int(bid[1:]) + 1 86 | if b_idx > n_refs: 87 | n_refs = b_idx 88 | 89 | flag = False 90 | 91 | cur_pos_bib = set() 92 | 93 | for bid in bid_to_title: 94 | cur_ref_title = bid_to_title[bid] 95 | for label_title in source_titles: 
96 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 97 | flag = True 98 | cur_pos_bib.add(bid) 99 | 100 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib 101 | 102 | if not flag: 103 | continue 104 | 105 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0: 106 | continue 107 | 108 | bib_to_contexts = utils.find_bib_context(xml) 109 | 110 | n_pos = len(cur_pos_bib) 111 | n_neg = n_pos * 10 112 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True) 113 | 114 | if cur_pid in pids_train: 115 | cur_x = x_train 116 | cur_y = y_train 117 | elif cur_pid in pids_valid: 118 | cur_x = x_valid 119 | cur_y = y_valid 120 | elif cur_pid in pids_test: 121 | cur_x = x_test 122 | cur_y = y_test 123 | else: 124 | continue 125 | # raise Exception("cur_pid not in train/valid/test") 126 | 127 | for bib in cur_pos_bib: 128 | cur_context = " ".join(bib_to_contexts[bib]) 129 | cur_x.append(cur_context) 130 | cur_y.append(1) 131 | 132 | for bib in cur_neg_bib_sample: 133 | cur_context = " ".join(bib_to_contexts[bib]) 134 | cur_x.append(cur_context) 135 | cur_y.append(0) 136 | 137 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid), "len(x_test)", len(x_test)) 138 | 139 | with open(join(data_year_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f: 140 | for line in x_train: 141 | f.write(line + "\n") 142 | 143 | with open(join(data_year_dir, "bib_context_valid.txt"), "w", encoding="utf-8") as f: 144 | for line in x_valid: 145 | f.write(line + "\n") 146 | 147 | with open(join(data_year_dir, "bib_context_test.txt"), "w", encoding="utf-8") as f: 148 | for line in x_test: 149 | f.write(line + "\n") 150 | 151 | with open(join(data_year_dir, "bib_context_train_label.txt"), "w", encoding="utf-8") as f: 152 | for line in y_train: 153 | f.write(str(line) + "\n") 154 | 155 | with open(join(data_year_dir, "bib_context_valid_label.txt"), "w", encoding="utf-8") as f: 156 | for line in y_valid: 157 | f.write(str(line) + "\n") 158 | 159 | with open(join(data_year_dir, "bib_context_test_label.txt"), "w", encoding="utf-8") as f: 160 | for line in y_test: 161 | f.write(str(line) + "\n") 162 | 163 | 164 | 165 | def prepare_bert_input(): 166 | x_train = [] 167 | y_train = [] 168 | x_valid = [] 169 | y_valid = [] 170 | 171 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 172 | papers = utils.load_json(data_dir, "paper_source_trace_train_ans.json") 173 | n_papers = len(papers) 174 | papers = sorted(papers, key=lambda x: x["_id"]) 175 | n_train = int(n_papers * 2 / 3) 176 | # n_valid = n_papers - n_train 177 | 178 | papers_train = papers[:n_train] 179 | papers_valid = papers[n_train:] 180 | 181 | pids_train = {p["_id"] for p in papers_train} 182 | pids_valid = {p["_id"] for p in papers_valid} 183 | 184 | in_dir = join(data_dir, "paper-xml") 185 | files = [] 186 | for f in os.listdir(in_dir): 187 | if f.endswith(".xml"): 188 | files.append(f) 189 | 190 | pid_to_source_titles = dd(list) 191 | for paper in tqdm(papers): 192 | pid = paper["_id"] 193 | for ref in paper["refs_trace"]: 194 | pid_to_source_titles[pid].append(ref["title"].lower()) 195 | 196 | # files = sorted(files) 197 | # for file in tqdm(files): 198 | for cur_pid in tqdm(pids_train | pids_valid): 199 | # cur_pid = file.split(".")[0] 200 | # if cur_pid not in pids_train and cur_pid not in pids_valid: 201 | # continue 202 | f = open(join(in_dir, cur_pid + ".xml"), encoding='utf-8') 203 | xml = f.read() 204 | bs = BeautifulSoup(xml, "xml") 205 | 206 | source_titles = pid_to_source_titles[cur_pid] 207 | if len(source_titles) 
== 0: 208 | continue 209 | 210 | references = bs.find_all("biblStruct") 211 | bid_to_title = {} 212 | n_refs = 0 213 | for ref in references: 214 | if "xml:id" not in ref.attrs: 215 | continue 216 | bid = ref.attrs["xml:id"] 217 | if ref.analytic is None: 218 | continue 219 | if ref.analytic.title is None: 220 | continue 221 | bid_to_title[bid] = ref.analytic.title.text.lower() 222 | b_idx = int(bid[1:]) + 1 223 | if b_idx > n_refs: 224 | n_refs = b_idx 225 | 226 | flag = False 227 | 228 | cur_pos_bib = set() 229 | 230 | for bid in bid_to_title: 231 | cur_ref_title = bid_to_title[bid] 232 | for label_title in source_titles: 233 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 234 | flag = True 235 | cur_pos_bib.add(bid) 236 | 237 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib 238 | 239 | if not flag: 240 | continue 241 | 242 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0: 243 | continue 244 | 245 | bib_to_contexts = utils.find_bib_context(xml) 246 | 247 | n_pos = len(cur_pos_bib) 248 | n_neg = n_pos * 10 249 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True) 250 | 251 | if cur_pid in pids_train: 252 | cur_x = x_train 253 | cur_y = y_train 254 | elif cur_pid in pids_valid: 255 | cur_x = x_valid 256 | cur_y = y_valid 257 | else: 258 | continue 259 | # raise Exception("cur_pid not in train/valid/test") 260 | 261 | for bib in cur_pos_bib: 262 | cur_context = " ".join(bib_to_contexts[bib]) 263 | cur_x.append(cur_context) 264 | cur_y.append(1) 265 | 266 | for bib in cur_neg_bib_sample: 267 | cur_context = " ".join(bib_to_contexts[bib]) 268 | cur_x.append(cur_context) 269 | cur_y.append(0) 270 | 271 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid)) 272 | 273 | 274 | with open(join(data_dir, "bib_context_train.txt"), "w", encoding="utf-8") as f: 275 | for line in x_train: 276 | f.write(line + "\n") 277 | 278 | with open(join(data_dir, "bib_context_valid.txt"), "w", encoding="utf-8") as f: 279 | for line in x_valid: 280 | f.write(line + "\n") 281 | 282 | with open(join(data_dir, "bib_context_train_label.txt"), "w", encoding="utf-8") as f: 283 | for line in y_train: 284 | f.write(str(line) + "\n") 285 | 286 | with open(join(data_dir, "bib_context_valid_label.txt"), "w", encoding="utf-8") as f: 287 | for line in y_valid: 288 | f.write(str(line) + "\n") 289 | 290 | 291 | class BertInputItem(object): 292 | """An item with all the necessary attributes for finetuning BERT.""" 293 | 294 | def __init__(self, text, input_ids, input_mask, segment_ids, label_id): 295 | self.text = text 296 | self.input_ids = input_ids 297 | self.input_mask = input_mask 298 | self.segment_ids = segment_ids 299 | self.label_id = label_id 300 | 301 | 302 | def convert_examples_to_inputs(example_texts, example_labels, max_seq_length, tokenizer, verbose=0): 303 | """Loads a data file into a list of `InputBatch`s.""" 304 | 305 | input_items = [] 306 | examples = zip(example_texts, example_labels) 307 | for (ex_index, (text, label)) in enumerate(examples): 308 | 309 | # Create a list of token ids 310 | input_ids = tokenizer.encode(f"[CLS] {text} [SEP]") 311 | if len(input_ids) > max_seq_length: 312 | input_ids = input_ids[:max_seq_length] 313 | 314 | # All our tokens are in the first input segment (id 0). 315 | segment_ids = [0] * len(input_ids) 316 | 317 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 318 | # tokens are attended to. 319 | input_mask = [1] * len(input_ids) 320 | 321 | # Zero-pad up to the sequence length. 
322 | padding = [0] * (max_seq_length - len(input_ids)) 323 | input_ids += padding 324 | input_mask += padding 325 | segment_ids += padding 326 | 327 | assert len(input_ids) == max_seq_length 328 | assert len(input_mask) == max_seq_length 329 | assert len(segment_ids) == max_seq_length 330 | 331 | label_id = label 332 | 333 | input_items.append( 334 | BertInputItem(text=text, 335 | input_ids=input_ids, 336 | input_mask=input_mask, 337 | segment_ids=segment_ids, 338 | label_id=label_id)) 339 | 340 | return input_items 341 | 342 | 343 | def get_data_loader(features, max_seq_length, batch_size, shuffle=True): 344 | 345 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 346 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 347 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 348 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 349 | data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 350 | 351 | dataloader = DataLoader(data, shuffle=shuffle, batch_size=batch_size) 352 | return dataloader 353 | 354 | 355 | def evaluate(model, dataloader, device, criterion): 356 | model.eval() 357 | 358 | eval_loss = 0 359 | nb_eval_steps = 0 360 | predicted_labels, correct_labels = [], [] 361 | 362 | for step, batch in enumerate(tqdm(dataloader, desc="Evaluation iteration")): 363 | batch = tuple(t.to(device) for t in batch) 364 | input_ids, input_mask, segment_ids, label_ids = batch 365 | 366 | with torch.no_grad(): 367 | r = model(input_ids, attention_mask=input_mask, 368 | token_type_ids=segment_ids, labels=label_ids) 369 | # tmp_eval_loss = r[0] 370 | logits = r[1] 371 | # print("logits", logits) 372 | tmp_eval_loss = criterion(logits, label_ids) 373 | 374 | outputs = np.argmax(logits.to('cpu'), axis=1) 375 | label_ids = label_ids.to('cpu').numpy() 376 | 377 | predicted_labels += list(outputs) 378 | correct_labels += list(label_ids) 379 | 380 | eval_loss += tmp_eval_loss.mean().item() 381 | nb_eval_steps += 1 382 | 383 | eval_loss = eval_loss / nb_eval_steps 384 | 385 | correct_labels = np.array(correct_labels) 386 | predicted_labels = np.array(predicted_labels) 387 | 388 | return eval_loss, correct_labels, predicted_labels 389 | 390 | 391 | def train(year=2023, model_name="scibert"): 392 | print("model name", model_name) 393 | train_texts = [] 394 | dev_texts = [] 395 | train_labels = [] 396 | dev_labels = [] 397 | data_year_dir = join(settings.DATA_TRACE_DIR, "PST") 398 | print("data_year_dir", data_year_dir) 399 | 400 | with open(join(data_year_dir, "bib_context_train.txt"), "r", encoding="utf-8") as f: 401 | for line in f: 402 | train_texts.append(line.strip()) 403 | with open(join(data_year_dir, "bib_context_valid.txt"), "r", encoding="utf-8") as f: 404 | for line in f: 405 | dev_texts.append(line.strip()) 406 | 407 | with open(join(data_year_dir, "bib_context_train_label.txt"), "r", encoding="utf-8") as f: 408 | for line in f: 409 | train_labels.append(int(line.strip())) 410 | with open(join(data_year_dir, "bib_context_valid_label.txt"), "r", encoding="utf-8") as f: 411 | for line in f: 412 | dev_labels.append(int(line.strip())) 413 | 414 | 415 | print("Train size:", len(train_texts)) 416 | print("Dev size:", len(dev_texts)) 417 | 418 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 419 | 420 | class_weight = len(train_labels) / (2 * np.bincount(train_labels)) 421 | class_weight = 
torch.Tensor(class_weight).to(device) 422 | print("Class weight:", class_weight) 423 | 424 | if model_name == "bert": 425 | BERT_MODEL = "bert-base-uncased" 426 | elif model_name == "scibert": 427 | BERT_MODEL = "allenai/scibert_scivocab_uncased" 428 | else: 429 | raise NotImplementedError 430 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL) 431 | 432 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2) 433 | model.to(device) 434 | 435 | criterion = torch.nn.CrossEntropyLoss(weight=class_weight) 436 | 437 | train_features = convert_examples_to_inputs(train_texts, train_labels, MAX_SEQ_LENGTH, tokenizer, verbose=0) 438 | dev_features = convert_examples_to_inputs(dev_texts, dev_labels, MAX_SEQ_LENGTH, tokenizer) 439 | 440 | BATCH_SIZE = 16 441 | train_dataloader = get_data_loader(train_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=True) 442 | dev_dataloader = get_data_loader(dev_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False) 443 | 444 | GRADIENT_ACCUMULATION_STEPS = 1 445 | NUM_TRAIN_EPOCHS = 20 446 | LEARNING_RATE = 5e-5 447 | WARMUP_PROPORTION = 0.1 448 | MAX_GRAD_NORM = 5 449 | 450 | num_train_steps = int(len(train_dataloader.dataset) / BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS * NUM_TRAIN_EPOCHS) 451 | num_warmup_steps = int(WARMUP_PROPORTION * num_train_steps) 452 | 453 | param_optimizer = list(model.named_parameters()) 454 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 455 | optimizer_grouped_parameters = [ 456 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 457 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 458 | ] 459 | 460 | optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, correct_bias=False) 461 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps) 462 | 463 | OUTPUT_DIR = join(settings.OUT_DIR, "kddcup", model_name) 464 | os.makedirs(OUTPUT_DIR, exist_ok=True) 465 | 466 | MODEL_FILE_NAME = "pytorch_model.bin" 467 | PATIENCE = 5 468 | 469 | loss_history = [] 470 | no_improvement = 0 471 | for _ in trange(int(NUM_TRAIN_EPOCHS), desc="Epoch"): 472 | model.train() 473 | tr_loss = 0 474 | nb_tr_examples, nb_tr_steps = 0, 0 475 | for step, batch in enumerate(tqdm(train_dataloader, desc="Training iteration")): 476 | batch = tuple(t.to(device) for t in batch) 477 | input_ids, input_mask, segment_ids, label_ids = batch 478 | 479 | outputs = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids, labels=label_ids) 480 | # loss = outputs[0] 481 | logits = outputs[1] 482 | 483 | loss = criterion(logits, label_ids) 484 | 485 | if GRADIENT_ACCUMULATION_STEPS > 1: 486 | loss = loss / GRADIENT_ACCUMULATION_STEPS 487 | 488 | loss.backward() 489 | tr_loss += loss.item() 490 | 491 | if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: 492 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM) 493 | 494 | optimizer.step() 495 | optimizer.zero_grad() 496 | scheduler.step() 497 | 498 | dev_loss, _, _ = evaluate(model, dev_dataloader, device, criterion) 499 | 500 | print("Loss history:", loss_history) 501 | print("Dev loss:", dev_loss) 502 | 503 | if len(loss_history) == 0 or dev_loss < min(loss_history): 504 | no_improvement = 0 505 | model_to_save = model.module if hasattr(model, 'module') else model 506 | output_model_file = os.path.join(OUTPUT_DIR, MODEL_FILE_NAME) 507 | torch.save(model_to_save.state_dict(), 
output_model_file) 508 | else: 509 | no_improvement += 1 510 | 511 | if no_improvement >= PATIENCE: 512 | print("No improvement on development set. Finish training.") 513 | break 514 | 515 | loss_history.append(dev_loss) 516 | 517 | 518 | def eval_test_papers_bert(year=2023, model_name="scibert"): 519 | print("model name", model_name) 520 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 521 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json") 522 | pids_test = {p["_id"] for p in papers_test} 523 | 524 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 525 | files = [] 526 | for f in os.listdir(in_dir): 527 | cur_pid = f.split(".")[0] 528 | if f.endswith(".xml") and cur_pid in pids_test: 529 | files.append(f) 530 | 531 | truths = papers_test 532 | pid_to_source_titles = dd(list) 533 | for paper in tqdm(truths): 534 | pid = paper["_id"] 535 | for ref in paper["refs_trace"]: 536 | pid_to_source_titles[pid].append(ref["title"].lower()) 537 | 538 | if model_name == "bert": 539 | BERT_MODEL = "bert-base-uncased" 540 | elif model_name == "scibert": 541 | BERT_MODEL = "allenai/scibert_scivocab_uncased" 542 | else: 543 | raise NotImplementedError 544 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL) 545 | 546 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 547 | print("device", device) 548 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2) 549 | # model.load_state_dict(torch.load(join(settings.OUT_DIR, model_name, "pytorch_model.bin"))) 550 | # model.load_state_dict(torch.load(join(settings.OUT_DIR, "bert", "pytorch_model.bin"))) 551 | model.to(device) 552 | model.eval() 553 | 554 | BATCH_SIZE = 16 555 | metrics = [] 556 | f_idx = 0 557 | 558 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 559 | 560 | for paper in tqdm(papers_test): 561 | cur_pid = paper["_id"] 562 | file = join(xml_dir, cur_pid + ".tei.xml") 563 | f = open(file, encoding='utf-8') 564 | 565 | xml = f.read() 566 | bs = BeautifulSoup(xml, "xml") 567 | f.close() 568 | 569 | source_titles = pid_to_source_titles[cur_pid] 570 | if len(source_titles) == 0: 571 | continue 572 | 573 | references = bs.find_all("biblStruct") 574 | bid_to_title = {} 575 | n_refs = 0 576 | for ref in references: 577 | if "xml:id" not in ref.attrs: 578 | continue 579 | bid = ref.attrs["xml:id"] 580 | if ref.analytic is None: 581 | continue 582 | if ref.analytic.title is None: 583 | continue 584 | bid_to_title[bid] = ref.analytic.title.text.lower() 585 | b_idx = int(bid[1:]) + 1 586 | if b_idx > n_refs: 587 | n_refs = b_idx 588 | 589 | bib_to_contexts = utils.find_bib_context(xml) 590 | bib_sorted = sorted(bib_to_contexts.keys()) 591 | 592 | for bib in bib_sorted: 593 | cur_bib_idx = int(bib[1:]) 594 | if cur_bib_idx + 1 > n_refs: 595 | n_refs = cur_bib_idx + 1 596 | 597 | y_true = [0] * n_refs 598 | y_score = [0] * n_refs 599 | 600 | flag = False 601 | for bid in bid_to_title: 602 | cur_ref_title = bid_to_title[bid] 603 | for label_title in source_titles: 604 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 605 | flag = True 606 | b_idx = int(bid[1:]) 607 | y_true[b_idx] = 1 608 | 609 | if not flag: 610 | continue 611 | 612 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted] 613 | 614 | test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer) 615 | test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False) 616 | 617 | predicted_scores = [] 618 | for step, 
batch in enumerate(test_dataloader): 619 | batch = tuple(t.to(device) for t in batch) 620 | input_ids, input_mask, segment_ids, label_ids = batch 621 | 622 | with torch.no_grad(): 623 | r = model(input_ids, attention_mask=input_mask, 624 | token_type_ids=segment_ids, labels=label_ids) 625 | tmp_eval_loss = r[0] 626 | logits = r[1] 627 | 628 | cur_pred_scores = logits[:, 1].to('cpu').numpy() 629 | predicted_scores.extend(cur_pred_scores) 630 | 631 | try: 632 | for ii in range(len(predicted_scores)): 633 | bib_idx = int(bib_sorted[ii][1:]) 634 | # print("bib_idx", bib_idx) 635 | y_score[bib_idx] = predicted_scores[ii] 636 | except IndexError as e: 637 | metrics.append(0) 638 | continue 639 | 640 | cur_map = average_precision_score(y_true, y_score) 641 | metrics.append(cur_map) 642 | f_idx += 1 643 | if f_idx % 20 == 0: 644 | print("map until now", np.mean(metrics), len(metrics), cur_map) 645 | 646 | print("bert average map", np.mean(metrics), len(metrics)) 647 | 648 | 649 | def gen_kddcup_valid_submission_bert(model_name="scibert"): 650 | print("model name", model_name) 651 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 652 | papers = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json") 653 | 654 | if model_name == "bert": 655 | BERT_MODEL = "bert-base-uncased" 656 | elif model_name == "scibert": 657 | BERT_MODEL = "allenai/scibert_scivocab_uncased" 658 | else: 659 | raise NotImplementedError 660 | tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL) 661 | 662 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json") 663 | 664 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 665 | print("device", device) 666 | model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2) 667 | model.load_state_dict(torch.load(join(settings.OUT_DIR, "kddcup", model_name, "pytorch_model.bin"))) 668 | 669 | model.to(device) 670 | model.eval() 671 | 672 | BATCH_SIZE = 16 673 | # metrics = [] 674 | # f_idx = 0 675 | 676 | xml_dir = join(data_dir, "paper-xml") 677 | sub_dict = {} 678 | 679 | for paper in tqdm(papers): 680 | cur_pid = paper["_id"] 681 | file = join(xml_dir, cur_pid + ".xml") 682 | f = open(file, encoding='utf-8') 683 | xml = f.read() 684 | bs = BeautifulSoup(xml, "xml") 685 | f.close() 686 | 687 | references = bs.find_all("biblStruct") 688 | bid_to_title = {} 689 | n_refs = 0 690 | for ref in references: 691 | if "xml:id" not in ref.attrs: 692 | continue 693 | bid = ref.attrs["xml:id"] 694 | if ref.analytic is None: 695 | continue 696 | if ref.analytic.title is None: 697 | continue 698 | bid_to_title[bid] = ref.analytic.title.text.lower() 699 | b_idx = int(bid[1:]) + 1 700 | if b_idx > n_refs: 701 | n_refs = b_idx 702 | 703 | bib_to_contexts = utils.find_bib_context(xml) 704 | # bib_sorted = sorted(bib_to_contexts.keys()) 705 | bib_sorted = ["b" + str(ii) for ii in range(n_refs)] 706 | 707 | y_score = [0] * n_refs 708 | 709 | assert len(sub_example_dict[cur_pid]) == n_refs 710 | # continue 711 | 712 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted] 713 | 714 | test_features = convert_examples_to_inputs(contexts_sorted, y_score, MAX_SEQ_LENGTH, tokenizer) 715 | test_dataloader = get_data_loader(test_features, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False) 716 | 717 | predicted_scores = [] 718 | for step, batch in enumerate(test_dataloader): 719 | batch = tuple(t.to(device) for t in batch) 720 | input_ids, input_mask, segment_ids, label_ids = batch 721 | 722 | with torch.no_grad(): 723 | r = 
model(input_ids, attention_mask=input_mask, 724 | token_type_ids=segment_ids, labels=label_ids) 725 | tmp_eval_loss = r[0] 726 | logits = r[1] 727 | 728 | cur_pred_scores = logits[:, 1].to('cpu').numpy() 729 | predicted_scores.extend(cur_pred_scores) 730 | 731 | for ii in range(len(predicted_scores)): 732 | bib_idx = int(bib_sorted[ii][1:]) 733 | # print("bib_idx", bib_idx) 734 | y_score[bib_idx] = float(utils.sigmoid(predicted_scores[ii])) 735 | 736 | sub_dict[cur_pid] = y_score 737 | 738 | utils.dump_json(sub_dict, join(settings.OUT_DIR, "kddcup", model_name), "valid_submission_scibert.json") 739 | 740 | 741 | if __name__ == "__main__": 742 | # prepare_train_test_data_for_bert() 743 | prepare_bert_input() 744 | train(model_name="scibert") 745 | # eval_test_papers_bert(model_name="scibert") 746 | gen_kddcup_valid_submission_bert(model_name="scibert") 747 | -------------------------------------------------------------------------------- /chatglm/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | 64 | @dataclass 65 | class DataTrainingArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 
68 | """ 69 | 70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 71 | 72 | dataset_name: Optional[str] = field( 73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 74 | ) 75 | dataset_config_name: Optional[str] = field( 76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 77 | ) 78 | prompt_column: Optional[str] = field( 79 | default=None, 80 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 81 | ) 82 | response_column: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 85 | ) 86 | history_column: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 89 | ) 90 | train_file: Optional[str] = field( 91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 92 | ) 93 | validation_file: Optional[str] = field( 94 | default=None, 95 | metadata={ 96 | "help": ( 97 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 98 | ) 99 | }, 100 | ) 101 | test_file: Optional[str] = field( 102 | default=None, 103 | metadata={ 104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 105 | }, 106 | ) 107 | overwrite_cache: bool = field( 108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 109 | ) 110 | preprocessing_num_workers: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "The number of processes to use for the preprocessing."}, 113 | ) 114 | max_source_length: Optional[int] = field( 115 | default=1024, 116 | metadata={ 117 | "help": ( 118 | "The maximum total input sequence length after tokenization. Sequences longer " 119 | "than this will be truncated, sequences shorter will be padded." 120 | ) 121 | }, 122 | ) 123 | max_target_length: Optional[int] = field( 124 | default=128, 125 | metadata={ 126 | "help": ( 127 | "The maximum total sequence length for target text after tokenization. Sequences longer " 128 | "than this will be truncated, sequences shorter will be padded." 129 | ) 130 | }, 131 | ) 132 | val_max_target_length: Optional[int] = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 139 | "during ``evaluate`` and ``predict``." 140 | ) 141 | }, 142 | ) 143 | pad_to_max_length: bool = field( 144 | default=False, 145 | metadata={ 146 | "help": ( 147 | "Whether to pad all samples to model maximum sentence length. " 148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 149 | "efficient on GPU but very bad for TPU." 150 | ) 151 | }, 152 | ) 153 | max_train_samples: Optional[int] = field( 154 | default=None, 155 | metadata={ 156 | "help": ( 157 | "For debugging purposes or quicker training, truncate the number of training examples to this " 158 | "value if set." 
159 | ) 160 | }, 161 | ) 162 | max_eval_samples: Optional[int] = field( 163 | default=None, 164 | metadata={ 165 | "help": ( 166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 167 | "value if set." 168 | ) 169 | }, 170 | ) 171 | max_predict_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 176 | "value if set." 177 | ) 178 | }, 179 | ) 180 | num_beams: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 185 | "which is used during ``evaluate`` and ``predict``." 186 | ) 187 | }, 188 | ) 189 | ignore_pad_token_for_loss: bool = field( 190 | default=True, 191 | metadata={ 192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 193 | }, 194 | ) 195 | source_prefix: Optional[str] = field( 196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 197 | ) 198 | 199 | forced_bos_token: Optional[str] = field( 200 | default=None, 201 | metadata={ 202 | "help": ( 203 | "The token to force as the first generated token after the decoder_start_token_id." 204 | "Useful for multilingual models like mBART where the first generated token" 205 | "needs to be the target language token (Usually it is the target language token)" 206 | ) 207 | }, 208 | ) 209 | 210 | 211 | 212 | def __post_init__(self): 213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 214 | raise ValueError("Need either a dataset name or a training/validation/test file.") 215 | else: 216 | if self.train_file is not None: 217 | extension = self.train_file.split(".")[-1] 218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 219 | if self.validation_file is not None: 220 | extension = self.validation_file.split(".")[-1] 221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
222 | if self.val_max_target_length is None: 223 | self.val_max_target_length = self.max_target_length 224 | 225 | -------------------------------------------------------------------------------- /chatglm/data/calc_map.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | warnings.filterwarnings('ignore') 4 | from sklearn.metrics import average_precision_score 5 | # base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/v2/" 6 | # base_path = "output/ptuning/" 7 | base_path = "output/finetune/" 8 | 9 | num = 5 10 | dic_list = [base_path + "generated_predictions.txt" for i in range(1, 2)] 11 | # dic_list = [base_path + str(i) + "000/generated_predictions.txt" for i in range(1, 6)] 12 | for i in range(len(dic_list)): 13 | result_list = [] 14 | with open(dic_list[i], "r") as read_file: 15 | result_dic = json.load(read_file) 16 | for key in result_dic.keys(): 17 | pre = [] 18 | res = [] 19 | for jtem in result_dic[key]: 20 | data = json.loads(jtem.strip()) 21 | if data["labels"] == "Yes": 22 | res.append(1) 23 | else: 24 | res.append(0) 25 | if "yes" in data["predict"][0].lower():  # covers "yes"/"Yes"/"YES" 26 | pre.append(1) 27 | else: 28 | pre.append(0) 29 | result_list.append(average_precision_score(res, pre)) 30 | print(f"{i}:map={sum(result_list)/len(result_list)}") -------------------------------------------------------------------------------- /chatglm/data/format_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | with open("test.json", "r") as read_file: 3 | data_dic = json.load(read_file) 4 | write_data = open("test2.json", "w") 5 | write_ID_num = open("ID_num_dic.json", "w") 6 | ID_num_dic = {} 7 | n = 0 8 | for item in data_dic: 9 | this_data = data_dic[item] 10 | ID_num_dic[item] = [] 11 | for jtem in this_data: 12 | ID_num_dic[item].append(n) 13 | n += 1 14 | write_data.write(json.dumps(jtem) + '\n') 15 | json.dump(ID_num_dic, write_ID_num, indent=2)  # json.dump (not dumps) writes to the file object 16 | write_data.close() 17 | write_ID_num.close() 18 | -------------------------------------------------------------------------------- /chatglm/data/positive_negetive_balance.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | target_list = ["train", "valid"] 5 | for item in target_list: 6 | yes_list, no_list = [], []  # reset per split so "train" and "valid" samples are not mixed 7 | with open("chatglm/data/" + item + '.json', "r") as read_file: 8 | all_lines = read_file.readlines() 9 | for data in all_lines: 10 | data_dic = json.loads(data.strip()) 11 | if data_dic["summary"] == "Yes": 12 | yes_list.append(data.strip()) 13 | else: 14 | no_list.append(data.strip()) 15 | no_list = random.sample(no_list, len(yes_list)) 16 | all_list = yes_list + no_list 17 | random.shuffle(all_list) 18 | with open("chatglm/data/" + item + "_balance.json", "w") as write_file: 19 | for jtem in all_list: 20 | write_file.write(jtem + "\n") 21 | 22 | -------------------------------------------------------------------------------- /chatglm/ds_test_finetune.sh: -------------------------------------------------------------------------------- 1 | 2 | LR=1e-4 3 | 4 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 5 | 6 | deepspeed --include localhost:1,2,6,7 --master_port $MASTER_PORT finetune_test.py \ 7 | --deepspeed ft_ds_config.json \ 8 | --do_predict \ 9 | --test_file data/test2.json \ 10 | --prompt_column content \ 11 | --response_column summary \ 12 | --overwrite_cache \ 13 | --model_name_or_path
/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/chatglm-ft/checkpoint-5000/ \ 14 | --output_dir output/finetune/ \ 15 | --overwrite_output_dir \ 16 | --max_source_length 64 \ 17 | --max_target_length 64 \ 18 | --per_device_train_batch_size 4 \ 19 | --per_device_eval_batch_size 1 \ 20 | --gradient_accumulation_steps 1 \ 21 | --predict_with_generate \ 22 | --max_steps 5000 \ 23 | --logging_steps 10 \ 24 | --save_steps 1000 \ 25 | --learning_rate $LR \ 26 | --fp16 27 | 28 | -------------------------------------------------------------------------------- /chatglm/ds_test_ptv2.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=2e-2 3 | 4 | CUDA_VISIBLE_DEVICES=2 python3 test.py \ 5 | --do_predict \ 6 | --test_file data/test2.json \ 7 | --prompt_column content \ 8 | --response_column summary \ 9 | --overwrite_cache \ 10 | --model_name_or_path /data/caokun/huge_model/chatGLM/model/ \ 11 | --output_dir output/ptuning/ \ 12 | --overwrite_output_dir \ 13 | --max_source_length 64 \ 14 | --max_target_length 64 \ 15 | --per_device_train_batch_size 1 \ 16 | --per_device_eval_batch_size 1 \ 17 | --gradient_accumulation_steps 16 \ 18 | --predict_with_generate \ 19 | --max_steps 3000 \ 20 | --logging_steps 10 \ 21 | --save_steps 1000 \ 22 | --learning_rate $LR \ 23 | --pre_seq_len $PRE_SEQ_LEN \ 24 | --quantization_bit 4 \ 25 | # --ptuning_checkpoint /data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/v2/checkpoint-1000/ 26 | 27 | -------------------------------------------------------------------------------- /chatglm/finetune_test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/paper-source-trace/d0896c62508ad4bb00d28b28b8411c92034a409d/chatglm/finetune_test.py -------------------------------------------------------------------------------- /chatglm/ft_ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "fp16": { 5 | "enabled": "auto", 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 3, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true 20 | } 21 | } -------------------------------------------------------------------------------- /chatglm/output/test/trans.py: -------------------------------------------------------------------------------- 1 | # -*-coding:gbk -*- 2 | import os 3 | base_path = "v2/" 4 | all_file = os.listdir(base_path) 5 | for ktem in all_file: 6 | result = [0, 0, 0, 0] #"YesYes, YesNo, NoYes, NoNo" 7 | with open(base_path + ktem + "/generated_predictions.txt", "r") as read_file: 8 | all_lines = read_file.readlines() 9 | for item in all_lines: 10 | data = item.strip() 11 | data_list = data.split(",") 12 | if data_list[0][-2] == "s": 13 | if data_list[1][-3] == "s": 14 | result[0] += 1 15 | else: 16 | result[1] += 1 17 | else: 18 | if data_list[1][-3] == "s": 19 | result[2] += 1 20 | else: 21 | result[3] += 1 22 | Accuracy = (result[0]+result[3])/(sum(result)) 23 | Precision = result[0]/(result[0]+result[2]) 24 | Recall = result[0]/(result[0]+result[1]) 25 | print(ktem+":") 26 | print("Accuracy:" + str(Accuracy)) 27 | 
print("Precision:"+ str(Precision)) 28 | print("Recall:"+ str(Recall)) 29 | print("F1:"+ str((2*Precision*Recall)/(Precision+Recall))) 30 | print(result) 31 | 32 | 33 | -------------------------------------------------------------------------------- /chatglm/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | import json 25 | 26 | import numpy as np 27 | from datasets import load_dataset 28 | import jieba 29 | from rouge_chinese import Rouge 30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 31 | import torch 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModel, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | set_seed, 42 | ) 43 | from trainer_seq2seq import Seq2SeqTrainer 44 | 45 | from arguments import ModelArguments, DataTrainingArguments 46 | logger = logging.getLogger(__name__) 47 | 48 | def main(): 49 | 50 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 51 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 52 | # If we pass only one argument to the script and it's the path to a json file, 53 | # let's parse it to get our arguments. 54 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 55 | else: 56 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 57 | 58 | # Setup logging 59 | logging.basicConfig( 60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 61 | datefmt="%m/%d/%Y %H:%M:%S", 62 | handlers=[logging.StreamHandler(sys.stdout)], 63 | ) 64 | 65 | if training_args.should_log: 66 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
67 | transformers.utils.logging.set_verbosity_info() 68 | 69 | log_level = training_args.get_process_log_level() 70 | logger.setLevel(log_level) 71 | # datasets.utils.logging.set_verbosity(log_level) 72 | transformers.utils.logging.set_verbosity(log_level) 73 | transformers.utils.logging.enable_default_handler() 74 | transformers.utils.logging.enable_explicit_format() 75 | 76 | # Log on each process the small summary: 77 | logger.warning( 78 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 79 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 80 | ) 81 | logger.info(f"Training/evaluation parameters {training_args}") 82 | 83 | # Set seed before initializing model. 84 | set_seed(training_args.seed) 85 | 86 | # Load dataset 87 | data_files = {} 88 | if data_args.train_file is not None: 89 | data_files["train"] = data_args.train_file 90 | extension = data_args.train_file.split(".")[-1] 91 | if data_args.validation_file is not None: 92 | data_files["validation"] = data_args.validation_file 93 | extension = data_args.validation_file.split(".")[-1] 94 | if data_args.test_file is not None: 95 | data_files["test"] = data_args.test_file 96 | extension = data_args.test_file.split(".")[-1] 97 | 98 | raw_datasets = load_dataset( 99 | extension, 100 | data_files=data_files, 101 | cache_dir=model_args.cache_dir, 102 | use_auth_token=True if model_args.use_auth_token else None, 103 | ) 104 | # Load pretrained model and tokenizer 105 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 106 | config.pre_seq_len = model_args.pre_seq_len 107 | config.prefix_projection = model_args.prefix_projection 108 | 109 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 110 | 111 | if model_args.ptuning_checkpoint is not None: 112 | # Evaluation 113 | # Loading extra state dict of prefix encoder 114 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^load model~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") 115 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 116 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 117 | new_prefix_state_dict = {} 118 | for k, v in prefix_state_dict.items(): 119 | if k.startswith("transformer.prefix_encoder."): 120 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 121 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 122 | else: 123 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 124 | 125 | if model_args.quantization_bit is not None: 126 | print(f"Quantized to {model_args.quantization_bit} bit") 127 | model = model.quantize(model_args.quantization_bit) 128 | if model_args.pre_seq_len is not None: 129 | # P-tuning v2 130 | model = model.half() 131 | model.transformer.prefix_encoder.float() 132 | else: 133 | # Finetune 134 | model = model.float() 135 | 136 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 137 | 138 | # Preprocessing the datasets. 139 | # We need to tokenize inputs and targets. 
140 | if training_args.do_train: 141 | column_names = raw_datasets["train"].column_names 142 | elif training_args.do_eval: 143 | column_names = raw_datasets["validation"].column_names 144 | elif training_args.do_predict: 145 | column_names = raw_datasets["test"].column_names 146 | else: 147 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 148 | return 149 | 150 | # Get the column names for input/target. 151 | prompt_column = data_args.prompt_column 152 | response_column = data_args.response_column 153 | history_column = data_args.history_column 154 | 155 | # Temporarily set max_target_length for training. 156 | max_target_length = data_args.max_target_length 157 | 158 | def preprocess_function_eval(examples): 159 | inputs, targets = [], [] 160 | for i in range(len(examples[prompt_column])): 161 | if examples[prompt_column][i] and examples[response_column][i]: 162 | query = examples[prompt_column][i] 163 | if history_column is None or len(examples[history_column][i]) == 0: 164 | prompt = query 165 | else: 166 | prompt = "" 167 | history = examples[history_column][i] 168 | for turn_idx, (old_query, response) in enumerate(history): 169 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 170 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 171 | inputs.append(prompt) 172 | targets.append(examples[response_column][i]) 173 | 174 | inputs = [prefix + inp for inp in inputs] 175 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 176 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) 177 | 178 | if data_args.ignore_pad_token_for_loss: 179 | labels["input_ids"] = [ 180 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 181 | ] 182 | model_inputs["labels"] = labels["input_ids"] 183 | 184 | return model_inputs 185 | 186 | def preprocess_function_train(examples): 187 | max_seq_length = data_args.max_source_length + data_args.max_target_length 188 | 189 | model_inputs = { 190 | "input_ids": [], 191 | "labels": [], 192 | } 193 | for i in range(len(examples[prompt_column])): 194 | if examples[prompt_column][i] and examples[response_column][i]: 195 | query, answer = examples[prompt_column][i], examples[response_column][i] 196 | 197 | if history_column is None: 198 | prompt = query 199 | else: 200 | prompt = "" 201 | history = examples[history_column][i] 202 | for turn_idx, (old_query, response) in enumerate(history): 203 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 204 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 205 | 206 | prompt = prefix + prompt 207 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 208 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False) 209 | 210 | if len(a_ids) > data_args.max_source_length - 1: 211 | a_ids = a_ids[: data_args.max_source_length - 1] 212 | 213 | if len(b_ids) > data_args.max_target_length - 2: 214 | b_ids = b_ids[: data_args.max_target_length - 2] 215 | 216 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) 217 | 218 | context_length = input_ids.index(tokenizer.bos_token_id) 219 | mask_position = context_length - 1 220 | labels = [-100] * context_length + input_ids[mask_position+1:] 221 | 222 | pad_len = max_seq_length - len(input_ids) 223 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 224 | labels = labels + [tokenizer.pad_token_id] * 
pad_len 225 | if data_args.ignore_pad_token_for_loss: 226 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 227 | 228 | model_inputs["input_ids"].append(input_ids) 229 | model_inputs["labels"].append(labels) 230 | 231 | return model_inputs 232 | 233 | def print_dataset_example(example): 234 | print("input_ids",example["input_ids"]) 235 | print("inputs", tokenizer.decode(example["input_ids"])) 236 | print("label_ids", example["labels"]) 237 | print("labels", tokenizer.decode(example["labels"])) 238 | 239 | if training_args.do_train: 240 | if "train" not in raw_datasets: 241 | raise ValueError("--do_train requires a train dataset") 242 | train_dataset = raw_datasets["train"] 243 | if data_args.max_train_samples is not None: 244 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 245 | train_dataset = train_dataset.select(range(max_train_samples)) 246 | with training_args.main_process_first(desc="train dataset map pre-processing"): 247 | train_dataset = train_dataset.map( 248 | preprocess_function_train, 249 | batched=True, 250 | num_proc=data_args.preprocessing_num_workers, 251 | remove_columns=column_names, 252 | load_from_cache_file=not data_args.overwrite_cache, 253 | desc="Running tokenizer on train dataset", 254 | ) 255 | print_dataset_example(train_dataset[0]) 256 | 257 | if training_args.do_eval: 258 | max_target_length = data_args.val_max_target_length 259 | if "validation" not in raw_datasets: 260 | raise ValueError("--do_eval requires a validation dataset") 261 | eval_dataset = raw_datasets["validation"] 262 | if data_args.max_eval_samples is not None: 263 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 264 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 265 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 266 | eval_dataset = eval_dataset.map( 267 | preprocess_function_eval, 268 | batched=True, 269 | num_proc=data_args.preprocessing_num_workers, 270 | remove_columns=column_names, 271 | load_from_cache_file=not data_args.overwrite_cache, 272 | desc="Running tokenizer on validation dataset", 273 | ) 274 | print_dataset_example(eval_dataset[0]) 275 | 276 | if training_args.do_predict: 277 | max_target_length = data_args.val_max_target_length 278 | if "test" not in raw_datasets: 279 | raise ValueError("--do_predict requires a test dataset") 280 | predict_dataset = raw_datasets["test"] 281 | if data_args.max_predict_samples is not None: 282 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 283 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 284 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 285 | predict_dataset = predict_dataset.map( 286 | preprocess_function_eval, 287 | batched=True, 288 | num_proc=data_args.preprocessing_num_workers, 289 | remove_columns=column_names, 290 | load_from_cache_file=not data_args.overwrite_cache, 291 | desc="Running tokenizer on prediction dataset", 292 | ) 293 | print_dataset_example(predict_dataset[0]) 294 | 295 | # Data collator 296 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 297 | data_collator = DataCollatorForSeq2Seq( 298 | tokenizer, 299 | model=model, 300 | label_pad_token_id=label_pad_token_id, 301 | pad_to_multiple_of=None, 302 | padding=False 303 | ) 304 | 305 | # Metric 306 | def compute_metrics(eval_preds): 307 | preds, labels = eval_preds 308 | if isinstance(preds, tuple): 309 
| preds = preds[0] 310 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 311 | if data_args.ignore_pad_token_for_loss: 312 | # Replace -100 in the labels as we can't decode them. 313 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 314 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 315 | 316 | score_dict = { 317 | "rouge-1": [], 318 | "rouge-2": [], 319 | "rouge-l": [], 320 | "bleu-4": [] 321 | } 322 | for pred, label in zip(decoded_preds, decoded_labels): 323 | hypothesis = list(jieba.cut(pred)) 324 | reference = list(jieba.cut(label)) 325 | rouge = Rouge() 326 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) 327 | result = scores[0] 328 | 329 | for k, v in result.items(): 330 | score_dict[k].append(round(v["f"] * 100, 4)) 331 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 332 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 333 | 334 | for k, v in score_dict.items(): 335 | score_dict[k] = float(np.mean(v)) 336 | return score_dict 337 | 338 | # Override the decoding parameters of Seq2SeqTrainer 339 | training_args.generation_max_length = ( 340 | training_args.generation_max_length 341 | if training_args.generation_max_length is not None 342 | else data_args.val_max_target_length 343 | ) 344 | training_args.generation_num_beams = ( 345 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 346 | ) 347 | # Initialize our Trainer 348 | trainer = Seq2SeqTrainer( 349 | model=model, 350 | args=training_args, 351 | train_dataset=train_dataset if training_args.do_train else None, 352 | eval_dataset=eval_dataset if training_args.do_eval else None, 353 | tokenizer=tokenizer, 354 | data_collator=data_collator, 355 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 356 | save_prefixencoder=model_args.pre_seq_len is not None 357 | ) 358 | 359 | # Training 360 | if training_args.do_train: 361 | checkpoint = None 362 | if training_args.resume_from_checkpoint is not None: 363 | checkpoint = training_args.resume_from_checkpoint 364 | # elif last_checkpoint is not None: 365 | # checkpoint = last_checkpoint 366 | model.gradient_checkpointing_enable() 367 | model.enable_input_require_grads() 368 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 369 | # trainer.save_model() # Saves the tokenizer too for easy upload 370 | 371 | metrics = train_result.metrics 372 | max_train_samples = ( 373 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 374 | ) 375 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 376 | 377 | trainer.log_metrics("train", metrics) 378 | trainer.save_metrics("train", metrics) 379 | trainer.save_state() 380 | 381 | # Evaluation 382 | results = {} 383 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 384 | if training_args.do_eval: 385 | logger.info("*** Evaluate ***") 386 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) 387 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 388 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 389 | 390 | trainer.log_metrics("eval", metrics) 391 | trainer.save_metrics("eval", metrics) 392 | 393 | if training_args.do_predict: 394 | logger.info("*** Predict 
***") 395 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) 396 | metrics = predict_results.metrics 397 | max_predict_samples = ( 398 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 399 | ) 400 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 401 | 402 | trainer.log_metrics("predict", metrics) 403 | trainer.save_metrics("predict", metrics) 404 | 405 | if trainer.is_world_process_zero(): 406 | if training_args.predict_with_generate: 407 | predictions = tokenizer.batch_decode( 408 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 409 | ) 410 | predictions = [pred.strip() for pred in predictions] 411 | labels = tokenizer.batch_decode( 412 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 413 | ) 414 | labels = [label.strip() for label in labels] 415 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 416 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 417 | for p, l in zip(predictions, labels): 418 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) 419 | writer.write(f"{res}\n") 420 | return results 421 | 422 | 423 | def _mp_fn(index): 424 | # For xla_spawn (TPUs) 425 | main() 426 | 427 | 428 | if __name__ == "__main__": 429 | main() 430 | -------------------------------------------------------------------------------- /chatglm/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | from trainer import Trainer 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class Seq2SeqTrainer(Trainer): 31 | def evaluate( 32 | self, 33 | eval_dataset: Optional[Dataset] = None, 34 | ignore_keys: Optional[List[str]] = None, 35 | metric_key_prefix: str = "eval", 36 | **gen_kwargs 37 | ) -> Dict[str, float]: 38 | """ 39 | Run evaluation and returns metrics. 40 | 41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 42 | (pass it to the init `compute_metrics` argument). 43 | 44 | You can also subclass and override this method to inject custom behavior. 45 | 46 | Args: 47 | eval_dataset (`Dataset`, *optional*): 48 | Pass a dataset if you wish to override `self.eval_dataset`. 
If it is an [`~datasets.Dataset`], columns 49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 50 | method. 51 | ignore_keys (`List[str]`, *optional*): 52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 53 | gathering predictions. 54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 56 | "eval_bleu" if the prefix is `"eval"` (default) 57 | max_length (`int`, *optional*): 58 | The maximum target length to use when predicting with the generate method. 59 | num_beams (`int`, *optional*): 60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 61 | beam search. 62 | gen_kwargs: 63 | Additional `generate` specific kwargs. 64 | 65 | Returns: 66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 67 | dictionary also contains the epoch number which comes from the training state. 68 | """ 69 | 70 | gen_kwargs = gen_kwargs.copy() 71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 72 | gen_kwargs["max_length"] = self.args.generation_max_length 73 | gen_kwargs["num_beams"] = ( 74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 75 | ) 76 | self._gen_kwargs = gen_kwargs 77 | 78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 79 | 80 | def predict( 81 | self, 82 | test_dataset: Dataset, 83 | ignore_keys: Optional[List[str]] = None, 84 | metric_key_prefix: str = "test", 85 | **gen_kwargs 86 | ) -> PredictionOutput: 87 | """ 88 | Run prediction and returns predictions and potential metrics. 89 | 90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 91 | will also return metrics, like in `evaluate()`. 92 | 93 | Args: 94 | test_dataset (`Dataset`): 95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 96 | `model.forward()` method are automatically removed. Has to implement the method `__len__` 97 | ignore_keys (`List[str]`, *optional*): 98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 99 | gathering predictions. 100 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 101 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 102 | "eval_bleu" if the prefix is `"eval"` (default) 103 | max_length (`int`, *optional*): 104 | The maximum target length to use when predicting with the generate method. 105 | num_beams (`int`, *optional*): 106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 107 | beam search. 108 | gen_kwargs: 109 | Additional `generate` specific kwargs. 110 | 111 | 112 | 113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 114 | padding in a token classification task) the predictions will be padded (on the right) to allow for 115 | concatenation into one array. The padding index is -100. 116 | 117 | 118 | 119 | Returns: *NamedTuple* A namedtuple with the following keys: 120 | 121 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 124 | labels). 125 | """ 126 | 127 | gen_kwargs = gen_kwargs.copy() 128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 129 | gen_kwargs["max_length"] = self.args.generation_max_length 130 | gen_kwargs["num_beams"] = ( 131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 132 | ) 133 | self._gen_kwargs = gen_kwargs 134 | 135 | 136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 137 | 138 | def prediction_step( 139 | self, 140 | model: nn.Module, 141 | inputs: Dict[str, Union[torch.Tensor, Any]], 142 | prediction_loss_only: bool, 143 | ignore_keys: Optional[List[str]] = None, 144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 145 | """ 146 | Perform an evaluation step on `model` using `inputs`. 147 | 148 | Subclass and override to inject custom behavior. 149 | 150 | Args: 151 | model (`nn.Module`): 152 | The model to evaluate. 153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 154 | The inputs and targets of the model. 155 | 156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 157 | argument `labels`. Check your model's documentation for all accepted arguments. 158 | prediction_loss_only (`bool`): 159 | Whether or not to return the loss only. 160 | 161 | Return: 162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 163 | labels (each being optional). 164 | """ 165 | 166 | if not self.args.predict_with_generate or prediction_loss_only: 167 | return super().prediction_step( 168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 169 | ) 170 | 171 | has_labels = "labels" in inputs 172 | inputs = self._prepare_inputs(inputs) 173 | 174 | # XXX: adapt synced_gpus for fairscale as well 175 | gen_kwargs = self._gen_kwargs.copy() 176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 177 | gen_kwargs["max_length"] = self.model.config.max_length 178 | gen_kwargs["num_beams"] = ( 179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 180 | ) 181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 182 | gen_kwargs["synced_gpus"] = ( 183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 184 | ) 185 | 186 | if "attention_mask" in inputs: 187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 188 | if "position_ids" in inputs: 189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 190 | if "global_attention_mask" in inputs: 191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 192 | 193 | # prepare generation inputs 194 | # some encoder-decoder models can have varying encoder's and thus 195 | # varying model input names 196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 197 | generation_inputs = inputs[self.model.encoder.main_input_name] 198 | else: 199 | generation_inputs = inputs[self.model.main_input_name] 200 | 201 | gen_kwargs["input_ids"] = generation_inputs 202 | generated_tokens = 
self.model.generate(**gen_kwargs) 203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 204 | 205 | # in case the batch is shorter than max length, the output should be padded 206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 209 | gen_kwargs["max_new_tokens"] + 1 210 | ): 211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 212 | 213 | loss = None 214 | 215 | if self.args.prediction_loss_only: 216 | return (loss, None, None) 217 | 218 | if has_labels: 219 | labels = inputs["labels"] 220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 223 | gen_kwargs["max_new_tokens"] + 1 224 | ): 225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 226 | else: 227 | labels = None 228 | 229 | return (loss, generated_tokens, labels) 230 | 231 | def _pad_tensors_to_max_len(self, tensor, max_length): 232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 233 | # If PAD token is not defined at least EOS token has to be defined 234 | pad_token_id = ( 235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 236 | ) 237 | else: 238 | if self.model.config.pad_token_id is not None: 239 | pad_token_id = self.model.config.pad_token_id 240 | else: 241 | raise ValueError("pad_token_id must be set in the model configuration in order to pad tensors") 242 | 243 | padded_tensor = pad_token_id * torch.ones( 244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 245 | ) 246 | padded_tensor[:, : tensor.shape[-1]] = tensor 247 | return padded_tensor 248 | -------------------------------------------------------------------------------- /claude/claude.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | url = "" # Your url 5 | 6 | headers = { 7 | "Authorization": "", # Your Authorization 8 | "content-type": "application/json" 9 | } 10 | with open("result/claude.json", "r") as read_file: 11 | finished_dic = json.load(read_file) 12 | result_dic = {} 13 | write_file = open("result/claude.json", "w")  # rewrite the full file; appending a second JSON document would corrupt it 14 | try: 15 | with open("test.json", "r") as read_file: 16 | data_dic = json.load(read_file) 17 | for item in data_dic.keys(): 18 | if item in finished_dic: 19 | continue 20 | result_list = [] 21 | for jtem in data_dic[item]: 22 | this_data = json.loads(jtem.strip()) 23 | content, summary = this_data["content"], this_data["summary"] 24 | data = { 25 | "messages": [ 26 | { 27 | "role": "user", 28 | "content": content, 29 | } 30 | ], 31 | "model": "claude-instant-1-100k", 32 | "max_tokens_to_sample": 300, 33 | } 34 | response = requests.post(url, headers=headers, json=data) 35 | answer = json.loads(response.text)["choices"][0]["message"]["content"] 36 | result_list.append({"labels": summary, "predict": answer}) 37 | result_dic[item] = result_list 38 | except Exception as e: 39 | print(f"stopped early: {e}") 40 | finally: 41 | json.dump({**finished_dic, **result_dic}, write_file, indent=2)
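Note: `claude.py` sends one request per context and gives up on the first failure, keeping only what the `finally` block manages to persist. A minimal per-request retry helper is sketched below; it assumes the same OpenAI-style chat response schema (`choices[0]["message"]["content"]`) and the same user-supplied `url`/`headers` placeholders as the script above.

```python
import time
import requests

def chat_once(url, headers, content, model="claude-instant-1-100k",
              max_tokens=300, retries=3, backoff=2.0):
    """Post a single chat request, retrying transient failures with linear backoff."""
    payload = {
        "messages": [{"role": "user", "content": content}],
        "model": model,
        "max_tokens_to_sample": max_tokens,
    }
    for attempt in range(retries):
        try:
            resp = requests.post(url, headers=headers, json=payload, timeout=60)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except (requests.RequestException, KeyError, ValueError):
            if attempt == retries - 1:
                raise  # give up after the last attempt
            time.sleep(backoff * (attempt + 1))
```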
-------------------------------------------------------------------------------- /galactica/calc_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict as dd 4 | from tqdm import tqdm 5 | import warnings 6 | warnings.filterwarnings('ignore') 7 | from sklearn.metrics import average_precision_score 8 | 9 | 10 | result_file = "galactica_lora_result_2.json" 11 | pid_to_labels = dd(list) 12 | pid_to_predict = dd(list) 13 | with open(os.path.join("result", result_file), "r") as rf: 14 | for i, line in tqdm(enumerate(rf)): 15 | cur_item = json.loads(line.strip()) 16 | pid = cur_item["pid"] 17 | cur_label = cur_item["label"] 18 | pid_to_labels[pid].append(cur_label) 19 | cur_predict = cur_item["predict"] 20 | try: 21 | s_idx = cur_predict.index("The answer is") 22 | cur_predict = cur_predict[s_idx + len("The answer is"):] 23 | except ValueError:  # marker not found; score the raw generation 24 | pass 25 | if "no" in cur_predict or "No" in cur_predict or "NO" in cur_predict: 26 | pid_to_predict[pid].append(0) 27 | else: 28 | pid_to_predict[pid].append(1) 29 | 30 | maps = [] 31 | for pid in tqdm(pid_to_labels): 32 | cur_labels = pid_to_labels[pid] 33 | if sum(cur_labels) == 0: 34 | continue 35 | cur_predict = pid_to_predict[pid] 36 | cur_map = average_precision_score(cur_labels, cur_predict) 37 | maps.append(cur_map) 38 | 39 | print(maps) 40 | print(f"n_papers={len(maps)}, map={sum(maps)/len(maps)}") 41 | 42 | 43 | """ 44 | dic_list = ["result/gala_standard.json"] 45 | for i in range(len(dic_list)): 46 | result_list = [] 47 | with open(dic_list[i], "r") as read_file: 48 | result_dic = json.load(read_file) 49 | for key in result_dic.keys(): 50 | pre = [] 51 | res = [] 52 | for jtem in result_dic[key]: 53 | # data = json.loads(jtem.strip()) 54 | data = jtem 55 | if data["labels"] == "Yes": 56 | res.append(1) 57 | else: 58 | res.append(0) 59 | if "no" in data["predict"] or "No" in data["predict"] or "NO" in data["predict"]: 60 | pre.append(0) 61 | elif "yes" in data["predict"] or "Yes" in data["predict"] or "YES" in data["predict"]: 62 | pre.append(1) 63 | else: 64 | pre.append(0) 65 | if sum(res) == 0: 66 | continue 67 | cur_map = average_precision_score(res, pre) 68 | print(cur_map) 69 | result_list.append(cur_map) 70 | print(f"{i}:map={sum(result_list)/len(result_list)}") 71 | """ -------------------------------------------------------------------------------- /galactica/gala_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 3 | import warnings 4 | warnings.filterwarnings('ignore') 5 | import galai as gal 6 | import json 7 | from tqdm import tqdm 8 | from galai.notebook_utils import * 9 | 10 | model = gal.load_model("standard") 11 | # with open("../chatglm/data/test.json", "r") as read_file: 12 | # all_lines = read_file.readlines() 13 | os.makedirs("result", exist_ok=True) 14 | 15 | write_file_2 = open("result/galactica_standard_4.json", "a") 16 | # with open("../chatglm/data/test.json", "r") as read_file: 17 | # data_dic = json.load(read_file) 18 | result_dic = {} 19 | with open("data/test.jsonl", "r") as rf: 20 | for i, line in tqdm(enumerate(rf)): 21 | item = json.loads(line.strip()) 22 | context = item["context"] 23 | answer = item["label"] 24 | item_new = item 25 | if len(context) <= 3: 26 | item_new["predict"] = 0 27 | else: 28 | result = model.generate([" ".join(context.split()[-200:])]) 29 | item_new["predict"] = result[0] 30 | write_file_2.write(json.dumps(item_new,
ensure_ascii=False) + "\n") 31 | write_file_2.flush() 32 | 33 | write_file_2.close() 34 | 35 | 36 | """ 37 | write_file = open("result/gala_standard.json", "w") 38 | write_file_2 = open("result/gala_standard_2.json", "a") 39 | with open("../chatglm/data/test.json", "r") as read_file: 40 | data_dic = json.load(read_file) 41 | result_dic = {} 42 | for item in tqdm(data_dic.keys()): 43 | data_list = data_dic[item] 44 | result_dic[item] = [] 45 | for jtem in data_list: 46 | # dic = json.loads(jtem.strip()) 47 | dic = jtem 48 | content = dic["content"] 49 | answer = dic["summary"] 50 | if len(content) <= 3: 51 | result = {"labels": answer, "predict": "No"} 52 | else: 53 | result = model.generate([" ".join(content.split()[-200:])]) 54 | result = {"labels": answer, "predict":result[0]} 55 | result_with_id = {"id": item, "labels": answer, "predict":result["predict"]} 56 | write_file_2.write(json.dumps(result_with_id, ensure_ascii=False) + "\n") 57 | write_file_2.flush() 58 | result_dic[item].append(result) 59 | json.dump(result_dic, write_file, indent=2) 60 | write_file.close() 61 | write_file_2.close() 62 | """ 63 | -------------------------------------------------------------------------------- /galactica/gala_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" 3 | from xturing.datasets.instruction_dataset import InstructionDataset 4 | from xturing.models import BaseModel 5 | import json 6 | from tqdm import tqdm 7 | 8 | data_dic = {} 9 | # with open("data/train_balance_2.json", "r") as read_file: 10 | # all_lines = read_file.readlines() 11 | instruction_dataset = InstructionDataset("data/pst_data") 12 | # Initializes the model 13 | #model = BaseModel.load("/root/huge_model/galactica/galactica") 14 | # model = BaseModel.create("galactica_lora") 15 | model = BaseModel.create("galactica") 16 | 17 | # Finetuned the model 18 | # model.finetune(dataset=instruction_dataset) 19 | # model.load("saved_model/") 20 | 21 | # Once the model has been finetuned, you can start doing inferences 22 | output = model.generate(texts=["Why LLM models are becoming so important?"]) 23 | print("Generated output by the model: {}".format(output)) 24 | 25 | # Save the model 26 | # model.save("saved_model/") 27 | 28 | # write_file = open("result/galactica_lora_result.json", "w") 29 | # write_file_2 = open("result/galactica_lora_result_2.json", "a") 30 | write_file_2 = open("result/galactica_standard_3.json", "a") 31 | # with open("../chatglm/data/test.json", "r") as read_file: 32 | # data_dic = json.load(read_file) 33 | result_dic = {} 34 | with open("data/test.jsonl", "r") as rf: 35 | for i, line in tqdm(enumerate(rf)): 36 | item = json.loads(line.strip()) 37 | context = item["context"] 38 | answer = item["label"] 39 | item_new = item 40 | if len(context) <= 3: 41 | item_new["predict"] = 0 42 | else: 43 | result = model.generate(texts=[" ".join(context.split()[-200:])]) 44 | item_new["predict"] = result[0] 45 | write_file_2.write(json.dumps(item_new, ensure_ascii=False) + "\n") 46 | write_file_2.flush() 47 | 48 | write_file_2.close() 49 | 50 | """ 51 | for item in tqdm(data_dic.keys()): 52 | data_list = data_dic[item] 53 | result_dic[item] = [] 54 | for jtem in data_list: 55 | dic = jtem 56 | content = dic["content"] 57 | answer = dic["summary"] 58 | if len(content) <= 3: 59 | result = {"labels": answer, "predict": "No"} 60 | else: 61 | result = model.generate(texts=[content]) 62 | result = {"labels": answer, 
"predict":result[0]} 63 | result_with_id = {"id": item, "labels": answer, "predict":result["predict"]} 64 | write_file_2.write(json.dumps(result_with_id, ensure_ascii=False) + "\n") 65 | result_dic[item].append(result) 66 | json.dump(result_dic, write_file, indent=2) 67 | write_file.close() 68 | write_file_2.close() 69 | """ -------------------------------------------------------------------------------- /galactica/process.py: -------------------------------------------------------------------------------- 1 | #import json 2 | #import jsonlines 3 | #with open("datasets/train_balance.json", "r") as read_file: 4 | # all_lines = read_file.readlines() 5 | #write_file = jsonlines.open("datasets/train_balance_3.jsonl", "w") 6 | #for i, item in enumerate(all_lines): 7 | # data_dic = json.loads(item.strip()) 8 | # new_dic = {"id":f"seed_task_{i}", "name":f"{i}", "instruction":data_dic["content"], "instances":[{"input":"", "output":data_dic["summary"], "is_classification": False}]} 9 | # write_file.write(json.dumps(new_dic)) 10 | import json 11 | # import random 12 | import numpy as np 13 | from tqdm import tqdm 14 | from collections import defaultdict as dd 15 | from os.path import join 16 | import os 17 | from bs4 import BeautifulSoup 18 | from fuzzywuzzy import fuzz 19 | 20 | import utils 21 | import settings 22 | 23 | from datasets import Dataset, DatasetDict 24 | 25 | # Convert the alpaca JSON dataset to HF format 26 | 27 | 28 | # Right now only the HuggingFace datasets are supported, that's why the JSON Alpaca dataset 29 | # needs to be converted to the HuggingFace format. In addition, this HF dataset should have 3 columns for instruction finetuning: instruction, text and target. 30 | def preprocess_alpaca_json_data(alpaca_dataset_path: str): 31 | """Creates a dataset given the alpaca JSON dataset. 
You can download it here: https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json 32 | 33 | :param alpaca_dataset_path: path of the Alpaca dataset 34 | """ 35 | read_file = open(alpaca_dataset_path, "r") 36 | instructions = [] 37 | inputs = [] 38 | outputs = [] 39 | for item in read_file: 40 | dic = json.loads(item.strip()) 41 | instructions.append(dic["instruction"]) 42 | inputs.append(dic["input"]) 43 | outputs.append(dic["output"]) 44 | 45 | data_dict = { 46 | "train": {"instruction": instructions, "text": inputs, "target": outputs} 47 | } 48 | 49 | dataset = DatasetDict() 50 | # using your `Dict` object 51 | for k, v in data_dict.items(): 52 | dataset[k] = Dataset.from_dict(v) 53 | 54 | dataset.save_to_disk(str("galactica/data/pst_data")) 55 | 56 | 57 | def make_balance_data_for_galactica(): 58 | target_list = ["train", "valid"] 59 | yes_list, no_list = [], [] 60 | for item in target_list: 61 | with open("glm/data/" + item + '.json', "r") as read_file: 62 | all_lines = read_file.readlines() 63 | for data in all_lines: 64 | data_dic = json.loads(data.strip()) 65 | if data_dic["label"] == 1: 66 | yes_list.append(data_dic) 67 | else: 68 | no_list.append(data_dic) 69 | np.random.seed(42) 70 | no_list = np.random.choice(no_list, len(yes_list), replace=False).tolist() 71 | all_list = yes_list+no_list 72 | print(len(all_list)) 73 | np.random.shuffle(all_list) 74 | with open("galactica/data/" + item+"_balance_gala.json", "w") as write_file: 75 | for jtem in all_list: 76 | new_item = {} 77 | new_item["output"] = "Yes" if jtem["label"] else "No" 78 | new_item["instruction"] = jtem["inputs_pretokenized"][:-7] 79 | new_item["input"] = "" 80 | write_file.write(json.dumps(new_item)+"\n") 81 | 82 | 83 | def gen_test_data_json_lines(year=2023): 84 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 85 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json") 86 | pids_test = {p["_id"] for p in papers_test} 87 | 88 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 89 | files = [] 90 | for f in os.listdir(in_dir): 91 | cur_pid = f.split(".")[0] 92 | if f.endswith(".xml") and cur_pid in pids_test: 93 | files.append(f) 94 | 95 | truths = papers_test 96 | pid_to_source_titles = dd(list) 97 | for paper in tqdm(truths): 98 | pid = paper["_id"] 99 | for ref in paper["refs_trace"]: 100 | pid_to_source_titles[pid].append(ref["title"].lower()) 101 | 102 | 103 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 104 | wf = open("galactica/data/test.jsonl", "w") 105 | 106 | for paper in tqdm(papers_test): 107 | cur_pid = paper["_id"] 108 | file = join(xml_dir, cur_pid + ".tei.xml") 109 | f = open(file, encoding='utf-8') 110 | 111 | xml = f.read() 112 | bs = BeautifulSoup(xml, "xml") 113 | f.close() 114 | 115 | source_titles = pid_to_source_titles[cur_pid] 116 | if len(source_titles) == 0: 117 | continue 118 | 119 | references = bs.find_all("biblStruct") 120 | bid_to_title = {} 121 | n_refs = 0 122 | for ref in references: 123 | if "xml:id" not in ref.attrs: 124 | continue 125 | bid = ref.attrs["xml:id"] 126 | if ref.analytic is None: 127 | continue 128 | if ref.analytic.title is None: 129 | continue 130 | bid_to_title[bid] = ref.analytic.title.text.lower() 131 | b_idx = int(bid[1:]) + 1 132 | if b_idx > n_refs: 133 | n_refs = b_idx 134 | 135 | bib_to_contexts = utils.find_bib_context(xml) 136 | bib_sorted = sorted(bib_to_contexts.keys()) 137 | 138 | for bib in bib_sorted: 139 | cur_bib_idx = int(bib[1:]) 140 | if cur_bib_idx + 1 > n_refs: 141 | n_refs = 
cur_bib_idx + 1 142 | 143 | y_true = [0] * n_refs 144 | y_score = [0] * n_refs 145 | 146 | flag = False 147 | for bid in bid_to_title: 148 | cur_ref_title = bid_to_title[bid] 149 | for label_title in source_titles: 150 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 151 | flag = True 152 | b_idx = int(bid[1:]) 153 | y_true[b_idx] = 1 154 | 155 | if not flag: 156 | continue 157 | 158 | contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted] 159 | # print(bid_to_title) 160 | 161 | for bi, cur_bib in enumerate(bib_sorted): 162 | new_item = {"pid": cur_pid, "bib_id": cur_bib} 163 | cur_context = contexts_sorted[bi] 164 | cur_context = cur_context + ". Is the current reference important? Please answer Yes or No. The answer is " 165 | cur_label = y_true[int(cur_bib[1:])] 166 | new_item["label"] = cur_label 167 | new_item["context"] = cur_context 168 | try: 169 | new_item["title"]= bid_to_title[cur_bib] 170 | except: 171 | pass 172 | wf.write(json.dumps(new_item) + "\n") 173 | wf.flush() 174 | 175 | wf.close() 176 | 177 | 178 | # preprocess_alpaca_json_data("data/train_balance_2.json") 179 | 180 | # make_balance_data_for_galactica() 181 | # preprocess_alpaca_json_data("galactica/data/train_balance_gala.json") 182 | gen_test_data_json_lines(2023) -------------------------------------------------------------------------------- /glm/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TensorDataset(Dataset): 7 | def __init__(self, path, tokenizer): 8 | self.variable_num_choices = True 9 | self.example_list = [] 10 | with open(path, "r", encoding="utf-8") as file: 11 | for idx, line in enumerate(file): 12 | item = json.loads(line) 13 | item["idx"] = str(idx) 14 | item["answer"]= item["choices_pretokenized"][item["label"]] 15 | self.example_list.append(item) 16 | # self.example_list = self.example_list[:200] # debug 17 | self.examples = {example["idx"]: example for example in self.example_list} 18 | print(f"Creating {len(self.example_list)} examples") 19 | self.dataset_name = "multichoice-" + os.path.basename(path).split(".")[0] 20 | 21 | contexts = [x["inputs_pretokenized"][-500:] for x in self.example_list] 22 | candidates = [x["choices_pretokenized"] for x in self.example_list] 23 | self.labels = [x["label"] for x in self.example_list] 24 | answers = [x["answer"] for x in self.example_list] 25 | self.input_dict = {"contexts": contexts, "candidates": candidates, "labels": self.labels, "answers": answers} 26 | 27 | 28 | def __len__(self): 29 | return len(self.example_list) 30 | 31 | def __getitem__(self, idx): 32 | return {k: v[idx] for k, v in self.input_dict.items()} 33 | -------------------------------------------------------------------------------- /glm/ds_config_glm_10b.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 2, 4 | "steps_per_print": 50, 5 | "gradient_clipping": 1.0, 6 | "zero_optimization": { 7 | "stage": 1, 8 | "contiguous_gradients": false, 9 | "overlap_comm": true, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": 5e7, 12 | "allgather_bucket_size": 5e7, 13 | "cpu_offload": true 14 | }, 15 | "zero_allow_untested_optimizer": true, 16 | "fp16": { 17 | "enabled": true, 18 | "loss_scale": 0, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "optimizer": { 24 | "type": 
"Adam", 25 | "params": { 26 | "lr": 1e-7, 27 | "betas": [ 28 | 0.9, 29 | 0.95 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": 1e-2 33 | } 34 | }, 35 | "activation_checkpointing": { 36 | "partition_activations": false, 37 | "contiguous_memory_optimization": false 38 | }, 39 | "wall_clock_breakdown": false 40 | } 41 | -------------------------------------------------------------------------------- /glm/ds_config_glm_2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 6, 3 | "steps_per_print": 2000, 4 | "optimizer": { 5 | "type": "Adam", 6 | "params": { 7 | "lr": 0.0001, 8 | "betas": [ 9 | 0.8, 10 | 0.999 11 | ], 12 | "eps": 1e-8, 13 | "weight_decay": 3e-7 14 | } 15 | }, 16 | "gradient_clipping": 1.0, 17 | "prescale_gradients": false, 18 | "fp16": { 19 | "enabled": true, 20 | "fp16_master_weights_and_grads": false, 21 | "loss_scale": 0, 22 | "loss_scale_window": 500, 23 | "hysteresis": 2, 24 | "min_loss_scale": 1, 25 | "initial_scale_power": 15 26 | }, 27 | "wall_clock_breakdown": false, 28 | "zero_optimization": { 29 | "stage": 0, 30 | "allgather_partitions": true, 31 | "reduce_scatter": true, 32 | "allgather_bucket_size": 50000000, 33 | "reduce_bucket_size": 50000000, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "cpu_offload": false 37 | } 38 | } -------------------------------------------------------------------------------- /glm/finetune_glm_10b_ds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | import numpy as np 5 | from sklearn.metrics import top_k_accuracy_score 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMultipleChoice 9 | import deepspeed 10 | from glm.dataset import TensorDataset 11 | from utils import Log 12 | 13 | BATCH_SIZE = 8 14 | 15 | def add_argument(): 16 | 17 | parser = argparse.ArgumentParser(description='CIFAR') 18 | 19 | #data 20 | # cuda 21 | parser.add_argument('--with_cuda', 22 | default=False, 23 | action='store_true', 24 | help='use CPU in case there\'s no GPU support') 25 | parser.add_argument('--use_ema', 26 | default=False, 27 | action='store_true', 28 | help='whether use exponential moving average') 29 | 30 | # train 31 | parser.add_argument('-b', 32 | '--batch_size', 33 | default=8, 34 | type=int, 35 | help='mini-batch size (default: 32)') 36 | parser.add_argument('-e', 37 | '--epochs', 38 | default=50, 39 | type=int, 40 | help='number of total epochs (default: 30)') 41 | parser.add_argument('--local_rank', 42 | type=int, 43 | default=-1, 44 | help='local rank passed from distributed launcher') 45 | 46 | parser.add_argument('--log-interval', 47 | type=int, 48 | default=2000, 49 | help="output logging information at a given interval") 50 | 51 | parser.add_argument('--moe', 52 | default=False, 53 | action='store_true', 54 | help='use deepspeed mixture of experts (moe)') 55 | 56 | parser.add_argument('--ep-world-size', 57 | default=1, 58 | type=int, 59 | help='(moe) expert parallel world size') 60 | parser.add_argument('--num-experts', 61 | type=int, 62 | nargs='+', 63 | default=[ 64 | 1, 65 | ], 66 | help='number of experts list, MoE related.') 67 | parser.add_argument( 68 | '--mlp-type', 69 | type=str, 70 | default='standard', 71 | help= 72 | 'Only applicable when num-experts > 1, accepts [standard, residual]') 73 | parser.add_argument('--top-k', 74 | default=1, 75 | 
type=int, 76 | help='(moe) gating top 1 and 2 supported') 77 | parser.add_argument( 78 | '--min-capacity', 79 | default=0, 80 | type=int, 81 | help= 82 | '(moe) minimum capacity of an expert regardless of the capacity_factor' 83 | ) 84 | parser.add_argument( 85 | '--noisy-gate-policy', 86 | default=None, 87 | type=str, 88 | help= 89 | '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter' 90 | ) 91 | parser.add_argument( 92 | '--moe-param-group', 93 | default=False, 94 | action='store_true', 95 | help= 96 | '(moe) create separate moe param groups, required when using ZeRO w. MoE' 97 | ) 98 | 99 | # Include DeepSpeed configuration arguments 100 | parser = deepspeed.add_config_arguments(parser) 101 | 102 | args = parser.parse_args() 103 | 104 | return args 105 | 106 | 107 | def glm_get_params_for_weight_decay_optimization(module): 108 | weight_decay_params = {'params': []} 109 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 110 | for module_ in module.modules(): 111 | if isinstance(module_, (torch.nn.LayerNorm)): 112 | no_weight_decay_params['params'].extend( 113 | [p for p in list(module_._parameters.values()) 114 | if p is not None and p.requires_grad]) 115 | else: 116 | weight_decay_params['params'].extend( 117 | [p for n, p in list(module_._parameters.items()) 118 | if p is not None and p.requires_grad and n != 'bias']) 119 | no_weight_decay_params['params'].extend( 120 | [p for n, p in list(module_._parameters.items()) 121 | if p is not None and p.requires_grad and n == 'bias']) 122 | 123 | return weight_decay_params, no_weight_decay_params 124 | 125 | 126 | def get_optimizer_param_groups(model): 127 | # Build parameter groups (weight decay and non-decay). 128 | # while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)): 129 | # model = model.module 130 | param_groups = glm_get_params_for_weight_decay_optimization(model) 131 | 132 | # Add model parallel attribute if it is not set.
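# Each group above is a dict that torch.optim/DeepSpeed accept directly (a
# "params" list plus optional per-group options such as "weight_decay"). The
# model_parallel flag below is only stamped onto parameters so that
# Megatron/DeepSpeed-style utilities that inspect this attribute do not fail
# on a checkpoint created without model parallelism.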
133 | for param_group in param_groups: 134 | # print('## param_group', len(param_group['params'])) 135 | for param in param_group['params']: 136 | if not hasattr(param, 'model_parallel'): 137 | param.model_parallel = False 138 | 139 | return param_groups 140 | 141 | 142 | def train(model_engine, dataloader, tokenizer, fp16): 143 | loss_total = 0 144 | n_item = 0 145 | for data in dataloader: 146 | # print("here3") 147 | # data = {k: v.to(model_engine.local_rank) for k, v in batch.items()} 148 | # if fp16: 149 | # data = {k: v.half() for k, v in data.items()} 150 | # outputs = model_engine(**data) 151 | # loss = outputs.loss 152 | # model_engine.backward(loss) 153 | # model_engine.step() 154 | 155 | # loss_total += loss.item() 156 | # cur_n_item = outputs.logits.shape[0] 157 | # n_item += cur_n_item 158 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True) 159 | inputs = tokenizer.build_inputs_for_generation(inputs, targets=data["answers"], max_gen_length=2, padding=False) 160 | inputs = inputs.to(model_engine.local_rank) 161 | 162 | outputs = model_engine(**inputs) 163 | loss = outputs.loss 164 | model_engine.backward(loss) 165 | model_engine.step() 166 | 167 | loss_total += loss.item() 168 | cur_n_item = outputs.logits.shape[0] 169 | n_item += cur_n_item 170 | 171 | return loss_total / n_item 172 | 173 | 174 | def eval(dataloader, model_infer_ds, tokenizer, fp16): 175 | # model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True) 176 | # model_infer.load_state_dict(model.state_dict()) 177 | # model_infer = model_infer.half() 178 | # model_infer = model_infer.half().cuda() # half? 179 | # model_infer.eval() 180 | 181 | # ds_engine = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True) 182 | # model_ds = ds_engine.module 183 | 184 | pred_scores = np.empty((0, 6)) 185 | labels = [] 186 | 187 | with torch.no_grad(): 188 | for data in dataloader: 189 | label = data["labels"] 190 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True) 191 | inputs = tokenizer.build_inputs_for_multiple_choice(inputs, data["candidates"]) 192 | inputs = inputs.to('cuda') 193 | # if fp16: 194 | # inputs = {k: v.half() if torch.is_tensor(v) else v for k, v in inputs.items()} 195 | # outputs = model_infer(**inputs) 196 | outputs = model_infer_ds(**inputs) 197 | logits = outputs.logits 198 | score = logits.detach().cpu().numpy() 199 | if score.shape[0] > 6: 200 | score = score[:6] 201 | elif score.shape[0] < 6: 202 | score = np.concatenate((score, np.zeros((6 - score.shape[0], score.shape[1]))), axis=0) 203 | pred_scores = np.concatenate((pred_scores, np.transpose(score)), axis=0) 204 | labels.extend(label) 205 | 206 | hit_1 = top_k_accuracy_score(labels, pred_scores, k=1) 207 | hit_3 = top_k_accuracy_score(labels, pred_scores, k=3) 208 | hit_5 = top_k_accuracy_score(labels, pred_scores, k=5) 209 | 210 | # del model_infer 211 | 212 | return [hit_1, hit_3, hit_5] 213 | 214 | 215 | def finetune_ds(model_name="THUDM/glm-2b"): 216 | deepspeed.init_distributed() 217 | 218 | # if torch.distributed.get_rank() != 0: 219 | # # might be downloading data, let rank 0 download first 220 | # torch.distributed.barrier() 221 | 222 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 223 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True) 224 | # tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) 225 | # model = 
AutoModelForSeq2SeqLM.from_pretrained(model_name, local_files_only=True) 226 | model = model.half() 227 | 228 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True) 229 | model_infer = model_infer.half() 230 | 231 | args = add_argument() 232 | 233 | exp_name = "finetune-" + model_name.split('/')[-1] + '-ds-' 234 | exp_name = exp_name + str(datetime.now()) 235 | 236 | # Dataset 237 | train_dataset_path = "glm/data/train.json" 238 | valid_dataset_path = "glm/data/valid.json" 239 | test_dataset_path = "glm/data/test.json" 240 | log_path = "./log/" + exp_name + ".log" 241 | os.makedirs("./log", exist_ok=True) 242 | 243 | train_dataset = TensorDataset(train_dataset_path, tokenizer) 244 | valid_dataset = TensorDataset(valid_dataset_path, tokenizer) 245 | test_dataset = TensorDataset(test_dataset_path, tokenizer) 246 | 247 | # Log 248 | log = Log(file_path= log_path) 249 | 250 | # initialize the dataloader 251 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) 252 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) 253 | test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) 254 | 255 | parameters = get_optimizer_param_groups(model) 256 | 257 | print("here1") 258 | model_engine, optimizer, trainloader, __ = deepspeed.initialize( 259 | args=args, model=model, model_parameters=parameters, training_data=train_dataset) 260 | print("here2") 261 | 262 | ds_engine_infer = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True) 263 | 264 | fp16 = model_engine.fp16_enabled() 265 | print(f'fp16={fp16}') 266 | 267 | # ds_engine_infer._load_from_state_dict(model.state_dict()) 268 | ds_engine_infer.module.load_state_dict(model.state_dict()) 269 | # model_infer_ds = ds_engine_infer.module 270 | # res_valid = eval(valid_loader, model_infer_ds, tokenizer, fp16) 271 | # log.log('valid epoch {} result: {}'.format(-1, str(res_valid))) 272 | # res_test = eval(test_loader, model_infer_ds, tokenizer, fp16) 273 | # log.log('test epoch {} result: {}'.format(-1, str(res_test))) 274 | 275 | valid_best = 0 276 | test_best = [] 277 | best_epoch = 0 278 | out_dir = "glm/saved" 279 | os.makedirs(out_dir, exist_ok=True) 280 | for epoch in range(args.epochs): 281 | log.log('begin epoch {}'.format(epoch)) 282 | loss_train = train(model_engine, trainloader, tokenizer, fp16) 283 | log.log('train epoch {} end, loss {}'.format(epoch, str(loss_train))) 284 | 285 | # ds_engine_infer._load_from_state_dict(model.state_dict()) 286 | 287 | # ds_engine_infer.module.load_state_dict(model.state_dict()) 288 | # model_infer_ds = ds_engine_infer.module 289 | # res_valid = eval(valid_loader, model_infer_ds, tokenizer, fp16) 290 | # log.log('valid epoch {} result: {}'.format(epoch, str(res_valid))) 291 | # res_test = eval(test_loader, model_infer_ds, tokenizer, fp16) 292 | # log.log('test epoch {} result: {}'.format(epoch, str(res_test))) 293 | 294 | # if res_valid[0] > valid_best: 295 | # valid_best = res_valid[0] 296 | # test_best = res_test 297 | # best_epoch = epoch 298 | # log.log('best epoch {} result: {}'.format(best_epoch, str(test_best))) 299 | 300 | ds_engine_infer.module.load_state_dict(model.state_dict()) 301 | torch.save(ds_engine_infer.module, os.path.join(out_dir, "{}-epoch-{}.pt".format("glm-10b", epoch))) 302 | 303 | 304 | 305 | if __name__ == "__main__": 306 | model_name = "THUDM/glm-10b" 307 | 
finetune_ds(model_name) 308 | -------------------------------------------------------------------------------- /glm/finetune_glm_ds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | import numpy as np 5 | from sklearn.metrics import top_k_accuracy_score 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMultipleChoice 10 | import deepspeed 11 | 12 | from glm.dataset import TensorDataset 13 | from utils import Log 14 | 15 | BATCH_SIZE = 8 16 | 17 | 18 | def add_argument(): 19 | 20 | parser = argparse.ArgumentParser(description='GLM finetuning') 21 | 22 | #data 23 | # cuda 24 | parser.add_argument('--with_cuda', 25 | default=False, 26 | action='store_true', 27 | help='use CPU in case there\'s no GPU support') 28 | parser.add_argument('--use_ema', 29 | default=False, 30 | action='store_true', 31 | help='whether to use exponential moving average') 32 | 33 | # train 34 | parser.add_argument('-b', 35 | '--batch_size', 36 | default=8, 37 | type=int, 38 | help='mini-batch size (default: 8)') 39 | parser.add_argument('-e', 40 | '--epochs', 41 | default=10, 42 | type=int, 43 | help='number of total epochs (default: 10)') 44 | parser.add_argument('--local_rank', 45 | type=int, 46 | default=-1, 47 | help='local rank passed from distributed launcher') 48 | 49 | parser.add_argument('--log-interval', 50 | type=int, 51 | default=2000, 52 | help="output logging information at a given interval") 53 | 54 | parser.add_argument('--moe', 55 | default=False, 56 | action='store_true', 57 | help='use deepspeed mixture of experts (moe)') 58 | 59 | parser.add_argument('--ep-world-size', 60 | default=1, 61 | type=int, 62 | help='(moe) expert parallel world size') 63 | parser.add_argument('--num-experts', 64 | type=int, 65 | nargs='+', 66 | default=[ 67 | 1, 68 | ], 69 | help='number of experts list, MoE related.') 70 | parser.add_argument( 71 | '--mlp-type', 72 | type=str, 73 | default='standard', 74 | help= 75 | 'Only applicable when num-experts > 1, accepts [standard, residual]') 76 | parser.add_argument('--top-k', 77 | default=1, 78 | type=int, 79 | help='(moe) gating top 1 and 2 supported') 80 | parser.add_argument( 81 | '--min-capacity', 82 | default=0, 83 | type=int, 84 | help= 85 | '(moe) minimum capacity of an expert regardless of the capacity_factor' 86 | ) 87 | parser.add_argument( 88 | '--noisy-gate-policy', 89 | default=None, 90 | type=str, 91 | help= 92 | '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter' 93 | ) 94 | parser.add_argument( 95 | '--moe-param-group', 96 | default=False, 97 | action='store_true', 98 | help= 99 | '(moe) create separate moe param groups, required when using ZeRO w.
MoE' 100 | ) 101 | 102 | # Include DeepSpeed configuration arguments 103 | parser = deepspeed.add_config_arguments(parser) 104 | 105 | args = parser.parse_args() 106 | 107 | return args 108 | 109 | 110 | def glm_get_params_for_weight_decay_optimization(module): 111 | weight_decay_params = {'params': []} 112 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 113 | for module_ in module.modules(): 114 | if isinstance(module_, (torch.nn.LayerNorm)): 115 | no_weight_decay_params['params'].extend( 116 | [p for p in list(module_._parameters.values()) 117 | if p is not None and p.requires_grad]) 118 | else: 119 | weight_decay_params['params'].extend( 120 | [p for n, p in list(module_._parameters.items()) 121 | if p is not None and p.requires_grad and n != 'bias']) 122 | no_weight_decay_params['params'].extend( 123 | [p for n, p in list(module_._parameters.items()) 124 | if p is not None and p.requires_grad and n == 'bias']) 125 | 126 | return weight_decay_params, no_weight_decay_params 127 | 128 | 129 | def get_optimizer_param_groups(model): 130 | # Build parameter groups (weight decay and non-decay). 131 | # while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)): 132 | # model = model.module 133 | param_groups = glm_get_params_for_weight_decay_optimization(model) 134 | 135 | # Add model parallel attribute if it is not set. 136 | for param_group in param_groups: 137 | # print('## param_group', len(param_group['params'])) 138 | for param in param_group['params']: 139 | if not hasattr(param, 'model_parallel'): 140 | param.model_parallel = False 141 | 142 | return param_groups 143 | 144 | 145 | def train(model_engine, dataloader, tokenizer, fp16): 146 | loss_total = 0 147 | n_item = 0 148 | for data in dataloader: 149 | # data = {k: v.to(model_engine.local_rank) for k, v in batch.items()} 150 | # if fp16: 151 | # data = {k: v.half() for k, v in data.items()} 152 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True) 153 | inputs = tokenizer.build_inputs_for_generation(inputs, targets=data["answers"], max_gen_length=2, padding=False) 154 | inputs = inputs.to(model_engine.local_rank) 155 | 156 | outputs = model_engine(**inputs) 157 | loss = outputs.loss 158 | model_engine.backward(loss) 159 | model_engine.step() 160 | 161 | loss_total += loss.item() 162 | cur_n_item = outputs.logits.shape[0] 163 | n_item += cur_n_item 164 | 165 | return loss_total / n_item 166 | 167 | 168 | def eval(dataloader, model_infer_ds, tokenizer, fp16): 169 | pred_scores = np.empty((0, 6)) 170 | labels = [] 171 | 172 | with torch.no_grad(): 173 | for data in dataloader: 174 | label = data["labels"] 175 | inputs = tokenizer(data["contexts"], return_tensors="pt", padding=True) 176 | inputs = tokenizer.build_inputs_for_multiple_choice(inputs, data["candidates"]) 177 | inputs = inputs.to('cuda') 178 | # if fp16: 179 | # inputs = {k: v.half() if torch.is_tensor(v) else v for k, v in inputs.items()} 180 | # outputs = model_infer(**inputs) 181 | outputs = model_infer_ds(**inputs) 182 | logits = outputs.logits 183 | score = logits.detach().cpu().numpy() 184 | if score.shape[0] > 6: 185 | score = score[:6] 186 | elif score.shape[0] < 6: 187 | score = np.concatenate((score, np.zeros((6 - score.shape[0], score.shape[1]))), axis=0) 188 | pred_scores = np.concatenate((pred_scores, np.transpose(score)), axis=0) 189 | labels.extend(label) 190 | 191 | hit_1 = top_k_accuracy_score(labels, pred_scores, k=1) 192 | hit_3 = top_k_accuracy_score(labels, pred_scores, k=3) 193 | hit_5 = top_k_accuracy_score(labels, pred_scores, k=5) 194 |
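# pred_scores holds one row of 6 candidate logits per example (scores are
# truncated or zero-padded to a fixed 6-way candidate set above), so hit@k is
# the fraction of examples whose true candidate ranks among the k highest.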
195 | # del model_infer 196 | 197 | return [hit_1, hit_3, hit_5] 198 | 199 | 200 | def finetune_ds(model_name="THUDM/glm-2b"): 201 | deepspeed.init_distributed() 202 | 203 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 204 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True) 205 | model = model.half() 206 | 207 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True) 208 | model_infer = model_infer.half() 209 | 210 | args = add_argument() 211 | 212 | exp_name = "finetune-" + model_name.split('/')[-1] + '-ds-' 213 | exp_name = exp_name + str(datetime.now()) 214 | 215 | # Dataset 216 | train_dataset_path = "glm/data/train.json" 217 | valid_dataset_path = "glm/data/valid.json" 218 | test_dataset_path = "glm/data/test.json" 219 | log_path = "glm/log/" + exp_name + ".log" 220 | os.makedirs("glm/log", exist_ok=True) 221 | 222 | train_dataset = TensorDataset(train_dataset_path, tokenizer) 223 | valid_dataset = TensorDataset(valid_dataset_path, tokenizer) 224 | test_dataset = TensorDataset(test_dataset_path, tokenizer) 225 | 226 | # Log 227 | log = Log(file_path= log_path) 228 | 229 | # initialize the dataloader 230 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) 231 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) 232 | test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) 233 | 234 | parameters = get_optimizer_param_groups(model) 235 | 236 | print("here1") 237 | model_engine, optimizer, trainloader, __ = deepspeed.initialize( 238 | args=args, model=model, model_parameters=parameters, training_data=train_dataset) 239 | print("here2") 240 | 241 | ds_engine_infer = deepspeed.init_inference(model_infer, mp_size=1, dtype=torch.half, replace_with_kernel_inject=True) 242 | 243 | fp16 = model_engine.fp16_enabled() 244 | print(f'fp16={fp16}') 245 | 246 | valid_best = 0 247 | test_best = [] 248 | best_epoch = 0 249 | out_dir = "glm/saved" 250 | os.makedirs(out_dir, exist_ok=True) 251 | for epoch in range(args.epochs): 252 | log.log('begin epoch {}'.format(epoch)) 253 | loss_train = train(model_engine, trainloader, tokenizer, fp16) 254 | log.log('train epoch {} end, loss {}'.format(epoch, str(loss_train))) 255 | 256 | ds_engine_infer.module.load_state_dict(model.state_dict()) 257 | 258 | torch.save(ds_engine_infer.module, os.path.join(out_dir, "{}-epoch-{}.pt".format("glm-2b", epoch))) 259 | 260 | 261 | if __name__ == "__main__": 262 | model_name = "THUDM/glm-2b" 263 | # model_name = "~/.cache/huggingface/hub/models--THUDM--glm-10b/snapshots/696788d4f82ac96b90823555f547d1e754839ff4" 264 | # model_name = "~/.cache/huggingface/hub/models--THUDM--glm-2b/snapshots/774fda883d7ad028b8effc3c65afec510fce9634" 265 | finetune_ds(model_name) 266 | -------------------------------------------------------------------------------- /glm/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import json 4 | from tqdm import tqdm 5 | from collections import defaultdict as dd 6 | from bs4 import BeautifulSoup 7 | import numpy as np 8 | from fuzzywuzzy import fuzz 9 | 10 | import utils 11 | import settings 12 | 13 | 14 | def prepare_train_test_data_for_glm(year=2023): 15 | x_train = [] 16 | y_train = [] 17 | x_valid = [] 18 | y_valid = [] 19 | x_test = [] 20 | y_test = [] 21 | 22 | 
truths = utils.load_json(settings.DATA_TRACE_DIR, "paper_source_trace_{}_final_filtered.json".format(year)) 23 | pid_to_source_titles = dd(list) 24 | for paper in tqdm(truths): 25 | pid = paper["_id"] 26 | for ref in paper["refs_trace"]: 27 | pid_to_source_titles[pid].append(ref["title"].lower()) 28 | 29 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 30 | papers_train = utils.load_json(data_year_dir, "paper_source_trace_train.json") 31 | papers_valid = utils.load_json(data_year_dir, "paper_source_trace_valid.json") 32 | papers_test = utils.load_json(data_year_dir, "paper_source_trace_test.json") 33 | 34 | pids_train = {p["_id"] for p in papers_train} 35 | pids_valid = {p["_id"] for p in papers_valid} 36 | pids_test = {p["_id"] for p in papers_test} 37 | 38 | in_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 39 | files = [] 40 | for f in os.listdir(in_dir): 41 | if f.endswith(".xml"): 42 | files.append(f) 43 | 44 | files = sorted(files) 45 | for file in tqdm(files): 46 | f = open(join(in_dir, file), encoding='utf-8') 47 | cur_pid = file.split(".")[0] 48 | if cur_pid not in pids_train and cur_pid not in pids_valid and cur_pid not in pids_test: 49 | continue 50 | xml = f.read() 51 | bs = BeautifulSoup(xml, "xml") 52 | 53 | source_titles = pid_to_source_titles[cur_pid] 54 | if len(source_titles) == 0: 55 | continue 56 | 57 | references = bs.find_all("biblStruct") 58 | bid_to_title = {} 59 | n_refs = 0 60 | for ref in references: 61 | if "xml:id" not in ref.attrs: 62 | continue 63 | bid = ref.attrs["xml:id"] 64 | if ref.analytic is None: 65 | continue 66 | if ref.analytic.title is None: 67 | continue 68 | bid_to_title[bid] = ref.analytic.title.text.lower() 69 | b_idx = int(bid[1:]) + 1 70 | if b_idx > n_refs: 71 | n_refs = b_idx 72 | 73 | flag = False 74 | 75 | cur_pos_bib = set() 76 | 77 | for bid in bid_to_title: 78 | cur_ref_title = bid_to_title[bid] 79 | for label_title in source_titles: 80 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 81 | flag = True 82 | cur_pos_bib.add(bid) 83 | 84 | cur_neg_bib = set(bid_to_title.keys()) - cur_pos_bib 85 | 86 | if not flag: 87 | continue 88 | 89 | if len(cur_pos_bib) == 0 or len(cur_neg_bib) == 0: 90 | continue 91 | 92 | bib_to_contexts = utils.find_bib_context(xml) 93 | 94 | n_pos = len(cur_pos_bib) 95 | n_neg = n_pos * 10 96 | cur_neg_bib_sample = np.random.choice(list(cur_neg_bib), n_neg, replace=True) 97 | 98 | if cur_pid in pids_train: 99 | cur_x = x_train 100 | cur_y = y_train 101 | elif cur_pid in pids_valid: 102 | cur_x = x_valid 103 | cur_y = y_valid 104 | elif cur_pid in pids_test: 105 | cur_x = x_test 106 | cur_y = y_test 107 | else: 108 | continue 109 | # raise Exception("cur_pid not in train/valid/test") 110 | 111 | for bib in cur_pos_bib: 112 | cur_context = "The context is: " + " ".join(bib_to_contexts[bib]) + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]." 113 | cur_x.append(cur_context) 114 | cur_y.append(1) 115 | 116 | for bib in cur_neg_bib_sample: 117 | cur_context = "The context is: " + " ".join(bib_to_contexts[bib]) + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]." 
118 | cur_x.append(cur_context) 119 | cur_y.append(0) 120 | 121 | print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid), "len(x_test)", len(x_test)) 122 | 123 | out_dir = "glm/data/" 124 | os.makedirs(out_dir, exist_ok=True) 125 | 126 | with open(join(out_dir, "train.json"), "w") as f: 127 | for i in range(len(x_train)): 128 | f.write(json.dumps({"inputs_pretokenized": x_train[i], "choices_pretokenized": ["No", "Yes"], "label": y_train[i]}) + "\n") 129 | 130 | 131 | with open(join(out_dir, "valid.json"), "w") as f: 132 | for i in range(len(x_valid)): 133 | f.write(json.dumps({"inputs_pretokenized": x_valid[i], "choices_pretokenized": ["No", "Yes"], "label": y_valid[i]}) + "\n") 134 | 135 | with open(join(out_dir, "test.json"), "w") as f: 136 | for i in range(len(x_test)): 137 | f.write(json.dumps({"inputs_pretokenized": x_test[i], "choices_pretokenized": ["No", "Yes"], "label": y_test[i]}) + "\n") 138 | 139 | 140 | if __name__ == "__main__": 141 | prepare_train_test_data_for_glm() 142 | -------------------------------------------------------------------------------- /glm/run_finetune_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --master_port 23619 glm/finetune_glm_ds.py --deepspeed --deepspeed_config glm/ds_config_glm_2b.json $@ 4 | -------------------------------------------------------------------------------- /glm/run_finetune_ds_10b.sh: -------------------------------------------------------------------------------- 1 | deepspeed --master_port 23620 --include localhost:2,6,7 glm/finetune_glm_10b_ds.py --deepspeed --deepspeed_config glm/ds_config_glm_10b.json $@ 2 | -------------------------------------------------------------------------------- /glm/test_glm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from collections import defaultdict as dd 4 | import numpy as np 5 | import torch 6 | from fuzzywuzzy import fuzz 7 | from bs4 import BeautifulSoup 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForMultipleChoice 10 | from sklearn.metrics import average_precision_score 11 | 12 | import utils 13 | import settings 14 | 15 | 16 | def eval_test(model_name, ckpt_epoch=1, year=2023, role="test", ft=True): 17 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 18 | 19 | out_dir = "glm/saved" 20 | prefix = model_name.split("/")[1].lower() 21 | model_path = os.path.join(out_dir, "{}-epoch-{}.pt".format(prefix, ckpt_epoch)) 22 | if not ft: 23 | model_infer = AutoModelForMultipleChoice.from_pretrained(model_name, trust_remote_code=True) 24 | # model_infer = AutoModelForMultipleChoice.from_pretrained("/home/zhangfanjin/.cache/huggingface/hub/models--THUDM--glm-10b/snapshots/696788d4f82ac96b90823555f547d1e754839ff4", trust_remote_code=True) 25 | model_infer.to('cuda') 26 | elif prefix == "glm-10b": 27 | model_infer = torch.load(model_path, map_location=torch.device('cuda')) 28 | elif prefix == "glm-2b": 29 | model_infer = torch.load(model_path, map_location=torch.device('cuda')) 30 | # model_infer.to('cuda') 31 | model_infer.eval() 32 | print("model load successfully") 33 | 34 | papers_test = utils.load_json(join(settings.DATA_TRACE_DIR, str(year)), "paper_source_trace_{}.json".format(role)) 35 | pids_test = {p["_id"] for p in papers_test} 36 | 37 | truths = papers_test 38 | pid_to_source_titles = dd(list) 39 | for paper in tqdm(truths): 40 | pid = paper["_id"] 41 | for 
ref in paper["refs_trace"]: 42 | pid_to_source_titles[pid].append(ref["title"].lower()) 43 | 44 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 45 | candidates = ["No", "Yes"] 46 | metrics = [] 47 | f_idx = 0 48 | 49 | for paper in tqdm(papers_test): 50 | cur_pid = paper["_id"] 51 | file = join(xml_dir, cur_pid + ".tei.xml") 52 | f = open(file, encoding='utf-8') 53 | 54 | xml = f.read() 55 | bs = BeautifulSoup(xml, "xml") 56 | f.close() 57 | 58 | source_titles = pid_to_source_titles[cur_pid] 59 | if len(source_titles) == 0: 60 | continue 61 | 62 | references = bs.find_all("biblStruct") 63 | bid_to_title = {} 64 | n_refs = 0 65 | for ref in references: 66 | if "xml:id" not in ref.attrs: 67 | continue 68 | bid = ref.attrs["xml:id"] 69 | if ref.analytic is None: 70 | continue 71 | if ref.analytic.title is None: 72 | continue 73 | bid_to_title[bid] = ref.analytic.title.text.lower() 74 | b_idx = int(bid[1:]) + 1 75 | if b_idx > n_refs: 76 | n_refs = b_idx 77 | 78 | bib_to_contexts = utils.find_bib_context(xml) 79 | bib_sorted = sorted(bib_to_contexts.keys()) 80 | 81 | for bib in bib_sorted: 82 | cur_bib_idx = int(bib[1:]) 83 | if cur_bib_idx + 1 > n_refs: 84 | n_refs = cur_bib_idx + 1 85 | 86 | y_true = [0] * n_refs 87 | y_score = [0] * n_refs 88 | 89 | flag = False 90 | for bid in bid_to_title: 91 | cur_ref_title = bid_to_title[bid] 92 | for label_title in source_titles: 93 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 94 | flag = True 95 | b_idx = int(bid[1:]) 96 | y_true[b_idx] = 1 97 | 98 | if not flag: 99 | continue 100 | 101 | contexts_sorted = ["The context is: " + " ".join(bib_to_contexts[bib]) 102 | + ". Is the current reference important? Please answer Yes or No. The answer is [MASK]." 103 | for bib in bib_sorted] 104 | contexts_sorted = [x[-500:] for x in contexts_sorted] 105 | 106 | predicted_scores = [] 107 | for cur_context in contexts_sorted: 108 | token = tokenizer([cur_context], return_tensors="pt", padding=True) 109 | inputs = tokenizer.build_inputs_for_multiple_choice(token, [candidates]) 110 | inputs = inputs.to('cuda') 111 | outputs = model_infer(**inputs) 112 | logits = outputs.logits 113 | # print("logits", logits.shape) 114 | score = logits.detach().cpu().numpy().tolist() 115 | # print("score", score) 116 | predicted_scores.append(score[0][1]) 117 | 118 | try: 119 | for ii in range(len(predicted_scores)): 120 | bib_idx = int(bib_sorted[ii][1:]) 121 | y_score[bib_idx] = predicted_scores[ii] 122 | except IndexError as e: 123 | metrics.append(0) 124 | continue 125 | 126 | cur_map = average_precision_score(y_true, y_score) 127 | metrics.append(cur_map) 128 | f_idx += 1 129 | if f_idx % 20 == 0: 130 | print("map until now", np.mean(metrics), len(metrics), cur_map) 131 | 132 | map_avg = np.mean(metrics) 133 | print("epoch {} average map".format(ckpt_epoch), map_avg, len(metrics)) 134 | return np.mean(metrics) 135 | 136 | 137 | if __name__ == "__main__": 138 | # eval_test(model_name="THUDM/GLM-2b", ckpt_epoch=1) 139 | 140 | """ 141 | model_name = "THUDM/GLM-2b" 142 | prefix = model_name.split("/")[1].lower() 143 | wf = open("glm/saved/valid_map_{}.txt".format(prefix), "w") 144 | for i in range(10): 145 | cur_map = eval_test(model_name=model_name, ckpt_epoch=i, role="valid") 146 | wf.write("{}\t{}\n".format(i, cur_map)) 147 | wf.flush() 148 | wf.close() 149 | """ 150 | 151 | # eval_test(model_name="THUDM/GLM-2b", ckpt_epoch=3) 152 | eval_test(model_name="THUDM/GLM-10b", ckpt_epoch=1) 153 | 
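For reference, here is a minimal standalone sketch of the Yes/No multiple-choice scoring that `glm/test_glm.py` performs per citation context. It only reuses calls that appear in the file above (the `build_inputs_for_multiple_choice` helper ships with the `trust_remote_code` GLM tokenizer); the context string is a placeholder:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-2b", trust_remote_code=True)
model = AutoModelForMultipleChoice.from_pretrained("THUDM/glm-2b", trust_remote_code=True)
model = model.half().to("cuda").eval()

# placeholder context in the same prompt template used above
context = ("The context is: ... Is the current reference important? "
           "Please answer Yes or No. The answer is [MASK].")
token = tokenizer([context], return_tensors="pt", padding=True)
inputs = tokenizer.build_inputs_for_multiple_choice(token, [["No", "Yes"]]).to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits  # one row of choice scores
yes_score = logits[0][1].item()  # larger means "Yes" (important reference)
```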
-------------------------------------------------------------------------------- /gpt-api/gpt.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | openai.api_key = "" # your api key 4 | openai.api_base = "" # your api base 5 | model_list = ["gpt-3.5-turbo","gpt-4"] 6 | for model in model_list: 7 | with open("result/" + model + "2.json", "r") as read_dic: 8 | result_dic = json.load(read_dic) # resume from already finished queries 9 | finished_dic = set(result_dic.keys()) 10 | write_dic = open("result/" + model + "2.json", "w") # rewrite the merged results as valid json 11 | try: 12 | with open("test.json", "r") as read_file: 13 | data_dic = json.load(read_file) 14 | for this_data in data_dic.keys(): 15 | if this_data in finished_dic: 16 | continue 17 | result_list = [] 18 | for jtem in data_dic[this_data]: 19 | question = jtem["content"] 20 | result = jtem["summary"] 21 | flag = True 22 | while flag: 23 | try: 24 | chat_completion = openai.ChatCompletion.create(model=model, messages=[{"role": "user", "content":question}], stream=True) 25 | completion_text = "" 26 | for event in chat_completion: 27 | if len(event["choices"]) > 0: 28 | completion_text += event["choices"][0]["delta"].get("content", "") 29 | result_list.append({"labels": result, "predict": completion_text}) # store dicts so map.py can index them 30 | flag = False 31 | except openai.error.APIError: 32 | continue 33 | result_dic[this_data] = result_list 34 | except Exception: 35 | pass # keep whatever finished so far and dump it below 36 | finally: 37 | json.dump(result_dic, write_dic, indent=2) -------------------------------------------------------------------------------- /gpt-api/map.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | warnings.filterwarnings('ignore') 4 | from sklearn.metrics import average_precision_score 5 | # base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/v2/" 6 | # base_path = "/data/caokun/huge_model/chatGLM/ChatGLM-6B-main/ptuning/output/test/finetune/" # unused leftover 7 | # num = 5 # unused leftover 8 | dic_list = ["result/gpt-3.5-turbo2.json", "result/gpt-4.json"] 9 | for i in range(len(dic_list)): 10 | result_list = [] 11 | with open(dic_list[i], "r") as read_file: 12 | result_dic = json.load(read_file) 13 | for item in result_dic.keys(): 14 | this_list = result_dic[item] 15 | pre = [] 16 | res = [] 17 | for jtem in this_list: 18 | if jtem["labels"] == "Yes": 19 | res.append(1) 20 | else: 21 | res.append(0) 22 | #if data["predict"] == "Yes": 23 | # pre.append(1) 24 | #else: 25 | # pre.append(0) 26 | if "yes" in jtem["predict"] or "Yes" in jtem["predict"]: 27 | pre.append(1) 28 | elif "not important" in jtem["predict"]: 29 | pre.append(0) 30 | elif "important" in jtem["predict"]: 31 | pre.append(1) 32 | else: 33 | pre.append(0) 34 | result_list.append(average_precision_score(res, pre)) 35 | print(f"{i}:map={sum(result_list)/len(result_list)}") -------------------------------------------------------------------------------- /net_emb.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | from fuzzywuzzy import fuzz 7 | from cogdl import pipeline 8 | from bs4 import BeautifulSoup 9 | from sklearn.metrics import average_precision_score 10 | 11 | import utils 12 | import settings 13 | 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp 18 | 19 | 20 | 21 | def extract_paper_citation_graph(): 22
| data_dir = join(settings.DATA_TRACE_DIR, "PST") 23 | papers = [] 24 | dblp_fname = "DBLP-Citation-network-V15.1.json" 25 | parse_err_cnt = 0 26 | 27 | wf = open(join(data_dir, "dblp_pids.txt"), "w") 28 | with open(join(data_dir, dblp_fname), "r", encoding="utf-8") as myFile: 29 | for i, line in enumerate(myFile): 30 | if len(line) <= 2: 31 | continue 32 | if i % 10000 == 0: 33 | logger.info("reading papers %d, parse err cnt %d", i, parse_err_cnt) 34 | try: 35 | paper_tmp = json.loads(line.strip()) 36 | wf.write(paper_tmp["id"] + "\n") 37 | wf.flush() 38 | papers.append(paper_tmp) 39 | except Exception: 40 | parse_err_cnt += 1 41 | wf.close() 42 | 43 | paper_dict_filter = {} 44 | for paper in tqdm(papers): 45 | paper_dict_filter[paper["id"]] = paper.get("references", []) 46 | 47 | logger.info("number of papers after filtering %d", len(paper_dict_filter)) 48 | utils.dump_json(paper_dict_filter, data_dir, "dblp_papers_refs.json") 49 | 50 | 51 | def merge_paper_references(): 52 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 53 | paper_dict_open = utils.load_json(data_dir, "dblp_papers_refs.json") 54 | papers_train = utils.load_json(data_dir, "paper_source_trace_train_ans.json") 55 | papers_valid = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json") 56 | 57 | for paper in tqdm(papers_train + papers_valid): 58 | pid = paper["_id"] 59 | cur_refs = paper.get("references", []) 60 | if len(cur_refs) == 0: 61 | continue 62 | refs_open = paper_dict_open.get(pid, []) 63 | refs_update = list(set(cur_refs + refs_open)) 64 | paper_dict_open[pid] = refs_update 65 | 66 | utils.dump_json(paper_dict_open, data_dir, "dblp_papers_refs_merged.json") 67 | 68 | 69 | def gen_paper_emb(year=2023, method="prone"): 70 | print("method", method) 71 | paper_dict = utils.load_json(settings.DATA_TRACE_DIR, "dblp_papers_refs_merged_{}.json".format(year)) 72 | pids_set = set() 73 | edges = [] 74 | for pid in tqdm(paper_dict): 75 | pids_set.add(pid) 76 | for ref_id in paper_dict[pid]: 77 | pids_set.add(ref_id) 78 | edges.append([pid, ref_id]) 79 | edges.append([ref_id, pid]) 80 | 81 | pids_sorted = sorted(list(pids_set)) 82 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids_sorted)} 83 | edges = [[pid_to_idx[pid], pid_to_idx[ref_id]] for pid, ref_id in edges] 84 | 85 | if method == "prone": 86 | generator = pipeline("generate-emb", model="prone") 87 | elif method == "line": 88 | generator = pipeline("generate-emb", model="line", walk_length=5, walk_num=5) 89 | elif method == "netsmf": 90 | generator = pipeline("generate-emb", model="netsmf", window_size=5, num_round=5) 91 | else: 92 | raise NotImplementedError 93 | 94 | # generate embedding by an unweighted graph 95 | edge_index = np.array(edges) 96 | print("generate_emb...", edge_index.shape) 97 | outputs = generator(edge_index) 98 | print("outputs", outputs.shape) 99 | 100 | out_dir = join(settings.OUT_DIR, method) 101 | os.makedirs(out_dir, exist_ok=True) 102 | with open(join(out_dir, "paper_id.txt"), "w") as f: 103 | for pid in pids_sorted: 104 | f.write(pid + "\n") 105 | f.flush() 106 | np.savez(join(out_dir, "paper_emb_{}.npz".format(method)), emb=outputs) 107 | 108 | 109 | def gen_paper_emb_kddcup(method="prone"): 110 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 111 | print("method", method) 112 | paper_dict = utils.load_json(data_dir, "dblp_papers_refs_merged.json") 113 | pids_set = set() 114 | edges = [] 115 | for pid in tqdm(paper_dict): 116 | pids_set.add(pid) 117 | for ref_id in paper_dict[pid]: 118 | pids_set.add(ref_id) 119 | edges.append([pid, ref_id]) 120 | edges.append([ref_id, pid]) 121 | 122 | 123 | pids_sorted = sorted(list(pids_set)) 124 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids_sorted)} 125 | edges = [[pid_to_idx[pid], pid_to_idx[ref_id]] for pid, ref_id in edges] 126 | 127 | if method == "prone": 128 | generator = pipeline("generate-emb", model="prone") 129 | elif method == "line": 130 | generator = pipeline("generate-emb", model="line", walk_length=5, walk_num=5) 131 | elif method == "netsmf": 132 | generator = pipeline("generate-emb", model="netsmf", window_size=5, num_round=5) 133 | else: 134 | raise NotImplementedError 135 | 136 | # generate embedding by an unweighted graph 137 | edge_index = np.array(edges) 138 | print("generate_emb...", edge_index.shape) 139 | outputs = generator(edge_index) 140 | print("outputs", outputs.shape) 141 | 142 | out_dir = join(settings.OUT_DIR, "kddcup", method) 143 | os.makedirs(out_dir, exist_ok=True) 144 | with open(join(out_dir, "paper_id.txt"), "w") as f: 145 | for pid in pids_sorted: 146 | f.write(pid + "\n") 147 | f.flush() 148 | np.savez(join(out_dir, "paper_emb_{}.npz".format(method)), emb=outputs) 149 | 150 | 151 | def eval_node_sim(year=2023, method="prone"): 152 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 153 | test_papers = utils.load_json(data_year_dir, "paper_source_trace_test.json") 154 | pids = [] 155 | with open(join(settings.OUT_DIR, method, "paper_id.txt"), "r") as f: 156 | for line in f: 157 | pids.append(line.strip()) 158 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids)} 159 | emb = np.load(join(settings.OUT_DIR, method, "paper_emb_{}.npz".format(method)))["emb"] 160 | 161 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 162 | metrics = [] 163 | f_idx = 0 164 | for paper in tqdm(test_papers): 165 | pid = paper["_id"] 166 | file = join(xml_dir, pid + ".tei.xml") 167 | f = open(file, encoding='utf-8') 168 | 169 | xml = f.read() 170 | bs = BeautifulSoup(xml, "xml") 171 | 172 | references = bs.find_all("biblStruct") 173 | bid_to_title = {} 174 | n_refs = 0 175 | for ref in references: 176 | if "xml:id" not in ref.attrs: 177 | continue 178 | bid = ref.attrs["xml:id"] 179 | if ref.analytic is None: 180 | continue 181 | if ref.analytic.title is None: 182 | continue 183 | bid_to_title[bid] = ref.analytic.title.text.lower() 184 | b_idx = int(bid[1:]) + 1 185 | if b_idx > n_refs: 186 | n_refs = b_idx 187 | 188 | bib_to_contexts = utils.find_bib_context(xml) 189 | bib_sorted = sorted(bib_to_contexts.keys()) 190 | 191 | for bib in bib_sorted: 192 | cur_bib_idx = int(bib[1:]) 193 | if cur_bib_idx + 1 > n_refs: 194 | n_refs = cur_bib_idx + 1 195 | 196 | f.close() 197 | 198 | ref_id_to_score = {} 199 | ref_id_to_label = {} 200 | cur_emb = emb[pid_to_idx[pid]] 201 | cur_refs = paper.get("references", []) 202 | ref_truths = set([x["_id"] for x in paper.get("refs_trace", [])]) 203 | for ref in cur_refs: 204 | ref_emb = emb[pid_to_idx[ref]] 205 | cur_sim = np.dot(cur_emb, ref_emb) 206 | cur_sim = 1/(1 + np.exp(-cur_sim)) 207 | ref_id_to_score[ref] = cur_sim 208 | if ref in ref_truths: 209 | ref_id_to_label[ref] = 1 210 | else: 211 | ref_id_to_label[ref] = 0 212 | 213 | ref_id_to_score_sorted = sorted(ref_id_to_score.items(), key=lambda x: x[1], reverse=True) 214 | ref_labels = [ref_id_to_label[x[0]] for x in ref_id_to_score_sorted] 215 | truth_id_not_in = ref_truths - set(cur_refs) 216 | n_limit = n_refs - len(truth_id_not_in) 217 | scores = [x[1] for x in ref_id_to_score_sorted][:n_limit] 218 | labels =
ref_labels[:n_limit] 219 | if len(truth_id_not_in) > 0: 220 | scores += [0] * len(truth_id_not_in) 221 | labels += [1] * len(truth_id_not_in) 222 | if len(scores) < n_refs: 223 | scores += [0] * (n_refs - len(scores)) 224 | labels += [0] * (n_refs - len(labels)) 225 | 226 | cur_map = average_precision_score(labels, scores) 227 | metrics.append(cur_map) 228 | f_idx += 1 229 | if f_idx % 20 == 0: 230 | print("map until now", np.mean(metrics), len(metrics), cur_map) 231 | 232 | print("prone map", np.mean(metrics), len(metrics)) 233 | 234 | 235 | def eval_node_sim_kddcup(method="prone", role="valid"): 236 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 237 | papers = utils.load_json(data_dir, "paper_source_trace_{}_wo_ans.json".format(role)) 238 | out_dir = join(settings.OUT_DIR, "kddcup", method) 239 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json") 240 | 241 | pids = [] 242 | with open(join(out_dir, "paper_id.txt"), "r") as f: 243 | for line in f: 244 | pids.append(line.strip()) 245 | pid_to_idx = {pid: idx for idx, pid in enumerate(pids)} 246 | emb = np.load(join(out_dir, "paper_emb_{}.npz".format(method)))["emb"] 247 | 248 | xml_dir = join(data_dir, "paper-xml") 249 | sub_dict = {} 250 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json") 251 | 252 | for paper in tqdm(papers): 253 | cur_pid = paper["_id"] 254 | file = join(xml_dir, cur_pid + ".xml") 255 | f = open(file, encoding='utf-8') 256 | xml = f.read() 257 | bs = BeautifulSoup(xml, "xml") 258 | f.close() 259 | 260 | ref_ids = paper.get("references", []) 261 | cur_title_to_pid = {} 262 | for ref_id in ref_ids: 263 | if ref_id in paper_info_more: 264 | cur_title_to_pid[paper_info_more[ref_id]["title"].lower()] = ref_id 265 | 266 | references = bs.find_all("biblStruct") 267 | bid_to_title = {} 268 | n_refs = 0 269 | cur_title_to_b_idx = {} 270 | for ref in references: 271 | if "xml:id" not in ref.attrs: 272 | continue 273 | bid = ref.attrs["xml:id"] 274 | if ref.analytic is None: 275 | continue 276 | if ref.analytic.title is None: 277 | continue 278 | bid_to_title[bid] = ref.analytic.title.text.lower() 279 | b_idx = int(bid[1:]) + 1 280 | cur_title_to_b_idx[ref.analytic.title.text.lower()] = b_idx - 1 281 | if b_idx > n_refs: 282 | n_refs = b_idx 283 | 284 | assert len(sub_example_dict[cur_pid]) == n_refs 285 | y_score = [0] * n_refs 286 | 287 | cur_emb = emb[pid_to_idx[cur_pid]] 288 | 289 | for r_idx, ref_id in enumerate(ref_ids): 290 | if ref_id not in paper_info_more: 291 | continue 292 | cur_title = paper_info_more[ref_id].get("title", "").lower() 293 | if len(cur_title) == 0: 294 | continue 295 | cur_b_idx = None 296 | for b_title in cur_title_to_b_idx: 297 | cur_sim = fuzz.ratio(cur_title, b_title) 298 | if cur_sim >= 80: 299 | cur_b_idx = cur_title_to_b_idx[b_title] 300 | break 301 | if cur_b_idx is None: 302 | continue 303 | ref_emb = emb[pid_to_idx[ref_id]] 304 | cur_sim = np.dot(cur_emb, ref_emb) 305 | cur_sim = utils.sigmoid(cur_sim) 306 | y_score[cur_b_idx] = float(cur_sim) 307 | 308 | print(y_score) 309 | sub_dict[cur_pid] = y_score 310 | 311 | utils.dump_json(sub_dict, out_dir, f"{role}_sub_{method}.json") 312 | 313 | 314 | if __name__ == "__main__": 315 | method = "prone" 316 | extract_paper_citation_graph() 317 | merge_paper_references() 318 | # gen_paper_emb(method=method) 319 | gen_paper_emb_kddcup(method=method) 320 | # eval_node_sim(method=method) 321 | eval_node_sim_kddcup(method=method) 322 | 
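A minimal sketch of how the saved ProNE artifacts above can be reused to score a single paper-reference pair. It assumes the default `settings.OUT_DIR` layout (`out/kddcup/prone`) written by `gen_paper_emb_kddcup`, and mirrors the dot-product-plus-sigmoid scoring of `eval_node_sim_kddcup`:

```python
import numpy as np
from os.path import join

out_dir = join("out", "kddcup", "prone")  # assumed default settings.OUT_DIR layout
with open(join(out_dir, "paper_id.txt")) as f:
    pid_to_idx = {line.strip(): idx for idx, line in enumerate(f)}
emb = np.load(join(out_dir, "paper_emb_prone.npz"))["emb"]

def pair_score(pid: str, ref_id: str) -> float:
    # unnormalized dot product squashed into (0, 1), as in eval_node_sim_kddcup
    s = np.dot(emb[pid_to_idx[pid]], emb[pid_to_idx[ref_id]])
    return float(1.0 / (1.0 + np.exp(-s)))
```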
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.11.1 2 | cogdl==0.5.3 3 | fuzzywuzzy==0.18.0 4 | numpy==1.21.6 5 | scikit_learn==1.2.1 6 | tqdm==4.36.1 7 | transformers==4.22.2 8 | python-Levenshtein==0.20.9 9 | lxml==4.8.0 10 | numba==0.56.4 -------------------------------------------------------------------------------- /result.txt: -------------------------------------------------------------------------------- 1 | fine-tune glm-2b epoch 1 average map 0.15643463634349988 280 -------------------------------------------------------------------------------- /rf/model_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import set_param 4 | import pickle 5 | from sklearn import svm 6 | from sklearn.ensemble import RandomForestClassifier 7 | import json 8 | from sklearn.metrics import average_precision_score 9 | import warnings 10 | from sklearn.linear_model import LogisticRegression 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | def calculate_TPFN(pre_result, label): 16 | TP, TN, FP, FN = 0, 0, 0, 0 17 | total = len(label) 18 | for item, jtem in zip(pre_result, label): 19 | if item == 1 and jtem == 1: 20 | TP += 1 21 | elif item == 0 and jtem == 1: 22 | FN += 1 # predicted negative but actually positive 23 | elif item == 1 and jtem == 0: 24 | FP += 1 # predicted positive but actually negative 25 | else: 26 | TN += 1 27 | print("Accuracy:", (TP + TN) / total) 28 | print("Precision:", TP / (TP + FP)) 29 | print("Recall:", TP / (TP + FN)) 30 | 31 | 32 | # mode = "valid" 33 | # mode = "train" 34 | mode = "test" 35 | # model_type = "SVM" 36 | model_type = "RandomForest" 37 | # model_type = "LR" 38 | 39 | 40 | params = set_param.Args(model_type) 41 | if mode == "train": 42 | data = np.loadtxt(open(f"processed_data/train_data.csv", "rb"), delimiter=",") 43 | label = np.loadtxt(open(f"processed_data/train_label.csv", "rb"), delimiter=",") 44 | if model_type == "SVM": 45 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter, 46 | tol=params.tol, probability=True) 47 | elif model_type == "RandomForest": 48 | model = RandomForestClassifier(n_estimators=params.n_estimators) 49 | elif model_type == "LR": 50 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class) 51 | model.fit(data, label) 52 | save_model = pickle.dumps(model) 53 | os.makedirs("saved_model", exist_ok=True) 54 | write_file = open(f"saved_model/{model_type}.pkl", 'wb') 55 | write_file.write(save_model) 56 | write_file.close() 57 | pre_result = model.predict(data) 58 | calculate_TPFN(pre_result, label) 59 | elif mode == "test": 60 | with open('processed_data/test_data.json') as read_file: 61 | data_dic = json.load(read_file) 62 | with open('processed_data/test_label.json') as read_file: 63 | label_dic = json.load(read_file) 64 | with open(f'saved_model/{model_type}.pkl', 'rb') as read_file: 65 | model = pickle.load(read_file) 66 | 67 | map_list = [] 68 | total_pre_list = [] 69 | total_label = [] 70 | for item in data_dic.keys(): 71 | this_data = data_dic[item] 72 | this_label = [jtem[0] for jtem in label_dic[item]] 73 | pre_result = model.predict_proba(this_data) 74 | pre_result = [jtem[1] for jtem in pre_result] 75 | total_pre_list += [(1 if jtem >= 0.5 else 0) for jtem in pre_result] 76 | total_label += this_label 77 | map_list.append(average_precision_score(this_label, pre_result)) 78 |
calculate_TPFN(total_pre_list, total_label) 79 | print("MAP:", sum(map_list) / len(map_list)) 80 | # show feature importances 81 | # feature_importances = model.feature_importances_ 82 | # print(feature_importances) 83 | 84 | 85 | elif mode == "valid": 86 | data = np.loadtxt(open(f"processed_data/valid/valid_data.csv", "rb"), delimiter=",") 87 | label = np.loadtxt(open(f"processed_data/valid/valid_label.csv", "rb"), delimiter=",") 88 | if model_type == "SVM": 89 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter, 90 | tol=params.tol, probability=True) 91 | elif model_type == "RandomForest": 92 | model = RandomForestClassifier(n_estimators=params.n_estimators) 93 | elif model_type == "LR": 94 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class) 95 | model.fit(data, label) 96 | predict = model.predict(data) 97 | calculate_TPFN(predict, label) 98 | else: 99 | print("mode input error!") 100 | -------------------------------------------------------------------------------- /rf/model_rf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import numpy as np 4 | from tqdm import tqdm 5 | from bs4 import BeautifulSoup 6 | import set_param 7 | import pickle 8 | from sklearn import svm 9 | from sklearn.ensemble import RandomForestClassifier 10 | import json 11 | from sklearn.metrics import average_precision_score 12 | import warnings 13 | from sklearn.linear_model import LogisticRegression 14 | from fuzzywuzzy import fuzz 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | import utils 19 | import settings 20 | 21 | 22 | def calculate_TPFN(pre_result, label): 23 | TP, TN, FP, FN = 0, 0, 0, 0 24 | total = len(label) 25 | for item, jtem in zip(pre_result, label): 26 | if item == 1 and jtem == 1: 27 | TP += 1 28 | elif item == 0 and jtem == 1: 29 | FN += 1 # predicted negative but actually positive 30 | elif item == 1 and jtem == 0: 31 | FP += 1 # predicted positive but actually negative 32 | else: 33 | TN += 1 34 | print("Accuracy:", (TP + TN) / total) 35 | print("Precision:", TP / (TP + FP)) 36 | print("Recall:", TP / (TP + FN)) 37 | 38 | 39 | def train_classifier(model_type="RandomForest"): 40 | params = set_param.Args(model_type) 41 | feature_dir = join(settings.OUT_DIR, "kddcup", "rf") 42 | data = np.loadtxt(open(join(feature_dir, "train_data.csv"), "rb"), delimiter=",") 43 | label = np.loadtxt(open(join(feature_dir, "train_label.csv"), "rb"), delimiter=",") 44 | if model_type == "SVM": 45 | model = svm.SVC(C=params.C, kernel=params.kernel, verbose=params.verbose, max_iter=params.max_iter, 46 | tol=params.tol, probability=True) 47 | elif model_type == "RandomForest": 48 | model = RandomForestClassifier(n_estimators=params.n_estimators) 49 | elif model_type == "LR": 50 | model = LogisticRegression(solver=params.solver, multi_class=params.multi_class) 51 | model.fit(data, label) 52 | save_model = pickle.dumps(model) 53 | # os.makedirs("saved_model", exist_ok=True) 54 | write_file = open(join(feature_dir, "{}.pkl".format(model_type)), 'wb') 55 | write_file.write(save_model) 56 | write_file.close() 57 | pre_result = model.predict(data) 58 | calculate_TPFN(pre_result, label) 59 | 60 | 61 | def eval_classifier(model_type="RandomForest", role="valid"): 62 | feature_dir = join(settings.OUT_DIR, "kddcup", "rf") 63 | with open(join(feature_dir, f"{role}_data.json")) as read_file: 64 | eval_features = json.load(read_file) 65 | 66 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 67 | papers = utils.load_json(data_dir,
"paper_source_trace_valid_wo_ans.json") 68 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json") 69 | 70 | 71 | with open(join(feature_dir, "{}.pkl".format(model_type)), 'rb') as read_file: 72 | model = pickle.loads(read_file.read()) 73 | 74 | xml_dir = join(data_dir, "paper-xml") 75 | sub_dict = {} 76 | sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json") 77 | 78 | for paper in tqdm(papers): 79 | cur_pid = paper["_id"] 80 | file = join(xml_dir, cur_pid + ".xml") 81 | f = open(file, encoding='utf-8') 82 | xml = f.read() 83 | bs = BeautifulSoup(xml, "xml") 84 | f.close() 85 | 86 | ref_ids = paper.get("references", []) 87 | cur_title_to_pid = {} 88 | for ref_id in ref_ids: 89 | if ref_id in paper_info_more: 90 | cur_title_to_pid[paper_info_more[ref_id]["title"].lower()] = ref_id 91 | 92 | references = bs.find_all("biblStruct") 93 | bid_to_title = {} 94 | n_refs = 0 95 | cur_title_to_b_idx = {} 96 | for ref in references: 97 | if "xml:id" not in ref.attrs: 98 | continue 99 | bid = ref.attrs["xml:id"] 100 | if ref.analytic is None: 101 | continue 102 | if ref.analytic.title is None: 103 | continue 104 | bid_to_title[bid] = ref.analytic.title.text.lower() 105 | b_idx = int(bid[1:]) + 1 106 | cur_title_to_b_idx[ref.analytic.title.text.lower()] = b_idx - 1 107 | if b_idx > n_refs: 108 | n_refs = b_idx 109 | 110 | assert len(sub_example_dict[cur_pid]) == n_refs 111 | y_score = [0] * n_refs 112 | cur_feature = eval_features[cur_pid] 113 | 114 | for r_idx, ref_id in enumerate(ref_ids): 115 | if ref_id not in paper_info_more: 116 | continue 117 | cur_title = paper_info_more[ref_id].get("title", "").lower() 118 | if len(cur_title) == 0: 119 | continue 120 | cur_ref_feature = cur_feature[r_idx] 121 | if len(cur_ref_feature) == 0: 122 | continue 123 | cur_b_idx = None 124 | for b_title in cur_title_to_b_idx: 125 | cur_sim = fuzz.ratio(cur_title, b_title) 126 | if cur_sim >= 80: 127 | cur_b_idx = cur_title_to_b_idx[b_title] 128 | break 129 | if cur_b_idx is None: 130 | continue 131 | cur_prob = model.predict_proba([cur_ref_feature])[0][1] 132 | y_score[cur_b_idx] = float(cur_prob) 133 | 134 | # print(y_score) 135 | sub_dict[cur_pid] = y_score 136 | 137 | utils.dump_json(sub_dict, feature_dir, f"{role}_sub_{model_type}.json") 138 | 139 | 140 | if __name__ == "__main__": 141 | train_classifier() 142 | eval_classifier("RandomForest", "valid") 143 | -------------------------------------------------------------------------------- /rf/process_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import csv 4 | import pymongo 5 | from bson import ObjectId 6 | import random 7 | from lxml import etree 8 | from fuzzywuzzy import fuzz 9 | import re 10 | from tqdm import tqdm 11 | 12 | random.seed(1) # 随机数种子,用来确定训练和验证集中用抽中哪些论文。如果改动,需要重新从数据库中查询数据。(查询语句留着,但是地址,用户名和密码删掉了,需要的话再填一下) 13 | 14 | 15 | class Process_data(object): 16 | def __init__(self, paper_dic, mode): 17 | self.paper_dic = paper_dic 18 | paper_id = paper_dic['_id'] 19 | paper_positive_id = [item['_id'] for item in paper_dic['refs_trace']] 20 | self.authors_set = set([item.get('name') for item in paper_dic.get('authors', {})]) 21 | # 通过xml获取tree和listBibl 22 | try: 23 | path = f'data/paper-xml/{paper_id}.tei.xml' 24 | self.tree = etree.parse(path) 25 | root = self.tree.getroot() 26 | listBibl = root.xpath("//*[local-name()='listBibl']")[0] 27 | self.biblStruct = listBibl.getchildren() 28 | self.num_ref = len(self.biblStruct) 29 | except 
OSError: 30 | self.tree = None 31 | self.num_ref = 0 32 | print('xml does not exist: ' + paper_id) 33 | # number of references in this paper 34 | self.reference_num = self.get_reciprocal_of_reference_num() 35 | if mode == 'test': # for train/valid we randomly sample references to build examples, while test puts all references into query_list directly 36 | query_list = paper_dic.get('references', []) 37 | else: 38 | references = paper_dic.get('references', []) 39 | for item in paper_positive_id: 40 | try: 41 | references.remove(item) 42 | except ValueError: 43 | continue 44 | query_list = random.sample(references, min(max(len(paper_positive_id), 1), 45 | len(references))) # pick at least one negative, at most len(references) (positives can outnumber the references!) 46 | query_list += paper_positive_id 47 | self.data_list = [] 48 | self.label_list = [] 49 | for i, item in enumerate(query_list): 50 | this_data = [] 51 | # query = {'_id': ObjectId(item)} 52 | # self.query_result = list(collection.find(query)) 53 | self.query_result = saved_data.get(item, {}) 54 | reference_place_list = self.get_referenced_place_num(item) 55 | if len(reference_place_list) == 0: 56 | continue # an empty result means the paper title could not be matched to an index in the xml references; ignore it for now 57 | this_data.append(self.get_referenced_num()) 58 | this_data.append(self.get_common_authors(item)) 59 | this_data.append(self.reference_num) 60 | this_data.append(self.key_words()) 61 | this_data += reference_place_list 62 | self.data_list.append(this_data) 63 | self.label_list.append([1] if item in paper_positive_id else [0]) 64 | 65 | # ONE: citation count of the reference 66 | def get_referenced_num(self): 67 | return self.query_result.get('n_citation', 0) 68 | 69 | # TWO, SIX, EIGHT: citing positions, whether cited in a figure/table, citing count / total citations 70 | # 0 abstract 71 | # 1 introduction 72 | # 2 related work 73 | # 3 method 74 | # 4 graph and figure 75 | # 5 result 76 | # 6 others 77 | def get_referenced_place_num(self, paper_id): 78 | # retrieve the title from the database 79 | # query = {"_id": ObjectId(paper_id)} 80 | # title = list(collection.find(query))[0]['title'] 81 | title = self.query_result.get('title', '') 82 | # look up the reference index in the xml 83 | if self.tree is None: 84 | return [0] * 8 85 | 86 | paper_number = -1 87 | for i, item in enumerate(self.biblStruct): 88 | this_test = item.xpath('.//*[local-name()="title"]') 89 | this_text = this_test[0].text 90 | if this_text is None: 91 | try: 92 | this_text = this_test[1].text 93 | except IndexError: 94 | this_text = '' 95 | try: 96 | score = fuzz.partial_ratio(title, this_text) 97 | except ValueError: 98 | score = 0 99 | if score >= 80: 100 | paper_number = i + 1 101 | break 102 | place_num = [0 for i in range(8)] 103 | self.paper_number = paper_number 104 | if paper_number == -1: 105 | return place_num 106 | # use the index to locate the citing positions in the xml 107 | nodes = self.tree.xpath(f"//*[contains(text(), '[{paper_number}]')]") 108 | reference_times = len(nodes) 109 | 110 | for item in nodes: 111 | found_text = '' 112 | this_node = item 113 | while found_text == '': 114 | this_node = this_node.getparent() 115 | if this_node is None: 116 | break 117 | if this_node.xpath("local-name()") == 'figure': 118 | place_num[4] = 1 119 | it_children = this_node.iterchildren() 120 | for jtem in it_children: 121 | node = this_node 122 | if jtem.xpath("local-name()") == 'head': 123 | found_text = node.text 124 | n_num = jtem.attrib.get('n') 125 | node = this_node 126 | if n_num is None: 127 | break 128 | while not n_num.isdigit(): 129 | node = node.getprevious() 130 | if node is None: 131 | break 132 | node_children = node.iterchildren() 133 | for ktem in node_children: 134 | if ktem.xpath("local-name()") == 'head': 135 | n = ktem.attrib.get('n') 136 | if n is not None and n.isdigit(): 137 | n_num = ktem.attrib.get('n') 138 | found_text = ktem.text 139 | break 140 | break 141 | 142 | if this_node is None or found_text == '': 143 | place_num[6] = 1 144 | continue 145 | if found_text is not None: 146 | found_text = found_text.lower() 147 | if fuzz.partial_ratio('abstract', found_text) >= 60: 148 | place_num[0] = 1 149 | elif fuzz.partial_ratio('introduction', found_text) >= 60: 150 | place_num[1] = 1 151 | elif fuzz.partial_ratio('related work', found_text) >= 60: 152 | place_num[2] = 1 153 | elif fuzz.partial_ratio('method', found_text) >= 60: 154 | place_num[3] = 1 155 | elif fuzz.partial_ratio('result', found_text) >= 60 or fuzz.partial_ratio('experiment', found_text) >= 60: 156 | place_num[5] = 1 157 | else: 158 | place_num[6] = 1 159 | pattern = re.compile(r'[\d+]') 160 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]", 161 | namespaces={"re": "http://exslt.org/regular-expressions"}, 162 | pattern=pattern.pattern) 163 | total_ref_num = len(nodes) 164 | if not total_ref_num == 0: 165 | place_num[7] = reference_times / total_ref_num 166 | return place_num 167 | 168 | # FOUR: overlapping authors 169 | def get_common_authors(self, paper_id): 170 | # ref_authors_set = set([item.get('name') for item in self.query_result.get('authors', {})]) 171 | ref_authors_set = set(self.query_result.get('authors', [])) 172 | if not len(self.authors_set & ref_authors_set) == 0: 173 | return 1 174 | else: 175 | return 0 176 | 177 | # FIVE: key phrases 178 | def key_words(self): 179 | if self.paper_number == -1: 180 | return 0 181 | pattern = re.compile(r'[\d+]') 182 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]", 183 | namespaces={"re": "http://exslt.org/regular-expressions"}, 184 | pattern=pattern.pattern) 185 | key_words_list = ['motivated by', 'inspired by'] 186 | for item in nodes: 187 | if item.xpath('local-name()') == 'ref': 188 | node_text = item.getparent().text 189 | else: 190 | node_text = item.text 191 | if node_text is None: 192 | return 0 193 | node_text = node_text.lower() 194 | for jtem in key_words_list: 195 | pattern = re.compile(fr"{jtem}") 196 | match = pattern.search(node_text) 197 | if match is not None: 198 | return 1 199 | return 0 200 | 201 | # SEVEN 202 | 203 | def get_reciprocal_of_reference_num(self): 204 | if self.num_ref == 0: 205 | return 0 206 | else: 207 | return 1 / self.num_ref 208 | 209 | 210 | # mode = 'valid' 211 | # mode = 'train' 212 | mode = 'test' 213 | year = 2023 214 | with open(f"data/{year}/paper_source_trace_{mode}.json", 'r', encoding='utf-8') as read_file: 215 | data_dic = json.load(read_file) 216 | all_id = [item['_id'] for item in data_dic] 217 | data_list = [] 218 | label_list = [] 219 | 220 | # connect to the database 221 | # client = pymongo.MongoClient(host='', port=-1, authSource='', username='', password='') 222 | # db = client.get_database('') 223 | # collection = db[""] 224 | with open('processed_data/saved_data2.json', 'r') as read_file: 225 | saved_data = json.load(read_file) 226 | n = 0 227 | # build the features, then write them to csv or json files 228 | if mode == 'test': 229 | total_data_dic = {} 230 | total_label_dic = {} 231 | for i, item in tqdm(enumerate(all_id), total=len(all_id)): 232 | process_data = Process_data(data_dic[i], mode) 233 | this_data, this_label = process_data.data_list, process_data.label_list 234 | total_data_dic[item] = this_data 235 | total_label_dic[item] = this_label 236 | with open(f'processed_data/{mode}_data.json', 'w') as write_file: 237 | write_file.write(json.dumps(total_data_dic)) 238 | with open(f'processed_data/{mode}_label.json', 'w') as write_file:
239 | write_file.write(json.dumps(total_label_dic)) 240 | else: 241 | # train/valid: 242 | for i, item in tqdm(enumerate(all_id), total=len(all_id)): 243 | process_data = Process_data(data_dic[i], mode) 244 | this_data, this_label = process_data.data_list, process_data.label_list 245 | data_list += this_data 246 | label_list += this_label 247 | with open(f'processed_data/{mode}_label.csv', 'w', newline='') as write_file: 248 | writer = csv.writer(write_file) 249 | writer.writerows(label_list) 250 | with open(f'processed_data/{mode}_data.csv', 'w', newline='') as write_file: 251 | writer = csv.writer(write_file) 252 | writer.writerows(data_list) 253 | -------------------------------------------------------------------------------- /rf/process_kddcup_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os.path import join 4 | import csv 5 | import random 6 | from lxml import etree 7 | from fuzzywuzzy import fuzz 8 | import re 9 | from collections import defaultdict as dd 10 | from tqdm import tqdm 11 | 12 | import utils 13 | import settings 14 | 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp 19 | 20 | 21 | random.seed(1) 22 | 23 | 24 | class Process_data(object): 25 | def __init__(self, paper_dic, mode, paper_info_more): 26 | self.paper_dic = paper_dic 27 | paper_id = paper_dic['_id'] 28 | if mode == "train": 29 | paper_positive_id = [item['_id'] for item in paper_dic['refs_trace']] 30 | self.authors_set = set([item.get('name') for item in paper_dic.get('authors', {})]) 31 | # parse the xml to get the tree and listBibl 32 | try: 33 | path = f'data/PST/paper-xml/{paper_id}.xml' 34 | self.tree = etree.parse(path) 35 | root = self.tree.getroot() 36 | listBibl = root.xpath("//*[local-name()='listBibl']")[0] 37 | self.biblStruct = listBibl.getchildren() 38 | self.num_ref = len(self.biblStruct) 39 | except OSError: 40 | self.tree = None 41 | self.num_ref = 0 42 | print('xml does not exist: ' + paper_id) 43 | # number of references in this paper 44 | self.reference_num = self.get_reciprocal_of_reference_num() 45 | if mode == 'test': # for train/valid we randomly sample references to build examples, while test puts all references into query_list directly 46 | query_list = paper_dic.get('references', []) 47 | else: 48 | references = paper_dic.get('references', []) 49 | for item in paper_positive_id: 50 | try: 51 | references.remove(item) 52 | except ValueError: 53 | continue 54 | query_list = random.sample(references, min(max(len(paper_positive_id), 1), 55 | len(references))) # pick at least one negative, at most len(references) (positives can outnumber the references!)
56 | query_list += paper_positive_id 57 | self.data_list = [] 58 | self.label_list = [] 59 | for i, item in enumerate(query_list): 60 | this_data = [] 61 | self.query_result = paper_info_more.get(item, {}) 62 | reference_place_list = self.get_referenced_place_num(item) 63 | if len(reference_place_list) == 0: 64 | self.data_list.append([]) 65 | continue # an empty result means the paper title could not be matched to an index in the xml references; ignore it for now 66 | this_data.append(self.get_referenced_num()) 67 | this_data.append(self.get_common_authors(item)) 68 | this_data.append(self.reference_num) 69 | this_data.append(self.key_words()) 70 | this_data += reference_place_list 71 | self.data_list.append(this_data) 72 | if mode == "train": 73 | self.label_list.append([1] if item in paper_positive_id else [0]) 74 | 75 | # ONE: citation count of the reference 76 | def get_referenced_num(self): 77 | return self.query_result.get('n_citation', 0) 78 | 79 | # TWO, SIX, EIGHT: citing positions, whether cited in a figure/table, citing count / total citations 80 | # 0 abstract 81 | # 1 introduction 82 | # 2 related work 83 | # 3 method 84 | # 4 graph and figure 85 | # 5 result 86 | # 6 others 87 | def get_referenced_place_num(self, paper_id): 88 | title = self.query_result.get('title', '') 89 | # look up the reference index in the xml 90 | if self.tree is None: 91 | return [0] * 8 92 | 93 | paper_number = -1 94 | for i, item in enumerate(self.biblStruct): 95 | this_test = item.xpath('.//*[local-name()="title"]') 96 | this_text = this_test[0].text 97 | if this_text is None: 98 | try: 99 | this_text = this_test[1].text 100 | except IndexError: 101 | this_text = '' 102 | try: 103 | score = fuzz.partial_ratio(title, this_text) 104 | except ValueError: 105 | score = 0 106 | if score >= 80: 107 | paper_number = i + 1 108 | break 109 | place_num = [0 for i in range(8)] 110 | self.paper_number = paper_number 111 | if paper_number == -1: 112 | return place_num 113 | # use the index to locate the citing positions in the xml 114 | nodes = self.tree.xpath(f"//*[contains(text(), '[{paper_number}]')]") 115 | reference_times = len(nodes) 116 | 117 | for item in nodes: 118 | found_text = '' 119 | this_node = item 120 | while found_text == '': 121 | this_node = this_node.getparent() 122 | if this_node is None: 123 | break 124 | if this_node.xpath("local-name()") == 'figure': 125 | place_num[4] = 1 126 | it_children = this_node.iterchildren() 127 | for jtem in it_children: 128 | node = this_node 129 | if jtem.xpath("local-name()") == 'head': 130 | found_text = node.text 131 | n_num = jtem.attrib.get('n') 132 | node = this_node 133 | if n_num is None: 134 | break 135 | while not n_num.isdigit(): 136 | node = node.getprevious() 137 | if node is None: 138 | break 139 | node_children = node.iterchildren() 140 | for ktem in node_children: 141 | if ktem.xpath("local-name()") == 'head': 142 | n = ktem.attrib.get('n') 143 | if n is not None and n.isdigit(): 144 | n_num = ktem.attrib.get('n') 145 | found_text = ktem.text 146 | break 147 | break 148 | 149 | if this_node is None or found_text == '': 150 | place_num[6] = 1 151 | continue 152 | if found_text is not None: 153 | found_text = found_text.lower() 154 | if fuzz.partial_ratio('abstract', found_text) >= 60: 155 | place_num[0] = 1 156 | elif fuzz.partial_ratio('introduction', found_text) >= 60: 157 | place_num[1] = 1 158 | elif fuzz.partial_ratio('related work', found_text) >= 60: 159 | place_num[2] = 1 160 | elif fuzz.partial_ratio('method', found_text) >= 60: 161 | place_num[3] = 1 162 | elif fuzz.partial_ratio('result', found_text) >= 60 or fuzz.partial_ratio('experiment', found_text) >= 60: 163 | place_num[5] = 1 164 | else: 165 | place_num[6] = 1 166 | pattern =
re.compile(r'[\d+]') 167 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]", 168 | namespaces={"re": "http://exslt.org/regular-expressions"}, 169 | pattern=pattern.pattern) 170 | total_ref_num = len(nodes) 171 | if not total_ref_num == 0: 172 | place_num[7] = reference_times / total_ref_num 173 | return place_num 174 | 175 | # FOUR: overlapping authors 176 | def get_common_authors(self, paper_id): 177 | # ref_authors_set = set([item.get('name') for item in self.query_result.get('authors', {})]) 178 | ref_authors_set = set(self.query_result.get('authors', [])) 179 | if not len(self.authors_set & ref_authors_set) == 0: 180 | return 1 181 | else: 182 | return 0 183 | 184 | # FIVE: key phrases 185 | def key_words(self): 186 | if self.paper_number == -1: 187 | return 0 188 | pattern = re.compile(r'[\d+]') 189 | nodes = self.tree.xpath("//*[re:match(text(), $pattern)]", 190 | namespaces={"re": "http://exslt.org/regular-expressions"}, 191 | pattern=pattern.pattern) 192 | key_words_list = ['motivated by', 'inspired by'] 193 | for item in nodes: 194 | if item.xpath('local-name()') == 'ref': 195 | node_text = item.getparent().text 196 | else: 197 | node_text = item.text 198 | if node_text is None: 199 | return 0 200 | node_text = node_text.lower() 201 | for jtem in key_words_list: 202 | pattern = re.compile(fr"{jtem}") 203 | match = pattern.search(node_text) 204 | if match is not None: 205 | return 1 206 | return 0 207 | 208 | # SEVEN 209 | 210 | def get_reciprocal_of_reference_num(self): 211 | if self.num_ref == 0: 212 | return 0 213 | else: 214 | return 1 / self.num_ref 215 | 216 | 217 | def extract_paper_info_from_dblp(): 218 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 219 | papers_train = utils.load_json(data_dir, "paper_source_trace_train_ans.json") 220 | papers_valid = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json") 221 | 222 | paper_dict_open = {} 223 | dblp_fname = "DBLP-Citation-network-V15.1.json" 224 | with open(join(data_dir, dblp_fname), "r", encoding="utf-8") as myFile: 225 | for i, line in enumerate(myFile): 226 | if len(line) <= 2: 227 | continue 228 | if i % 10000 == 0: 229 | logger.info("reading papers %d", i) 230 | paper_tmp = json.loads(line.strip()) 231 | paper_dict_open[paper_tmp["id"]] = paper_tmp 232 | 233 | paper_dict_hit = dd(dict) 234 | for paper in tqdm(papers_train + papers_valid): 235 | cur_pid = paper["_id"] 236 | ref_ids = paper.get("references", []) 237 | pids = [cur_pid] + ref_ids 238 | for pid in pids: 239 | if pid not in paper_dict_open: 240 | continue 241 | cur_paper_info = paper_dict_open[pid] 242 | cur_authors = [a.get("name", "") for a in cur_paper_info.get("authors", [])] 243 | n_citation = cur_paper_info.get("n_citation", 0) 244 | title = cur_paper_info.get("title", "") 245 | paper_dict_hit[pid] = {"authors": cur_authors, "n_citation": n_citation, "title": title} 246 | 247 | print("number of papers after filtering", len(paper_dict_hit)) 248 | utils.dump_json(paper_dict_hit, data_dir, "paper_info_hit_from_dblp.json") 249 | 250 | 251 | def extract_train_features(): 252 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 253 | with open(join(data_dir, "paper_source_trace_train_ans.json"), 'r', encoding='utf-8') as read_file: 254 | data_dic = json.load(read_file) 255 | all_id = [item['_id'] for item in data_dic] 256 | data_list = [] 257 | label_list = [] 258 | 259 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json") 260 | 261 | for i, item in tqdm(enumerate(all_id), total=len(all_id)): 262 | process_data =
Process_data(data_dic[i], "train", paper_info_more) 263 | this_data, this_label = process_data.data_list, process_data.label_list 264 | data_list += this_data 265 | label_list += this_label 266 | 267 | out_dir = join(settings.OUT_DIR, "kddcup", "rf") 268 | os.makedirs(out_dir, exist_ok=True) 269 | 270 | with open(join(out_dir, "train_label.csv"), 'w', newline='') as f: 271 | writer = csv.writer(f) 272 | writer.writerows(label_list) 273 | 274 | with open(join(out_dir, "train_data.csv"), 'w', newline='') as write_file: 275 | writer = csv.writer(write_file) 276 | writer.writerows(data_list) 277 | 278 | 279 | def extract_valid_features(): 280 | data_dir = join(settings.DATA_TRACE_DIR, "PST") 281 | with open(join(data_dir, "paper_source_trace_valid_wo_ans.json"), 'r', encoding='utf-8') as read_file: 282 | data_dic = json.load(read_file) 283 | all_id = [item['_id'] for item in data_dic] 284 | paper_info_more = utils.load_json(data_dir, "paper_info_hit_from_dblp.json") 285 | 286 | total_data_dic = {} 287 | for i, item in tqdm(enumerate(all_id), total=len(all_id)): 288 | process_data = Process_data(data_dic[i], "test", paper_info_more) 289 | n_refs = len(data_dic[i].get('references', [])) 290 | this_data, this_label = process_data.data_list, process_data.label_list 291 | total_data_dic[item] = this_data 292 | assert len(this_data) == n_refs 293 | 294 | out_dir = join(settings.OUT_DIR, "kddcup", "rf") 295 | os.makedirs(out_dir, exist_ok=True) 296 | utils.dump_json(total_data_dic, out_dir, "valid_data.json") 297 | 298 | 299 | if __name__ == "__main__": 300 | extract_paper_info_from_dblp() 301 | extract_train_features() 302 | extract_valid_features() 303 | -------------------------------------------------------------------------------- /rf/set_param.py: -------------------------------------------------------------------------------- 1 | class Args(object): 2 | def __init__(self, model_type): 3 | if model_type == "SVM": 4 | self.C = 1 5 | self.kernel = "sigmoid" 6 | self.verbose = 2 7 | self.max_iter = 100 8 | self.tol = 0.1 9 | elif model_type == "RandomForest": 10 | self.n_estimators = 100 11 | elif model_type == "LR": 12 | self.solver = 'saga' 13 | self.multi_class = 'ovr' -------------------------------------------------------------------------------- /rule.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from os.path import join 4 | from tqdm import tqdm 5 | from collections import defaultdict as dd 6 | from bs4 import BeautifulSoup 7 | from fuzzywuzzy import fuzz 8 | import numpy as np 9 | from sklearn.metrics import average_precision_score 10 | 11 | import utils 12 | import settings 13 | 14 | 15 | def extract_one_paper_via_rule(xml): 16 | bs = BeautifulSoup(xml, "xml") 17 | ref = [] 18 | importantlist = [] 19 | for item in bs.find_all(type='bibr'): 20 | if "target" not in item.attrs: 21 | continue 22 | item_str = "{}".format(item.attrs["target"], item.get_text()) 23 | try: 24 | refer = item.attrs["target"][1:] 25 | ref.append((item_str, refer)) # record the citation marker and its bib id 26 | # print(refer) 27 | pass 28 | except IndexError as e: 29 | continue 30 | xml = xml.lower() 31 | s2 = [ii for ii in range(len(xml)) if xml.startswith('motivated by', ii)] 32 | s3 = [ii for ii in range(len(xml)) if xml.startswith('inspired by', ii)] 33 | s = s2 + s3 34 | pos_to_signal = {} 35 | for i in s2: 36 | pos_to_signal[i] = "motivated by" 37 | for i in s3: 38 | pos_to_signal[i] = "inspired by" 39 | 40 | for i in ref: 41 | cur_bibr, idx = i 42 | p_ref = [ii for ii in
range(len(xml)) if xml.startswith(cur_bibr, ii)] 43 | # print("p_ref", p_ref) 44 | for j in p_ref: 45 | for k in s: 46 | if abs(j-k) < 100: 47 | importantlist.append(idx) 48 | # print("hit***************************", j, k, i, pos_to_signal[k]) 49 | break 50 | return importantlist 51 | 52 | 53 | def find_paper_source_by_rule(year=2023): 54 | data_year_dir = join(settings.DATA_TRACE_DIR, str(year)) 55 | truths = utils.load_json(data_year_dir, "paper_source_trace_test.json") 56 | pid_to_source_titles = dd(list) 57 | for paper in tqdm(truths): 58 | pid = paper["_id"] 59 | for ref in paper["refs_trace"]: 60 | pid_to_source_titles[pid].append(ref["title"].lower()) 61 | xml_dir = join(settings.DATA_TRACE_DIR, "paper-xml") 62 | metrics = [] 63 | p_idx = 0 64 | 65 | for paper in tqdm(truths): 66 | cur_pid = paper["_id"] 67 | file = join(xml_dir, cur_pid + ".tei.xml") 68 | f = open(file, encoding='utf-8') 69 | xml = f.read() 70 | bs = BeautifulSoup(xml, "xml") 71 | f.close() 72 | 73 | references = bs.find_all("biblStruct") 74 | bid_to_title = {} 75 | n_refs = 0 76 | for ref in references: 77 | if "xml:id" not in ref.attrs: 78 | continue 79 | bid = ref.attrs["xml:id"] 80 | if ref.analytic is None: 81 | continue 82 | if ref.analytic.title is None: 83 | continue 84 | bid_to_title[bid] = ref.analytic.title.text.lower() 85 | b_idx = int(bid[1:]) + 1 86 | if b_idx > n_refs: 87 | n_refs = b_idx 88 | 89 | bib_to_contexts = utils.find_bib_context(xml) 90 | bib_sorted = sorted(bib_to_contexts.keys()) 91 | 92 | for bib in bib_sorted: 93 | cur_bib_idx = int(bib[1:]) 94 | if cur_bib_idx + 1 > n_refs: 95 | n_refs = cur_bib_idx + 1 96 | 97 | y_true = [] 98 | y_score = [] 99 | try: 100 | source_titles = pid_to_source_titles[cur_pid] 101 | if len(source_titles) == 0: 102 | # no labeled source titles for this paper; 103 | # skip it instead of raising 104 | continue 105 | pred_sources = extract_one_paper_via_rule(xml) 106 | y_true = [0] * n_refs 107 | y_score = [0] * n_refs 108 | 109 | for bid in bid_to_title: 110 | cur_ref_title = bid_to_title[bid] 111 | for label_title in source_titles: 112 | if fuzz.ratio(cur_ref_title, label_title) >= 80: 113 | b_idx = int(bid[1:]) 114 | y_true[b_idx] = 1 115 | break 116 | for ii in pred_sources: 117 | y_score[int(ii[1:])] = 1 118 | if sum(y_true) == 0: 119 | metrics.append(0) 120 | continue 121 | cur_map = average_precision_score(y_true, y_score) 122 | # print("cur_map", cur_map) 123 | metrics.append(cur_map) 124 | except IndexError as e: 125 | metrics.append(0) 126 | continue 127 | p_idx += 1 128 | if p_idx % 20 == 0: 129 | print("map until now", np.mean(metrics), len(metrics), cur_map) 130 | 131 | print("average map", np.mean(metrics), len(metrics)) 132 | 133 | 134 | if __name__ == '__main__': 135 | find_paper_source_by_rule() 136 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import abspath, dirname, join 3 | 4 | 5 | PROJ_DIR = join(abspath(dirname(__file__))) 6 | DATA_DIR = join(PROJ_DIR, "data") 7 | OUT_DIR = join(PROJ_DIR, "out") 8 | 9 | os.makedirs(DATA_DIR, exist_ok=True) 10 | os.makedirs(OUT_DIR, exist_ok=True) 11 | DATA_TRACE_DIR = DATA_DIR 12 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import json 3 | import numpy as np 4 | import pickle 5 | from collections import defaultdict as dd 6 | from bs4
import BeautifulSoup 7 | from datetime import datetime 8 | 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') # include timestamp 13 | 14 | 15 | def load_json(rfdir, rfname): 16 | logger.info('loading %s ...', rfname) 17 | with open(join(rfdir, rfname), 'r', encoding='utf-8') as rf: 18 | data = json.load(rf) 19 | logger.info('%s loaded', rfname) 20 | return data 21 | 22 | 23 | def dump_json(obj, wfdir, wfname): 24 | logger.info('dumping %s ...', wfname) 25 | with open(join(wfdir, wfname), 'w', encoding='utf-8') as wf: 26 | json.dump(obj, wf, indent=4, ensure_ascii=False) 27 | logger.info('%s dumped.', wfname) 28 | 29 | 30 | def serialize_embedding(embedding): 31 | return pickle.dumps(embedding) 32 | 33 | 34 | def deserialize_embedding(s): 35 | return pickle.loads(s) 36 | 37 | 38 | def find_bib_context(xml, dist=100): 39 | bs = BeautifulSoup(xml, "xml") 40 | bib_to_context = dd(list) 41 | bibr_strs_to_bid_id = {} 42 | for item in bs.find_all(type='bibr'): 43 | if "target" not in item.attrs: 44 | continue 45 | bib_id = item.attrs["target"][1:] 46 | item_str = "{}".format(item.attrs["target"], item.get_text()) 47 | bibr_strs_to_bid_id[item_str] = bib_id 48 | 49 | for item_str in bibr_strs_to_bid_id: 50 | bib_id = bibr_strs_to_bid_id[item_str] 51 | cur_bib_context_pos_start = [ii for ii in range(len(xml)) if xml.startswith(item_str, ii)] 52 | for pos in cur_bib_context_pos_start: 53 | bib_to_context[bib_id].append(xml[pos - dist: pos + dist].replace("\n", " ").replace("\r", " ").strip()) 54 | return bib_to_context 55 | 56 | 57 | def sigmoid(x): 58 | return 1 / (1 + np.exp(-x)) 59 | 60 | 61 | class Log: 62 | def __init__(self, file_path): 63 | self.file_path = file_path 64 | self.f = open(file_path, 'w+') 65 | 66 | def log(self, s): 67 | self.f.write(str(datetime.now()) + "\t" + s + '\n') 68 | self.f.flush() --------------------------------------------------------------------------------
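Finally, a minimal usage sketch for `utils.find_bib_context`, the helper every baseline above relies on to collect citation contexts; the input path is illustrative:

```python
# Map each bibliography id (e.g. "b12") to the +/-100-character windows
# of text around its in-text citation markers.
from utils import find_bib_context

with open("data/paper-xml/example.tei.xml", encoding="utf-8") as f:  # hypothetical file
    xml = f.read()

bib_to_contexts = find_bib_context(xml, dist=100)
for bid, contexts in sorted(bib_to_contexts.items()):
    print(bid, len(contexts), contexts[:1])
```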