├── .gitignore ├── README.md ├── batch_loader.py ├── bert ├── __init__.py ├── file_utils.py ├── modeling.py ├── optimization.py └── tokenization.py ├── config.py ├── data ├── nlpcc-iccpol-2016.kbqa.testing-data └── nlpcc-iccpol-2016.kbqa.training-data ├── graph.py ├── main.py ├── model.py ├── nega_sampling.py ├── run.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | NLPCC2017-OpenDomainQA 3 | msm 4 | graph 5 | experiments 6 | *.json 7 | __pycache__ 8 | server.py 9 | statistics.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 百科KBQA 2 | 3 | ## 数据 4 | Knowledge Graph: 5 | [NLPCC 2017图谱(提取码:khrv)](https://pan.baidu.com/s/1yO77WW5XQwA_RtkxRHI7Yw) 6 | 7 | QAs: 8 | [NLPCC 2016KBQA](https://github.com/fyubang/Joint-BERT-KBQA/tree/master/data) 9 | 10 | -------------------------------------------------------------------------------- /batch_loader.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import json 4 | import torch 5 | 6 | from random import shuffle 7 | from tqdm import tqdm 8 | from collections import OrderedDict 9 | 10 | from bert import BertTokenizer 11 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 12 | 13 | 14 | WHITESPACE_PLACEHOLDER = ' □ ' 15 | 16 | class BatchLoader(object): 17 | def __init__(self, args): 18 | self.data_dir = args.data_dir 19 | self.nega_num = args.nega_num 20 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_dir, do_lower_case=True) 21 | 22 | def load_data(self, file_path): 23 | data = [] 24 | num = 0 25 | with open(os.path.join(self.data_dir, file_path), 'r', encoding='utf-8') as f: 26 | for line in f: 27 | sample = json.loads(line) 28 | data.append(sample) 29 | return data 30 | 31 | def build_data(self, data, is_train): 32 | ''' 33 | build data list, split to train and dev 34 | tokenization 35 | ''' 36 | ner_data = [] 37 | re_data = [] 38 | if is_train: 39 | for sample in data: 40 | # ner 41 | question = sample['question'] 42 | s, p, o = sample['triple'] 43 | ner_sample = self._get_ner_sample(question, s) 44 | ner_data.append(ner_sample) 45 | # re 46 | re_sample = self._get_re_sample(question, p, 1) 47 | re_data.append(re_sample) 48 | nps = sample.get('negative_predicates', []) 49 | shuffle(nps) 50 | for np in nps[:self.nega_num]: 51 | re_sample = self._get_re_sample(question, np, 0) 52 | re_data.append(re_sample) 53 | else: 54 | for sample in data: 55 | question = sample['question'] 56 | s, p, o = sample['triple'] 57 | ner_sample = self._get_ner_sample(question, s) 58 | ner_data.append(ner_sample) 59 | return ner_data, re_data 60 | 61 | def build_ner_data(self, question): 62 | ''' 63 | build ner data for inference 64 | question: 65 | ''' 66 | ner_data = [self._get_ner_sample(question, None)] 67 | return ner_data 68 | 69 | def build_re_data(self, question, predicates): 70 | ''' 71 | build re data for inference 72 | question: 73 | predicates: all predicates relative to the subject 74 | ''' 75 | re_data = [] 76 | for p in predicates: 77 | re_sample = self._get_re_sample(question, p, 0) 78 | re_data.append(re_sample) 79 | return re_data 80 | 81 | 82 | def batch_loader(self, ner_data=None, re_data=None, ner_max_len=32, re_max_len=64, batch_size=32, is_train=True): 83 | if is_train: # input all three datas 84 | ner_dataset = 
self._build_dataset(ner_data, None, ner_max_len) 85 | re_dataset = self._build_dataset(None, re_data, re_max_len) 86 | datasets = [ner_dataset, re_dataset] 87 | dataloaders = [] 88 | for dataset in datasets: 89 | dataloaders.append(DataLoader(dataset, batch_size, sampler=RandomSampler(dataset), drop_last=True)) 90 | return dataloaders 91 | else: # only input ner_data 92 | if ner_data: 93 | ner_dataset = self._build_dataset(ner_data, None, ner_max_len) 94 | return DataLoader(dataset=ner_dataset, batch_size=batch_size, sampler=SequentialSampler(ner_dataset), drop_last=False) 95 | elif re_data: 96 | re_dataset = self._build_dataset(None, re_data, re_max_len) 97 | return DataLoader(dataset=re_dataset, batch_size=batch_size, sampler=SequentialSampler(re_dataset), drop_last=False) 98 | else: 99 | raise Exception('At least ner or re should not be None') 100 | 101 | def _build_dataset(self, ner_data=None, re_data=None, max_len=32): 102 | ''' 103 | only input one data 104 | ''' 105 | data, dataset = None, None 106 | if ner_data: 107 | data = ner_data 108 | elif re_data: 109 | data = re_data 110 | else: 111 | raise Exception('as least an input, ner or re') 112 | token_ids = torch.tensor(self._padding([item['token_ids'] for item in data], max_len), dtype=torch.long) 113 | token_types = torch.tensor(self._padding([item['token_types'] for item in data], max_len), dtype=torch.long) 114 | if ner_data: 115 | head = torch.tensor([item['head'] for item in data], dtype=torch.long) 116 | tail = torch.tensor([item['tail'] for item in data], dtype=torch.long) 117 | dataset = TensorDataset(token_ids, token_types, head, tail) 118 | elif re_data: 119 | label = torch.tensor([item['label'] for item in data], dtype=torch.float) 120 | dataset = TensorDataset(token_ids, token_types, label) 121 | return dataset 122 | 123 | def _get_token_ids(self, text, add_cls=False, add_sep=False): 124 | new_text = text.replace(' ', WHITESPACE_PLACEHOLDER).replace(' ', WHITESPACE_PLACEHOLDER) 125 | tokens = self.tokenizer.tokenize(new_text, inference=True) 126 | tokens = ['[CLS]'] + tokens if add_cls else tokens 127 | tokens = tokens + ['[SEP]'] if add_sep else tokens 128 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 129 | return tokens, token_ids 130 | 131 | def _get_ner_sample(self, question, subject): 132 | tokens, token_ids = self._get_token_ids(question, True, True) 133 | token_types = [0]*len(tokens) 134 | head = tail = 0 135 | if subject: 136 | span_tokens, _ = self._get_token_ids(subject) 137 | head, tail = self._get_head_tail(tokens, span_tokens) 138 | ner_sample = {'tokens': tokens, 'token_ids': token_ids, 'token_types': token_types, 139 | 'head': head, 'tail': tail, 'subject': subject} 140 | return ner_sample 141 | 142 | def _get_re_sample(self, question, predicate, label): 143 | p_tokens, p_token_ids = self._get_token_ids(predicate, True, True) 144 | tokens, token_ids = self._get_token_ids(question, False, True) 145 | tokens, token_ids, token_types = p_tokens + tokens, p_token_ids + token_ids, [1]*(len(p_tokens)) + [0]*len(tokens) 146 | re_sample = {'tokens': tokens, 'token_ids': token_ids, 'token_types': token_types, 147 | 'label': label} 148 | return re_sample 149 | 150 | def _get_head_tail(self, tokens, span_tokens): 151 | len_span = len(span_tokens) 152 | head, tail = -1, -1 153 | for i in range(len(tokens)-len_span+1): 154 | if tokens[i:i+len_span] == span_tokens: 155 | head, tail = i, i+len_span-1 156 | return head, tail 157 | 158 | def _padding(self, data, max_len): 159 | res = [] 160 | for seq in data: 
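# added note: truncate sequences longer than max_len, otherwise right-pad with zeros up to max_len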
161 | if len(seq) > max_len: 162 | res.append(seq[:max_len]) 163 | else: 164 | res.append(seq + [0]*(max_len-len(seq))) 165 | return res 166 | 167 | 168 | if __name__ == "__main__": 169 | dl = BatchLoader('/root/pretrain_model_weights/torch/chinese/chinese_wwm_ext_pytorch/') 170 | data = dl.load_data('data/train.json') 171 | ner_data, re_data= dl.build_data(data, True) 172 | bgs = dl.batch_loader(ner_data, re_data, is_train=True, batch_size=1) 173 | task_ids = [0]*10 + [1]*10 174 | shuffle(task_ids) 175 | print(len(bgs[1])) 176 | print(bgs[1]) 177 | iters = [iter(bg) for bg in bgs] 178 | for i in range(len(bgs[0])): 179 | sample = ner_data[i] 180 | print(sample) 181 | print(next(iters[0])) 182 | head, tail = sample['head'], sample['tail'] 183 | print(sample['tokens'][head: tail+1]) 184 | print('****') 185 | input() 186 | 187 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | 4 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 5 | BertForMaskedLM, BertForNextSentencePrediction, 6 | BertForSequenceClassification, BertForMultipleChoice, 7 | BertForTokenClassification, BertForQuestionAnswering, 8 | load_tf_weights_in_bert, BertPreTrainedModel) 9 | 10 | from .optimization import BertAdam 11 | 12 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 13 | -------------------------------------------------------------------------------- /bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except (AttributeError, ImportError): 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | CONFIG_NAME = "config.json" 37 | WEIGHTS_NAME = "pytorch_model.bin" 38 | 39 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 40 | 41 | 42 | def url_to_filename(url, etag=None): 43 | """ 44 | Convert `url` into a hashed filename in a repeatable way. 45 | If `etag` is specified, append its hash to the url's, delimited 46 | by a period. 47 | """ 48 | url_bytes = url.encode('utf-8') 49 | url_hash = sha256(url_bytes) 50 | filename = url_hash.hexdigest() 51 | 52 | if etag: 53 | etag_bytes = etag.encode('utf-8') 54 | etag_hash = sha256(etag_bytes) 55 | filename += '.' 
+ etag_hash.hexdigest() 56 | 57 | return filename 58 | 59 | 60 | def filename_to_url(filename, cache_dir=None): 61 | """ 62 | Return the url and etag (which may be ``None``) stored for `filename`. 63 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 64 | """ 65 | if cache_dir is None: 66 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 67 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 68 | cache_dir = str(cache_dir) 69 | 70 | cache_path = os.path.join(cache_dir, filename) 71 | if not os.path.exists(cache_path): 72 | raise EnvironmentError("file {} not found".format(cache_path)) 73 | 74 | meta_path = cache_path + '.json' 75 | if not os.path.exists(meta_path): 76 | raise EnvironmentError("file {} not found".format(meta_path)) 77 | 78 | with open(meta_path, encoding="utf-8") as meta_file: 79 | metadata = json.load(meta_file) 80 | url = metadata['url'] 81 | etag = metadata['etag'] 82 | 83 | return url, etag 84 | 85 | 86 | def cached_path(url_or_filename, cache_dir=None): 87 | """ 88 | Given something that might be a URL (or might be a local path), 89 | determine which. If it's a URL, download the file and cache it, and 90 | return the path to the cached file. If it's already a local path, 91 | make sure the file exists and then return the path. 92 | """ 93 | if cache_dir is None: 94 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 95 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 96 | url_or_filename = str(url_or_filename) 97 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 98 | cache_dir = str(cache_dir) 99 | 100 | parsed = urlparse(url_or_filename) 101 | 102 | if parsed.scheme in ('http', 'https', 's3'): 103 | # URL, so get it from the cache (downloading if necessary) 104 | return get_from_cache(url_or_filename, cache_dir) 105 | elif os.path.exists(url_or_filename): 106 | # File, and it exists. 107 | return url_or_filename 108 | elif parsed.scheme == '': 109 | # File, but it doesn't exist. 110 | raise EnvironmentError("file {} not found".format(url_or_filename)) 111 | else: 112 | # Something unknown 113 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 114 | 115 | 116 | def split_s3_path(url): 117 | """Split a full s3 path into the bucket name and path.""" 118 | parsed = urlparse(url) 119 | if not parsed.netloc or not parsed.path: 120 | raise ValueError("bad s3 path {}".format(url)) 121 | bucket_name = parsed.netloc 122 | s3_path = parsed.path 123 | # Remove '/' at beginning of path. 124 | if s3_path.startswith("/"): 125 | s3_path = s3_path[1:] 126 | return bucket_name, s3_path 127 | 128 | 129 | def s3_request(func): 130 | """ 131 | Wrapper function for s3 requests in order to create more helpful error 132 | messages. 
133 | """ 134 | 135 | @wraps(func) 136 | def wrapper(url, *args, **kwargs): 137 | try: 138 | return func(url, *args, **kwargs) 139 | except ClientError as exc: 140 | if int(exc.response["Error"]["Code"]) == 404: 141 | raise EnvironmentError("file {} not found".format(url)) 142 | else: 143 | raise 144 | 145 | return wrapper 146 | 147 | 148 | @s3_request 149 | def s3_etag(url): 150 | """Check ETag on S3 object.""" 151 | s3_resource = boto3.resource("s3") 152 | bucket_name, s3_path = split_s3_path(url) 153 | s3_object = s3_resource.Object(bucket_name, s3_path) 154 | return s3_object.e_tag 155 | 156 | 157 | @s3_request 158 | def s3_get(url, temp_file): 159 | """Pull a file directly from S3.""" 160 | s3_resource = boto3.resource("s3") 161 | bucket_name, s3_path = split_s3_path(url) 162 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 163 | 164 | 165 | def http_get(url, temp_file): 166 | req = requests.get(url, stream=True) 167 | content_length = req.headers.get('Content-Length') 168 | total = int(content_length) if content_length is not None else None 169 | progress = tqdm(unit="B", total=total) 170 | for chunk in req.iter_content(chunk_size=1024): 171 | if chunk: # filter out keep-alive new chunks 172 | progress.update(len(chunk)) 173 | temp_file.write(chunk) 174 | progress.close() 175 | 176 | 177 | def get_from_cache(url, cache_dir=None): 178 | """ 179 | Given a URL, look for the corresponding dataset in the local cache. 180 | If it's not there, download it. Then return the path to the cached file. 181 | """ 182 | if cache_dir is None: 183 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 184 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 185 | cache_dir = str(cache_dir) 186 | 187 | if not os.path.exists(cache_dir): 188 | os.makedirs(cache_dir) 189 | 190 | # Get eTag to add to filename, if it exists. 191 | if url.startswith("s3://"): 192 | etag = s3_etag(url) 193 | else: 194 | response = requests.head(url, allow_redirects=True) 195 | if response.status_code != 200: 196 | raise IOError("HEAD request failed for url {} with status code {}" 197 | .format(url, response.status_code)) 198 | etag = response.headers.get("ETag") 199 | 200 | filename = url_to_filename(url, etag) 201 | 202 | # get cache path to put the file 203 | cache_path = os.path.join(cache_dir, filename) 204 | 205 | if not os.path.exists(cache_path): 206 | # Download to temporary file, then copy to cache dir once finished. 207 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
208 | with tempfile.NamedTemporaryFile() as temp_file: 209 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 210 | 211 | # GET file object 212 | if url.startswith("s3://"): 213 | s3_get(url, temp_file) 214 | else: 215 | http_get(url, temp_file) 216 | 217 | # we are copying the file before closing it, so flush to avoid truncation 218 | temp_file.flush() 219 | # shutil.copyfileobj() starts at the current position, so go to the start 220 | temp_file.seek(0) 221 | 222 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 223 | with open(cache_path, 'wb') as cache_file: 224 | shutil.copyfileobj(temp_file, cache_file) 225 | 226 | logger.info("creating metadata file for %s", cache_path) 227 | meta = {'url': url, 'etag': etag} 228 | meta_path = cache_path + '.json' 229 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 230 | json.dump(meta, meta_file) 231 | 232 | logger.info("removing temp file %s", temp_file.name) 233 | 234 | return cache_path 235 | 236 | 237 | def read_set_from_file(filename): 238 | ''' 239 | Extract a de-duped collection (set) of text from a file. 240 | Expected file format is one item per line. 241 | ''' 242 | collection = set() 243 | with open(filename, 'r', encoding='utf-8') as file_: 244 | for line in file_: 245 | collection.add(line.rstrip()) 246 | return collection 247 | 248 | 249 | def get_file_extension(path, dot=True, lower=True): 250 | ext = os.path.splitext(path)[1] 251 | ext = ext if dot else ext[1:] 252 | return ext.lower() if lower else ext 253 | -------------------------------------------------------------------------------- /bert/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import copy 21 | import json 22 | import logging 23 | import math 24 | import os 25 | import shutil 26 | import tarfile 27 | import tempfile 28 | import sys 29 | from io import open 30 | 31 | import torch 32 | from torch import nn 33 | from torch.nn import CrossEntropyLoss 34 | 35 | from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | PRETRAINED_MODEL_ARCHIVE_MAP = { 40 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 41 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 42 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 43 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 44 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 45 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 46 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 47 | } 48 | BERT_CONFIG_NAME = 'bert_config.json' 49 | TF_WEIGHTS_NAME = 'model.ckpt' 50 | 51 | def load_tf_weights_in_bert(model, tf_checkpoint_path): 52 | """ Load tf checkpoints in a pytorch model 53 | """ 54 | try: 55 | import re 56 | import numpy as np 57 | import tensorflow as tf 58 | except ImportError: 59 | print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " 60 | "https://www.tensorflow.org/install/ for installation instructions.") 61 | raise 62 | tf_path = os.path.abspath(tf_checkpoint_path) 63 | print("Converting TensorFlow checkpoint from {}".format(tf_path)) 64 | # Load weights from TF model 65 | init_vars = tf.train.list_variables(tf_path) 66 | names = [] 67 | arrays = [] 68 | for name, shape in init_vars: 69 | print("Loading TF weight {} with shape {}".format(name, shape)) 70 | array = tf.train.load_variable(tf_path, name) 71 | names.append(name) 72 | arrays.append(array) 73 | 74 | for name, array in zip(names, arrays): 75 | name = name.split('/') 76 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 77 | # which are not required for using pretrained model 78 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 79 | print("Skipping {}".format("/".join(name))) 80 | continue 81 | pointer = model 82 | for m_name in name: 83 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 84 | l = re.split(r'_(\d+)', m_name) 85 | else: 86 | l = [m_name] 87 | if l[0] == 'kernel' or l[0] == 'gamma': 88 | pointer = getattr(pointer, 'weight') 89 | elif l[0] == 'output_bias' or l[0] == 'beta': 90 | pointer = getattr(pointer, 'bias') 91 | elif l[0] == 'output_weights': 92 | pointer = getattr(pointer, 'weight') 93 | elif l[0] == 'squad': 94 | pointer = getattr(pointer, 'classifier') 95 | else: 96 | try: 97 | pointer = getattr(pointer, l[0]) 98 | except AttributeError: 99 | print("Skipping {}".format("/".join(name))) 100 | continue 101 | if len(l) >= 2: 102 | num = int(l[1]) 103 | pointer = pointer[num] 104 | if m_name[-11:] == '_embeddings': 105 | pointer = getattr(pointer, 'weight') 106 | elif m_name == 'kernel': 107 | array = np.transpose(array) 108 | try: 109 | assert 
pointer.shape == array.shape 110 | except AssertionError as e: 111 | e.args += (pointer.shape, array.shape) 112 | raise 113 | print("Initialize PyTorch weight {}".format(name)) 114 | pointer.data = torch.from_numpy(array) 115 | return model 116 | 117 | 118 | def gelu(x): 119 | """Implementation of the gelu activation function. 120 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 121 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 122 | Also see https://arxiv.org/abs/1606.08415 123 | """ 124 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 125 | 126 | 127 | def swish(x): 128 | return x * torch.sigmoid(x) 129 | 130 | 131 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 132 | 133 | 134 | class BertConfig(object): 135 | """Configuration class to store the configuration of a `BertModel`. 136 | """ 137 | def __init__(self, 138 | vocab_size_or_config_json_file, 139 | hidden_size=768, 140 | num_hidden_layers=12, 141 | num_attention_heads=12, 142 | intermediate_size=3072, 143 | hidden_act="gelu", 144 | hidden_dropout_prob=0.1, 145 | attention_probs_dropout_prob=0.1, 146 | max_position_embeddings=512, 147 | type_vocab_size=2, 148 | initializer_range=0.02): 149 | """Constructs BertConfig. 150 | 151 | Args: 152 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 153 | hidden_size: Size of the encoder layers and the pooler layer. 154 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 155 | num_attention_heads: Number of attention heads for each attention layer in 156 | the Transformer encoder. 157 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 158 | layer in the Transformer encoder. 159 | hidden_act: The non-linear activation function (function or string) in the 160 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 161 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 162 | layers in the embeddings, encoder, and pooler. 163 | attention_probs_dropout_prob: The dropout ratio for the attention 164 | probabilities. 165 | max_position_embeddings: The maximum sequence length that this model might 166 | ever be used with. Typically set this to something large just in case 167 | (e.g., 512 or 1024 or 2048). 168 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 169 | `BertModel`. 170 | initializer_range: The sttdev of the truncated_normal_initializer for 171 | initializing all weight matrices. 
172 | """ 173 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 174 | and isinstance(vocab_size_or_config_json_file, unicode)): 175 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 176 | json_config = json.loads(reader.read()) 177 | for key, value in json_config.items(): 178 | self.__dict__[key] = value 179 | elif isinstance(vocab_size_or_config_json_file, int): 180 | self.vocab_size = vocab_size_or_config_json_file 181 | self.hidden_size = hidden_size 182 | self.num_hidden_layers = num_hidden_layers 183 | self.num_attention_heads = num_attention_heads 184 | self.hidden_act = hidden_act 185 | self.intermediate_size = intermediate_size 186 | self.hidden_dropout_prob = hidden_dropout_prob 187 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 188 | self.max_position_embeddings = max_position_embeddings 189 | self.type_vocab_size = type_vocab_size 190 | self.initializer_range = initializer_range 191 | else: 192 | raise ValueError("First argument must be either a vocabulary size (int)" 193 | "or the path to a pretrained model config file (str)") 194 | 195 | @classmethod 196 | def from_dict(cls, json_object): 197 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 198 | config = BertConfig(vocab_size_or_config_json_file=-1) 199 | for key, value in json_object.items(): 200 | config.__dict__[key] = value 201 | return config 202 | 203 | @classmethod 204 | def from_json_file(cls, json_file): 205 | """Constructs a `BertConfig` from a json file of parameters.""" 206 | with open(json_file, "r", encoding='utf-8') as reader: 207 | text = reader.read() 208 | return cls.from_dict(json.loads(text)) 209 | 210 | def __repr__(self): 211 | return str(self.to_json_string()) 212 | 213 | def to_dict(self): 214 | """Serializes this instance to a Python dictionary.""" 215 | output = copy.deepcopy(self.__dict__) 216 | return output 217 | 218 | def to_json_string(self): 219 | """Serializes this instance to a JSON string.""" 220 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 221 | 222 | def to_json_file(self, json_file_path): 223 | """ Save this instance to a json file.""" 224 | with open(json_file_path, "w", encoding='utf-8') as writer: 225 | writer.write(self.to_json_string()) 226 | 227 | # try: 228 | # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 229 | # except ImportError: 230 | # logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 231 | class BertLayerNorm(nn.Module): 232 | # eps = 1e-12 original 233 | def __init__(self, hidden_size, eps=1e-12): 234 | """Construct a layernorm module in the TF style (epsilon inside the square root). 235 | """ 236 | super(BertLayerNorm, self).__init__() 237 | self.weight = nn.Parameter(torch.ones(hidden_size)) 238 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 239 | self.variance_epsilon = eps 240 | 241 | def forward(self, x): 242 | u = x.mean(-1, keepdim=True) 243 | s = (x - u).pow(2).mean(-1, keepdim=True) 244 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 245 | return self.weight * x + self.bias 246 | 247 | class BertEmbeddings(nn.Module): 248 | """Construct the embeddings from word, position and token_type embeddings. 
249 | """ 250 | def __init__(self, config): 251 | super(BertEmbeddings, self).__init__() 252 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 253 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 254 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 255 | 256 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 257 | # any TensorFlow checkpoint file 258 | self.LayerNorm = BertLayerNorm(config.hidden_size) 259 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 260 | 261 | def forward(self, input_ids, token_type_ids=None): 262 | seq_length = input_ids.size(1) 263 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 264 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 265 | if token_type_ids is None: 266 | token_type_ids = torch.zeros_like(input_ids) 267 | 268 | words_embeddings = self.word_embeddings(input_ids) 269 | position_embeddings = self.position_embeddings(position_ids) 270 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 271 | 272 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 273 | embeddings = self.LayerNorm(embeddings) 274 | embeddings = self.dropout(embeddings) 275 | return embeddings 276 | 277 | 278 | class BertSelfAttention(nn.Module): 279 | def __init__(self, config): 280 | super(BertSelfAttention, self).__init__() 281 | if config.hidden_size % config.num_attention_heads != 0: 282 | raise ValueError( 283 | "The hidden size (%d) is not a multiple of the number of attention " 284 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 285 | self.num_attention_heads = config.num_attention_heads 286 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 287 | self.all_head_size = self.num_attention_heads * self.attention_head_size 288 | 289 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 290 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 291 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 292 | 293 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 294 | 295 | def transpose_for_scores(self, x): 296 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 297 | x = x.view(*new_x_shape) 298 | return x.permute(0, 2, 1, 3) 299 | 300 | def forward(self, hidden_states, attention_mask): 301 | mixed_query_layer = self.query(hidden_states) 302 | mixed_key_layer = self.key(hidden_states) 303 | mixed_value_layer = self.value(hidden_states) 304 | 305 | query_layer = self.transpose_for_scores(mixed_query_layer) 306 | key_layer = self.transpose_for_scores(mixed_key_layer) 307 | value_layer = self.transpose_for_scores(mixed_value_layer) 308 | 309 | # Take the dot product between "query" and "key" to get the raw attention scores. 310 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 311 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 312 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 313 | attention_scores = attention_scores + attention_mask 314 | 315 | # Normalize the attention scores to probabilities. 
316 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 317 | 318 | # This is actually dropping out entire tokens to attend to, which might 319 | # seem a bit unusual, but is taken from the original Transformer paper. 320 | attention_probs = self.dropout(attention_probs) 321 | 322 | context_layer = torch.matmul(attention_probs, value_layer) 323 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 324 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 325 | context_layer = context_layer.view(*new_context_layer_shape) 326 | return context_layer 327 | 328 | 329 | class BertSelfOutput(nn.Module): 330 | def __init__(self, config): 331 | super(BertSelfOutput, self).__init__() 332 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 333 | self.LayerNorm = BertLayerNorm(config.hidden_size) 334 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 335 | 336 | def forward(self, hidden_states, input_tensor): 337 | hidden_states = self.dense(hidden_states) 338 | hidden_states = self.dropout(hidden_states) 339 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 340 | return hidden_states 341 | 342 | 343 | class BertAttention(nn.Module): 344 | def __init__(self, config): 345 | super(BertAttention, self).__init__() 346 | self.self = BertSelfAttention(config) 347 | self.output = BertSelfOutput(config) 348 | 349 | def forward(self, input_tensor, attention_mask): 350 | self_output = self.self(input_tensor, attention_mask) 351 | attention_output = self.output(self_output, input_tensor) 352 | return attention_output 353 | 354 | 355 | class BertIntermediate(nn.Module): 356 | def __init__(self, config): 357 | super(BertIntermediate, self).__init__() 358 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 359 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 360 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 361 | else: 362 | self.intermediate_act_fn = config.hidden_act 363 | 364 | def forward(self, hidden_states): 365 | hidden_states = self.dense(hidden_states) 366 | hidden_states = self.intermediate_act_fn(hidden_states) 367 | return hidden_states 368 | 369 | 370 | class BertOutput(nn.Module): 371 | def __init__(self, config): 372 | super(BertOutput, self).__init__() 373 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 374 | self.LayerNorm = BertLayerNorm(config.hidden_size) 375 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 376 | 377 | def forward(self, hidden_states, input_tensor): 378 | hidden_states = self.dense(hidden_states) 379 | hidden_states = self.dropout(hidden_states) 380 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 381 | return hidden_states 382 | 383 | 384 | class BertLayer(nn.Module): 385 | def __init__(self, config): 386 | super(BertLayer, self).__init__() 387 | self.attention = BertAttention(config) 388 | self.intermediate = BertIntermediate(config) 389 | self.output = BertOutput(config) 390 | 391 | def forward(self, hidden_states, attention_mask): 392 | attention_output = self.attention(hidden_states, attention_mask) 393 | intermediate_output = self.intermediate(attention_output) 394 | layer_output = self.output(intermediate_output, attention_output) 395 | return layer_output 396 | 397 | 398 | class BertEncoder(nn.Module): 399 | def __init__(self, config): 400 | super(BertEncoder, self).__init__() 401 | layer = BertLayer(config) 402 | self.layer = nn.ModuleList([copy.deepcopy(layer) for 
_ in range(config.num_hidden_layers)]) 403 | 404 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): 405 | all_encoder_layers = [] 406 | for layer_module in self.layer: 407 | hidden_states = layer_module(hidden_states, attention_mask) 408 | if output_all_encoded_layers: 409 | all_encoder_layers.append(hidden_states) 410 | if not output_all_encoded_layers: 411 | all_encoder_layers.append(hidden_states) 412 | return all_encoder_layers 413 | 414 | 415 | class BertPooler(nn.Module): 416 | def __init__(self, config): 417 | super(BertPooler, self).__init__() 418 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 419 | self.activation = nn.Tanh() 420 | 421 | def forward(self, hidden_states): 422 | # We "pool" the model by simply taking the hidden state corresponding 423 | # to the first token. 424 | first_token_tensor = hidden_states[:, 0] 425 | pooled_output = self.dense(first_token_tensor) 426 | pooled_output = self.activation(pooled_output) 427 | return pooled_output 428 | 429 | 430 | class BertPredictionHeadTransform(nn.Module): 431 | def __init__(self, config): 432 | super(BertPredictionHeadTransform, self).__init__() 433 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 434 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 435 | self.transform_act_fn = ACT2FN[config.hidden_act] 436 | else: 437 | self.transform_act_fn = config.hidden_act 438 | self.LayerNorm = BertLayerNorm(config.hidden_size) 439 | 440 | def forward(self, hidden_states): 441 | hidden_states = self.dense(hidden_states) 442 | hidden_states = self.transform_act_fn(hidden_states) 443 | hidden_states = self.LayerNorm(hidden_states) 444 | return hidden_states 445 | 446 | 447 | class BertLMPredictionHead(nn.Module): 448 | def __init__(self, config, bert_model_embedding_weights): 449 | super(BertLMPredictionHead, self).__init__() 450 | self.transform = BertPredictionHeadTransform(config) 451 | 452 | # The output weights are the same as the input embeddings, but there is 453 | # an output-only bias for each token. 
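# added note: decoder.weight is tied to the input word embedding matrix of shape [vocab_size, hidden_size]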
454 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 455 | bert_model_embedding_weights.size(0), 456 | bias=False) 457 | self.decoder.weight = bert_model_embedding_weights 458 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 459 | 460 | def forward(self, hidden_states): 461 | hidden_states = self.transform(hidden_states) 462 | hidden_states = self.decoder(hidden_states) + self.bias 463 | return hidden_states 464 | 465 | 466 | class BertOnlyMLMHead(nn.Module): 467 | def __init__(self, config, bert_model_embedding_weights): 468 | super(BertOnlyMLMHead, self).__init__() 469 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 470 | 471 | def forward(self, sequence_output): 472 | prediction_scores = self.predictions(sequence_output) 473 | return prediction_scores 474 | 475 | 476 | class BertOnlyNSPHead(nn.Module): 477 | def __init__(self, config): 478 | super(BertOnlyNSPHead, self).__init__() 479 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 480 | 481 | def forward(self, pooled_output): 482 | seq_relationship_score = self.seq_relationship(pooled_output) 483 | return seq_relationship_score 484 | 485 | 486 | class BertPreTrainingHeads(nn.Module): 487 | def __init__(self, config, bert_model_embedding_weights): 488 | super(BertPreTrainingHeads, self).__init__() 489 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 490 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 491 | 492 | def forward(self, sequence_output, pooled_output): 493 | prediction_scores = self.predictions(sequence_output) 494 | seq_relationship_score = self.seq_relationship(pooled_output) 495 | return prediction_scores, seq_relationship_score 496 | 497 | 498 | class BertPreTrainedModel(nn.Module): 499 | """ An abstract class to handle weights initialization and 500 | a simple interface for dowloading and loading pretrained models. 501 | """ 502 | def __init__(self, config, *inputs, **kwargs): 503 | super(BertPreTrainedModel, self).__init__() 504 | if not isinstance(config, BertConfig): 505 | raise ValueError( 506 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 507 | "To create a model from a Google pretrained model use " 508 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 509 | self.__class__.__name__, self.__class__.__name__ 510 | )) 511 | self.config = config 512 | 513 | def init_bert_weights(self, module): 514 | """ Initialize the weights. 515 | """ 516 | if isinstance(module, (nn.Linear, nn.Embedding)): 517 | # Slightly different from the TF version which uses truncated_normal for initialization 518 | # cf https://github.com/pytorch/pytorch/pull/5617 519 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 520 | elif isinstance(module, BertLayerNorm): 521 | module.bias.data.zero_() 522 | module.weight.data.fill_(1.0) 523 | if isinstance(module, nn.Linear) and module.bias is not None: 524 | module.bias.data.zero_() 525 | 526 | @classmethod 527 | def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, 528 | from_tf=False, *inputs, **kwargs): 529 | """ 530 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 531 | Download and cache the pre-trained model file if needed. 532 | 533 | Params: 534 | pretrained_model_name_or_path: either: 535 | - a str with the name of a pre-trained model to load selected in the list of: 536 | . `bert-base-uncased` 537 | . 
`bert-large-uncased` 538 | . `bert-base-cased` 539 | . `bert-large-cased` 540 | . `bert-base-multilingual-uncased` 541 | . `bert-base-multilingual-cased` 542 | . `bert-base-chinese` 543 | - a path or url to a pretrained model archive containing: 544 | . `bert_config.json` a configuration file for the model 545 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 546 | - a path or url to a pretrained model archive containing: 547 | . `bert_config.json` a configuration file for the model 548 | . `model.chkpt` a TensorFlow checkpoint 549 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 550 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 551 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 552 | *inputs, **kwargs: additional input for the specific Bert class 553 | (ex: num_labels for BertForSequenceClassification) 554 | """ 555 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 556 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 557 | else: 558 | archive_file = pretrained_model_name_or_path 559 | # redirect to the cache, if necessary 560 | try: 561 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 562 | except EnvironmentError: 563 | logger.error( 564 | "Model name '{}' was not found in model name list ({}). " 565 | "We assumed '{}' was a path or url but couldn't find any file " 566 | "associated to this path or url.".format( 567 | pretrained_model_name_or_path, 568 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 569 | archive_file)) 570 | return None 571 | if resolved_archive_file == archive_file: 572 | logger.info("loading archive file {}".format(archive_file)) 573 | else: 574 | logger.info("loading archive file {} from cache at {}".format( 575 | archive_file, resolved_archive_file)) 576 | tempdir = None 577 | if os.path.isdir(resolved_archive_file) or from_tf: 578 | serialization_dir = resolved_archive_file 579 | else: 580 | # Extract archive to temp dir 581 | tempdir = tempfile.mkdtemp() 582 | logger.info("extracting archive file {} to temp dir {}".format( 583 | resolved_archive_file, tempdir)) 584 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 585 | archive.extractall(tempdir) 586 | serialization_dir = tempdir 587 | # Load config 588 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 589 | if not os.path.exists(config_file): 590 | # Backward compatibility with old naming format 591 | config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) 592 | config = BertConfig.from_json_file(config_file) 593 | logger.info("Model config {}".format(config)) 594 | # Instantiate model. 
595 | model = cls(config, *inputs, **kwargs) 596 | if state_dict is None and not from_tf: 597 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 598 | state_dict = torch.load(weights_path, map_location='cpu') 599 | if tempdir: 600 | # Clean up temp dir 601 | shutil.rmtree(tempdir) 602 | if from_tf: 603 | # Directly load from a TensorFlow checkpoint 604 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) 605 | return load_tf_weights_in_bert(model, weights_path) 606 | # Load from a PyTorch state_dict 607 | old_keys = [] 608 | new_keys = [] 609 | for key in state_dict.keys(): 610 | new_key = None 611 | if 'gamma' in key: 612 | new_key = key.replace('gamma', 'weight') 613 | if 'beta' in key: 614 | new_key = key.replace('beta', 'bias') 615 | if new_key: 616 | old_keys.append(key) 617 | new_keys.append(new_key) 618 | for old_key, new_key in zip(old_keys, new_keys): 619 | state_dict[new_key] = state_dict.pop(old_key) 620 | 621 | missing_keys = [] 622 | unexpected_keys = [] 623 | error_msgs = [] 624 | # copy state_dict so _load_from_state_dict can modify it 625 | metadata = getattr(state_dict, '_metadata', None) 626 | state_dict = state_dict.copy() 627 | if metadata is not None: 628 | state_dict._metadata = metadata 629 | 630 | def load(module, prefix=''): 631 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 632 | module._load_from_state_dict( 633 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 634 | for name, child in module._modules.items(): 635 | if child is not None: 636 | load(child, prefix + name + '.') 637 | start_prefix = '' 638 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): 639 | start_prefix = 'bert.' 640 | load(model, prefix=start_prefix) 641 | if len(missing_keys) > 0: 642 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 643 | model.__class__.__name__, missing_keys)) 644 | if len(unexpected_keys) > 0: 645 | logger.info("Weights from pretrained model not used in {}: {}".format( 646 | model.__class__.__name__, unexpected_keys)) 647 | if len(error_msgs) > 0: 648 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 649 | model.__class__.__name__, "\n\t".join(error_msgs))) 650 | return model 651 | 652 | 653 | class BertModel(BertPreTrainedModel): 654 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 655 | 656 | Params: 657 | config: a BertConfig class instance with the configuration to build a new model 658 | 659 | Inputs: 660 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 661 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 662 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 663 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 664 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 665 | a `sentence B` token (see BERT paper for more details). 666 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 667 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 668 | input sequence length in the current batch. It's the mask that we typically use for attention when 669 | a batch has varying length sentences. 
670 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 671 | 672 | Outputs: Tuple of (encoded_layers, pooled_output) 673 | `encoded_layers`: controled by `output_all_encoded_layers` argument: 674 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 675 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 676 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 677 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 678 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 679 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 680 | classifier pretrained on top of the hidden state associated to the first character of the 681 | input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 682 | 683 | Example usage: 684 | ```python 685 | # Already been converted into WordPiece token ids 686 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 687 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 688 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 689 | 690 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 691 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 692 | 693 | model = modeling.BertModel(config=config) 694 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 695 | ``` 696 | """ 697 | def __init__(self, config): 698 | super(BertModel, self).__init__(config) 699 | self.embeddings = BertEmbeddings(config) 700 | self.encoder = BertEncoder(config) 701 | self.pooler = BertPooler(config) 702 | self.apply(self.init_bert_weights) 703 | 704 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 705 | if attention_mask is None: 706 | attention_mask = torch.ones_like(input_ids) 707 | if token_type_ids is None: 708 | token_type_ids = torch.zeros_like(input_ids) 709 | 710 | # We create a 3D attention mask from a 2D tensor mask. 711 | # Sizes are [batch_size, 1, 1, to_seq_length] 712 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 713 | # this attention mask is more simple than the triangular masking of causal attention 714 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 715 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 716 | 717 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 718 | # masked positions, this operation will create a tensor which is 0.0 for 719 | # positions we want to attend and -10000.0 for masked positions. 720 | # Since we are adding it to the raw scores before the softmax, this is 721 | # effectively the same as removing these entirely. 
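# added note: extended_attention_mask has shape [batch_size, 1, 1, seq_length] and broadcasts across heads and query positions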
722 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 723 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 724 | 725 | embedding_output = self.embeddings(input_ids, token_type_ids) 726 | encoded_layers = self.encoder(embedding_output, 727 | extended_attention_mask, 728 | output_all_encoded_layers=output_all_encoded_layers) 729 | sequence_output = encoded_layers[-1] 730 | pooled_output = self.pooler(sequence_output) 731 | if not output_all_encoded_layers: 732 | encoded_layers = encoded_layers[-1] 733 | return encoded_layers, pooled_output 734 | 735 | 736 | class BertForPreTraining(BertPreTrainedModel): 737 | """BERT model with pre-training heads. 738 | This module comprises the BERT model followed by the two pre-training heads: 739 | - the masked language modeling head, and 740 | - the next sentence classification head. 741 | 742 | Params: 743 | config: a BertConfig class instance with the configuration to build a new model. 744 | 745 | Inputs: 746 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 747 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 748 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 749 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 750 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 751 | a `sentence B` token (see BERT paper for more details). 752 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 753 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 754 | input sequence length in the current batch. It's the mask that we typically use for attention when 755 | a batch has varying length sentences. 756 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 757 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 758 | is only computed for the labels set in [0, ..., vocab_size] 759 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 760 | with indices selected in [0, 1]. 761 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 762 | 763 | Outputs: 764 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 765 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 766 | sentence classification loss. 767 | if `masked_lm_labels` or `next_sentence_label` is `None`: 768 | Outputs a tuple comprising 769 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 770 | - the next sentence classification logits of shape [batch_size, 2]. 
771 | 772 | Example usage: 773 | ```python 774 | # Already been converted into WordPiece token ids 775 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 776 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 777 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 778 | 779 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 780 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 781 | 782 | model = BertForPreTraining(config) 783 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 784 | ``` 785 | """ 786 | def __init__(self, config): 787 | super(BertForPreTraining, self).__init__(config) 788 | self.bert = BertModel(config) 789 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 790 | self.apply(self.init_bert_weights) 791 | 792 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 793 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 794 | output_all_encoded_layers=False) 795 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 796 | 797 | if masked_lm_labels is not None and next_sentence_label is not None: 798 | loss_fct = CrossEntropyLoss(ignore_index=-1) 799 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 800 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 801 | total_loss = masked_lm_loss + next_sentence_loss 802 | return total_loss 803 | else: 804 | return prediction_scores, seq_relationship_score 805 | 806 | 807 | class BertForMaskedLM(BertPreTrainedModel): 808 | """BERT model with the masked language modeling head. 809 | This module comprises the BERT model followed by the masked language modeling head. 810 | 811 | Params: 812 | config: a BertConfig class instance with the configuration to build a new model. 813 | 814 | Inputs: 815 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 816 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 817 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 818 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 819 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 820 | a `sentence B` token (see BERT paper for more details). 821 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 822 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 823 | input sequence length in the current batch. It's the mask that we typically use for attention when 824 | a batch has varying length sentences. 825 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 826 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 827 | is only computed for the labels set in [0, ..., vocab_size] 828 | 829 | Outputs: 830 | if `masked_lm_labels` is not `None`: 831 | Outputs the masked language modeling loss. 832 | if `masked_lm_labels` is `None`: 833 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 
834 | 835 | Example usage: 836 | ```python 837 | # Already been converted into WordPiece token ids 838 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 839 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 840 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 841 | 842 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 843 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 844 | 845 | model = BertForMaskedLM(config) 846 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 847 | ``` 848 | """ 849 | def __init__(self, config): 850 | super(BertForMaskedLM, self).__init__(config) 851 | self.bert = BertModel(config) 852 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 853 | self.apply(self.init_bert_weights) 854 | 855 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 856 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 857 | output_all_encoded_layers=False) 858 | prediction_scores = self.cls(sequence_output) 859 | 860 | if masked_lm_labels is not None: 861 | loss_fct = CrossEntropyLoss(ignore_index=-1) 862 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 863 | return masked_lm_loss 864 | else: 865 | return prediction_scores 866 | 867 | 868 | class BertForNextSentencePrediction(BertPreTrainedModel): 869 | """BERT model with next sentence prediction head. 870 | This module comprises the BERT model followed by the next sentence classification head. 871 | 872 | Params: 873 | config: a BertConfig class instance with the configuration to build a new model. 874 | 875 | Inputs: 876 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 877 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 878 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 879 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 880 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 881 | a `sentence B` token (see BERT paper for more details). 882 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 883 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 884 | input sequence length in the current batch. It's the mask that we typically use for attention when 885 | a batch has varying length sentences. 886 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 887 | with indices selected in [0, 1]. 888 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 889 | 890 | Outputs: 891 | if `next_sentence_label` is not `None`: 892 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 893 | sentence classification loss. 894 | if `next_sentence_label` is `None`: 895 | Outputs the next sentence classification logits of shape [batch_size, 2]. 
896 | 897 | Example usage: 898 | ```python 899 | # Already been converted into WordPiece token ids 900 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 901 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 902 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 903 | 904 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 905 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 906 | 907 | model = BertForNextSentencePrediction(config) 908 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 909 | ``` 910 | """ 911 | def __init__(self, config): 912 | super(BertForNextSentencePrediction, self).__init__(config) 913 | self.bert = BertModel(config) 914 | self.cls = BertOnlyNSPHead(config) 915 | self.apply(self.init_bert_weights) 916 | 917 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 918 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 919 | output_all_encoded_layers=False) 920 | seq_relationship_score = self.cls( pooled_output) 921 | 922 | if next_sentence_label is not None: 923 | loss_fct = CrossEntropyLoss(ignore_index=-1) 924 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 925 | return next_sentence_loss 926 | else: 927 | return seq_relationship_score 928 | 929 | 930 | class BertForSequenceClassification(BertPreTrainedModel): 931 | """BERT model for classification. 932 | This module is composed of the BERT model with a linear layer on top of 933 | the pooled output. 934 | 935 | Params: 936 | `config`: a BertConfig class instance with the configuration to build a new model. 937 | `num_labels`: the number of classes for the classifier. Default = 2. 938 | 939 | Inputs: 940 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 941 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts 942 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 943 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 944 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 945 | a `sentence B` token (see BERT paper for more details). 946 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 947 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 948 | input sequence length in the current batch. It's the mask that we typically use for attention when 949 | a batch has varying length sentences. 950 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 951 | with indices selected in [0, ..., num_labels]. 952 | 953 | Outputs: 954 | if `labels` is not `None`: 955 | Outputs the CrossEntropy classification loss of the output with the labels. 956 | if `labels` is `None`: 957 | Outputs the classification logits of shape [batch_size, num_labels]. 
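    When `labels` is supplied, the same call returns the cross-entropy loss instead of the logits;
    a minimal sketch, reusing the variables from the example below:

    ```python
    labels = torch.LongTensor([1, 0])  # one class index per example in the batch
    loss = model(input_ids, token_type_ids, input_mask, labels)
    loss.backward()
    ```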
958 | 959 | Example usage: 960 | ```python 961 | # Already been converted into WordPiece token ids 962 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 963 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 964 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 965 | 966 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 967 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 968 | 969 | num_labels = 2 970 | 971 | model = BertForSequenceClassification(config, num_labels) 972 | logits = model(input_ids, token_type_ids, input_mask) 973 | ``` 974 | """ 975 | def __init__(self, config, num_labels): 976 | super(BertForSequenceClassification, self).__init__(config) 977 | self.num_labels = num_labels 978 | self.bert = BertModel(config) 979 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 980 | self.classifier = nn.Linear(config.hidden_size, num_labels) 981 | self.apply(self.init_bert_weights) 982 | 983 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 984 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 985 | pooled_output = self.dropout(pooled_output) 986 | logits = self.classifier(pooled_output) 987 | 988 | if labels is not None: 989 | loss_fct = CrossEntropyLoss() 990 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 991 | return loss 992 | else: 993 | return logits 994 | 995 | 996 | class BertForMultipleChoice(BertPreTrainedModel): 997 | """BERT model for multiple choice tasks. 998 | This module is composed of the BERT model with a linear layer on top of 999 | the pooled output. 1000 | 1001 | Params: 1002 | `config`: a BertConfig class instance with the configuration to build a new model. 1003 | `num_choices`: the number of classes for the classifier. Default = 2. 1004 | 1005 | Inputs: 1006 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1007 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1008 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1009 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1010 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1011 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1012 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1013 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1014 | input sequence length in the current batch. It's the mask that we typically use for attention when 1015 | a batch has varying length sentences. 1016 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1017 | with indices selected in [0, ..., num_choices]. 1018 | 1019 | Outputs: 1020 | if `labels` is not `None`: 1021 | Outputs the CrossEntropy classification loss of the output with the labels. 1022 | if `labels` is `None`: 1023 | Outputs the classification logits of shape [batch_size, num_labels]. 
1024 | 1025 | Example usage: 1026 | ```python 1027 | # Already been converted into WordPiece token ids 1028 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1029 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1030 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1031 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1032 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1033 | 1034 | num_choices = 2 1035 | 1036 | model = BertForMultipleChoice(config, num_choices) 1037 | logits = model(input_ids, token_type_ids, input_mask) 1038 | ``` 1039 | """ 1040 | def __init__(self, config, num_choices): 1041 | super(BertForMultipleChoice, self).__init__(config) 1042 | self.num_choices = num_choices 1043 | self.bert = BertModel(config) 1044 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1045 | self.classifier = nn.Linear(config.hidden_size, 1) 1046 | self.apply(self.init_bert_weights) 1047 | 1048 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1049 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1050 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 1051 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 1052 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) 1053 | pooled_output = self.dropout(pooled_output) 1054 | logits = self.classifier(pooled_output) 1055 | reshaped_logits = logits.view(-1, self.num_choices) 1056 | 1057 | if labels is not None: 1058 | loss_fct = CrossEntropyLoss() 1059 | loss = loss_fct(reshaped_logits, labels) 1060 | return loss 1061 | else: 1062 | return reshaped_logits 1063 | 1064 | 1065 | class BertForTokenClassification(BertPreTrainedModel): 1066 | """BERT model for token-level classification. 1067 | This module is composed of the BERT model with a linear layer on top of 1068 | the full hidden state of the last layer. 1069 | 1070 | Params: 1071 | `config`: a BertConfig class instance with the configuration to build a new model. 1072 | `num_labels`: the number of classes for the classifier. Default = 2. 1073 | 1074 | Inputs: 1075 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1076 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1077 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1078 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1079 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1080 | a `sentence B` token (see BERT paper for more details). 1081 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1082 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1083 | input sequence length in the current batch. It's the mask that we typically use for attention when 1084 | a batch has varying length sentences. 1085 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1086 | with indices selected in [0, ..., num_labels]. 1087 | 1088 | Outputs: 1089 | if `labels` is not `None`: 1090 | Outputs the CrossEntropy classification loss of the output with the labels. 
1091 | if `labels` is `None`: 1092 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 1093 | 1094 | Example usage: 1095 | ```python 1096 | # Already been converted into WordPiece token ids 1097 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1098 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1099 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1100 | 1101 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1102 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1103 | 1104 | num_labels = 2 1105 | 1106 | model = BertForTokenClassification(config, num_labels) 1107 | logits = model(input_ids, token_type_ids, input_mask) 1108 | ``` 1109 | """ 1110 | def __init__(self, config, num_labels): 1111 | super(BertForTokenClassification, self).__init__(config) 1112 | self.num_labels = num_labels 1113 | self.bert = BertModel(config) 1114 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1115 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1116 | self.apply(self.init_bert_weights) 1117 | 1118 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1119 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1120 | sequence_output = self.dropout(sequence_output) 1121 | logits = self.classifier(sequence_output) 1122 | 1123 | if labels is not None: 1124 | loss_fct = CrossEntropyLoss() 1125 | # Only keep active parts of the loss 1126 | if attention_mask is not None: 1127 | active_loss = attention_mask.view(-1) == 1 1128 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1129 | active_labels = labels.view(-1)[active_loss] 1130 | loss = loss_fct(active_logits, active_labels) 1131 | else: 1132 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1133 | return loss 1134 | else: 1135 | return logits 1136 | 1137 | 1138 | class BertForQuestionAnswering(BertPreTrainedModel): 1139 | """BERT model for Question Answering (span extraction). 1140 | This module is composed of the BERT model with a linear layer on top of 1141 | the sequence output that computes start_logits and end_logits 1142 | 1143 | Params: 1144 | `config`: a BertConfig class instance with the configuration to build a new model. 1145 | 1146 | Inputs: 1147 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1148 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1149 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1150 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1151 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1152 | a `sentence B` token (see BERT paper for more details). 1153 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1154 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1155 | input sequence length in the current batch. It's the mask that we typically use for attention when 1156 | a batch has varying length sentences. 1157 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1158 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1159 | into account for computing the loss. 
1160 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1161 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1162 | into account for computing the loss. 1163 | 1164 | Outputs: 1165 | if `start_positions` and `end_positions` are not `None`: 1166 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 1167 | if `start_positions` or `end_positions` is `None`: 1168 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1169 | position tokens of shape [batch_size, sequence_length]. 1170 | 1171 | Example usage: 1172 | ```python 1173 | # Already been converted into WordPiece token ids 1174 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1175 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1176 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1177 | 1178 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1179 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1180 | 1181 | model = BertForQuestionAnswering(config) 1182 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1183 | ``` 1184 | """ 1185 | def __init__(self, config): 1186 | super(BertForQuestionAnswering, self).__init__(config) 1187 | self.bert = BertModel(config) 1188 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 1189 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 1190 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1191 | self.apply(self.init_bert_weights) 1192 | 1193 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 1194 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1195 | logits = self.qa_outputs(sequence_output) 1196 | start_logits, end_logits = logits.split(1, dim=-1) 1197 | start_logits = start_logits.squeeze(-1) 1198 | end_logits = end_logits.squeeze(-1) 1199 | 1200 | if start_positions is not None and end_positions is not None: 1201 | # If we are on multi-GPU, split add a dimension 1202 | if len(start_positions.size()) > 1: 1203 | start_positions = start_positions.squeeze(-1) 1204 | if len(end_positions.size()) > 1: 1205 | end_positions = end_positions.squeeze(-1) 1206 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1207 | ignored_index = start_logits.size(1) 1208 | start_positions.clamp_(0, ignored_index) 1209 | end_positions.clamp_(0, ignored_index) 1210 | 1211 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1212 | start_loss = loss_fct(start_logits, start_positions) 1213 | end_loss = loss_fct(end_logits, end_positions) 1214 | total_loss = (start_loss + end_loss) / 2 1215 | return total_loss 1216 | else: 1217 | return start_logits, end_logits 1218 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import re 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | def warmup_cosine(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | x_ = (x - warmup) / (1 - warmup) # progress after warmup - 31 | return 0.5 * (1. + math.cos(math.pi * x_)) 32 | 33 | def warmup_constant(x, warmup=0.002): 34 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 35 | Learning rate is 1. afterwards. """ 36 | if x < warmup: 37 | return x/warmup 38 | return 1.0 39 | 40 | def warmup_linear(x, warmup=0.002): 41 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 42 | After `t_total`-th training step, learning rate is zero. """ 43 | if x < warmup: 44 | return x/warmup 45 | return max((x-1.)/(warmup-1.), 0) 46 | 47 | def warmup_linear_constant(x, warmup=0.002): 48 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 49 | After `t_total`-th training step, learning rate is zero. """ 50 | if x < warmup: 51 | return x/warmup 52 | return max((x-1.)/(warmup-1.), 0.2) 53 | 54 | SCHEDULES = { 55 | 'warmup_cosine': warmup_cosine, 56 | 'warmup_constant': warmup_constant, 57 | 'warmup_linear': warmup_linear, 58 | 'warmup_linear_constant': warmup_linear_constant 59 | } 60 | 61 | 62 | class BertAdam(Optimizer): 63 | """Implements BERT version of Adam algorithm with weight decay fix. 64 | Params: 65 | lr: learning rate 66 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 67 | t_total: total number of training steps for the learning 68 | rate schedule, -1 means constant learning rate. Default: -1 69 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 70 | b1: Adams b1. Default: 0.9 71 | b2: Adams b2. Default: 0.999 72 | e: Adams epsilon. Default: 1e-6 73 | weight_decay: Weight decay. Default: 0.01 74 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 75 | """ 76 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 77 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 78 | max_grad_norm=1.0): 79 | if lr is not required and lr < 0.0: 80 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 81 | if schedule not in SCHEDULES: 82 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 83 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 84 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 85 | if not 0.0 <= b1 < 1.0: 86 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 87 | if not 0.0 <= b2 < 1.0: 88 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 89 | if not e >= 0.0: 90 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 91 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 92 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 93 | max_grad_norm=max_grad_norm) 94 | super(BertAdam, self).__init__(params, defaults) 95 | 96 | def get_lr(self): 97 | lr = [] 98 | for group in self.param_groups: 99 | for p in group['params']: 100 | state = self.state[p] 101 | if len(state) == 0: 102 | return [0] 103 | if group['t_total'] != -1: 104 | schedule_fct = SCHEDULES[group['schedule']] 105 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 106 | else: 107 | lr_scheduled = group['lr'] 108 | lr.append(lr_scheduled) 109 | return lr 110 | 111 | def step(self, closure=None): 112 | """Performs a single optimization step. 113 | 114 | Arguments: 115 | closure (callable, optional): A closure that reevaluates the model 116 | and returns the loss. 117 | """ 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | warned_for_t_total = False 123 | 124 | for group in self.param_groups: 125 | for n, p in zip(group['names'], group['params']): 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data 129 | if grad.is_sparse: 130 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 131 | 132 | state = self.state[p] 133 | 134 | # State initialization 135 | if len(state) == 0: 136 | state['step'] = 0 137 | # Exponential moving average of gradient values 138 | state['next_m'] = torch.zeros_like(p.data) 139 | # Exponential moving average of squared gradient values 140 | state['next_v'] = torch.zeros_like(p.data) 141 | 142 | next_m, next_v = state['next_m'], state['next_v'] 143 | beta1, beta2 = group['b1'], group['b2'] 144 | 145 | # Add grad clipping 146 | if group['max_grad_norm'] > 0: 147 | clip_grad_norm_(p, group['max_grad_norm']) 148 | 149 | # Decay the first and second moment running average coefficient 150 | # In-place operations to update the averages at the same time 151 | next_m.mul_(beta1).add_(1 - beta1, grad) 152 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 153 | update = next_m / (next_v.sqrt() + group['e']) 154 | 155 | # Just adding the square of the weights to the loss function is *not* 156 | # the correct way of using L2 regularization/weight decay with Adam, 157 | # since that will interact with the m and v parameters in strange ways. 158 | # 159 | # Instead we want to decay the weights in a manner that doesn't interact 160 | # with the m/v parameters. This is equivalent to adding the square 161 | # of the weights to the loss with plain (non-momentum) SGD. 
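                # Putting the steps below together, each parameter p is updated roughly as
                #   update = m / (sqrt(v) + eps) + weight_decay * p   (decay applied to the step; no bias correction)
                #   p     <- p - lr_scheduled * update
                # where lr_scheduled is the warmup-scheduled learning rate, additionally scaled below
                # by a layer-wise factor: 0.9 ** (12 - layer_index) for encoder layers and
                # 0.9 ** 13 for the embeddings.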
162 | if group['weight_decay'] > 0.0: 163 | update += group['weight_decay'] * p.data 164 | 165 | if group['t_total'] != -1: 166 | schedule_fct = SCHEDULES[group['schedule']] 167 | progress = state['step']/group['t_total'] 168 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 169 | self.lr_scheduled = lr_scheduled 170 | # warning for exceeding t_total (only active with warmup_linear 171 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 172 | logger.warning( 173 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 174 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 175 | warned_for_t_total = True 176 | # end warning 177 | else: 178 | lr_scheduled = group['lr'] 179 | # layer-wise learning rate decay 180 | 181 | rr = re.search("layer\.(\d{1,2})\.", n) 182 | if rr: 183 | lr_scheduled *= 0.9 ** (12-int(rr.group(1))) 184 | if ".embeddings." in n: 185 | # token_in_batch = 1 - torch.eq(grad, 0).all(dim=-1, keepdim=True) 186 | # update *= token_in_batch.float() 187 | lr_scheduled *= 0.9 ** 13 188 | update_with_lr = lr_scheduled * update 189 | p.data.add_(-update_with_lr) 190 | 191 | state['step'] += 1 192 | 193 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 194 | # No bias correction 195 | # bias_correction1 = 1 - beta1 ** state['step'] 196 | # bias_correction2 = 1 - beta2 ** state['step'] 197 | 198 | return loss 199 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_wordpiece_only=False 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum length to truncate tokenized sequences to; 87 | Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_wordpiece_only=False 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text, inference=False): 108 | split_tokens = [] 109 | if self.do_basic_tokenize: 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token, inference=inference): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab.get(token, self.vocab['[UNK]'])) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). Running this" 126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | def save_vocabulary(self, vocab_path): 138 | """Save the tokenizer vocabulary to a directory or file.""" 139 | index = 0 140 | if os.path.isdir(vocab_path): 141 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 142 | with open(vocab_file, "w", encoding="utf-8") as writer: 143 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 144 | if index != token_index: 145 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 146 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 147 | index = token_index 148 | writer.write(token + u'\n') 149 | index += 1 150 | return vocab_file 151 | 152 | @classmethod 153 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 154 | """ 155 | Instantiate a PreTrainedBertModel from a pre-trained model file. 156 | Download and cache the pre-trained model file if needed. 157 | """ 158 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 159 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 160 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 161 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 162 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 163 | "you may want to check this behavior.") 164 | kwargs['do_lower_case'] = False 165 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 166 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 167 | "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " 168 | "but you may want to check this behavior.") 169 | kwargs['do_lower_case'] = True 170 | else: 171 | vocab_file = pretrained_model_name_or_path 172 | if os.path.isdir(vocab_file): 173 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 174 | # redirect to the cache, if necessary 175 | try: 176 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 177 | except EnvironmentError: 178 | logger.error( 179 | "Model name '{}' was not found in model name list ({}). " 180 | "We assumed '{}' was a path or url but couldn't find any file " 181 | "associated to this path or url.".format( 182 | pretrained_model_name_or_path, 183 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 184 | vocab_file)) 185 | return None 186 | if resolved_vocab_file == vocab_file: 187 | logger.info("loading vocabulary file {}".format(vocab_file)) 188 | else: 189 | logger.info("loading vocabulary file {} from cache at {}".format( 190 | vocab_file, resolved_vocab_file)) 191 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 192 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 193 | # than the number of positional embeddings 194 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 195 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 196 | # Instantiate tokenizer. 197 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 198 | return tokenizer 199 | 200 | 201 | class BasicTokenizer(object): 202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 203 | 204 | def __init__(self, 205 | do_lower_case=True, 206 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 207 | """Constructs a BasicTokenizer. 208 | 209 | Args: 210 | do_lower_case: Whether to lower case the input. 211 | """ 212 | self.do_lower_case = do_lower_case 213 | self.never_split = never_split 214 | 215 | def tokenize(self, text): 216 | """Tokenizes a piece of text.""" 217 | text = self._clean_text(text) 218 | # This was added on November 1st, 2018 for the multilingual and Chinese 219 | # models. This is also applied to the English models now, but it doesn't 220 | # matter since the English models were not trained on any Chinese data 221 | # and generally don't have any Chinese data in them (there are Chinese 222 | # characters in the vocabulary because Wikipedia does have some Chinese 223 | # words in the English Wikipedia.). 
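        # For example, _tokenize_chinese_chars("我爱NLP") returns " 我  爱 NLP", so every
        # CJK character becomes its own whitespace-separated token downstream.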
224 | text = self._tokenize_chinese_chars(text) 225 | orig_tokens = whitespace_tokenize(text) 226 | split_tokens = [] 227 | for token in orig_tokens: 228 | if self.do_lower_case and token not in self.never_split: 229 | token = token.lower() 230 | # token = self._run_strip_accents(token) 231 | split_tokens.extend(self._run_split_on_punc(token)) 232 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 233 | return output_tokens 234 | 235 | def _run_strip_accents(self, text): 236 | """Strips accents from a piece of text.""" 237 | text = unicodedata.normalize("NFD", text) 238 | output = [] 239 | for char in text: 240 | cat = unicodedata.category(char) 241 | if cat == "Mn": 242 | continue 243 | output.append(char) 244 | return "".join(output) 245 | 246 | def _run_split_on_punc(self, text): 247 | """Splits punctuation on a piece of text.""" 248 | if text in self.never_split: 249 | return [text] 250 | chars = list(text) 251 | i = 0 252 | start_new_word = True 253 | output = [] 254 | while i < len(chars): 255 | char = chars[i] 256 | if _is_punctuation(char): 257 | output.append([char]) 258 | start_new_word = True 259 | else: 260 | if start_new_word: 261 | output.append([]) 262 | start_new_word = False 263 | output[-1].append(char) 264 | i += 1 265 | 266 | return ["".join(x) for x in output] 267 | 268 | def _tokenize_chinese_chars(self, text): 269 | """Adds whitespace around any CJK character.""" 270 | output = [] 271 | for char in text: 272 | cp = ord(char) 273 | if self._is_chinese_char(cp): 274 | output.append(" ") 275 | output.append(char) 276 | output.append(" ") 277 | else: 278 | output.append(char) 279 | return "".join(output) 280 | 281 | def _is_chinese_char(self, cp): 282 | """Checks whether CP is the codepoint of a CJK character.""" 283 | # This defines a "chinese character" as anything in the CJK Unicode block: 284 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 285 | # 286 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 287 | # despite its name. The modern Korean Hangul alphabet is a different block, 288 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 289 | # space-separated words, so they are not treated specially and handled 290 | # like the all of the other languages. 291 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 292 | (cp >= 0x3400 and cp <= 0x4DBF) or # 293 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 294 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 295 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 296 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 297 | (cp >= 0xF900 and cp <= 0xFAFF) or # 298 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 299 | return True 300 | 301 | return False 302 | 303 | def _clean_text(self, text): 304 | """Performs invalid character removal and whitespace cleanup on text.""" 305 | output = [] 306 | for char in text: 307 | cp = ord(char) 308 | if cp == 0 or cp == 0xfffd or _is_control(char): 309 | continue 310 | if _is_whitespace(char): 311 | output.append(" ") 312 | else: 313 | output.append(char) 314 | return "".join(output) 315 | 316 | 317 | class WordpieceTokenizer(object): 318 | """Runs WordPiece tokenization.""" 319 | 320 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 321 | self.vocab = vocab 322 | self.unk_token = unk_token 323 | self.max_input_chars_per_word = max_input_chars_per_word 324 | 325 | def tokenize(self, text, inference=False): 326 | """Tokenizes a piece of text into its word pieces. 
327 | 328 | This uses a greedy longest-match-first algorithm to perform tokenization 329 | using the given vocabulary. 330 | 331 | For example: 332 | input = "unaffable" 333 | output = ["un", "##aff", "##able"] 334 | 335 | Args: 336 | text: A single token or whitespace separated tokens. This should have 337 | already been passed through `BasicTokenizer`. 338 | 339 | Returns: 340 | A list of wordpiece tokens. 341 | """ 342 | 343 | output_tokens = [] 344 | for token in whitespace_tokenize(text): 345 | chars = list(token) 346 | if len(chars) > self.max_input_chars_per_word: 347 | output_tokens.append(self.unk_token) 348 | continue 349 | 350 | is_bad = False 351 | start = 0 352 | sub_tokens = [] 353 | while start < len(chars): 354 | end = len(chars) 355 | cur_substr = None 356 | while start < end: 357 | substr = "".join(chars[start:end]) 358 | if start > 0: 359 | substr = "##" + substr 360 | if substr in self.vocab: 361 | cur_substr = substr 362 | break 363 | end -= 1 364 | if cur_substr is None: 365 | is_bad = True 366 | break 367 | sub_tokens.append(cur_substr) 368 | start = end 369 | if not inference: 370 | if is_bad: 371 | output_tokens.append(self.unk_token) 372 | else: 373 | output_tokens.extend(sub_tokens) 374 | else: 375 | if is_bad: 376 | output_tokens.append(text) 377 | else: 378 | output_tokens.extend(sub_tokens) 379 | return output_tokens 380 | 381 | 382 | def _is_whitespace(char): 383 | """Checks whether `chars` is a whitespace character.""" 384 | # \t, \n, and \r are technically contorl characters but we treat them 385 | # as whitespace since they are generally considered as such. 386 | if char == " " or char == "\t" or char == "\n" or char == "\r": 387 | return True 388 | cat = unicodedata.category(char) 389 | if cat == "Zs": 390 | return True 391 | return False 392 | 393 | 394 | def _is_control(char): 395 | """Checks whether `chars` is a control character.""" 396 | # These are technically control characters but we count them as whitespace 397 | # characters. 398 | if char == "\t" or char == "\n" or char == "\r": 399 | return False 400 | cat = unicodedata.category(char) 401 | if cat.startswith("C"): 402 | return True 403 | return False 404 | 405 | 406 | def _is_punctuation(char): 407 | """Checks whether `chars` is a punctuation character.""" 408 | cp = ord(char) 409 | # We treat all non-letter/number ASCII as punctuation. 410 | # Characters such as "^", "$", and "`" are not in the Unicode 411 | # Punctuation class but we treat them as punctuation anyways, for 412 | # consistency. 
413 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 414 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 415 | return True 416 | cat = unicodedata.category(char) 417 | if cat.startswith("P"): 418 | return True 419 | return False 420 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | # warmup_linear_constant 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--data_dir', default='data', help="Directory containing the dataset") 7 | parser.add_argument('--bert_model_dir', default='/root/pretrain_model_weights/torch/chinese/chinese_wwm_ext_pytorch', help="Directory containing the BERT model in PyTorch") 8 | 9 | parser.add_argument('--clip_grad', default=2, type=int, help="") 10 | parser.add_argument('--seed', default=42, type=int, help="random seed for initialization") # 8 11 | parser.add_argument('--schedule', default='warmup_linear', help="schedule for optimizer") 12 | parser.add_argument('--weight_decay', default=0.01, type=float, help="") 13 | parser.add_argument('--warmup', default=0.1, type=float, help="") 14 | 15 | parser.add_argument('--model_dir', default='experiments/baseline', help="model directory") 16 | parser.add_argument('--epoch_num', default=6, type=int, help="num of epoch") 17 | parser.add_argument('--nega_num', default=4, type=int, help="num of negative predicates") 18 | parser.add_argument('--batch_size', default=32, type=int, help="batch size") 19 | parser.add_argument('--ner_max_len', default=32, type=int, help="max sequence length for ner task") 20 | parser.add_argument('--re_max_len', default=64, type=int, help="max sequence length for re task") 21 | parser.add_argument('--learning_rate', default=5e-5, type=float, help="learning rate") 22 | 23 | parser.add_argument('--do_train_and_eval', action='store_true', help="do_train_and_eval") 24 | parser.add_argument('--do_eval', action='store_true', help="do_eval") 25 | parser.add_argument('--do_predict', action='store_true', help="do_predict") 26 | 27 | args = parser.parse_args() 28 | 29 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import pickle 4 | 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | file_path = 'NLPCC2017-OpenDomainQA/knowledge/nlpcc-iccpol-2016.kbqa.kb' 10 | 11 | graph = defaultdict(list) 12 | entity_linking = defaultdict(list) 13 | with open(file_path, 'r', encoding='utf-8') as f: 14 | for line in tqdm(f): 15 | s, p, o = line.strip().split(' ||| ') 16 | s = s.lower() 17 | p = p.replace(' ', '') 18 | o = o.lower() 19 | graph[s].append((s, p, o)) 20 | if '(' in s: 21 | s1 = s.split('(')[0] 22 | entity_linking[s1].append(s) 23 | if s[0] =='《' and s[-1] == '》': 24 | entity_linking[s[1:-1]].append(s) 25 | 26 | print('Dumping gragh...') 27 | with open('graph/graph.pkl', 'wb') as f: 28 | pickle.dump(graph, f) 29 | 30 | print('Dumping entity linking...') 31 | with open('graph/entity_linking.pkl', 'wb') as f: 32 | pickle.dump(entity_linking, f) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import logging 5 | import pickle 6 | 7 | from tqdm import tqdm, trange 
8 | from random import shuffle 9 | import numpy as np 10 | 11 | import utils 12 | from model import KBQA 13 | from batch_loader import BatchLoader 14 | from bert.optimization import BertAdam 15 | from config import args 16 | 17 | if not os.path.isdir(args.model_dir): 18 | os.makedirs(args.model_dir) 19 | 20 | 21 | def eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args): 22 | ## Evaluate 23 | model.eval() 24 | dev_iter = iter(dev_bl) 25 | num = cor_num_s = cor_num_p = cor_num_o = 0 26 | t = trange(len(dev_bl), desc='Eval') 27 | for i in t: 28 | batch_data = next(dev_iter) 29 | batch_samples = dev_data[args.batch_size*i: args.batch_size*i+batch_data[0].size(0)] 30 | batch_ners = ner_dev_data[args.batch_size*i: args.batch_size*i+batch_data[0].size(0)] 31 | batch_data = tuple(tmp.to(args.device) for tmp in batch_data) 32 | head_logits, tail_logits = (tmp.cpu() for tmp in model(batch_data, 0, False)) # (bz, seqlen) 33 | heads = head_logits.argmax(dim=-1).tolist() 34 | for j, head in enumerate(heads): 35 | tail = tail_logits[j][head:].argmax().item()+head 36 | tokens = batch_ners[j]['tokens'] 37 | subject = ''.join(tokens[head: tail+1]).replace('##', '').replace('□', ' ') 38 | if subject[0] == '《': 39 | subject = subject[1:] 40 | if subject[-1] == '》': 41 | subject = subject[:-1] 42 | question = batch_samples[j]['question'] 43 | gold_spo = batch_samples[j]['triple'] 44 | spos = graph.get(subject, []) 45 | ons = entity_linking.get(subject, []) 46 | for on in ons: 47 | spos += graph.get(on, []) 48 | pres = list(set([spo[1] for spo in spos])) 49 | objs = set() 50 | pre = '' 51 | if pres: 52 | sub_re_data = bl.build_re_data(question, pres) 53 | sub_re_bl = bl.batch_loader(None, sub_re_data, args.ner_max_len, args.re_max_len, args.batch_size, is_train=False) 54 | sub_labels = [] 55 | for batch_data in sub_re_bl: 56 | batch_data = tuple(tmp.to(args.device) for tmp in batch_data) 57 | label_logits = model(batch_data, 1, False).cpu().tolist() 58 | sub_labels += label_logits 59 | index_pre = np.argmax(sub_labels) 60 | pre = pres[index_pre].replace(' ', '') 61 | for spo in spos: 62 | s, p, o = spo 63 | if subject in s and p == pre: 64 | objs.add(o) 65 | num += 1 66 | cor_num_s += 1 if subject == gold_spo[0] else 0 67 | cor_num_p += 1 if pre == gold_spo[1] else 0 68 | cor_num_o += 1 if gold_spo[-1] in objs else 0 69 | # if gold_spo[-1] not in objs: 70 | # print('XXXXXXXXXXXX') 71 | # print(question) 72 | # print(gold_spo) 73 | # print(subject, pre) 74 | # print(objs) 75 | # print(pres) 76 | # input() 77 | t.set_postfix(acc_s='{:.2f}'.format(cor_num_s/num*100), 78 | acc_p='{:.2f}'.format(cor_num_p/num*100), 79 | acc_o='{:.2f}'.format(cor_num_o/num*100)) 80 | return cor_num_o/num 81 | 82 | if __name__ == '__main__': 83 | # Use GPUs if available 84 | args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 85 | # Set the random seed for reproducible experiments 86 | random.seed(args.seed) 87 | torch.manual_seed(args.seed) 88 | # Set the logger 89 | utils.set_logger(os.path.join(args.model_dir, 'train.log')) 90 | logging.info('device: {}'.format(args.device)) 91 | logging.info('Hyper params:%r'%args.__dict__) 92 | 93 | # Create the input data pipeline 94 | logging.info('Loading the datasets...') 95 | bl = BatchLoader(args) 96 | ## Load train and dev data 97 | train_data = bl.load_data('train.json') 98 | dev_data = bl.load_data('dev.json') 99 | ## Train data 100 | ner_train_data, re_train_data = bl.build_data(train_data, is_train=True) 101 | train_bls = 
bl.batch_loader(ner_train_data, re_train_data, args.ner_max_len, args.re_max_len, args.batch_size, is_train=True) 102 | num_batchs_per_task = [len(train_bl) for train_bl in train_bls] 103 | logging.info('num of batch per task for train: {}'.format(num_batchs_per_task)) 104 | train_task_ids = sum([[i]*num_batchs_per_task[i] for i in range(len(num_batchs_per_task))], []) 105 | shuffle(train_task_ids) 106 | ## Dev data 107 | ner_dev_data, _ = bl.build_data(dev_data, is_train=False) 108 | dev_bl = bl.batch_loader(ner_dev_data, None, args.ner_max_len, args.re_max_len, args.batch_size, is_train=False) 109 | logging.info('num of batch for dev: {}'.format(len(dev_bl))) 110 | 111 | # Model 112 | model = KBQA.from_pretrained(args.bert_model_dir) 113 | model.to(args.device) 114 | 115 | # Optimizer 116 | param_optimizer = list(model.named_parameters()) 117 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 118 | optimizer_grouped_parameters = [ 119 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 120 | 'names': [n for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 121 | 'weight_decay_rate': 0.01}, 122 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 123 | 'names': [n for n, p in param_optimizer if any(nd in n for nd in no_decay)], 124 | 'weight_decay_rate': 0.0} 125 | ] 126 | args.steps_per_epoch = sum(num_batchs_per_task) 127 | args.total_steps = args.steps_per_epoch * args.epoch_num 128 | optimizer = BertAdam(params=optimizer_grouped_parameters, 129 | lr=args.learning_rate, 130 | warmup=args.warmup, 131 | t_total=args.total_steps, 132 | max_grad_norm=args.clip_grad, 133 | schedule=args.schedule) 134 | 135 | logging.info('Loading graph and entity linking...') 136 | graph = pickle.load(open('graph/graph.pkl', 'rb')) 137 | entity_linking = pickle.load(open('graph/entity_linking.pkl', 'rb')) 138 | 139 | if args.do_train_and_eval: 140 | # Train and evaluate 141 | best_acc = 0 142 | for epoch in range(args.epoch_num): 143 | ## Train 144 | model.train() 145 | t = trange(args.steps_per_epoch, desc='Epoch {} -Train'.format(epoch)) 146 | loss_avg = utils.RunningAverage() 147 | train_iters = [iter(tmp) for tmp in train_bls] # to use next and reset the iterator 148 | for i in t: 149 | task_id = train_task_ids[i] 150 | batch_data = next(train_iters[task_id]) 151 | batch_data = tuple(tmp.to(args.device) for tmp in batch_data) 152 | loss = model(batch_data, task_id, True) 153 | loss.backward() 154 | optimizer.step() 155 | optimizer.zero_grad() 156 | loss_avg.update(loss.item()) 157 | t.set_postfix(loss='{:5.4f}'.format(loss.item()), avg_loss='{:5.4f}'.format(loss_avg())) 158 | acc = eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args) 159 | utils.save_checkpoint({'epoch': epoch + 1, 160 | 'state_dict': model.state_dict(), 161 | 'optim_dict': optimizer.state_dict()}, 162 | is_best=acc>best_acc, 163 | checkpoint=args.model_dir) 164 | best_acc = max(best_acc, acc) 165 | 166 | if args.do_eval: 167 | logging.info('num of batch for dev: {}'.format(len(dev_bl))) 168 | utils.load_checkpoint(os.path.join(args.model_dir, 'best.pth.tar'), model) 169 | eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args) 170 | 171 | if args.do_predict: 172 | utils.load_checkpoint(os.path.join(args.model_dir, 'best.pth.tar'), model) 173 | model.eval() 174 | logging.info('Loading graph and entity linking...') 175 | graph = pickle.load(open('graph/graph.pkl', 'rb')) 176 | entity_linking = 
pickle.load(open('graph/entity_linking.pkl', 'rb')) 177 | while True: 178 | try: 179 | logging.info('请输入问题:') 180 | question = input() 181 | ner_data = bl.build_ner_data(question) 182 | ner_bl = bl.batch_loader(ner_data, None, args.ner_max_len, args.re_max_len, args.batch_size, is_train=False) 183 | for batch_data in ner_bl: 184 | batch_data = tuple(tmp.to(args.device) for tmp in batch_data) 185 | head_logits, tail_logits = (tmp.cpu() for tmp in model(batch_data, 0, False)) # (bz, seqlen) 186 | head = head_logits[0].argmax().item() 187 | tail = tail_logits[0][head:].argmax().item()+head 188 | tokens = ner_data[0]['tokens'] 189 | subject = ''.join(tokens[head: tail+1]).replace('##', '').replace('□', ' ') 190 | if subject[0] == '《': 191 | subject = subject[1:] 192 | if subject[-1] == '》': 193 | subject = subject[:-1] 194 | logging.info('抽到的主语为:{}'.format(subject)) 195 | spos = [] 196 | spos += graph.get(subject, []) 197 | ons = entity_linking.get(subject, []) 198 | for on in ons: 199 | spos += graph.get(on, []) 200 | spos = set(spos) 201 | pres = list(set([spo[1] for spo in spos])) 202 | # logging.info('候选关系为:{}'.format(pres)) 203 | if pres: 204 | sub_re_data = bl.build_re_data(question, pres) 205 | sub_re_bl = bl.batch_loader(None, sub_re_data, args.ner_max_len, args.re_max_len, args.batch_size, is_train=False) 206 | sub_labels = [] 207 | for batch_data in sub_re_bl: 208 | batch_data = tuple(tmp.to(args.device) for tmp in batch_data) 209 | label_logits = model(batch_data, 1, False).cpu().tolist() 210 | sub_labels += label_logits 211 | index_pre = np.argmax(sub_labels) 212 | pre = pres[index_pre] 213 | logging.info('最可能的关系为:{}'.format(pre)) 214 | for spo in spos: 215 | s, p, o = spo 216 | if subject in s and p == pre: 217 | logging.info('答案为:{}'.format(spo)) 218 | logging.info('\n') 219 | except: 220 | logging.info('出错了!是否继续?y/n') 221 | cmd = input() 222 | if cmd == 'y': 223 | pass 224 | else: 225 | break 226 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.functional import binary_cross_entropy_with_logits 4 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss 5 | 6 | from bert import BertPreTrainedModel, BertModel 7 | 8 | from batch_loader import WHITESPACE_PLACEHOLDER 9 | 10 | class KBQA(BertPreTrainedModel): 11 | def __init__(self, config): 12 | super(KBQA, self).__init__(config) 13 | self.bert = BertModel(config) 14 | self.ner_layer = nn.Linear(config.hidden_size, 2) # head, tail 15 | self.re_layer = nn.Linear(config.hidden_size, 1) # yes/no 16 | self.apply(self.init_bert_weights) 17 | 18 | def forward(self, batch_data, task_id, is_train): 19 | if task_id == 0: # task is ner 20 | token_ids, token_types, head, tail = batch_data 21 | attention_mask = token_ids.gt(0) 22 | sequence_output, _ = self.bert(token_ids, token_types, attention_mask, output_all_encoded_layers=False) 23 | head_logits, tail_logits = self.ner_layer(sequence_output).split(1, dim=-1) 24 | head_logits = head_logits.squeeze(dim=-1) 25 | tail_logits = tail_logits.squeeze(dim=-1) 26 | logits = (head_logits, tail_logits) 27 | if is_train: 28 | seq_lengths = attention_mask.sum(-1).float() 29 | ignored_index = head_logits.size(1) 30 | head.clamp_(0, ignored_index) 31 | tail.clamp_(0, ignored_index) 32 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 33 | head_loss = loss_fct(head_logits, head) 34 | tail_loss = 
loss_fct(tail_logits, tail) 35 | loss = (head_loss + tail_loss) / 2 36 | else: # task is re 37 | token_ids, token_types, label = batch_data 38 | attention_mask = token_ids.gt(0) 39 | _, pooled_output = self.bert(token_ids, token_types, attention_mask, output_all_encoded_layers=False) 40 | logits = self.re_layer(pooled_output).squeeze(-1) 41 | if is_train: 42 | loss_fct = BCEWithLogitsLoss() 43 | loss = loss_fct(logits, label) 44 | return loss if is_train else logits -------------------------------------------------------------------------------- /nega_sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from collections import defaultdict 5 | from copy import deepcopy 6 | 7 | from tqdm import tqdm 8 | 9 | def load_data(file_path): 10 | data = [] 11 | num = 0 12 | with open(file_path, 'r', encoding='utf-8') as f: 13 | sample = {} 14 | for line in f: 15 | if line.startswith(' 0: 52 | item['negative_predicates'] = ps 53 | else: 54 | for ss in s2p: 55 | if s in ss and p in s2p[ss]: 56 | ps = deepcopy(s2p[ss]) 57 | ps.remove(p) 58 | item['negative_predicates'] = ps 59 | break 60 | 61 | with open('data/train.json', 'w', encoding='utf-8') as f: 62 | for sample in train_data: 63 | f.write(json.dumps(sample, ensure_ascii=False)+'\n') 64 | 65 | test_data = load_data('data/nlpcc-iccpol-2016.kbqa.testing-data') 66 | with open('data/dev.json', 'w', encoding='utf-8') as f: 67 | for sample in test_data: 68 | f.write(json.dumps(sample, ensure_ascii=False)+'\n') 69 | 70 | 71 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export PYTHONIOENCODING=utf-8 2 | export CUDA_VISIBLE_DEVICES=1 3 | python main.py \ 4 | --do_predict \ 5 | --model_dir experiments/debug \ 6 | --nega_num 8 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | class RunningAverage(): 10 | """A simple class that maintains the running average of a quantity 11 | 12 | Example: 13 | ``` 14 | loss_avg = RunningAverage() 15 | loss_avg.update(2) 16 | loss_avg.update(4) 17 | loss_avg() = 3 18 | ``` 19 | """ 20 | 21 | def __init__(self): 22 | self.steps = 0 23 | self.total = 0 24 | 25 | def update(self, val): 26 | self.total += val 27 | self.steps += 1 28 | 29 | def __call__(self): 30 | return self.total / float(self.steps) 31 | 32 | 33 | def set_logger(log_path): 34 | """Set the logger to log info in terminal and file `log_path`. 35 | 36 | In general, it is useful to have a logger so that every output to the terminal is saved 37 | in a permanent file. Here we save it to `model_dir/train.log`. 
38 | 39 | Example: 40 | ``` 41 | logging.info("Starting training...") 42 | ``` 43 | 44 | Args: 45 | log_path: (string) where to log 46 | """ 47 | logger = logging.getLogger() 48 | logger.setLevel(logging.INFO) 49 | 50 | if not logger.handlers: 51 | # Logging to a file 52 | file_handler = logging.FileHandler(log_path, encoding='utf-8') 53 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 54 | logger.addHandler(file_handler) 55 | 56 | # Logging to console 57 | stream_handler = logging.StreamHandler() 58 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 59 | logger.addHandler(stream_handler) 60 | 61 | def save_checkpoint(state, is_best, checkpoint): 62 | """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves 63 | checkpoint + 'best.pth.tar' 64 | 65 | Args: 66 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict 67 | is_best: (bool) True if it is the best model seen till now 68 | checkpoint: (string) folder where parameters are to be saved 69 | """ 70 | filepath = os.path.join(checkpoint, 'last.pth.tar') 71 | if not os.path.exists(checkpoint): 72 | print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint)) 73 | os.mkdir(checkpoint) 74 | torch.save(state, filepath) 75 | if is_best: 76 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar')) 77 | 78 | def load_checkpoint(checkpoint, model, optimizer=None): 79 | """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of 80 | optimizer assuming it is present in checkpoint. 81 | 82 | Args: 83 | checkpoint: (string) filename which needs to be loaded 84 | model: (torch.nn.Module) model for which the parameters are loaded 85 | optimizer: (torch.optim) optional: resume optimizer from checkpoint 86 | """ 87 | if not os.path.exists(checkpoint): 88 | raise ("File doesn't exist {}".format(checkpoint)) 89 | logging.info('Loading the ckpt from {}'.format(checkpoint)) 90 | checkpoint = torch.load(checkpoint) 91 | # model.load_state_dict(checkpoint['state_dict']) 92 | model.load_state_dict(checkpoint['state_dict']) 93 | 94 | if optimizer: 95 | optimizer.load_state_dict(checkpoint['optim_dict']) 96 | 97 | return checkpoint 98 | 99 | 100 | class EMA(): 101 | def __init__(self, model, decay): 102 | self.model = model 103 | self.decay = decay 104 | self.shadow = {} 105 | self.backup = {} 106 | 107 | def register(self): 108 | for name, param in self.model.named_parameters(): 109 | if param.requires_grad: 110 | self.shadow[name] = param.data.clone() 111 | 112 | def update(self): 113 | for name, param in self.model.named_parameters(): 114 | if param.requires_grad: 115 | assert name in self.shadow 116 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 117 | self.shadow[name] = new_average.clone() 118 | 119 | def apply_shadow(self): 120 | for name, param in self.model.named_parameters(): 121 | if param.requires_grad: 122 | assert name in self.shadow 123 | self.backup[name] = param.data 124 | param.data = self.shadow[name] 125 | 126 | def restore(self): 127 | for name, param in self.model.named_parameters(): 128 | if param.requires_grad: 129 | assert name in self.backup 130 | param.data = self.backup[name] 131 | self.backup = {} --------------------------------------------------------------------------------
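The `EMA` helper above is defined in `utils.py` but never wired into `main.py`. A minimal usage
sketch, mirroring the training loop in `main.py` (the decay value and the placement of the calls
are assumptions, not taken from the repository):

```python
from utils import EMA

ema = EMA(model, decay=0.999)   # decay is an assumed value
ema.register()                  # snapshot the current weights as the shadow copy

for i in t:                     # the per-epoch loop from main.py, abridged
    task_id = train_task_ids[i]
    batch_data = tuple(tmp.to(args.device) for tmp in next(train_iters[task_id]))
    loss = model(batch_data, task_id, True)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update()                # shadow <- decay * shadow + (1 - decay) * weights

ema.apply_shadow()              # evaluate / checkpoint with the averaged weights
acc = eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args)
ema.restore()                   # back to the raw weights before the next epoch
```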