├── .gitignore
├── PLMConfig
│   └── LFM_config.json
├── ReadMe.md
├── config
│   ├── Lawformer.config
│   └── default.config
├── config_parser
│   ├── __init__.py
│   └── parser.py
├── convert_roberta_lfm.py
├── dataset
│   ├── FullTokenDataset.py
│   ├── IndexedDataset.py
│   └── __init__.py
├── formatter
│   ├── LawformerFormatter.py
│   └── __init__.py
├── model
│   ├── Lawformer.py
│   ├── __init__.py
│   ├── loss.py
│   ├── metric.py
│   └── optimizer.py
├── reader
│   ├── __init__.py
│   └── reader.py
├── requirements.txt
├── test.py
├── tools
│   ├── __init__.py
│   ├── accuracy_init.py
│   ├── accuracy_tool.py
│   ├── dataset_tool.py
│   ├── eval_tool.py
│   ├── init_tool.py
│   ├── output_init.py
│   ├── output_tool.py
│   ├── test_tool.py
│   └── train_tool.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *swp
3 | *swo
4 | .idea
5 | __pycache__
6 | *un~
7 | data/
8 | 
9 | PLMConfig/roberta-converted-lfm
10 | 
11 | config/default_local.config
12 | 
13 | temp
14 | notebook
15 | result
--------------------------------------------------------------------------------
/PLMConfig/LFM_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attention_mode": "longformer",
3 |   "attention_probs_dropout_prob": 0.1,
4 |   "attention_window": [
5 |     512,
6 |     512,
7 |     512,
8 |     512,
9 |     512,
10 |     512,
11 |     512,
12 |     512,
13 |     512,
14 |     512,
15 |     512,
16 |     512
17 |   ],
18 |   "bos_token_id": 0,
19 |   "eos_token_id": 2,
20 |   "gradient_checkpointing": false,
21 |   "hidden_act": "gelu",
22 |   "hidden_dropout_prob": 0.1,
23 |   "hidden_size": 768,
24 |   "ignore_attention_mask": false,
25 |   "initializer_range": 0.02,
26 |   "intermediate_size": 3072,
27 |   "layer_norm_eps": 1e-05,
28 |   "max_position_embeddings": 4098,
29 |   "model_type": "longformer",
30 |   "num_attention_heads": 12,
31 |   "num_hidden_layers": 12,
32 |   "pad_token_id": 0,
33 |   "position_embedding_type": "absolute",
34 |   "sep_token_id": 2,
35 |   "transformers_version": "4.3.3",
36 |   "type_vocab_size": 1,
37 |   "use_cache": true,
38 |   "vocab_size": 21128
39 | }
40 | 
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | ## Lawformer
2 | 
3 | ### Introduction
4 | This repository provides the source code and checkpoints of the paper "Lawformer: A Pre-trained Language Model for Chinese Legal Long Documents". You can download the checkpoint of Lawformer from the [huggingface model hub](https://huggingface.co/xcjthu/Lawformer) or from [here](https://data.thunlp.org/legal/Lawformer.zip). Besides, the checkpoint of our baseline model, Legal RoBERTa, can be downloaded from [here](https://data.thunlp.org/legal/LegalRoBERTa.zip).
5 | 
6 | The new judgement prediction dataset, CAIL-Long, can be downloaded from [here](https://data.thunlp.org/legal/CAIL-Long.tar.gz).
7 | 
8 | ### Installation
9 | ```
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ### Easy Start
14 | We have uploaded our model to the huggingface model hub. Make sure you have installed transformers.
15 | ```python
16 | >>> from transformers import AutoModel, AutoTokenizer
17 | >>> tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
18 | >>> model = AutoModel.from_pretrained("thunlp/Lawformer")
19 | >>> inputs = tokenizer("任某提起诉讼,请求判令解除婚姻关系并对夫妻共同财产进行分割。", return_tensors="pt")
20 | >>> outputs = model(**inputs)
21 | ```
22 | 
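Because the checkpoint follows the standard `transformers` Longformer interface, inputs of up to 4,096 tokens can be processed directly; positions that should attend globally are marked with a `global_attention_mask`. The snippet below is a minimal sketch: putting global attention only on the `[CLS]` token is a common Longformer default that we assume here, not something this repository prescribes.
```python
>>> import torch
>>> long_text = "任某提起诉讼,请求判令解除婚姻关系并对夫妻共同财产进行分割。" * 100
>>> inputs = tokenizer(long_text, max_length=4096, truncation=True, return_tensors="pt")
>>> global_attention_mask = torch.zeros_like(inputs["input_ids"])
>>> global_attention_mask[:, 0] = 1  # global attention on [CLS] only
>>> outputs = model(**inputs, global_attention_mask=global_attention_mask)
```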
23 | ### Pre-training
24 | We continue pre-training Lawformer from the `hfl/chinese-roberta-wwm-ext` checkpoint. Therefore, we first convert the RoBERTa model into a Longformer by running the following command:
25 | ```
26 | python3 convert_roberta_lfm.py
27 | ```
28 | Then run the following command to pre-train the model:
29 | ```
30 | python3 -m torch.distributed.launch --master_port 10086 --nproc_per_node 8 train.py -c config/Lawformer.config -g 0,1,2,3,4,5,6,7
31 | ```
32 | 
33 | ### Cite
34 | If you use the pre-trained models, please cite this paper:
35 | ```
36 | @article{xiao2021lawformer,
37 |   title={Lawformer: A Pre-trained Language Model for Chinese Legal Long Documents},
38 |   author={Xiao, Chaojun and Hu, Xueyu and Liu, Zhiyuan and Tu, Cunchao and Sun, Maosong},
39 |   journal={AI Open},
40 |   year={2021}
41 | }
42 | ```
43 | 
--------------------------------------------------------------------------------
/config/Lawformer.config:
--------------------------------------------------------------------------------
1 | [train] #train parameters
2 | epoch = 16
3 | batch_size = 4
4 | 
5 | shuffle = True
6 | 
7 | reader_num = 8
8 | 
9 | optimizer = AdamW
10 | learning_rate = 5e-5
11 | weight_decay = 0.01
12 | step_size = 1
13 | lr_multiplier = 1
14 | 
15 | max_len = 4096
16 | mlm_prob = 0.15
17 | 
18 | warmup_steps = 3000
19 | training_steps = 200000
20 | max_grad_norm = 1.0
21 | fp16 = True
22 | 
23 | valid_mode = step
24 | step_epoch = 3000
25 | 
26 | [eval] #eval parameters
27 | batch_size = 12
28 | 
29 | shuffle = False
30 | 
31 | reader_num = 4
32 | 
33 | [distributed]
34 | use = True
35 | backend = nccl
36 | 
37 | [data] #data parameters
38 | train_dataset_type = MultiDocDataset
39 | train_formatter_type = Lawformer
40 | train_data = /mnt/datadisk0/xcj/LegalBert/data/tokens
41 | train_files = ms_data_law_train_SS_document,xs_data_law_train_SS_document
42 | 
43 | valid_dataset_type = MultiDocDataset
44 | valid_formatter_type = Lawformer
45 | valid_data = /mnt/datadisk0/xcj/LegalBert/data/tokens
46 | valid_files = ms_data_law_valid_SS_document,xs_data_law_valid_SS_document
47 | 
48 | [model] #model parameters
49 | model_name = Lawformer
50 | 
51 | [output] #output parameters
52 | output_time = 1
53 | test_time = 1
54 | 
55 | model_path = checkpoint
56 | model_name = Lawformer
57 | 
58 | output_function = Null
--------------------------------------------------------------------------------
/config/default.config:
--------------------------------------------------------------------------------
1 | [train] #train parameters
2 | epoch = 16
3 | batch_size = 128
4 | 
5 | shuffle = True
6 | 
7 | reader_num = 8
8 | 
9 | optimizer = adam
10 | learning_rate = 1e-3
11 | weight_decay = 0
12 | step_size = 1
13 | lr_multiplier = 1
14 | max_grad_norm = 0
15 | fp16 = False
16 | 
17 | valid_mode = batch
18 | 
19 | grad_accumulate = 1
20 | 
21 | [eval] #eval parameters
22 | batch_size = 128
23 | 
24 | shuffle = False
25 | 
26 | reader_num = 4
27 | 
28 | [distributed]
29 | use = False
30 | backend = nccl
31 | 
32 | [data] #data parameters
33 | train_dataset_type = FilenameOnly
34 | train_formatter_type = Basic
35 | train_data_path = data
36 | train_file_list = train.json
37 | 
38 | valid_dataset_type = FilenameOnly
39 | valid_formatter_type = Basic
40 | valid_data_path = data
41 | valid_file_list = valid.json
42 | 
43 | test_dataset_type = FilenameOnly
44 | test_formatter_type = Basic
45 | test_data_path = data
46 | test_file_list = test.json
47 | 
48 | load_into_mem = True
49 | 
50 | [model] #model parameters
51 | model_name = BasicBert
52 | 
53 | [output] #output parameters
54 | output_time = 1
55 | test_time = 1
56 | 
57 | model_path = model
58 | model_name = name
59 | 
60 | accuracy_method = SingleLabelTop1
61 | output_function = Basic
62 | output_value = micro_precision,macro_precision,macro_recall,macro_f1
--------------------------------------------------------------------------------
/config_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import create_config
2 | 
--------------------------------------------------------------------------------
/config_parser/parser.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import functools
4 | 
5 | 
6 | class ConfigParser:
7 |     def __init__(self, *args, **params):
8 |         self.default_config = configparser.RawConfigParser(*args, **params)
9 |         self.local_config = configparser.RawConfigParser(*args, **params)
10 |         self.config = configparser.RawConfigParser(*args, **params)
11 | 
12 |     def read(self, filenames, encoding=None):
13 |         if os.path.exists("config/default_local.config"):
14 |             self.local_config.read("config/default_local.config", encoding=encoding)
15 |         else:
16 |             self.local_config.read("config/default.config", encoding=encoding)
17 | 
18 |         self.default_config.read("config/default.config", encoding=encoding)
19 |         self.config.read(filenames, encoding=encoding)
20 | 
21 | 
22 | def _build_func(func_name):
23 |     @functools.wraps(getattr(configparser.RawConfigParser, func_name))
24 |     def func(self, *args, **kwargs):
25 |         try:
26 |             return getattr(self.config, func_name)(*args, **kwargs)
27 |         except Exception as e:
28 |             try:
29 |                 return getattr(self.local_config, func_name)(*args, **kwargs)
30 |             except Exception as e:
31 |                 return getattr(self.default_config, func_name)(*args, **kwargs)
32 | 
33 |     return func
34 | 
35 | 
36 | def create_config(path):
37 |     for func_name in dir(configparser.RawConfigParser):
38 |         if not func_name.startswith('_') and func_name != "read":
39 |             setattr(ConfigParser, func_name, _build_func(func_name))
40 | 
41 |     config = ConfigParser()
42 |     config.read(path)
43 | 
44 |     return config
45 | 
--------------------------------------------------------------------------------
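`create_config` attaches every public `RawConfigParser` getter to `ConfigParser` through `_build_func`, so each lookup tries the task config first, then `config/default_local.config` (or, if that file is absent, `config/default.config`) as the local layer, and finally `config/default.config`. This is why `config/Lawformer.config` can omit keys such as `grad_accumulate` and the trainer still finds a value. A minimal sketch of the lookup order, using only files and keys that appear in this repository:
```python
from config_parser import create_config

config = create_config("config/Lawformer.config")
print(config.getint("train", "batch_size"))       # 4, found in Lawformer.config
print(config.getint("train", "grad_accumulate"))  # 1, falls back to default.config
```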
/convert_roberta_lfm.py:
--------------------------------------------------------------------------------
1 | from transformers import LongformerForMaskedLM, RobertaForMaskedLM, AutoModelForMaskedLM, AutoTokenizer
2 | import copy
3 | import torch
4 | 
5 | max_pos = 4096
6 | attention_window = 512
7 | 
8 | roberta = AutoModelForMaskedLM.from_pretrained("hfl/chinese-roberta-wwm-ext")
9 | tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext", model_max_length=max_pos)
10 | 
11 | # extend position embeddings
12 | config = roberta.config
13 | tokenizer.model_max_length = max_pos
14 | tokenizer.init_kwargs['model_max_length'] = max_pos
15 | current_max_pos, embed_size = roberta.bert.embeddings.position_embeddings.weight.shape
16 | max_pos += 2
17 | config.max_position_embeddings = max_pos
18 | assert max_pos > current_max_pos
19 | 
20 | new_pos_embed = roberta.bert.embeddings.position_embeddings.weight.new_empty(max_pos, embed_size)
21 | # copy position embeddings over and over to initialize the new position embeddings;
22 | # the first two rows are copied directly, since new_empty leaves them uninitialized
23 | new_pos_embed[:2] = roberta.bert.embeddings.position_embeddings.weight[:2]
24 | k = 2
25 | step = current_max_pos - 2
26 | while k < max_pos - 1:
27 |     if k + step >= max_pos:
28 |         new_pos_embed[k:] = roberta.bert.embeddings.position_embeddings.weight[2:(max_pos + 2 - k)]
29 |     else:
30 |         new_pos_embed[k:(k + step)] = roberta.bert.embeddings.position_embeddings.weight[2:]
31 |     k += step
32 | roberta.bert.embeddings.position_embeddings.weight.data = new_pos_embed
33 | roberta.bert.embeddings.position_ids.data = torch.tensor([i for i in range(max_pos)]).reshape(1, max_pos)
34 | 
35 | # add global attention projections, initialized from the local ones
36 | config.attention_window = [attention_window] * config.num_hidden_layers
37 | for i in range(len(roberta.bert.encoder.layer)):
38 |     roberta.bert.encoder.layer[i].attention.self.query_global = copy.deepcopy(roberta.bert.encoder.layer[i].attention.self.query)
39 |     roberta.bert.encoder.layer[i].attention.self.key_global = copy.deepcopy(roberta.bert.encoder.layer[i].attention.self.key)
40 |     roberta.bert.encoder.layer[i].attention.self.value_global = copy.deepcopy(roberta.bert.encoder.layer[i].attention.self.value)
41 | 
42 | lfm = LongformerForMaskedLM(config)
43 | lfm.longformer.load_state_dict(roberta.bert.state_dict())
44 | lfm.lm_head.dense.load_state_dict(roberta.cls.predictions.transform.dense.state_dict())
45 | lfm.lm_head.layer_norm.load_state_dict(roberta.cls.predictions.transform.LayerNorm.state_dict())
46 | lfm.lm_head.decoder.load_state_dict(roberta.cls.predictions.decoder.state_dict())
47 | lfm.lm_head.bias = copy.deepcopy(roberta.cls.predictions.bias)
48 | 
49 | lfm.save_pretrained('PLMConfig/roberta-converted-lfm')
50 | tokenizer.save_pretrained('PLMConfig/roberta-converted-lfm')
--------------------------------------------------------------------------------
/dataset/FullTokenDataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from .IndexedDataset import make_dataset, MMapIndexedDataset
5 | import random
6 | import numpy as np
7 | 
8 | class MultiDocDataset(Dataset):
9 |     def __init__(self, config, mode, encoding="utf8", *args, **params):
10 |         self.config = config
11 |         self.mode = mode
12 |         self.max_len = config.getint('train', 'max_len')
13 |         path = config.get('data', '%s_data' % mode)
14 |         flist = config.get('data', '%s_files' % mode).split(',')
15 |         self.datasets = [MMapIndexedDataset(os.path.join(path, f), False) for f in flist]
16 |         self.lens = [len(d) for d in self.datasets]
17 |         self.length = sum(self.lens)
18 |         self.idlist = np.arange(0, self.length)
19 |         np.random.shuffle(self.idlist)
20 | 
21 |     def get_index_i(self, idx):
22 |         ridx = int(self.idlist[idx])
23 |         sent = None
24 |         for i in range(len(self.lens)):
25 |             if ridx >= self.lens[i]:
26 |                 ridx -= self.lens[i]
27 |             else:
28 |                 sent = self.datasets[i][ridx]
29 |                 break
30 |         if sent is None:
31 |             raise ValueError('Index is larger than the number of data')
32 |         # truncate at the first [SEP] (token id 102) within the last 52 positions
33 |         for i in range(max(sent.shape[0] - 52, 0), sent.shape[0]):
34 |             if sent[i] == 102:
35 |                 break
36 |         return sent[:i+1]
37 | 
38 |     def __getitem__(self, idx):
39 |         # pack randomly sampled documents until the sequence is close to max_len
40 |         sent = self.get_index_i(idx)
41 |         while sent.shape[0] < self.max_len - 50:
42 |             ridx = random.randint(0, self.length - 1)
43 |             rsent = self.get_index_i(ridx)
44 |             sent = np.concatenate([sent, rsent])[:self.max_len]
45 |         return sent
46 | 
47 |     def __len__(self):
48 |         return self.length
--------------------------------------------------------------------------------
/dataset/IndexedDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 6 | 7 | # copied from fairseq/fairseq/data/indexed_dataset.py 8 | # Removed IndexedRawTextDataset since it relied on Fairseq dictionary 9 | # other slight modifications to remove fairseq dependencies 10 | # Added document index to index file and made it accessible. 11 | # An empty sentence no longer separates documents. 12 | 13 | from functools import lru_cache 14 | import os 15 | import shutil 16 | import struct 17 | from itertools import accumulate 18 | 19 | import numpy as np 20 | import torch 21 | 22 | def print_rank_0(message): 23 | """If distributed is initialized, print only on rank 0.""" 24 | if torch.distributed.is_initialized(): 25 | if torch.distributed.get_rank() == 0: 26 | print(message, flush=True) 27 | else: 28 | print(message, flush=True) 29 | 30 | 31 | def __best_fitting_dtype(vocab_size=None): 32 | if vocab_size is not None and vocab_size < 65500: 33 | return np.uint16 34 | else: 35 | return np.int32 36 | 37 | 38 | def get_available_dataset_impl(): 39 | return ['lazy', 'cached', 'mmap'] 40 | 41 | 42 | def infer_dataset_impl(path): 43 | if IndexedDataset.exists(path): 44 | with open(index_file_path(path), 'rb') as f: 45 | magic = f.read(8) 46 | if magic == IndexedDataset._HDR_MAGIC: 47 | return 'cached' 48 | elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: 49 | return 'mmap' 50 | else: 51 | return None 52 | else: 53 | print(f"Dataset does not exist: {path}") 54 | print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") 55 | return None 56 | 57 | 58 | def make_builder(out_file, impl, vocab_size=None): 59 | if impl == 'mmap': 60 | return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) 61 | else: 62 | return IndexedDatasetBuilder(out_file) 63 | 64 | def make_dataset(config, mode, encoding="utf8", *args, **params): 65 | # def make_dataset(path, impl, skip_warmup=False): 66 | path = config.get('data', '%s_data' % mode) 67 | impl = "mmap" 68 | skip_warmup = False 69 | if not IndexedDataset.exists(path): 70 | print(f"Dataset does not exist: {path}") 71 | print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") 72 | return None 73 | if impl == 'infer': 74 | impl = infer_dataset_impl(path) 75 | if impl == 'lazy' and IndexedDataset.exists(path): 76 | return IndexedDataset(path) 77 | elif impl == 'cached' and IndexedDataset.exists(path): 78 | return IndexedCachedDataset(path) 79 | elif impl == 'mmap' and MMapIndexedDataset.exists(path): 80 | return MMapIndexedDataset(path, skip_warmup) 81 | print(f"Unknown dataset implementation: {impl}") 82 | return None 83 | 84 | 85 | def dataset_exists(path, impl): 86 | if impl == 'mmap': 87 | return MMapIndexedDataset.exists(path) 88 | else: 89 | return IndexedDataset.exists(path) 90 | 91 | 92 | def read_longs(f, n): 93 | a = np.empty(n, dtype=np.int64) 94 | f.readinto(a) 95 | return a 96 | 97 | 98 | def write_longs(f, a): 99 | f.write(np.array(a, dtype=np.int64)) 100 | 101 | 102 | dtypes = { 103 | 1: np.uint8, 104 | 2: np.int8, 105 | 3: np.int16, 106 | 4: np.int32, 107 | 5: np.int64, 108 | 6: np.float, 109 | 7: np.double, 110 | 8: np.uint16 111 | } 112 | 113 | 114 | def code(dtype): 115 | for k in dtypes.keys(): 116 | if dtypes[k] == dtype: 117 | return k 118 | raise ValueError(dtype) 119 | 120 | 121 | def index_file_path(prefix_path): 122 | return prefix_path + '.idx' 123 | 124 | 125 | def data_file_path(prefix_path): 126 | return prefix_path + '.bin' 127 | 128 | 129 | def create_doc_idx(sizes): 130 | doc_idx = [0] 131 | for 
i, s in enumerate(sizes):
132 |         if s == 0:
133 |             doc_idx.append(i + 1)
134 |     return doc_idx
135 | 
136 | 
137 | class IndexedDataset(torch.utils.data.Dataset):
138 |     """Loader for IndexedDataset"""
139 |     _HDR_MAGIC = b'TNTIDX\x00\x00'
140 | 
141 |     def __init__(self, path):
142 |         super().__init__()
143 |         self.path = path
144 |         self.data_file = None
145 |         self.read_index(path)
146 | 
147 |     def read_index(self, path):
148 |         with open(index_file_path(path), 'rb') as f:
149 |             magic = f.read(8)
150 |             assert magic == self._HDR_MAGIC, (
151 |                 'Index file doesn\'t match expected format. '
152 |                 'Make sure that --dataset-impl is configured properly.'
153 |             )
154 |             version = f.read(8)
155 |             assert struct.unpack('<Q', version) == (1,)
156 |             code, self.element_size = struct.unpack('<QQ', f.read(16))
157 |             self.dtype = dtypes[code]
158 |             self._len, self.s = struct.unpack('<QQ', f.read(16))
159 |             self.doc_count = struct.unpack('<Q', f.read(8))
160 |             self.dim_offsets = read_longs(f, self._len + 1)
161 |             self.data_offsets = read_longs(f, self._len + 1)
162 |             self.sizes = read_longs(f, self.s)
163 |             self.doc_idx = read_longs(f, self.doc_count)
164 | 
165 |     def read_data(self, path):
166 |         self.data_file = open(data_file_path(path), 'rb', buffering=0)
167 | 
168 |     def check_index(self, i):
169 |         if i < 0 or i >= self._len:
170 |             raise IndexError('index out of range')
171 | 
172 |     def __del__(self):
173 |         if self.data_file:
174 |             self.data_file.close()
175 | 
176 |     # @lru_cache(maxsize=8)
177 |     def __getitem__(self, idx):
178 |         if not self.data_file:
179 |             self.read_data(self.path)
180 |         if isinstance(idx, int):
181 |             i = idx
182 |             self.check_index(i)
183 |             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
184 |             a = np.empty(tensor_size, dtype=self.dtype)
185 |             self.data_file.seek(self.data_offsets[i] * self.element_size)
186 |             self.data_file.readinto(a)
187 |             return a
188 |         elif isinstance(idx, slice):
189 |             start, stop, step = idx.indices(len(self))
190 |             if step != 1:
191 |                 raise ValueError("Slices into indexed_dataset must be contiguous")
192 |             sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
193 |             size = sum(sizes)
194 |             a = np.empty(size, dtype=self.dtype)
195 |             self.data_file.seek(self.data_offsets[start] * self.element_size)
196 |             self.data_file.readinto(a)
197 |             offsets = list(accumulate(sizes))
198 |             sents = np.split(a, offsets[:-1])
199 |             return sents
200 | 
201 |     def __len__(self):
202 |         return self._len
203 | 
204 |     def num_tokens(self, index):
205 |         return self.sizes[index]
206 | 
207 |     def size(self, index):
208 |         return self.sizes[index]
209 | 
210 |     @staticmethod
211 |     def exists(path):
212 |         return (
213 |             os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
214 |         )
215 | 
216 |     @property
217 |     def supports_prefetch(self):
218 |         return False  # avoid prefetching to save memory
219 | 
220 | 
221 | class IndexedCachedDataset(IndexedDataset):
222 | 
223 |     def __init__(self, path):
224 |         super().__init__(path)
225 |         self.cache = None
226 |         self.cache_index = {}
227 | 
228 |     @property
229 |     def supports_prefetch(self):
230 |         return True
231 | 
232 |     def prefetch(self, indices):
233 |         if all(i in self.cache_index for i in indices):
234 |             return
235 |         if not self.data_file:
236 |             self.read_data(self.path)
237 |         indices = sorted(set(indices))
238 |         total_size = 0
239 |         for i in indices:
240 |             total_size += self.data_offsets[i + 1] - self.data_offsets[i]
241 |         self.cache = np.empty(total_size, dtype=self.dtype)
242 |         ptx = 0
243 |         self.cache_index.clear()
244 |         for i in indices:
245 |             self.cache_index[i] = ptx
246 |             size = self.data_offsets[i + 1] - self.data_offsets[i]
247 |             a = self.cache[ptx: ptx + size]
248 |             self.data_file.seek(self.data_offsets[i] * self.element_size)
249 |             self.data_file.readinto(a)
250 |             ptx += size
251 |         if self.data_file:
252 |             # close and delete data file after prefetch so we can pickle
253 |             self.data_file.close()
254 |             self.data_file = None
255 | 
256 |     # @lru_cache(maxsize=8)
257 |     def __getitem__(self, idx):
258 |         if isinstance(idx, int):
259 |             i = idx
260 |             self.check_index(i)
261 |             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
262 |             a = np.empty(tensor_size, dtype=self.dtype)
263 |             ptx = self.cache_index[i]
264 |             np.copyto(a, self.cache[ptx: ptx + a.size])
265 |             return a
266 |         elif isinstance(idx, slice):
267 |             # Hack just to make this work, can optimize later if necessary
268 |             sents = []
269 |             for i in range(*idx.indices(len(self))):
270 |                 sents.append(self[i])
271 |             return sents
272 | 
273 | 
274 | class IndexedDatasetBuilder(object):
275 |     element_sizes = {
276 |         np.uint8: 1,
277 |         np.int8: 1,
278 |         np.int16: 2,
279 |         np.int32: 4,
280 |         np.int64: 8,
281 |         np.float: 4,
282 |         np.double: 8
283 |     }
284 | 
285 |     def __init__(self, out_file, dtype=np.int32):
286 |         self.out_file = open(out_file, 'wb')
287 |         self.dtype = dtype
288 |         self.data_offsets = [0]
289 |         self.dim_offsets = [0]
290 |         self.sizes = []
291 |         self.element_size = self.element_sizes[self.dtype]
292 |         self.doc_idx = [0]
293 | 
294 |     def add_item(self, tensor):
295 |         bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
296 |         self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
297 |         for s in tensor.size():
298 |             self.sizes.append(s)
299 |         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
300 | 
301 |     def end_document(self):
302 |         self.doc_idx.append(len(self.sizes))
303 | 
304 |     def merge_file_(self, another_file):
305 |         index = IndexedDataset(another_file)
306 |         assert index.dtype == self.dtype
307 | 
308 |         begin = self.data_offsets[-1]
309 |         for offset in index.data_offsets[1:]:
310 |             self.data_offsets.append(begin + offset)
311 |         self.sizes.extend(index.sizes)
312 |         begin = self.dim_offsets[-1]
313 |         for dim_offset in index.dim_offsets[1:]:
314 |             self.dim_offsets.append(begin + dim_offset)
315 | 
316 |         with open(data_file_path(another_file), 'rb') as f:
317 |             while True:
318 |                 data = f.read(1024)
319 |                 if data:
320 |                     self.out_file.write(data)
321 |                 else:
322 |                     break
323 | 
324 |     def finalize(self, index_file):
325 |         self.out_file.close()
326 |         index = open(index_file, 'wb')
327 |         index.write(b'TNTIDX\x00\x00')
328 |         index.write(struct.pack('<Q', 1))
329 |         index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
330 |         index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
331 |         index.write(struct.pack('<Q', len(self.doc_idx)))
332 |         write_longs(index, self.dim_offsets)
333 |         write_longs(index, self.data_offsets)
334 |         write_longs(index, self.sizes)
335 |         write_longs(index, self.doc_idx)
336 |         index.close()
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
47 |         if input.dim() > 2:
48 |             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
49 |             input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
50 |             input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
51 |         target = target.view(-1, 1)
52 | 
53 |         logpt = F.log_softmax(input, dim=1)
54 |         logpt = logpt.gather(1, target)
55 |         logpt = logpt.view(-1)
56 |         pt = Variable(logpt.data.exp())
57 | 
58 |         if self.alpha is not None:
59 |             if self.alpha.type() != input.data.type():
60 |                 self.alpha = self.alpha.type_as(input.data)
61 |             at = self.alpha.gather(0, target.data.view(-1))
62 |             logpt = logpt * Variable(at)
63 | 
64 |         loss = -1 * (1 - pt) ** self.gamma * logpt
65 |         if self.size_average:
66 |             return loss.mean()
67 |         else:
68 |             return loss.sum()
--------------------------------------------------------------------------------
/model/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import json
5 | 
6 | def softmax_acc(score, label, acc_result):
7 |     if acc_result is None:
8 |         acc_result = {'total': 0, 'right': 0}
9 |     predict = torch.max(score, dim = 1)[1]
10 |     acc_result['total'] += int(label.shape[0])
11 |     acc_result['right'] += int((predict == label).int().sum())
12 |     return acc_result
--------------------------------------------------------------------------------
/model/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from transformers import AdamW
3 | 
4 | 
5 | def init_optimizer(model, config, *args, **params):
6 |     optimizer_type = config.get("train", "optimizer")
7 |     learning_rate = config.getfloat("train", "learning_rate")
8 |     if optimizer_type == "adam":
9 |         optimizer = optim.Adam(model.parameters(), lr=learning_rate,
10 |                                weight_decay=config.getfloat("train", "weight_decay"))
11 |     elif optimizer_type == "sgd":
12 |         optimizer = optim.SGD(model.parameters(), lr=learning_rate,
13 |                               weight_decay=config.getfloat("train", "weight_decay"))
14 |     elif optimizer_type == "AdamW":
15 |         optimizer = AdamW(model.parameters(), lr=learning_rate,
16 |                           weight_decay=config.getfloat("train", "weight_decay"))
17 |     else:
18 |         raise NotImplementedError
19 | 
20 |     return optimizer
21 | 
--------------------------------------------------------------------------------
/reader/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader import init_dataset, init_test_dataset
2 | 
--------------------------------------------------------------------------------
/reader/reader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, RandomSampler
2 | import logging
3 | 
4 | import formatter as form
5 | from dataset import dataset_list
6 | from torch.utils.data.distributed import DistributedSampler
7 | 
8 | logger = logging.getLogger(__name__)
9 | 
10 | collate_fn = {}
11 | formatter = {}
12 | 
13 | 
14 | def init_formatter(config, task_list, *args, **params):
15 |     for task in task_list:
16 |         formatter[task] = form.init_formatter(config, task, *args, **params)
17 | 
18 |         def train_collate_fn(data):
19 |             return formatter["train"].process(data, config, "train")
20 | 
21 |         def valid_collate_fn(data):
22 |             return formatter["valid"].process(data, config, "valid")
23 | 
24 |         def test_collate_fn(data):
25 |             return formatter["test"].process(data, config, "test")
26 | 
27 |         if task == "train":
28 |             collate_fn[task] = train_collate_fn
29 |         elif task == "valid":
30 |             collate_fn[task] = valid_collate_fn
31 |         else:
32 |             collate_fn[task] = test_collate_fn
33 | 
34 | 
35 | def init_one_dataset(config, mode, *args, **params):
36 |     temp_mode = mode
37 |     if mode != "train":
38 |         try:
39 |             config.get("data", "%s_dataset_type" % temp_mode)
40 |         except Exception as e:
41 |             logger.warning(
42 |                 "[reader] %s_dataset_type has not been defined in config file, use [data] train_dataset_type instead." 
% temp_mode) 43 | temp_mode = "train" 44 | which = config.get("data", "%s_dataset_type" % temp_mode) 45 | 46 | if which in dataset_list: 47 | dataset = dataset_list[which](config, mode, *args, **params) 48 | batch_size = config.getint("train", "batch_size") 49 | shuffle = config.getboolean("train", "shuffle") 50 | reader_num = config.getint("train", "reader_num") 51 | drop_last = True 52 | if mode in ["valid", "test"]: 53 | if mode == "test": 54 | drop_last = False 55 | 56 | try: 57 | batch_size = config.getint("eval", "batch_size") 58 | except Exception as e: 59 | logger.warning("[eval] batch size has not been defined in config file, use [train] batch_size instead.") 60 | 61 | try: 62 | shuffle = config.getboolean("eval", "shuffle") 63 | except Exception as e: 64 | shuffle = False 65 | logger.warning("[eval] shuffle has not been defined in config file, use false as default.") 66 | try: 67 | reader_num = config.getint("eval", "reader_num") 68 | except Exception as e: 69 | logger.warning("[eval] reader num has not been defined in config file, use [train] reader num instead.") 70 | 71 | if config.getboolean('distributed', 'use'): 72 | sampler = DistributedSampler(dataset) 73 | else: 74 | sampler = RandomSampler(dataset) 75 | 76 | dataloader = DataLoader(dataset=dataset, 77 | batch_size=batch_size, 78 | #shuffle=shuffle, 79 | num_workers=reader_num, 80 | collate_fn=collate_fn[mode], 81 | drop_last=drop_last, 82 | sampler=sampler) 83 | 84 | return dataloader 85 | else: 86 | logger.error("There is no dataset called %s, check your config." % which) 87 | raise NotImplementedError 88 | 89 | 90 | def init_test_dataset(config, *args, **params): 91 | init_formatter(config, ["test"], *args, **params) 92 | test_dataset = init_one_dataset(config, "test", *args, **params) 93 | 94 | return test_dataset 95 | 96 | 97 | def init_dataset(config, *args, **params): 98 | init_formatter(config, ["train", "valid"], *args, **params) 99 | train_dataset = init_one_dataset(config, "train", *args, **params) 100 | valid_dataset = init_one_dataset(config, "valid", *args, **params) 101 | 102 | return train_dataset, valid_dataset 103 | 104 | 105 | if __name__ == "__main__": 106 | pass 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.4.2 2 | torch==1.6.0 3 | numpy==1.20.2 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import logging 5 | import json 6 | 7 | from tools.init_tool import init_all 8 | from config_parser import create_config 9 | from tools.test_tool import test 10 | 11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 12 | datefmt='%m/%d/%Y %H:%M:%S', 13 | level=logging.INFO) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--config', '-c', help="specific config file", required=True) 20 | parser.add_argument('--gpu', '-g', help="gpu id list") 21 | parser.add_argument('--checkpoint', help="checkpoint file path", required=True) 22 | parser.add_argument('--result', help="result file path", required=True) 23 | args = parser.parse_args() 24 | 25 | configFilePath = args.config 26 | 27 | use_gpu = True 28 | gpu_list = [] 29 | if args.gpu is None: 30 | 
use_gpu = False 31 | else: 32 | use_gpu = True 33 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 34 | 35 | device_list = args.gpu.split(",") 36 | for a in range(0, len(device_list)): 37 | gpu_list.append(int(a)) 38 | 39 | os.system("clear") 40 | 41 | config = create_config(configFilePath) 42 | 43 | cuda = torch.cuda.is_available() 44 | logger.info("CUDA available: %s" % str(cuda)) 45 | if not cuda and len(gpu_list) > 0: 46 | logger.error("CUDA is not available but specific gpu id") 47 | raise NotImplementedError 48 | 49 | parameters = init_all(config, gpu_list, args.checkpoint, "test") 50 | 51 | json.dump(test(parameters, config, gpu_list), open(args.result, "w", encoding="utf8"), ensure_ascii=False, 52 | sort_keys=True, indent=2) 53 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/LegalPLMs/35746bf3d2e751f453ae41c4129f1fa1233cf9b0/tools/__init__.py -------------------------------------------------------------------------------- /tools/accuracy_init.py: -------------------------------------------------------------------------------- 1 | from .accuracy_tool import single_label_top1_accuracy, single_label_top2_accuracy, multi_label_accuracy, \ 2 | null_accuracy_function 3 | 4 | accuracy_function_dic = { 5 | "SingleLabelTop1": single_label_top1_accuracy, 6 | "MultiLabel": multi_label_accuracy, 7 | "Null": null_accuracy_function 8 | } 9 | 10 | 11 | def init_accuracy_function(config, *args, **params): 12 | name = config.get("output", "accuracy_method") 13 | if name in accuracy_function_dic: 14 | return accuracy_function_dic[name] 15 | else: 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /tools/accuracy_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | logger = logging.Logger(__name__) 5 | 6 | 7 | def get_prf(res): 8 | # According to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure 9 | if res["TP"] == 0: 10 | if res["FP"] == 0 and res["FN"] == 0: 11 | precision = 1.0 12 | recall = 1.0 13 | f1 = 1.0 14 | else: 15 | precision = 0.0 16 | recall = 0.0 17 | f1 = 0.0 18 | else: 19 | precision = 1.0 * res["TP"] / (res["TP"] + res["FP"]) 20 | recall = 1.0 * res["TP"] / (res["TP"] + res["FN"]) 21 | f1 = 2 * precision * recall / (precision + recall) 22 | 23 | return precision, recall, f1 24 | 25 | 26 | def gen_micro_macro_result(res): 27 | precision = [] 28 | recall = [] 29 | f1 = [] 30 | total = {"TP": 0, "FP": 0, "FN": 0, "TN": 0} 31 | for a in range(0, len(res)): 32 | total["TP"] += res[a]["TP"] 33 | total["FP"] += res[a]["FP"] 34 | total["FN"] += res[a]["FN"] 35 | total["TN"] += res[a]["TN"] 36 | 37 | p, r, f = get_prf(res[a]) 38 | precision.append(p) 39 | recall.append(r) 40 | f1.append(f) 41 | 42 | micro_precision, micro_recall, micro_f1 = get_prf(total) 43 | 44 | macro_precision = 0 45 | macro_recall = 0 46 | macro_f1 = 0 47 | for a in range(0, len(f1)): 48 | macro_precision += precision[a] 49 | macro_recall += recall[a] 50 | macro_f1 += f1[a] 51 | 52 | macro_precision /= len(f1) 53 | macro_recall /= len(f1) 54 | macro_f1 /= len(f1) 55 | 56 | return { 57 | "micro_precision": round(micro_precision, 3), 58 | "micro_recall": round(micro_recall, 3), 59 | "micro_f1": round(micro_f1, 3), 60 | "macro_precision": round(macro_precision, 3), 61 | 
"macro_recall": round(macro_recall, 3), 62 | "macro_f1": round(macro_f1, 3) 63 | } 64 | 65 | 66 | def null_accuracy_function(outputs, label, config, result=None): 67 | return None 68 | 69 | 70 | def single_label_top1_accuracy(outputs, label, config, result=None): 71 | if result is None: 72 | result = [] 73 | id1 = torch.max(outputs, dim=1)[1] 74 | # id2 = torch.max(label, dim=1)[1] 75 | id2 = label 76 | nr_classes = outputs.size(1) 77 | while len(result) < nr_classes: 78 | result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 79 | for a in range(0, len(id1)): 80 | # if len(result) < a: 81 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 82 | 83 | it_is = int(id1[a]) 84 | should_be = int(id2[a]) 85 | if it_is == should_be: 86 | result[it_is]["TP"] += 1 87 | else: 88 | result[it_is]["FP"] += 1 89 | result[should_be]["FN"] += 1 90 | 91 | return result 92 | 93 | 94 | def multi_label_accuracy(outputs, label, config, result=None): 95 | if len(label[0]) != len(outputs[0]): 96 | raise ValueError('Input dimensions of labels and outputs must match.') 97 | 98 | if len(outputs.size()) > 2: 99 | outputs = outputs.view(outputs.size()[0], -1, 2) 100 | outputs = torch.nn.Softmax(dim=2)(outputs) 101 | outputs = outputs[:, :, 1] 102 | 103 | outputs = outputs.data 104 | labels = label.data 105 | 106 | if result is None: 107 | result = [] 108 | 109 | total = 0 110 | nr_classes = outputs.size(1) 111 | 112 | while len(result) < nr_classes: 113 | result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 114 | 115 | for i in range(nr_classes): 116 | outputs1 = (outputs[:, i] >= 0.5).long() 117 | labels1 = (labels[:, i].float() >= 0.5).long() 118 | total += int((labels1 * outputs1).sum()) 119 | total += int(((1 - labels1) * (1 - outputs1)).sum()) 120 | 121 | if result is None: 122 | continue 123 | 124 | # if len(result) < i: 125 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 126 | 127 | result[i]["TP"] += int((labels1 * outputs1).sum()) 128 | result[i]["FN"] += int((labels1 * (1 - outputs1)).sum()) 129 | result[i]["FP"] += int(((1 - labels1) * outputs1).sum()) 130 | result[i]["TN"] += int(((1 - labels1) * (1 - outputs1)).sum()) 131 | 132 | return result 133 | 134 | def single_label_top2_accuracy(outputs, label, config, result=None): 135 | raise NotImplementedError 136 | # still bug here 137 | 138 | if result is None: 139 | result = [] 140 | # print(label) 141 | 142 | id1 = torch.max(outputs, dim=1)[1] 143 | # id2 = torch.max(label, dim=1)[1] 144 | id2 = label 145 | nr_classes = outputs.size(1) 146 | while len(result) < nr_classes: 147 | result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 148 | for a in range(0, len(id1)): 149 | # if len(result) < a: 150 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 151 | 152 | it_is = int(id1[a]) 153 | should_be = int(id2[a]) 154 | if it_is == should_be: 155 | result[it_is]["TP"] += 1 156 | else: 157 | result[it_is]["FP"] += 1 158 | result[should_be]["FN"] += 1 159 | 160 | _, prediction = torch.topk(outputs, 2, 1, largest=True) 161 | prediction1 = prediction[:, 0:1] 162 | prediction2 = prediction[:, 1:] 163 | 164 | prediction1 = prediction1.view(-1) 165 | prediction2 = prediction2.view(-1) 166 | 167 | return result 168 | -------------------------------------------------------------------------------- /tools/dataset_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def dfs_search(path, recursive): 5 | if os.path.isfile(path): 6 | return [path] 7 | file_list = [] 8 | name_list = os.listdir(path) 9 | 
name_list.sort() 10 | for filename in name_list: 11 | real_path = os.path.join(path, filename) 12 | 13 | if os.path.isdir(real_path): 14 | if recursive: 15 | file_list = file_list + dfs_search(real_path, recursive) 16 | else: 17 | file_list.append(real_path) 18 | 19 | return file_list 20 | -------------------------------------------------------------------------------- /tools/eval_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.optim import lr_scheduler 6 | from timeit import default_timer as timer 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def gen_time_str(t): 12 | t = int(t) 13 | minute = t // 60 14 | second = t % 60 15 | return '%2d:%02d' % (minute, second) 16 | 17 | 18 | def output_value(epoch, mode, step, time, loss, info, end, config): 19 | try: 20 | delimiter = config.get("output", "delimiter") 21 | except Exception as e: 22 | delimiter = " " 23 | s = "" 24 | s = s + str(epoch) + " " 25 | while len(s) < 10: 26 | s += " " 27 | s = s + str(mode) + " " 28 | while len(s) < 18: 29 | s += " " 30 | s = s + str(step) + " " 31 | while len(s) < 30: 32 | s += " " 33 | s += str(time) 34 | while len(s) < 45: 35 | s += " " 36 | s += str(loss) 37 | while len(s) < 53: 38 | s += " " 39 | s += str(info) 40 | s = s.replace(" ", delimiter) 41 | if not (end is None): 42 | print(s, end=end) 43 | else: 44 | print(s) 45 | 46 | 47 | def valid(model, dataset, epoch, config, gpu_list, output_function, mode="valid"): 48 | model.eval() 49 | local_rank = config.getint('distributed', 'local_rank') 50 | 51 | acc_result = None 52 | total_loss = 0 53 | cnt = 0 54 | total_len = len(dataset) 55 | start_time = timer() 56 | output_info = "" 57 | 58 | output_time = config.getint("output", "output_time") 59 | step = -1 60 | more = "" 61 | if total_len < 10000: 62 | more = "\t" 63 | 64 | for step, data in enumerate(dataset): 65 | for key in data.keys(): 66 | if isinstance(data[key], torch.Tensor): 67 | if len(gpu_list) > 0: 68 | data[key] = Variable(data[key].cuda()) 69 | else: 70 | data[key] = Variable(data[key]) 71 | 72 | results = model(data, config, gpu_list, acc_result, "valid") 73 | 74 | loss, acc_result = results["loss"], results["acc_result"] 75 | total_loss += float(loss) 76 | cnt += 1 77 | 78 | if step % output_time == 0 and local_rank <= 0: 79 | delta_t = timer() - start_time 80 | 81 | output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % ( 82 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 83 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 84 | 85 | if step == -1: 86 | logger.error("There is no data given to the model in this epoch, check your data.") 87 | raise NotImplementedError 88 | 89 | if config.getboolean("distributed", "use"): 90 | shape = len(acc_result) + 1 91 | mytensor = torch.FloatTensor([total_loss] + [acc_result[key] for key in acc_result]).to(gpu_list[local_rank]) 92 | mylist = [torch.FloatTensor(shape).to(gpu_list[local_rank]) for i in range(config.getint('distributed', 'gpu_num'))] 93 | torch.distributed.all_gather(mylist, mytensor)#, 0) 94 | if local_rank == 0: 95 | mytensor = sum(mylist) 96 | total_loss = float(mytensor[0]) / config.getint('distributed', 'gpu_num') 97 | index = 1 98 | for key in acc_result: 99 | acc_result[key] = int(mytensor[index]) 100 | index += 1 101 | if local_rank <= 0: 102 | delta_t = timer() - start_time 103 | output_info = 
output_function(acc_result, config)
104 |         output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % (
105 |             gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))),
106 |                      "%.3lf" % (total_loss / (step + 1)), output_info, None, config)
107 | 
108 |     model.train()
109 | 
--------------------------------------------------------------------------------
/tools/init_tool.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | 
4 | from reader.reader import init_dataset, init_formatter, init_test_dataset
5 | from model import get_model
6 | from model.optimizer import init_optimizer
7 | from .output_init import init_output_function
8 | from torch import nn
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | def init_all(config, gpu_list, checkpoint, mode, *args, **params):
14 |     result = {}
15 | 
16 |     logger.info("Begin to initialize dataset and formatter...")
17 |     if mode == "train":
18 |         # init_formatter(config, ["train", "valid"], *args, **params)
19 |         result["train_dataset"], result["valid_dataset"] = init_dataset(config, *args, **params)
20 |     else:
21 |         # init_formatter(config, ["test"], *args, **params)
22 |         result["test_dataset"] = init_test_dataset(config, *args, **params)
23 | 
24 |     logger.info("Begin to initialize models...")
25 | 
26 |     model = get_model(config.get("model", "model_name"))(config, gpu_list, *args, **params)
27 |     optimizer = init_optimizer(model, config, *args, **params)
28 |     trained_epoch = 0
29 |     global_step = 0
30 | 
31 |     if len(gpu_list) > 0:
32 |         # test.py calls init_all without a local_rank parameter, so default to -1
33 |         local_rank = params.get('local_rank', -1)
34 |         if local_rank < 0:
35 |             model = model.cuda()
36 |         else:
37 |             model = model.to(gpu_list[local_rank])
38 |             try:
39 |                 model = nn.parallel.DistributedDataParallel(model, device_ids = [local_rank], find_unused_parameters = True)
40 |             except Exception as e:
41 |                 logger.warning("No init_multi_gpu implemented in the model, use single gpu instead.")
42 | 
43 |     try:
44 |         parameters = torch.load(checkpoint, map_location=lambda storage, loc: storage)
45 |         if hasattr(model, 'module'):
46 |             model.module.load_state_dict(parameters["model"])
47 |         else:
48 |             model.load_state_dict(parameters["model"])
49 |         if mode == "train":
50 |             trained_epoch = parameters["trained_epoch"]
51 |             if config.get("train", "optimizer") == parameters["optimizer_name"]:
52 |                 optimizer.load_state_dict(parameters["optimizer"])
53 |             else:
54 |                 logger.warning("Optimizer changed, do not load parameters of optimizer.")
55 | 
56 |             if "global_step" in parameters:
57 |                 global_step = parameters["global_step"]
58 |             if "lr_scheduler" in parameters:
59 |                 result["lr_scheduler"] = parameters["lr_scheduler"]
60 |     except Exception as e:
61 |         information = "Cannot load checkpoint file with error %s" % str(e)
62 |         if mode == "test":
63 |             logger.error(information)
64 |             raise e
65 |         else:
66 |             logger.warning(information)
67 | 
68 |     result["model"] = model
69 |     if mode == "train":
70 |         result["optimizer"] = optimizer
71 |         result["trained_epoch"] = trained_epoch
72 |         result["output_function"] = init_output_function(config)
73 |         result["global_step"] = global_step
74 | 
75 |     logger.info("Initialize done.")
76 | 
77 |     return result
78 | 
--------------------------------------------------------------------------------
/tools/output_init.py:
--------------------------------------------------------------------------------
1 | from .output_tool import basic_output_function, null_output_function, output_function1, binary_output_function
2 | 
3 | output_function_dic = 
{ 4 | "Basic": basic_output_function, 5 | "Null": null_output_function, 6 | "out1": output_function1, 7 | "binary": binary_output_function, 8 | } 9 | 10 | 11 | def init_output_function(config, *args, **params): 12 | name = config.get("output", "output_function") 13 | 14 | if name in output_function_dic: 15 | return output_function_dic[name] 16 | else: 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /tools/output_tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .accuracy_tool import gen_micro_macro_result 4 | 5 | 6 | def null_output_function(data, config, *args, **params): 7 | return "" 8 | 9 | 10 | def basic_output_function(data, config, *args, **params): 11 | which = config.get("output", "output_value").replace(" ", "").split(",") 12 | temp = gen_micro_macro_result(data) 13 | result = {} 14 | for name in which: 15 | result[name] = temp[name] 16 | 17 | return json.dumps(result, sort_keys=True) 18 | 19 | def output_function1(data, config, *args, **params): 20 | if data['pre_num'] != 0 and data['actual_num'] != 0: 21 | pre = data['right'] / data['pre_num'] 22 | recall = data['right'] / data['actual_num'] 23 | if pre + recall == 0: 24 | f1 = 0 25 | else: 26 | f1 = 2 * pre * recall / (pre + recall) 27 | else: 28 | pre = 0 29 | recall = 0 30 | f1 = 0 31 | 32 | metric = { 33 | 'precision': round(pre, 4), 34 | 'recall': round(recall, 4), 35 | 'f1': round(f1, 4), 36 | } 37 | if 'labelset' in data and 'doc_num' in data and data['doc_num'] != 0: 38 | metric['ave_len'] = data['labelset'] / data['doc_num'] 39 | return json.dumps(metric) 40 | 41 | def binary_output_function(data, config, *args, **params): 42 | if data['total'] == 0: 43 | metric = {'acc': 0} 44 | else: 45 | metric = {'acc': round(data['right'] / data['total'], 4)} 46 | return json.dumps(metric) 47 | 48 | -------------------------------------------------------------------------------- /tools/test_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from timeit import default_timer as timer 6 | 7 | from tools.eval_tool import gen_time_str, output_value 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def test(parameters, config, gpu_list): 13 | model = parameters["model"] 14 | dataset = parameters["test_dataset"] 15 | model.eval() 16 | 17 | acc_result = None 18 | total_loss = 0 19 | cnt = 0 20 | total_len = len(dataset) 21 | start_time = timer() 22 | output_info = "testing" 23 | 24 | output_time = config.getint("output", "output_time") 25 | step = -1 26 | result = [] 27 | 28 | for step, data in enumerate(dataset): 29 | for key in data.keys(): 30 | if isinstance(data[key], torch.Tensor): 31 | if len(gpu_list) > 0: 32 | data[key] = Variable(data[key].cuda()) 33 | else: 34 | data[key] = Variable(data[key]) 35 | 36 | results = model(data, config, gpu_list, acc_result, "test") 37 | result = result + results["output"] 38 | cnt += 1 39 | 40 | if step % output_time == 0: 41 | delta_t = timer() - start_time 42 | 43 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 44 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 45 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 46 | 47 | if step == -1: 48 | logger.error("There is no data given to the model in this epoch, check your data.") 49 | raise 
NotImplementedError 50 | 51 | delta_t = timer() - start_time 52 | output_info = "testing" 53 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 54 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 55 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 56 | 57 | return result 58 | -------------------------------------------------------------------------------- /tools/train_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.optim import lr_scheduler as lrs 6 | import shutil 7 | from timeit import default_timer as timer 8 | 9 | from tools.eval_tool import valid, gen_time_str, output_value 10 | from tools.init_tool import init_test_dataset, init_formatter 11 | from transformers import get_linear_schedule_with_warmup 12 | from torch.cuda.amp import autocast 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def checkpoint(filename, model, optimizer, trained_epoch, config, global_step, lr_scheduler): 18 | model_to_save = model.module if hasattr(model, 'module') else model 19 | save_params = { 20 | "model": model_to_save.state_dict(), 21 | "optimizer_name": config.get("train", "optimizer"), 22 | "optimizer": optimizer.state_dict(), 23 | "trained_epoch": trained_epoch, 24 | "global_step": global_step, 25 | "lr_scheduler": lr_scheduler.state_dict(), 26 | } 27 | 28 | try: 29 | torch.save(save_params, filename) 30 | except Exception as e: 31 | logger.warning("Cannot save models with error %s, continue anyway" % str(e)) 32 | 33 | 34 | def train(parameters, config, gpu_list, do_test=False, local_rank=-1): 35 | epoch = config.getint("train", "epoch") 36 | batch_size = config.getint("train", "batch_size") 37 | 38 | output_time = config.getint("output", "output_time") 39 | test_time = config.getint("output", "test_time") 40 | 41 | output_path = os.path.join(config.get("output", "model_path"), config.get("output", "model_name")) 42 | if os.path.exists(output_path): 43 | logger.warning("Output path exists, check whether need to change a name of model") 44 | os.makedirs(output_path, exist_ok=True) 45 | 46 | trained_epoch = parameters["trained_epoch"] + 1 47 | model = parameters["model"] 48 | optimizer = parameters["optimizer"] 49 | dataset = parameters["train_dataset"] 50 | global_step = parameters["global_step"] 51 | output_function = parameters["output_function"] 52 | 53 | if do_test: 54 | init_formatter(config, ["test"]) 55 | test_dataset = init_test_dataset(config) 56 | 57 | step_size = config.getint("train", "step_size") 58 | gamma = config.getfloat("train", "lr_multiplier") 59 | grad_accumulate = config.getint("train", "grad_accumulate") 60 | 61 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.getint('train', 'warmup_steps'), num_training_steps=config.getint('train', 'training_steps')) 62 | #if "lr_scheduler" in parameters: 63 | #lr_scheduler.load_state_dict(parameters["lr_scheduler"]) 64 | 65 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 66 | # exp_lr_scheduler.step(trained_epoch) 67 | 68 | fp16 = config.getboolean('train', 'fp16') 69 | if fp16: 70 | scaler = torch.cuda.amp.GradScaler() 71 | max_grad_norm = config.getfloat('train', 'max_grad_norm') 72 | valid_mode = config.get('train', 'valid_mode') 73 | if valid_mode != 'step' and valid_mode != 'batch': 74 | raise ValueError('The value of valid_mode is 
invalid.') 75 | print('valid_mode', valid_mode) 76 | if valid_mode == 'step': 77 | step_epoch = config.getint('train', 'step_epoch') 78 | print('step_epoch', step_epoch) 79 | logger.info("Training start....") 80 | 81 | print("Epoch Stage Iterations Time Usage Loss Output Information") 82 | 83 | total_len = len(dataset) 84 | more = "" 85 | if total_len < 10000: 86 | more = "\t" 87 | for epoch_num in range(trained_epoch, epoch): 88 | start_time = timer() 89 | current_epoch = epoch_num 90 | 91 | # exp_lr_scheduler.step(current_epoch) 92 | 93 | acc_result = None 94 | total_loss = 0 95 | 96 | output_info = "" 97 | step = -1 98 | for step, data in enumerate(dataset): 99 | for key in data.keys(): 100 | if isinstance(data[key], torch.Tensor): 101 | if len(gpu_list) > 0: 102 | data[key] = Variable(data[key].cuda()) 103 | else: 104 | data[key] = Variable(data[key]) 105 | 106 | if fp16: 107 | with autocast(): 108 | results = model(data, config, gpu_list, acc_result, "train") 109 | else: 110 | results = model(data, config, gpu_list, acc_result, "train") 111 | 112 | loss, acc_result = results["loss"], results["acc_result"] 113 | total_loss += float(loss) 114 | 115 | loss = loss / grad_accumulate 116 | if fp16: 117 | scaler.scale(loss).backward() 118 | else: 119 | loss.backward() 120 | 121 | if (step + 1) % grad_accumulate == 0: 122 | 123 | if max_grad_norm is not None and max_grad_norm > 0: 124 | if fp16: 125 | scaler.unscale_(optimizer) 126 | if hasattr(optimizer, "clip_grad_norm"): 127 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 128 | optimizer.clip_grad_norm(max_grad_norm) 129 | elif hasattr(model, "clip_grad_norm_"): 130 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 131 | model.clip_grad_norm_(max_grad_norm) 132 | else: 133 | # Revert to normal clipping otherwise, handling Apex or full precision 134 | torch.nn.utils.clip_grad_norm_( 135 | model.parameters(), 136 | max_grad_norm 137 | ) 138 | 139 | if fp16: 140 | scaler.step(optimizer) 141 | scaler.update() 142 | else: 143 | optimizer.step() 144 | lr_scheduler.step() 145 | optimizer.zero_grad() 146 | 147 | if step % output_time == 0 and local_rank <= 0: 148 | output_info = output_function(acc_result, config) 149 | 150 | delta_t = timer() - start_time 151 | 152 | output_value(current_epoch, "train", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 153 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 154 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 155 | 156 | global_step += 1 157 | if (step + 1) % grad_accumulate == 0 and valid_mode == 'step' and int((step + 1) / grad_accumulate) % step_epoch == 0: 158 | if local_rank <= 0: 159 | print() 160 | checkpoint(os.path.join(output_path, "%d.pkl" % current_epoch), model, optimizer, current_epoch, config, global_step, lr_scheduler) 161 | path = os.path.join(output_path, 'model_%d_%d' % (current_epoch, (step + 1) // grad_accumulate)) 162 | if local_rank < 0: 163 | model.save_pretrained(path) 164 | else: 165 | model.module.save_pretrained(path) 166 | with torch.no_grad(): 167 | valid(model, parameters["valid_dataset"], current_epoch, config, gpu_list, output_function) 168 | 169 | if step == -1: 170 | logger.error("There is no data given to the model in this epoch, check your data.") 171 | raise NotImplementedError 172 | 173 | 174 | if valid_mode != 'batch': 175 | continue 176 | 177 | if local_rank <= 0: 178 | output_info = output_function(acc_result, config) 179 | delta_t 
= timer() - start_time 180 | output_value(current_epoch, "train", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 181 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 182 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 183 | 184 | # if step == -1: 185 | # logger.error("There is no data given to the model in this epoch, check your data.") 186 | # raise NotImplementedError 187 | 188 | if local_rank <= 0: 189 | checkpoint(os.path.join(output_path, "%d.pkl" % current_epoch), model, optimizer, current_epoch, config, global_step, lr_scheduler) 190 | 191 | if current_epoch % test_time == 0: 192 | with torch.no_grad(): 193 | valid(model, parameters["valid_dataset"], current_epoch, config, gpu_list, output_function) 194 | if do_test: 195 | valid(model, test_dataset, current_epoch, config, gpu_list, output_function, mode="test") 196 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import logging 5 | 6 | from tools.init_tool import init_all 7 | from config_parser import create_config 8 | from tools.train_tool import train 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S', 12 | level=logging.INFO) 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', '-c', help="specific config file", required=True) 19 | parser.add_argument('--gpu', '-g', help="gpu id list") 20 | parser.add_argument('--checkpoint', help="checkpoint file path") 21 | parser.add_argument('--local_rank', type=int, help='local rank', default=-1) 22 | parser.add_argument('--do_test', help="do test while training or not", action="store_true") 23 | parser.add_argument('--comment', help="checkpoint file path", default=None) 24 | args = parser.parse_args() 25 | 26 | configFilePath = args.config 27 | 28 | config = create_config(configFilePath) 29 | 30 | 31 | use_gpu = True 32 | gpu_list = [] 33 | if args.gpu is None: 34 | use_gpu = False 35 | else: 36 | use_gpu = True 37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 38 | 39 | device_list = args.gpu.split(",") 40 | for a in range(0, len(device_list)): 41 | gpu_list.append(int(a)) 42 | 43 | os.system("clear") 44 | config.set('distributed', 'local_rank', args.local_rank) 45 | if config.getboolean("distributed", "use"): 46 | torch.cuda.set_device(gpu_list[args.local_rank]) 47 | torch.distributed.init_process_group(backend=config.get("distributed", "backend")) 48 | config.set('distributed', 'gpu_num', len(gpu_list)) 49 | 50 | cuda = torch.cuda.is_available() 51 | logger.info("CUDA available: %s" % str(cuda)) 52 | if not cuda and len(gpu_list) > 0: 53 | logger.error("CUDA is not available but specific gpu id") 54 | raise NotImplementedError 55 | 56 | parameters = init_all(config, gpu_list, args.checkpoint, "train", local_rank = args.local_rank) 57 | do_test = False 58 | if args.do_test: 59 | do_test = True 60 | 61 | print(args.comment) 62 | train(parameters, config, gpu_list, do_test, args.local_rank) 63 | --------------------------------------------------------------------------------
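The `train_data`/`train_files` entries in `config/Lawformer.config` point at `.bin`/`.idx` pairs that `MultiDocDataset` opens through `MMapIndexedDataset`. The listing of `dataset/IndexedDataset.py` above is truncated before the mmap builder, so the following sketch assumes `MMapIndexedDatasetBuilder` keeps the `add_item`/`end_document`/`finalize` interface of the fairseq/Megatron file it was copied from; the output file name and sample documents are illustrative only.
```python
import torch
from transformers import AutoTokenizer
from dataset.IndexedDataset import make_builder

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
# vocab_size=21128 < 65500, so token ids are stored as uint16
builder = make_builder("data/tokens/my_corpus_SS_document.bin", impl="mmap",
                       vocab_size=tokenizer.vocab_size)
for doc in ["任某提起诉讼。", "王某与李某签订借款合同。"]:  # one document per item
    ids = tokenizer(doc)["input_ids"]  # includes [CLS] ... [SEP]
    builder.add_item(torch.IntTensor(ids))
    builder.end_document()
builder.finalize("data/tokens/my_corpus_SS_document.idx")
```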