├── src ├── lsp_model_rl │ ├── util │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── configuration_roberta.py │ │ ├── configuration_bert.py │ │ ├── file_utils.py │ │ └── configuration_utils.py │ ├── __init__.py │ ├── automated_metrics.py │ ├── empathy_classifier_bi_encoder_attention.py │ ├── rewards.py │ ├── coherence_classifier2.py │ ├── optim.py │ └── modeling_gpt2.py ├── env.py ├── variables_ext.py ├── gpt2_training │ ├── eval_utils.py │ ├── distributed.py │ └── train_utils.py ├── process_data.py ├── data_loader.py └── train_model.py ├── asset └── social_media.jpg ├── dataset ├── README.md ├── csv2tsv.ipynb ├── sample_data.tsv └── sample_data.csv ├── .gitignore ├── README.md ├── environment.yml └── requirements.txt /src/lsp_model_rl/util/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.8.0" -------------------------------------------------------------------------------- /asset/social_media.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/claws-lab/MisinfoCorrect/HEAD/asset/social_media.jpg -------------------------------------------------------------------------------- /src/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | import os 4 | 5 | 6 | END_OF_TURN_TOKEN = '<|endofturn|>' 7 | END_OF_TEXT_TOKEN = '<|endoftext|>' 8 | PROJECT_FOLDER = os.path.dirname(__file__) -------------------------------------------------------------------------------- /src/lsp_model_rl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | from transformers import GPT2Tokenizer 3 | from transformers import PYTORCH_PRETRAINED_BERT_CACHE, cached_path 4 | from transformers import GPT2Config, GPT2Model, GPT2Config 5 | from transformers import GPT2Tokenizer 6 | 7 | from .modeling_gpt2 import GPT2LMHeadModel 8 | from .modeling_gpt2 import GPT2LMHeadModel_v2 9 | from .optim import Adam 10 | 11 | -------------------------------------------------------------------------------- /src/variables_ext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # ==== gpu server ==== 4 | if torch.cuda.is_available(): 5 | device = "cuda:0" 6 | n_gpu = 1 7 | 8 | # the paths for three reward functions 9 | clf_main_fp = r'./../../MisinfoCorrect_support_models' 10 | 11 | politeness_clf_fp = f"{clf_main_fp}/politeness_clf.pt" 12 | 13 | refutation_clf_fp = f"{clf_main_fp}/refutation_clf.pt" 14 | 15 | evidence_clf_fp = f"{clf_main_fp}/evidence_clf.pt" 16 | 17 | 18 | # general text reward 19 | if_perplexity = True 20 | if_relevance = True # i.e., the mentioned coherence reward in the paper 21 | # counter response reward 22 | # False True 23 | if_politeness = True 24 | if_refutation = True 25 | if_evidence = True 26 | 27 | # initially 5, 3 28 | every_k_epoch_save_model = 3 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset 2 | 3 | 1. In-the-wild social media data containing 754 annotated (misinformation tweet, counter-misinformation reply) pairs. 
The data statistics are shown below: 4 | 5 |
6 | 7 |
8 | 9 | 2. Crowdsourced data containing 591 (misinformation tweet, human-written counter-misinformation reply) pairs. Note that for these 591 human-written replies, compared to social media data, they are refuting misinformation, polite, providing evidence per the requirement in the paper. 10 | 3. Our dataset can be found [here](https://www.dropbox.com/sh/5u2mdo53tgh3vrh/AADfYHqhQbt0A2gUciT583E0a?dl=0). 11 | 4. We notice the change of Twitter API. If you have problems regarding the access to the whole dataset or the code, please contact Bing He (bhe46@gatech.edu). -------------------------------------------------------------------------------- /dataset/csv2tsv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "\n", 11 | "df = pd.read_csv('./sample_data.csv')" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "df.to_csv('./sample_data.tsv', index=False, header=False, sep='\\t')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [] 29 | } 30 | ], 31 | "metadata": { 32 | "kernelspec": { 33 | "display_name": "base", 34 | "language": "python", 35 | "name": "python3" 36 | }, 37 | "language_info": { 38 | "codemirror_mode": { 39 | "name": "ipython", 40 | "version": 3 41 | }, 42 | "file_extension": ".py", 43 | "mimetype": "text/x-python", 44 | "name": "python", 45 | "nbconvert_exporter": "python", 46 | "pygments_lexer": "ipython3", 47 | "version": "3.7.6" 48 | }, 49 | "orig_nbformat": 4, 50 | "vscode": { 51 | "interpreter": { 52 | "hash": "442c277bbd2a2bd37b27aa6e5dadfa8a95da05ea7996590fcf9f52713a791d05" 53 | } 54 | } 55 | }, 56 | "nbformat": 4, 57 | "nbformat_minor": 2 58 | } 59 | 
-------------------------------------------------------------------------------- /dataset/sample_data.tsv: -------------------------------------------------------------------------------- 1 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again. 2 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again. 3 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again. 4 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again. 5 | -------------------------------------------------------------------------------- /dataset/sample_data.csv: -------------------------------------------------------------------------------- 1 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. 
It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again." 2 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again." 3 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again." 4 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again." 5 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again." 
6 | -------------------------------------------------------------------------------- /src/lsp_model_rl/util/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | try: 37 | import torch_xla # noqa F401 38 | 39 | logger.warning( 40 | "The torch_xla package was detected in the python environment. PyTorch/XLA and JIT is untested," 41 | " no activation function will be traced with JIT." 
42 | ) 43 | except ImportError: 44 | gelu_new = torch.jit.script(gelu_new) 45 | 46 | ACT2FN = { 47 | "relu": F.relu, 48 | "swish": swish, 49 | "gelu": gelu, 50 | "tanh": torch.tanh, 51 | "gelu_new": gelu_new, 52 | } 53 | 54 | 55 | def get_activation(activation_string): 56 | if activation_string in ACT2FN: 57 | return ACT2FN[activation_string] 58 | else: 59 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # MisinfoCorrect details 2 | 3 | .DS_Store 4 | 5 | command.sh 6 | 7 | *.pt 8 | *.pth 9 | *.db 10 | models/medium/* 11 | output/* 12 | logs/* 13 | 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | -------------------------------------------------------------------------------- /src/gpt2_training/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import torch 4 | import logging 5 | 6 | import numpy as np 7 | 8 | from pycocoevalcap.bleu.bleu import Bleu 9 | from collections import defaultdict 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | EOS_ID = 50256 14 | 15 | 16 | def cal_BLEU_4(generated, reference, is_corpus=False): 17 | BLEUscore = [0.0, 0.0, 0.0, 0.0] 18 | for idx, g in enumerate(generated): 19 | if is_corpus: 20 | score, scores = Bleu(4).compute_score(reference, {0: [g]}) 21 | else: 22 | score, scores = Bleu(4).compute_score({0: [reference[0][idx]]}, 23 | {0: [g]}) 24 | for i, s in zip([0, 1, 2, 3], score): 25 | BLEUscore[i] += s 26 | BLEUscore[0] = BLEUscore[0]/len(generated) 27 | BLEUscore[1] = BLEUscore[1]/len(generated) 28 | BLEUscore[2] = BLEUscore[2]/len(generated) 29 | BLEUscore[3] = BLEUscore[3]/len(generated) 30 | return BLEUscore 31 | 32 | 33 | def cal_entropy(generated): 34 | etp_score = [0.0, 0.0, 0.0, 0.0] 35 | div_score = [0.0, 0.0, 0.0, 0.0] 36 | counter = [defaultdict(int), defaultdict(int), 37 | defaultdict(int), defaultdict(int)] 38 | for gg in generated: 39 | g = gg.rstrip().split() 40 | for n in 
range(4): 41 | for idx in range(len(g)-n): 42 | ngram = ' '.join(g[idx:idx+n+1]) 43 | counter[n][ngram] += 1 44 | for n in range(4): 45 | total = sum(counter[n].values()) + 1e-10 46 | for v in counter[n].values(): 47 | etp_score[n] += - (v+0.0) / total * (np.log(v+0.0) - np.log(total)) 48 | div_score[n] = (len(counter[n].values())+0.0) / total 49 | return etp_score, div_score 50 | 51 | 52 | def eval_model_loss(model, tokenizer, eval_dataloader, epoch_id, args): 53 | # use the same signature with eval_model_generation 54 | logger.info('compute eval model loss, using eval mode, ' 55 | 'please change it back to train after calling this function') 56 | model.eval() 57 | tot_loss = [] 58 | tot_ppl = [] 59 | tot_sample = [] 60 | with torch.no_grad(): 61 | for step, batch in enumerate(eval_dataloader): 62 | batch = tuple(t.to(args.device) for t in batch) 63 | input_ids, position_ids, token_ids, label_ids, src_len, _ = batch 64 | if args.no_token_id: 65 | token_ids = None 66 | n_sample = input_ids.shape[0] 67 | loss, ppl = model(input_ids, position_ids, token_ids, label_ids) 68 | tot_loss.append(loss.mean().item() * n_sample) 69 | tot_ppl.append(ppl.mean().item() * n_sample) 70 | tot_sample.append(n_sample) 71 | print(f"\n Epoch {epoch_id}: Val loss {np.sum(tot_loss) / np.sum(tot_sample)} Val ppl {np.sum(tot_ppl) / np.sum(tot_sample)} ") 72 | return np.sum(tot_loss) / np.sum(tot_sample), np.sum(tot_ppl) / np.sum(tot_sample) 73 | -------------------------------------------------------------------------------- /src/lsp_model_rl/util/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_bert import BertConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 28 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 29 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 30 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json", 31 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json", 32 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json", 33 | } 34 | 35 | 36 | class RobertaConfig(BertConfig): 37 | r""" 38 | This is the configuration class to store the configuration of an :class:`~transformers.RobertaModel`. 39 | It is used to instantiate an RoBERTa model according to the specified arguments, defining the model 40 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 41 | the BERT `bert-base-uncased `__ architecture. 42 | 43 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 44 | to control the model outputs. 
Read the documentation from :class:`~transformers.PretrainedConfig` 45 | for more information. 46 | 47 | The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`. 48 | It reuses the same defaults. Please check the parent class for more information. 49 | 50 | Example:: 51 | 52 | from transformers import RobertaConfig, RobertaModel 53 | 54 | # Initializing a RoBERTa configuration 55 | configuration = RobertaConfig() 56 | 57 | # Initializing a model from the configuration 58 | model = RobertaModel(configuration) 59 | 60 | # Accessing the model configuration 61 | configuration = model.config 62 | 63 | Attributes: 64 | pretrained_config_archive_map (Dict[str, str]): 65 | A dictionary containing all the available pre-trained checkpoints. 66 | """ 67 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 68 | model_type = "roberta" 69 | 70 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): 71 | """Constructs FlaubertConfig. 72 | """ 73 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) -------------------------------------------------------------------------------- /src/gpt2_training/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | """ Pytorch Distributed utils 4 | # NOTE: copied from OpenNMT-py 5 | This piece of code was heavily inspired by the equivalent of Fairseq-py 6 | https://github.com/pytorch/fairseq 7 | """ 8 | 9 | 10 | from __future__ import print_function 11 | 12 | import math 13 | import pickle 14 | import torch.distributed 15 | 16 | 17 | def is_master(opt, device_id): 18 | return opt.gpu_ranks[device_id] == 0 19 | 20 | 21 | def multi_init(opt, device_id, logger=None): 22 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format( 23 | master_ip=opt.master_ip, 24 | master_port=opt.master_port) 25 | dist_world_size = opt.world_size 26 | torch.distributed.init_process_group( 27 | backend=opt.gpu_backend, init_method=dist_init_method, 28 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id]) 29 | gpu_rank = torch.distributed.get_rank() 30 | if not is_master(opt, device_id) and logger is not None: 31 | logger.disabled = True 32 | 33 | return gpu_rank 34 | 35 | 36 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 37 | buffer_size=10485760): 38 | """All-reduce and rescale tensors in chunks of the specified size. 39 | 40 | Args: 41 | tensors: list of Tensors to all-reduce 42 | rescale_denom: denominator for rescaling summed Tensors 43 | buffer_size: all-reduce chunk size in bytes 44 | """ 45 | # buffer size in bytes, determine equiv. 
# of elements based on data type 46 | buffer_t = tensors[0].new( 47 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 48 | buffer = [] 49 | 50 | def all_reduce_buffer(): 51 | # copy tensors into buffer_t 52 | offset = 0 53 | for t in buffer: 54 | numel = t.numel() 55 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 56 | offset += numel 57 | 58 | # all-reduce and rescale 59 | torch.distributed.all_reduce(buffer_t[:offset]) 60 | buffer_t.div_(rescale_denom) 61 | 62 | # copy all-reduced buffer back into tensors 63 | offset = 0 64 | for t in buffer: 65 | numel = t.numel() 66 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 67 | offset += numel 68 | 69 | filled = 0 70 | for t in tensors: 71 | sz = t.numel() * t.element_size() 72 | if sz > buffer_size: 73 | # tensor is bigger than buffer, all-reduce and rescale directly 74 | torch.distributed.all_reduce(t) 75 | t.div_(rescale_denom) 76 | elif filled + sz > buffer_size: 77 | # buffer is full, all-reduce and replace buffer with grad 78 | all_reduce_buffer() 79 | buffer = [t] 80 | filled = sz 81 | else: 82 | # add tensor to buffer 83 | buffer.append(t) 84 | filled += sz 85 | 86 | if len(buffer) > 0: 87 | all_reduce_buffer() 88 | 89 | 90 | def all_gather_list(data, max_size=4096): 91 | """Gathers arbitrary data from all nodes into a list.""" 92 | world_size = torch.distributed.get_world_size() 93 | if not hasattr(all_gather_list, '_in_buffer') or \ 94 | max_size != all_gather_list._in_buffer.size(): 95 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 96 | all_gather_list._out_buffers = [ 97 | torch.cuda.ByteTensor(max_size) 98 | for i in range(world_size) 99 | ] 100 | in_buffer = all_gather_list._in_buffer 101 | out_buffers = all_gather_list._out_buffers 102 | 103 | enc = pickle.dumps(data) 104 | enc_size = len(enc) 105 | if enc_size + 2 > max_size: 106 | raise ValueError( 107 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 108 | assert max_size < 255*256 109 | in_buffer[0] = enc_size 
// 255 # this encoding works for max_size < 65k 110 | in_buffer[1] = enc_size % 255 111 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 112 | 113 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 114 | 115 | results = [] 116 | for i in range(world_size): 117 | out_buffer = out_buffers[i] 118 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 119 | 120 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 121 | result = pickle.loads(bytes_list) 122 | results.append(result) 123 | return results 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning-based Counter-Misinformation Response Generation: A Case Study of COVID-19 Vaccine Misinformation 2 | This repository contains code and data for our ACM WWW 2023 publication on counter-misinformation response generation. 3 | 4 | You may access the PDF here: [PDF](https://faculty.cc.gatech.edu/~srijan/pubs/he-www23-misinfocorrect.pdf) 5 | 6 | If our code or data helps you in your research, please cite: 7 | 8 | ``` 9 | @inproceedings{he2023reinforcement, 10 | title={Reinforcement Learning-based Counter-Misinformation Response Generation: A Case Study of COVID-19 Vaccine Misinformation}, 11 | author={He, Bing and Ahamad, Mustaque and Kumar, Srijan}, 12 | booktitle={Proceedings of the ACM Web Conference 2023}, 13 | year={2023} 14 | } 15 | ``` 16 | 17 | ## Introduction 18 | 19 | The COVID-19 vaccine misinformation on social media reduces vaccine uptake and threatens public health. While fact-checkers debunk these false claims, they do not interact with misinformation spreaders. Ordinary users, who make up 96% of counter-misinformation responses, often lack evidence and are rude in their responses. This study aims to create a counter-misinformation response generation model to empower users to effectively correct misinformation. 
We first create two novel datasets of misinformation and counter-misinformation responses from social media and crowdsourcing. A reinforcement-learning-based text generation model is then proposed to reward the generator to increase the politeness, refutation attitude, and factuality while retaining text fluency and relevancy. Extensive experiments show that the model outperforms baselines in generating high-quality responses, demonstrating the potential of generative text models for social good. 20 | 21 | 22 | ## Quickstart 23 | 24 | ### 1. Setup and Installation 25 | 26 | Our framework runs in Python 3 environments. The modules used in our code can be installed using: 27 | ``` 28 | $ pip install -r requirements.txt 29 | ``` 30 | 31 | or 32 | 33 | ``` 34 | $ conda env create -f environment.yml 35 | ``` 36 | 37 | 38 | ### 2. Prepare dataset 39 | 40 | A sample raw input data file is available in [dataset/sample_data.tsv](dataset/sample_data.tsv). Each line in the file has a tweet and a corresponding counter-response (tab-separated). This input file can be converted into a format that is recognized by the model using the following command: 41 | ``` 42 | python src/process_data.py --corpus dataset/sample_data.tsv --if_input_src_only 43 | ``` 44 | 45 | Running this command will generate a database folder next to the input file. The training command below expects it at `dataset/sample_data.128len.if_input_src_only.db`; adjust `--train_input_file` if your folder is named differently (e.g. `sample_data.128len.db`). 46 | 47 | ### 3. Training the model 48 | To train the model on the sample input data, run the following command: 49 | 50 | ``` 51 | python src/train_model.py \ 52 | --model_name_or_path models/medium/ \ 53 | --train_input_file dataset/sample_data.128len.if_input_src_only.db \ 54 | --output_dir output/ \ 55 | --log_dir output/ \ 56 | --train_batch_size 4 \ 57 | --num_optim_steps 100 58 | ``` 59 | Before running this command: 60 | 1. You will need a DialoGPT-like transformer model for initialization (`model_name_or_path`, ideally fine-tuned on your dataset; see the warm-start step in the paper); 61 | 2.
You will need to separately train multiple reward functions for the reinforcement learning framework. Here, we have three reward functions: politeness, refutation, and evidence reward classifiers. The classifier checkpoint paths are set by `politeness_clf_fp`, `refutation_clf_fp`, and `evidence_clf_fp` in `src/variables_ext.py`. 62 | 3. To configure the rewards (e.g. the `if_*` flags that enable or disable each reward), see `src/variables_ext.py`. 63 | 4. As a sanity check for this release, the current version uses only one GPU. If you have multiple GPUs, update `n_gpu` in `src/variables_ext.py`. Note that `device` and `n_gpu` are only defined there when CUDA is available, so the code currently expects a CUDA GPU. For the paper, we ran our experiments on an NVIDIA DGX-1 with 8 V100 GPUs. 64 | 5. The repository is based on [DialoGPT](https://github.com/microsoft/DialoGPT) and [PARTNER](https://github.com/behavioral-data/PARTNER) and uses a similar code structure and environment. 65 | 66 | ## Dataset 67 | 68 | 1. In-the-wild social media data containing 754 annotated (misinformation tweet, counter-misinformation reply) pairs. The data statistics are shown below: 69 | 70 |
71 | 72 |
73 | 74 | 2. Crowdsourced data containing 591 (misinformation tweet, human-written counter-misinformation reply) pairs. Note that for these 591 human-written replies, compared to social media data, they are refuting misinformation, polite, providing evidence per the requirement in the paper. 75 | 3. Our dataset can be found [here](https://www.dropbox.com/sh/5u2mdo53tgh3vrh/AADfYHqhQbt0A2gUciT583E0a?dl=0). 76 | 4. We notice the change of Twitter API. If you have problems regarding the access to the whole dataset or the code, please contact Bing He (bhe46@gatech.edu). 77 | 78 | If you have any questions, please contact the author Bing He (bhe46@gatech.edu) 79 | -------------------------------------------------------------------------------- /src/lsp_model_rl/automated_metrics.py: -------------------------------------------------------------------------------- 1 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 2 | from nltk import ngrams 3 | import numpy as np 4 | from numpy import dot 5 | from numpy.linalg import norm 6 | 7 | # before Oct 2022 8 | #from bert_serving.client import BertClient 9 | #bc = BertClient() 10 | # after Oct 2022 11 | from transformers import BertTokenizer, BertModel 12 | import torch 13 | 14 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 15 | bert_model = BertModel.from_pretrained("bert-base-uncased") 16 | 17 | 18 | 19 | import torch 20 | import math 21 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 22 | 23 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 24 | 25 | 26 | # ==== : device control ==== 27 | import sys, os 28 | sys.path.append("./") 29 | sys.path.append("../") 30 | sys.path.append(".../") 31 | sys.path.append("..../") 32 | from MisinfoCorrect.src.variables_ext import device 33 | # ==== : device control ==== 34 | 35 | GPT_model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium').to(device) 36 | GPT_tokenizer = 
GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium') 37 | 38 | GPT_tokenizer.pad_token = GPT_tokenizer.eos_token 39 | 40 | GPT_model.eval() 41 | 42 | MAX_LEN = 64 43 | 44 | # BLEU 45 | def bleu(predicted, target): 46 | 47 | all_bleus = [] 48 | 49 | smoothie = SmoothingFunction().method4 50 | 51 | for idx, elem in enumerate(predicted): 52 | try: 53 | curr_bleu = sentence_bleu([str(target[idx])], str(predicted[idx]), smoothing_function=smoothie) 54 | all_bleus.append(curr_bleu) 55 | except ZeroDivisionError: 56 | continue 57 | 58 | return np.mean(all_bleus) 59 | 60 | 61 | # Perplexity 62 | def perplexity(predicted): 63 | 64 | BATCH_SIZE = 1 65 | 66 | tokenized_input = GPT_tokenizer.batch_encode_plus(predicted, max_length=MAX_LEN, pad_to_max_length=True, truncation=True) 67 | 68 | input_ids = tokenized_input['input_ids'] 69 | attention_masks = tokenized_input['attention_mask'] 70 | 71 | input_ids = torch.tensor(input_ids) 72 | attention_masks = torch.tensor(attention_masks) 73 | 74 | data = TensorDataset(input_ids, attention_masks) 75 | 76 | sampler = SequentialSampler(data) 77 | dataloader = DataLoader(data, sampler=sampler, batch_size = BATCH_SIZE) 78 | 79 | all_loss = [] 80 | 81 | with torch.no_grad(): 82 | 83 | for batch in dataloader: 84 | b_input = batch[0].to(device) 85 | b_attn = batch[1].to(device) 86 | 87 | outputs = GPT_model(b_input, attention_mask=b_attn, labels=b_input) 88 | 89 | loss, logits = outputs[:2] 90 | all_loss.append(loss.item()) 91 | 92 | return math.exp(np.mean(all_loss)) 93 | 94 | 95 | # Diversity 96 | 97 | # Distinct-1, Distinct-2 (# of unigrams and bigrams divided by the total number of words) 98 | 99 | def distinct(predicted): 100 | 101 | UNIGRAMS = set() 102 | BIGRAMS = set() 103 | 104 | NUM_WORDS = 0 105 | 106 | for idx, elem in enumerate(predicted): 107 | curr_unigrams = ngrams(str(elem).split(), 1) 108 | curr_bigrams = ngrams(str(elem).split(), 2) 109 | 110 | NUM_WORDS += len(str(elem).split()) 111 | 112 | for unigram in 
curr_unigrams: 113 | UNIGRAMS.add(' '.join(unigram).strip()) 114 | 115 | for bigram in curr_bigrams: 116 | BIGRAMS.add(' '.join(bigram).strip()) 117 | 118 | 119 | 120 | DISTINCT_1 = len(UNIGRAMS) / NUM_WORDS 121 | DISTINCT_2 = len(BIGRAMS) / NUM_WORDS 122 | 123 | return DISTINCT_1, DISTINCT_2 124 | 125 | 126 | # Entropy 127 | # def entropy(): 128 | 129 | 130 | # Specificity 131 | # similar to relevance 132 | def specificity(seeker_post, predicted): 133 | 134 | # Get embeddings 135 | seeker_post_embeddings = bc.encode(seeker_post) 136 | predicted_response_embeddings = bc.encode(predicted) 137 | 138 | # Compute cosine similarity 139 | 140 | all_cos_sim = [] 141 | 142 | for idx, elem in enumerate(seeker_post_embeddings): 143 | a = seeker_post_embeddings[idx] 144 | b = predicted_response_embeddings[idx] 145 | 146 | cos_sim = dot(a, b)/(norm(a)*norm(b)) 147 | all_cos_sim.append(cos_sim) 148 | 149 | return np.mean(all_cos_sim) 150 | 151 | # ext: relevance reward - coherence reward 152 | def relevance(seeker_post, predicted): 153 | from numpy import dot 154 | from numpy.linalg import norm 155 | import numpy as np 156 | def encode(post): 157 | 158 | inputs = bert_tokenizer.encode_plus(post, return_tensors="pt", add_special_tokens=True) 159 | outputs = bert_model(**inputs) 160 | 161 | # [0] hidden state of all tokens in the sentence, [1] the final last only 1 hidden state output 162 | last_hidden_states = outputs[1] 163 | return last_hidden_states 164 | 165 | # Get embeddings 166 | seeker_post_embeddings = encode(seeker_post) 167 | predicted_response_embeddings = encode(predicted) 168 | 169 | # Compute cosine similarity 170 | 171 | all_cos_sim = [] 172 | 173 | for idx, elem in enumerate(seeker_post_embeddings): 174 | a = seeker_post_embeddings[idx] 175 | b = predicted_response_embeddings[idx] 176 | 177 | a = a.detach().numpy() 178 | b = b.detach().numpy() 179 | 180 | cos_sim = dot(a, b)/(norm(a)*norm(b)) 181 | all_cos_sim.append(cos_sim) 182 | 183 | return 
np.mean(all_cos_sim) 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /src/lsp_model_rl/empathy_classifier_bi_encoder_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import codecs 3 | import numpy as np 4 | 5 | 6 | import pandas as pd 7 | import re 8 | import csv 9 | import numpy as np 10 | import sys 11 | 12 | import time 13 | 14 | from sklearn.metrics import f1_score 15 | 16 | from transformers import RobertaTokenizer 17 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 18 | from torch.utils.data import TensorDataset, random_split 19 | 20 | 21 | from .util.models import BiEncoderAttentionWithRationaleClassification 22 | from transformers import AdamW, RobertaConfig 23 | 24 | import datetime 25 | 26 | 27 | class EmpathyClassifier(): 28 | 29 | def __init__(self, 30 | device, 31 | ER_model_path = './', 32 | IP_model_path = './', 33 | EX_model_path = './', 34 | batch_size=2): 35 | 36 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True) 37 | self.batch_size = batch_size 38 | self.device = device 39 | 40 | self.model_ER = BiEncoderAttentionWithRationaleClassification() 41 | self.model_IP = BiEncoderAttentionWithRationaleClassification() 42 | self.model_EX = BiEncoderAttentionWithRationaleClassification() 43 | 44 | ER_weights = torch.load(ER_model_path, map_location=self.device) 45 | self.model_ER.load_state_dict(ER_weights) 46 | 47 | IP_weights = torch.load(IP_model_path, map_location=self.device) 48 | self.model_IP.load_state_dict(IP_weights) 49 | 50 | EX_weights = torch.load(EX_model_path, map_location=self.device) 51 | self.model_EX.load_state_dict(EX_weights) 52 | 53 | self.model_ER.to(self.device) 54 | self.model_IP.to(self.device) 55 | self.model_EX.to(self.device) 56 | 57 | self.model_ER.eval() 58 | self.model_IP.eval() 59 | self.model_EX.eval() 60 | 61 | 62 | def 
predict_empathy(self, seeker_posts, response_posts): 63 | 64 | input_ids_SP = [] 65 | attention_masks_SP = [] 66 | 67 | for sent in seeker_posts: 68 | 69 | encoded_dict = self.tokenizer.encode_plus( 70 | sent, # Sentence to encode. 71 | add_special_tokens = True, # Add '[CLS]' and '[SEP]' 72 | max_length = 64, # Pad & truncate all sentences. 73 | truncation=True, 74 | pad_to_max_length = True, 75 | return_attention_mask = True, # Construct attn. masks. 76 | return_tensors = 'pt', # Return pytorch tensors. 77 | ) 78 | 79 | input_ids_SP.append(encoded_dict['input_ids']) 80 | attention_masks_SP.append(encoded_dict['attention_mask']) 81 | 82 | 83 | input_ids_RP = [] 84 | attention_masks_RP = [] 85 | 86 | for sent in response_posts: 87 | encoded_dict = self.tokenizer.encode_plus( 88 | sent, # Sentence to encode. 89 | add_special_tokens = True, # Add '[CLS]' and '[SEP]' 90 | max_length = 64, # Pad & truncate all sentences. 91 | truncation=True, 92 | pad_to_max_length = True, 93 | return_attention_mask = True, # Construct attn. masks. 94 | return_tensors = 'pt', # Return pytorch tensors. 95 | ) 96 | 97 | input_ids_RP.append(encoded_dict['input_ids']) 98 | attention_masks_RP.append(encoded_dict['attention_mask']) 99 | 100 | input_ids_SP = torch.cat(input_ids_SP, dim=0) 101 | attention_masks_SP = torch.cat(attention_masks_SP, dim=0) 102 | 103 | input_ids_RP = torch.cat(input_ids_RP, dim=0) 104 | attention_masks_RP = torch.cat(attention_masks_RP, dim=0) 105 | 106 | dataset = TensorDataset(input_ids_SP, attention_masks_SP, input_ids_RP, attention_masks_RP) 107 | 108 | dataloader = DataLoader( 109 | dataset, # The test samples. 110 | sampler = SequentialSampler(dataset), # Pull out batches sequentially. 111 | batch_size = self.batch_size # Evaluate with this batch size. 
112 | ) 113 | 114 | self.model_ER.eval() 115 | self.model_IP.eval() 116 | self.model_EX.eval() 117 | 118 | for batch in dataloader: 119 | b_input_ids_SP = batch[0].to(self.device) 120 | b_input_mask_SP = batch[1].to(self.device) 121 | b_input_ids_RP = batch[2].to(self.device) 122 | b_input_mask_RP = batch[3].to(self.device) 123 | 124 | with torch.no_grad(): 125 | (logits_empathy_ER, logits_rationale_ER,) = self.model_ER(input_ids_SP = b_input_ids_SP, 126 | input_ids_RP = b_input_ids_RP, 127 | token_type_ids_SP=None, 128 | token_type_ids_RP=None, 129 | attention_mask_SP=b_input_mask_SP, 130 | attention_mask_RP=b_input_mask_RP) 131 | 132 | (logits_empathy_IP, logits_rationale_IP,) = self.model_IP(input_ids_SP = b_input_ids_SP, 133 | input_ids_RP = b_input_ids_RP, 134 | token_type_ids_SP=None, 135 | token_type_ids_RP=None, 136 | attention_mask_SP=b_input_mask_SP, 137 | attention_mask_RP=b_input_mask_RP) 138 | 139 | (logits_empathy_EX, logits_rationale_EX,) = self.model_EX(input_ids_SP = b_input_ids_SP, 140 | input_ids_RP = b_input_ids_RP, 141 | token_type_ids_SP=None, 142 | token_type_ids_RP=None, 143 | attention_mask_SP=b_input_mask_SP, 144 | attention_mask_RP=b_input_mask_RP) 145 | 146 | 147 | logits_empathy_ER = logits_empathy_ER.detach().cpu().numpy().tolist() 148 | predictions_ER = np.argmax(logits_empathy_ER, axis=1).flatten() 149 | 150 | logits_empathy_IP = logits_empathy_IP.detach().cpu().numpy().tolist() 151 | predictions_IP = np.argmax(logits_empathy_IP, axis=1).flatten() 152 | 153 | logits_empathy_EX = logits_empathy_EX.detach().cpu().numpy().tolist() 154 | predictions_EX = np.argmax(logits_empathy_EX, axis=1).flatten() 155 | 156 | return (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) 157 | 158 | 159 | ''' 160 | Example: 161 | ''' 162 | 163 | import sys, os 164 | sys.path.append("../") 165 | sys.path.append(".../") 166 | from MisinfoCorrect.src.variables_ext import device 167 | 168 | 169 | 
seeker_posts = ['I need help', 'I want to talk to someone','I do not have any friends.'] 170 | response_posts = ['why do you feel this way?', 'I understand how you feel','do you want to talk about it?'] 171 | 172 | empathy_classifier = EmpathyClassifier(device) 173 | 174 | (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy(seeker_posts, response_posts) 175 | 176 | print(predictions_ER, predictions_IP, predictions_EX) 177 | -------------------------------------------------------------------------------- /src/gpt2_training/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import logging 6 | import torch 7 | from collections import defaultdict 8 | 9 | from env import END_OF_TEXT_TOKEN 10 | from lsp_model_rl.optim import warmup_linear, noam_decay, noamwd_decay 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | SEQ_LENGTH_SHRINK_PROP = 0.9 16 | 17 | # ==== understand by ext: it seems like device setup ==== 18 | def load_model(model, checkpoint, args, verbose=False): 19 | n_gpu = args.n_gpu 20 | device = args.device 21 | if checkpoint is None or checkpoint == "None": 22 | if verbose: 23 | logger.info('no checkpoint provided for %s!' 
% model._get_name()) 24 | else: 25 | if not os.path.exists(checkpoint): 26 | raise ValueError('checkpoint %s not exist' % checkpoint) 27 | if verbose: 28 | logger.info('loading finetuned model from %s' % checkpoint) 29 | model_state_dict = torch.load(checkpoint) 30 | 31 | model_state_dict = fix_state_dict_namespace(model_state_dict) 32 | 33 | start_model = model 34 | if (hasattr(model, "transformer") 35 | and all(not s.startswith('transformer.') 36 | for s in model_state_dict.keys())): 37 | logger.info('loading transfomer only') 38 | start_model = model.transformer 39 | start_model.load_state_dict(model_state_dict) 40 | 41 | if args.fp16: 42 | logger.info('in fp16, model.half() activated') 43 | model.half() 44 | model.to(device) 45 | if n_gpu > 1: 46 | logging.info('data parallel because more than one gpu') 47 | model = torch.nn.DataParallel(model) 48 | return model 49 | 50 | 51 | def fix_state_dict_namespace(model_state_dict): 52 | old_keys = [] 53 | new_keys = [] 54 | for t in model_state_dict: 55 | new_key = t 56 | if t.startswith('module.'): 57 | new_key = t.replace('module.', '') 58 | old_keys.append(t) 59 | new_keys.append(new_key) 60 | 61 | for old_key, new_key in zip(old_keys, new_keys): 62 | model_state_dict[new_key] = model_state_dict.pop(old_key) 63 | 64 | return model_state_dict 65 | 66 | 67 | 68 | class InputFeatures(object): 69 | def __init__(self, conv_id, input_ids, position_ids, token_type_ids, 70 | seeker_post, response_post, input_len=None): 71 | self.conv_id = conv_id 72 | self.input_ids = input_ids 73 | self.position_ids = position_ids 74 | self.token_type_ids = token_type_ids 75 | 76 | self.seeker_post = seeker_post 77 | self.response_post = response_post 78 | 79 | if input_len is None: 80 | self.input_len = len(input_ids) 81 | else: 82 | self.input_len = input_len 83 | 84 | 85 | 86 | 87 | class RedditExample(object): 88 | def __init__(self, conv_id, context, response): 89 | self.conv_id = conv_id 90 | self.context = context 91 | self.response 
= response 92 | 93 | def __repr__(self): 94 | return 'conv_id = {}\ncontext = {}\nresponse = {}'.format( 95 | self.conv_id, self.context, self.response) 96 | 97 | def __str__(self): 98 | return self.__repr__() 99 | 100 | 101 | def boolean_string(s): 102 | if s.lower() not in {'false', 'true'}: 103 | raise ValueError('Not a valid boolean string') 104 | return s.lower() == 'true' 105 | 106 | 107 | def get_eval_list_same_length(input_file, tokenizer, max_batch_size, 108 | norm=True): 109 | examples = [] 110 | with open(input_file, 'r', encoding="utf-8") as f: 111 | content = [l.split('\t') for l in f.read().splitlines()] 112 | 113 | context, response = [c[0] for c in content], [c[1:] for c in content] 114 | i = 0 115 | for src, tgt_all in zip(context, response): 116 | for tgt in tgt_all: 117 | if norm: 118 | src_line = ' '.join(src.strip().split()) 119 | tgt_line = ' '.join(tgt.strip().split()) 120 | else: 121 | src_line = src.strip() 122 | tgt_line = tgt.strip() 123 | examples.append(RedditExample(i, src_line, tgt_line)) 124 | i += 1 125 | 126 | def featurize(example): 127 | conv_id = example.conv_id 128 | context_id = tokenizer.encode(example.context) 129 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN] 130 | 131 | response_id = tokenizer.encode(example.response) 132 | input_ids = context_id + [end_of_text_id] 133 | lm_labels = response_id 134 | 135 | position_ids = list(range(len(input_ids))) 136 | 137 | token_type_id = [0] * len(input_ids) 138 | 139 | return InputFeatures(conv_id, input_ids, position_ids, token_type_id, 140 | lm_labels, len(context_id), len(response_id)) 141 | 142 | def batch_feature_same_len(features): 143 | input_ids = torch.stack([torch.tensor(f.choices_features['input_ids'], 144 | dtype=torch.long) 145 | for f in features]) 146 | position_ids = torch.stack( 147 | [torch.tensor(f.choices_features['position_ids'], dtype=torch.long) 148 | for f in features]) 149 | token_type_ids = torch.stack( 150 | 
[torch.tensor(f.choices_features['token_type_ids'], 151 | dtype=torch.long) 152 | for f in features]) 153 | labels = torch.nn.utils.rnn.pad_sequence( 154 | [torch.tensor(f.lm_labels, dtype=torch.long) for f in features], 155 | batch_first=True, padding_value=-1) 156 | 157 | context_len = torch.tensor([f.context_len for f in features], 158 | dtype=torch.long) 159 | response_len = torch.tensor([f.response_len for f in features], 160 | dtype=torch.long) 161 | return (input_ids, position_ids, token_type_ids, labels, 162 | context_len, response_len) 163 | 164 | features = [featurize(e) for e in examples] 165 | dataloader_pre = defaultdict(list) 166 | for f in features: 167 | dataloader_pre[f.context_len].append(f) 168 | 169 | dataloader = [] 170 | for l in sorted(dataloader_pre): 171 | f = batch_feature_same_len(dataloader_pre[l]) 172 | if len(f[0]) <= max_batch_size: 173 | dataloader.append(f) 174 | else: 175 | start_index = 0 176 | while True: 177 | dataloader.append([ff[start_index:start_index + max_batch_size] 178 | for ff in f]) 179 | start_index += max_batch_size 180 | if start_index >= len(f[0]): 181 | break 182 | return dataloader 183 | 184 | 185 | def set_lr(optimizer, step, schedule, lr, 186 | warmup_steps, warmup_proportion, n_embd, tot_steps): 187 | if schedule == 'None': 188 | lr_this_step = lr 189 | elif schedule == 'noam': # transformer like 190 | lr_this_step = lr * 1e4 * noam_decay(step+1, warmup_steps, n_embd) 191 | elif schedule == 'noamwd': # transformer like 192 | lr_this_step = lr * 1e4 * noamwd_decay(step+1, warmup_steps, n_embd) 193 | else: 194 | lr_this_step = lr * warmup_linear(step / tot_steps, 195 | warmup_proportion) 196 | for param_group in optimizer.param_groups: 197 | param_group['lr'] = lr_this_step 198 | -------------------------------------------------------------------------------- /src/lsp_model_rl/rewards.py: -------------------------------------------------------------------------------- 1 | from .automated_metrics import * 2 | 
import numpy as np 3 | import nltk 4 | # input: edited sentence, initial sentence, weights 5 | # from .empathy_classifier_bi_encoder_attention import empathy_classifier 6 | from .coherence_classifier2 import coherence_classifier 7 | 8 | from .coherence_classifier2 import politeness_classifier 9 | from .coherence_classifier2 import refutation_classifier 10 | from .coherence_classifier2 import evidence_classifier 11 | 12 | orig_sent_list = ["It might help to re-install Python if possible","The version might behind the bug."] 13 | new_sent_list = ["The version might be the reason for the bug.","The version might be the reason behind the bug."] 14 | 15 | # w: weight for the importance of the rewards 16 | w = {'edit':1,'bleu':1,'dist1':10,'dist2':1,'pp':10,'spec':1,'empathy':1, 17 | 'politeness':1, 'refutation':1, 'evidence':1, 'coherence':0.1, 'relevance':0.1} 18 | 19 | def calc_rewards(seeker_posts, original_responses, generated_responses, candidates = None, 20 | _edit=False, 21 | _bleu=False, 22 | _distinct=False, 23 | _perplexity=False, 24 | _specificity=False, 25 | _empathy=False, 26 | _empathy_change=False, 27 | _coherence=False, 28 | _add_noise=True, 29 | _pick_categorical='', 30 | _empathy_adaptive=False, 31 | _politeness=False, 32 | _refutation=False, 33 | _evidence=False, 34 | _relevance=False, 35 | NOISE=0.00001): 36 | 37 | total_score = 0 38 | 39 | if _edit: 40 | edit = edit_level_jaccard(original_responses, generated_responses) 41 | total_score += edit*w['edit'] 42 | 43 | if _bleu: 44 | bleu_score = bleu(generated_responses, original_responses) 45 | total_score += bleu_score*w['bleu'] 46 | 47 | if _distinct: 48 | distinct_1, distinct_2 = distinct(generated_responses) 49 | total_score += distinct_1*w['dist1']+distinct_2*w['dist2'] 50 | 51 | if _perplexity: 52 | perplexity_score = perplexity(generated_responses) 53 | total_score += perplexity_score*w['pp'] 54 | 55 | if _specificity: 56 | specificity_score = 0 # specificity(seeker_posts, generated_responses) 
57 | total_score += specificity_score*w['spec'] 58 | 59 | if _empathy: 60 | empathy_score = calc_empathy_score(seeker_posts, generated_responses) 61 | total_score += empathy_score*w['empathy'] 62 | 63 | if _empathy_change: 64 | prev_empathy_score = calc_empathy_score(seeker_posts, original_responses) 65 | curr_empathy_score = calc_empathy_score(seeker_posts, generated_responses) 66 | 67 | total_score += curr_empathy_score - prev_empathy_score 68 | 69 | 70 | if _coherence: 71 | reward_val = relevance(seeker_posts, generated_responses) 72 | total_score += reward_val*w['coherence'] 73 | if _politeness: 74 | reward_val = calc_politeness_score(generated_responses, "") 75 | total_score += reward_val*w['politeness'] 76 | if _refutation: 77 | reward_val = calc_refutation_score(seeker_posts, generated_responses) 78 | total_score += reward_val*w['refutation'] 79 | if _evidence: 80 | reward_val = calc_evidence_score(seeker_posts, generated_responses) 81 | total_score += reward_val*w['evidence'] 82 | if _relevance: 83 | reward_val = relevance(seeker_posts, generated_responses) 84 | total_score += reward_val*w['relevance'] 85 | 86 | 87 | if _add_noise: 88 | total_score -= NOISE 89 | 90 | if _empathy_adaptive: 91 | _,_,_,_,ER_score, IP_score, EX_score = calc_empathy_score_3dim(seeker_posts, generated_responses) 92 | total_score += ((2-ER_score)*ER_score+(2-IP_score)*IP_score+(2-EX_score)*EX_score)*w['empathy']*0.5 93 | 94 | return total_score 95 | 96 | 97 | def edit_level_jaccard(orig_sent_list, new_sent_list): 98 | total_score = 0 99 | for i, orig_sent in enumerate(orig_sent_list): 100 | total_score += (nltk.jaccard_distance(set(orig_sent), set(new_sent_list[i]))) 101 | return total_score/len(orig_sent_list) 102 | 103 | edit_level_jaccard(orig_sent_list, new_sent_list) 104 | 105 | def calc_empathy_score(seeker_posts, generated_responses): 106 | batch_score = 0 107 | for i in range(len(seeker_posts)): 108 | (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, 
logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy([seeker_posts[i]], [generated_responses[i]]) 109 | batch_score += ((predictions_ER[0]+predictions_IP[0]+predictions_EX[0])*0.5) 110 | 111 | return batch_score/len(seeker_posts) 112 | 113 | def calc_empathy_score_3dim(seeker_posts, generated_responses): 114 | batch_score = 0 115 | ER_score_list =[] 116 | IP_score_list =[] 117 | EX_score_list =[] 118 | 119 | for i in range(len(seeker_posts)): 120 | try: 121 | (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy([seeker_posts[i]], [generated_responses[i]]) 122 | batch_score += ((predictions_ER[0]+predictions_IP[0]+predictions_EX[0])) 123 | ER_score_list.append(predictions_ER[0]) 124 | IP_score_list.append(predictions_IP[0]) 125 | EX_score_list.append(predictions_EX[0]) 126 | except: 127 | print('Error:', seeker_posts[i], generated_responses[i]) 128 | 129 | ER_score = np.sum(ER_score_list)/len(seeker_posts) 130 | IP_score = np.sum(IP_score_list)/len(seeker_posts) 131 | EX_score = np.sum(EX_score_list)/len(seeker_posts) 132 | 133 | return batch_score/len(seeker_posts),ER_score_list,IP_score_list,EX_score_list, ER_score, IP_score, EX_score 134 | 135 | def log2prob(logs): 136 | probs = np.divide(np.exp(logs), (1+np.exp(logs))) 137 | return probs 138 | 139 | def calc_coherence_score(original_responses, candidate): # original_response: list of strings, candidate: string 140 | (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, candidate) 141 | logs_1 = [log[1] for log in logits] 142 | score = np.mean(log2prob(logs_1)) 143 | return score 144 | 145 | ## for politeness, we only use the original response: 1 input 146 | def calc_politeness_score(original_responses, candidate): # original_response: list of strings, candidate: string 147 | (logits, predictions,) = politeness_classifier.predict_empathy(original_responses, candidate) 148 | # 
here logits: 0:dim: impolite; 1:polite(normal and polite) 149 | logs_1 = [log[1] for log in logits] 150 | score = np.mean(log2prob(logs_1)) 151 | return score 152 | 153 | ## for refutation scores, we should use two inputs: 154 | def calc_refutation_score(original_responses, candidate): # original_response: list of strings, candidate: string 155 | (logits, predictions,) = refutation_classifier.predict_empathy(original_responses, candidate) 156 | # here logits: 0:dim: impolite; 1:polite(normal and polite) 157 | logs_1 = [log[1] for log in logits] 158 | score = np.mean(log2prob(logs_1)) 159 | return score 160 | 161 | def calc_evidence_score(original_responses, candidate): # original_response: list of strings, candidate: string 162 | (logits, predictions,) = evidence_classifier.predict_empathy(original_responses, candidate) 163 | # here logits: 0:dim: impolite; 1:polite(normal and polite) 164 | logs_1 = [log[1] for log in logits] 165 | score = np.mean(log2prob(logs_1)) 166 | return score 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /src/process_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | preprocess input data into feature and stores binary as python shelve DB 3 | each chunk is gzipped JSON string 4 | """ 5 | import argparse 6 | import gzip 7 | import json 8 | import subprocess as sp 9 | import shelve 10 | import os 11 | from os.path import dirname, exists, join 12 | 13 | import torch 14 | from lsp_model_rl import GPT2Tokenizer 15 | from tqdm import tqdm 16 | 17 | from env import END_OF_TEXT_TOKEN 18 | from gpt2_training.train_utils import InputFeatures as InputFeatures 19 | 20 | 21 | def _get_file_len(corpus): 22 | n_line = int(sp.check_output(f"wc -l {corpus}".split(), 23 | universal_newlines=True).split()[0]) 24 | return n_line 25 | 26 | 27 | def _norm_text(text): 28 | w, *toks = text.strip().split() 29 | try: 30 | w = float(w) 31 | except 
Exception: 32 | toks = [w] + toks 33 | w = 1.0 34 | return w, ' '.join(toks) 35 | 36 | 37 | def _get_inputs_from_text(text, tokenizer, if_input_src_only=False): 38 | src_seeker_post = text.strip().split('\t')[0].strip() # srcs.split('')[0].strip()[4:] 39 | src_response_post = text.strip().split('\t')[1].strip() # srcs.split('')[1].strip() 40 | 41 | if if_input_src_only: 42 | # ==== condition 1 ====: but, see so many split tokens in the text generation 43 | # srcs = src_seeker_post + ' ' + ' ' 44 | # ==== condition 2 ====: like Dialog-GPT setup 45 | srcs = src_seeker_post + ' <|endoftext|> ' 46 | else: 47 | srcs = src_seeker_post + ' ' + src_response_post + ' ' 48 | 49 | inputs = [] 50 | 51 | for src in srcs.split(' EOS '): 52 | context_id = tokenizer.encode(src) 53 | inputs.append(context_id) 54 | 55 | # pos = [pos,] 56 | return inputs, src_seeker_post, src_response_post 57 | 58 | 59 | def _make_features(id_, inputs, src_seeker_post, src_response_post, tokenizer, max_len): 60 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN] 61 | features = [] 62 | sents = [] 63 | ws = [] 64 | ps = [] 65 | len_ = 0 66 | i = 0 67 | # ids: list of list: [[1,2,3]] 68 | for ids in inputs: 69 | if len(ids) > max_len: 70 | ids = ids[:max_len] 71 | 72 | len_ += (len(ids) + 1) 73 | sents.append(ids) 74 | # : it seems sents is just a concated 75 | if len(sents) >= 1: 76 | feat = _make_feature(id_ + i, sents, src_seeker_post, src_response_post, end_of_text_id) 77 | if feat is not None: 78 | features.append(feat) 79 | 80 | return features 81 | 82 | 83 | def _make_feature(id_, sents, src_seeker_post, src_response_post, eos): 84 | # : change from list of list to a single list 85 | input_ids = [i for s in sents for i in s+[eos]][:-1] 86 | 87 | weights = [] 88 | token_type_ids = [] # this becomes round ids 89 | 90 | # : in the input index list: we have split in the middle 91 | # like: source reply 92 | split_id = toker.encode("")[0] 93 | 94 | curr_id = 0 95 | 96 | for i, s in 
enumerate(input_ids): 97 | 98 | if s == split_id: 99 | curr_id = 1 100 | 101 | token_type_ids.append(curr_id) 102 | 103 | 104 | # TODO: handle trailing -1's 105 | # input_ids = input_ids[:i+1] 106 | # token_type_ids = token_type_ids[:i+1] 107 | 108 | # pad to multiples of 8 109 | while len(input_ids) % 8 != 0: 110 | input_ids.append(0) 111 | token_type_ids.append(0) 112 | 113 | position_ids = list(range(len(input_ids))) 114 | assert (len(input_ids) == len(position_ids) == len(token_type_ids)) 115 | assert len(input_ids) % 8 == 0 116 | 117 | if len(input_ids) == 0: 118 | import pdb 119 | pdb.set_trace() 120 | 121 | # : here, we have three important ids: 122 | # 1) input_ids: sequential; 2) position_ids: provide the position information 123 | # 3) token_type: for the related split/EOS tagging 124 | feature = InputFeatures(id_, input_ids, position_ids, token_type_ids, 125 | src_seeker_post, src_response_post) 126 | 127 | return feature 128 | 129 | 130 | toker = GPT2Tokenizer.from_pretrained('gpt2-medium') 131 | toker.add_tokens(['', '', '']) 132 | 133 | def main(args): 134 | 135 | attrs = [] 136 | if args.if_input_src_only: 137 | print(f"we only keep source text in the input") 138 | attrs.append("if_input_src_only") 139 | if args.reverse: 140 | attrs.append('reverse') 141 | if args.two_turn: 142 | attrs.append('2turn') 143 | if attrs: 144 | db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.' 
145 | f'{".".join(attrs)}.db/db') 146 | else: 147 | db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db' 148 | print(f"the db path is: {db_path}") 149 | if exists(dirname(db_path)): 150 | raise ValueError('Found existing DB, please backup') 151 | else: 152 | os.makedirs(dirname(db_path)) 153 | with open(args.corpus, "r", encoding="utf-8") as reader, \ 154 | shelve.open(db_path, 'n') as db: 155 | chunk = [] 156 | n_chunk = 0 157 | n_example = 0 158 | for line in tqdm(reader, total=_get_file_len(args.corpus)): 159 | # print('line:', line) 160 | # print('n_chunk:', len(chunk)) 161 | 162 | try: 163 | if len(chunk) >= args.chunk_size: 164 | # save and renew chunk 165 | db[f'chunk_{n_chunk}'] = gzip.compress( 166 | json.dumps(chunk[:args.chunk_size]).encode('utf-8')) 167 | chunk = chunk[args.chunk_size:] 168 | n_chunk += 1 169 | 170 | # : inputs will have the whole seeker post+response post after tokenization and encoding 171 | inputs, src_seeker_post, src_response_post = _get_inputs_from_text(line, toker, if_input_src_only=args.if_input_src_only) 172 | if args.reverse: 173 | inputs = list(reversed(inputs)) 174 | if args.two_turn: 175 | inputs = inputs[:2] 176 | # : when running the code, it is one line with only one seeker post and response post 177 | # print(f"inputs is: {inputs}\n{type(inputs)}") 178 | # inputs: list of list: [[22087, 13, 10478]] 179 | features = _make_features(n_example, inputs, src_seeker_post, src_response_post, 180 | toker, args.max_seq_len) 181 | 182 | # print('features:', features) # : no need to see it: 183 | # features: [] 184 | 185 | 186 | for feature in features: 187 | chunk.append(vars(feature)) 188 | n_example += 1 189 | except Exception as e: 190 | print('!!! 
prepro exception !!!', e) 191 | continue 192 | # save last chunk 193 | db[f'chunk_{n_chunk}'] = gzip.compress( 194 | json.dumps(chunk).encode('utf-8')) 195 | # save relevant information to reproduce 196 | meta = {'n_example': n_example, 197 | 'chunk_size': args.chunk_size, 198 | 'max_seq_len': args.max_seq_len, 199 | 'reverse': args.reverse, 200 | 'two_turn': args.two_turn} 201 | with open(join(dirname(db_path), 'meta.json'), 'w') as writer: 202 | json.dump(meta, writer, indent=4) 203 | torch.save(toker, join(dirname(db_path), 'tokenizer.pt')) 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser() 208 | parser.add_argument('--corpus', required=True, 209 | help='file name of training corpus (should be .tsv)') 210 | parser.add_argument('--chunk_size', type=int, default=65536, 211 | help='num of data examples in a storing chunk') 212 | parser.add_argument('--max_seq_len', type=int, default=128, 213 | help='discard data longer than this') 214 | parser.add_argument('--reverse', action='store_true', 215 | help='reverse the src tgt') 216 | parser.add_argument('--two_turn', action='store_true', 217 | help='take only the first 2 turns: for toy examples? guessed by ') 218 | parser.add_argument('--if_input_src_only', action='store_true', default=False, 219 | help='by : input with src only') 220 | 221 | args = parser.parse_args() 222 | 223 | main(args) 224 | -------------------------------------------------------------------------------- /src/lsp_model_rl/util/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json", 43 | "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json", 44 | "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json", 45 | "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json", 46 | "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | } 50 | 51 | 52 | class BertConfig(PretrainedConfig): 53 | r""" 54 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 55 | It is used to instantiate an BERT model according to the specified arguments, defining the model 56 | architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of 57 | the BERT `bert-base-uncased `__ architecture. 58 | 59 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 60 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 61 | for more information. 62 | 63 | 64 | Args: 65 | vocab_size (:obj:`int`, optional, defaults to 30522): 66 | Vocabulary size of the BERT model. Defines the different tokens that 67 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 68 | hidden_size (:obj:`int`, optional, defaults to 768): 69 | Dimensionality of the encoder layers and the pooler layer. 70 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 71 | Number of hidden layers in the Transformer encoder. 72 | num_attention_heads (:obj:`int`, optional, defaults to 12): 73 | Number of attention heads for each attention layer in the Transformer encoder. 74 | intermediate_size (:obj:`int`, optional, defaults to 3072): 75 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 76 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 77 | The non-linear activation function (function or string) in the encoder and pooler. 78 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 79 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 80 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 81 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 82 | The dropout ratio for the attention probabilities. 83 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 84 | The maximum sequence length that this model might ever be used with. 85 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 
86 | type_vocab_size (:obj:`int`, optional, defaults to 2): 87 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 88 | initializer_range (:obj:`float`, optional, defaults to 0.02): 89 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 90 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 91 | The epsilon used by the layer normalization layers. 92 | 93 | Example:: 94 | 95 | from transformers import BertModel, BertConfig 96 | 97 | # Initializing a BERT bert-base-uncased style configuration 98 | configuration = BertConfig() 99 | 100 | # Initializing a model from the bert-base-uncased style configuration 101 | model = BertModel(configuration) 102 | 103 | # Accessing the model configuration 104 | configuration = model.config 105 | 106 | Attributes: 107 | pretrained_config_archive_map (Dict[str, str]): 108 | A dictionary containing all the available pre-trained checkpoints. 109 | """ 110 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 111 | model_type = "bert" 112 | 113 | def __init__( 114 | self, 115 | vocab_size=30522, 116 | hidden_size=768, 117 | num_hidden_layers=12, 118 | num_attention_heads=12, 119 | intermediate_size=3072, 120 | hidden_act="gelu", 121 | hidden_dropout_prob=0.1, 122 | attention_probs_dropout_prob=0.1, 123 | max_position_embeddings=512, 124 | type_vocab_size=2, 125 | initializer_range=0.02, 126 | layer_norm_eps=1e-12, 127 | pad_token_id=0, 128 | **kwargs 129 | ): 130 | super().__init__(pad_token_id=pad_token_id, **kwargs) 131 | 132 | self.vocab_size = vocab_size 133 | self.hidden_size = hidden_size 134 | self.num_hidden_layers = num_hidden_layers 135 | self.num_attention_heads = num_attention_heads 136 | self.hidden_act = hidden_act 137 | self.intermediate_size = intermediate_size 138 | self.hidden_dropout_prob = hidden_dropout_prob 139 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 140 | 
self.max_position_embeddings = max_position_embeddings 141 | self.type_vocab_size = type_vocab_size 142 | self.initializer_range = initializer_range 143 | self.layer_norm_eps = layer_norm_eps -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import gzip 4 | import json 5 | import math 6 | import random 7 | import shelve 8 | import torch 9 | 10 | import subprocess as sp 11 | 12 | from math import ceil 13 | from torch.utils.data import DataLoader, Sampler, Dataset 14 | from torch.nn.utils.rnn import pad_sequence 15 | 16 | from env import END_OF_TEXT_TOKEN 17 | from gpt2_training.train_utils import (InputFeatures, 18 | RedditExample) 19 | 20 | 21 | class BucketSampler(Sampler): 22 | """ 23 | this sampler will sort data by sequence length 24 | """ 25 | def __init__(self, lens, bucket_size, batch_size, 26 | droplast=False, shuffle=True): 27 | self._lens = lens 28 | self._batch_size = batch_size 29 | self._bucket_size = bucket_size 30 | self._droplast = droplast 31 | self._shuf = shuffle 32 | 33 | def __iter__(self): 34 | ids = list(range(len(self._lens))) 35 | if self._shuf: 36 | random.shuffle(ids) 37 | buckets = [sorted(ids[i:i+self._bucket_size], 38 | key=lambda i: self._lens[i], reverse=True) 39 | for i in range(0, len(ids), self._bucket_size)] 40 | batches = [bucket[i:i+self._batch_size] 41 | for bucket in buckets 42 | for i in range(0, len(bucket), self._batch_size)] 43 | if self._droplast: 44 | batches = [batch for batch in batches 45 | if len(batch) == self._batch_size] 46 | if self._shuf: 47 | random.shuffle(batches) 48 | return iter(batches) 49 | 50 | def __len__(self): 51 | bucket_sizes = ([self._bucket_size] 52 | * (len(self._lens) // self._bucket_size) 53 | + [len(self._lens) % self._bucket_size]) 54 | if self._droplast: 55 | return 
sum(s//self._batch_size for s in bucket_sizes) 56 | else: 57 | return sum(math.ceil(s/self._batch_size) for s in bucket_sizes) 58 | 59 | 60 | class GPT2FeatureDataset(Dataset): 61 | """ pytorch dataset for GPT2 training """ 62 | def __init__(self, features, max_len=None): 63 | self.features = features 64 | self.max_len = max_len # this max_len do truncate 65 | 66 | def __getitem__(self, i): 67 | feat_dict = self.features[i] 68 | if self.max_len is not None and feat_dict['input_len'] > self.max_len: 69 | # tuncate on the left side (context) 70 | feat_dict['input_ids'] = feat_dict['input_ids'][-self.max_len:] 71 | feat_dict['position_ids'] = feat_dict['position_ids'][ 72 | -self.max_len:] 73 | feat_dict['token_type_ids'] = feat_dict['token_type_ids'][ 74 | -self.max_len:] 75 | # feat_dict['lm_labels'] = feat_dict['lm_labels'][-self.max_len:] 76 | # feat_dict['pos_labels'] = feat_dict['pos_labels'] 77 | try: 78 | for s in ['context_len', 'response_len']: 79 | if s in feat_dict.keys(): 80 | print("db file missing "+s) 81 | del feat_dict[s] 82 | except Exception: 83 | import pdb 84 | pdb.set_trace() 85 | 86 | feat = InputFeatures(**feat_dict) 87 | return feat 88 | 89 | def __len__(self): 90 | return len(self.features) 91 | 92 | @staticmethod 93 | def collate(features): 94 | input_ids = pad_sequence([torch.tensor(f.input_ids, dtype=torch.long) 95 | for f in features], 96 | batch_first=True, padding_value=0) 97 | position_ids = pad_sequence([torch.tensor(f.position_ids, 98 | dtype=torch.long) 99 | for f in features], 100 | batch_first=True, padding_value=0) 101 | token_type_ids = pad_sequence([torch.tensor(f.token_type_ids, 102 | dtype=torch.long) 103 | for f in features], 104 | batch_first=True, padding_value=0) 105 | # labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long) 106 | # for f in features], 107 | # batch_first=True, padding_value=-1) 108 | # pos_labels = torch.tensor([torch.tensor(f.pos_label, dtype=torch.long) for f in features]) 109 | seeker_post 
= [f.seeker_post for f in features] 110 | response_post = [f.response_post for f in features] 111 | 112 | return (input_ids, position_ids, token_type_ids, seeker_post, response_post) 113 | 114 | 115 | class BucketingDataLoader(object): 116 | """ this loads shelve db chunks and then convert to mini-batch loader""" 117 | def __init__(self, db_name, batch_size, max_seq_len, 118 | bucket=100, shuffle=True): 119 | print(f'*** {db_name}/db****') 120 | self.db = shelve.open(f'{db_name}/db', 'r') 121 | self.batch_size = batch_size 122 | self.max_len = max_seq_len 123 | self.bucket_size = bucket * batch_size 124 | self.shuffle = shuffle 125 | 126 | def _get_keys(self): 127 | keys = list(self.db.keys()) 128 | return keys 129 | 130 | def __iter__(self): 131 | keys = self._get_keys() 132 | if self.shuffle: 133 | random.shuffle(keys) 134 | for key in keys: 135 | chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8')) 136 | # discard long examples 137 | trunc_chunk = [] 138 | lens = [] 139 | for feat in chunk: 140 | if feat['input_len'] > self.max_len: 141 | continue 142 | trunc_chunk.append(feat) 143 | lens.append(feat['input_len']) 144 | 145 | # print('trunc_chunk:', trunc_chunk) 146 | 147 | dataset = GPT2FeatureDataset(trunc_chunk, self.max_len) 148 | sampler = BucketSampler(lens, self.bucket_size, self.batch_size, 149 | droplast=True, shuffle=self.shuffle) 150 | loader = DataLoader(dataset, batch_sampler=sampler, 151 | num_workers=0, # can test multi-worker 152 | collate_fn=GPT2FeatureDataset.collate) 153 | yield from loader 154 | 155 | def __len__(self): 156 | raise NotImplementedError() 157 | 158 | def __del__(self): 159 | self.db.close() # ext: not close it? how it goes. 
160 | # pass 161 | 162 | 163 | class DistributedBucketingDataLoader(BucketingDataLoader): 164 | """ distributed version """ 165 | def __init__(self, rank, num_replica, *args, **kwargs): 166 | super().__init__(*args, **kwargs) 167 | self.rank = rank 168 | self.num_replica = num_replica 169 | 170 | def _get_keys(self): 171 | keys = list(self.db.keys())[self.rank::self.num_replica] 172 | return keys 173 | 174 | 175 | def convert_examples_to_features_dynamic(examples, tokenizer, 176 | max_seq_length=512): 177 | """ 178 | do not pad 179 | """ 180 | def featurize(example): 181 | conv_id = example.conv_id 182 | context_id = tokenizer.encode(example.context) 183 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN] 184 | 185 | # response is provided in example 186 | response_id = tokenizer.encode(example.response) 187 | 188 | input_ids_len = len(context_id) + len(response_id) + 2 189 | if input_ids_len > max_seq_length: 190 | if len(context_id) > input_ids_len - max_seq_length: 191 | # cut context from beginning if length of context + response is too long 192 | # and len of context is long enough to cut 193 | context_id = context_id[input_ids_len - max_seq_length:] 194 | else: 195 | # cut response from end if length of context + response is too long 196 | # and len of response is long enough to cut 197 | # if no response is available, discard the data 198 | if max_seq_length-len(context_id)-2 < 0: 199 | return None 200 | response_id = response_id[:max_seq_length-len(context_id)-2] 201 | 202 | input_ids = context_id + [end_of_text_id] + response_id + [end_of_text_id] 203 | 204 | # label simplely is next token in sequences. 
MASK all context_id tokens except for the last one 205 | # lm_labels = [-1] * len(context_id) + response_id + [end_of_text_id] + [-1] 206 | 207 | position_ids = list(range(len(input_ids))) 208 | 209 | token_type_id = [0] * len(input_ids) 210 | 211 | return InputFeatures(conv_id, input_ids, position_ids, token_type_id, 212 | len(context_id), len(response_id)) 213 | 214 | # discard None feature 215 | features = [f for f in [featurize(ex) for ex in examples] if f is not None] 216 | return features 217 | 218 | 219 | class DynamicBatchingLoader(object): 220 | """ this loader takes raw text file, used for validate perplexity """ 221 | def __init__(self, corpus_file, tokenizer, normalize_data, 222 | batch_size, max_seq_length): 223 | self.corpus = corpus_file 224 | self.toker = tokenizer 225 | self.norm = normalize_data 226 | self.bs = batch_size 227 | self.max_seq_length = max_seq_length 228 | self.num_examples = self.get_len(corpus_file) 229 | 230 | def __iter__(self, epoch=1): 231 | if epoch > 0: 232 | for epoch in range(epoch): 233 | yield from self._iter_epoch() 234 | else: 235 | while True: 236 | yield from self._iter_epoch() 237 | 238 | def __len__(self): 239 | return ceil(self.num_examples/self.bs) 240 | 241 | def _iter_epoch(self): 242 | try: 243 | with open(self.corpus, 'r', encoding="utf-8") as corpus: 244 | i = 0 245 | while True: 246 | examples = [] 247 | cur_bs = 0 248 | while True: 249 | line = next(corpus).encode('utf-8').decode('utf-8') 250 | contents = line.split('\t') 251 | src, tgt_all = contents[0], contents[1:] 252 | for tgt in tgt_all: 253 | if self.norm: 254 | src_line = ' '.join(src.strip().split()) 255 | tgt_line = ' '.join(tgt.strip().split()) 256 | else: 257 | src_line = src.strip() 258 | tgt_line = tgt.strip() 259 | examples.append( 260 | RedditExample(i, src_line, tgt_line), 261 | ) 262 | i += 1 263 | cur_bs += 1 264 | if cur_bs >= self.bs: 265 | break 266 | features = convert_examples_to_features_dynamic( 267 | examples, self.toker, 
self.max_seq_length) 268 | batch = self._batch_feature(features) 269 | yield batch 270 | except StopIteration: 271 | pass 272 | 273 | def _batch_feature(self, features): 274 | input_ids = pad_sequence([torch.tensor(f.choices_features['input_ids'], 275 | dtype=torch.long) 276 | for f in features], 277 | batch_first=True, padding_value=0) 278 | position_ids = pad_sequence( 279 | [torch.tensor(f.choices_features['position_ids'], dtype=torch.long) 280 | for f in features], 281 | batch_first=True, padding_value=0) 282 | token_type_ids = pad_sequence( 283 | [torch.tensor(f.choices_features['token_type_ids'], 284 | dtype=torch.long) 285 | for f in features], 286 | batch_first=True, padding_value=0) 287 | # labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long) 288 | # for f in features], 289 | # batch_first=True, padding_value=-1) 290 | context_len = torch.tensor([f.context_len for f in features], 291 | dtype=torch.long) 292 | response_len = torch.tensor([f.response_len for f in features], 293 | dtype=torch.long) 294 | return (input_ids, position_ids, token_type_ids, 295 | context_len, response_len) 296 | 297 | def get_len(self, corpus): 298 | n_line = int(sp.check_output(f"wc -l {corpus}".split(), 299 | universal_newlines=True).split()[0]) 300 | return n_line 301 | -------------------------------------------------------------------------------- /src/lsp_model_rl/coherence_classifier2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import codecs 3 | import numpy as np 4 | 5 | import tqdm 6 | from tqdm import tqdm 7 | 8 | import pandas as pd 9 | import re 10 | import csv 11 | import numpy as np 12 | 13 | import time 14 | 15 | from sklearn.metrics import f1_score 16 | 17 | from transformers import RobertaTokenizer 18 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 19 | from torch.utils.data import TensorDataset, random_split 20 | 21 | from transformers import AdamW, 
RobertaConfig, RobertaForSequenceClassification 22 | 23 | import datetime 24 | # 25 | import sys, os 26 | sys.path.append("../") 27 | sys.path.append(".../") 28 | from MisinfoCorrect.src.variables_ext import politeness_clf_fp, refutation_clf_fp, evidence_clf_fp 29 | 30 | class CoherenceClassifier(): 31 | def __init__(self, 32 | device, 33 | model_path, 34 | batch_size = 2): 35 | 36 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True) 37 | self.batch_size = batch_size 38 | self.device = device 39 | 40 | self.model = RobertaForSequenceClassification.from_pretrained( 41 | "roberta-base", # Use the 12-layer BERT model, with an uncased vocab. 42 | num_labels = 2, # The number of output labels--2 for binary classification. 43 | # You can increase this for multi-class tasks. 44 | output_attentions = False, # Whether the model returns attentions weights. 45 | output_hidden_states = False, # Whether the model returns all hidden-states. 46 | ) 47 | 48 | # comment data parallel by ext, due to the missing key errors: the possible reason is 49 | # that the keys are allocated in different gpus? --- be careful of it? TODO 50 | # OR: another solution, we put nn.DataParallel()after the loading: 51 | # the only way it can work is that the saving is through nn.DataParallel: interesting 52 | # self.model = torch.nn.DataParallel(self.model) # TODO: check the details?? by ext 53 | print(f"the model path is: {model_path}") 54 | weights = torch.load(model_path, map_location=self.device) # commented by ext 55 | self.model.load_state_dict(weights) 56 | 57 | self.model.to(self.device) 58 | 59 | 60 | def predict_empathy(self, original_responses, candidate): 61 | 62 | input_ids = [] 63 | attention_masks = [] 64 | 65 | for idx, elem in enumerate(original_responses): 66 | 67 | response_sentence = original_responses[idx] + ' ' + candidate 68 | 69 | encoded_dict = self.tokenizer.encode_plus( 70 | response_sentence, # Sentence to encode. 
71 | add_special_tokens = True, # Add '[CLS]' and '[SEP]' 72 | max_length = 64, # Pad & truncate all sentences. 73 | pad_to_max_length = True, 74 | return_attention_mask = True, # Construct attn. masks. 75 | return_tensors = 'pt', # Return pytorch tensors. 76 | ) 77 | 78 | input_ids.append(encoded_dict['input_ids']) 79 | attention_masks.append(encoded_dict['attention_mask']) 80 | 81 | input_ids = torch.cat(input_ids, dim=0) 82 | attention_masks = torch.cat(attention_masks, dim=0) 83 | 84 | dataset = TensorDataset(input_ids, attention_masks) 85 | 86 | dataloader = DataLoader( 87 | dataset, # The test samples. 88 | sampler = SequentialSampler(dataset), # Pull out batches sequentially. 89 | batch_size = self.batch_size # Evaluate with this batch size. 90 | ) 91 | 92 | self.model.eval() 93 | 94 | for batch in dataloader: 95 | b_input_ids = batch[0].to(self.device) 96 | b_input_mask = batch[1].to(self.device) 97 | 98 | with torch.no_grad(): 99 | (logits, ) = self.model(input_ids = b_input_ids, 100 | token_type_ids=None, 101 | attention_mask=b_input_mask,) 102 | # res = self.model(input_ids = b_input_ids, 103 | # token_type_ids=None, 104 | # attention_mask=b_input_mask,) 105 | # logits = res.logits 106 | 107 | 108 | logits = logits.detach().cpu().numpy().tolist() 109 | predictions = np.argmax(logits, axis=1).flatten() 110 | 111 | return (logits, predictions) 112 | 113 | 114 | class PolitenessClassifier(CoherenceClassifier): 115 | pass 116 | # def __init__(self): 117 | # super().__init__() 118 | 119 | class RefutationClassifier(CoherenceClassifier): 120 | def __init__(self, device, model_path, batch_size=2): 121 | super().__init__(device, model_path, batch_size) 122 | 123 | def predict_empathy(self, original_responses, candidate): 124 | 125 | 126 | test_encode = self.tokenizer.batch_encode_plus(list(zip(original_responses, candidate)), 127 | padding='max_length', truncation=True, max_length=128, return_tensors='pt', pad_to_max_length=True) 128 | 129 | test_seq = 
torch.tensor(test_encode['input_ids']) 130 | test_mask = torch.tensor(test_encode['attention_mask']) 131 | # test_token = torch.tensor(test_encode['token_type_ids']) 132 | 133 | 134 | test_data = TensorDataset(test_seq, test_mask) #, test_token 135 | 136 | dataloader = DataLoader( 137 | test_data, # The test samples. 138 | sampler = SequentialSampler(test_data), # Pull out batches sequentially. 139 | batch_size = self.batch_size # Evaluate with this batch size. 140 | ) 141 | 142 | def model_eval_and_infer(model, prediction_dataloader, device, if_infer=False, if_have_token_types=True): 143 | 144 | model.eval() 145 | if not if_infer: 146 | predictions , true_labels = [], [] 147 | 148 | for batch in tqdm(prediction_dataloader): 149 | 150 | if if_have_token_types: 151 | b_input_ids, b_input_mask, b_token_type, b_labels = batch 152 | else: 153 | b_input_ids, b_input_mask, b_labels = batch 154 | with torch.no_grad(): 155 | 156 | if if_have_token_types: 157 | 158 | outputs = model(b_input_ids.to(device), token_type_ids=b_token_type.to(device), 159 | attention_mask=b_input_mask.to(device)) 160 | else: 161 | outputs = model(b_input_ids.to(device), attention_mask=b_input_mask.to(device)) 162 | 163 | b_proba = outputs[0] 164 | 165 | proba = b_proba.detach().cpu().numpy() 166 | label_ids = b_labels.numpy() 167 | 168 | predictions.append(proba) 169 | true_labels.append(label_ids) 170 | 171 | print(' DONE.') 172 | 173 | flat_predictions = np.concatenate(predictions, axis=0) 174 | y_pred = np.argmax(flat_predictions, axis=1).flatten() 175 | y_true = np.concatenate(true_labels, axis=0) 176 | 177 | 178 | return y_pred, y_true 179 | if if_infer: 180 | predictions = [] 181 | 182 | for batch in prediction_dataloader: 183 | 184 | if if_have_token_types: 185 | b_input_ids, b_input_mask, b_token_type = batch 186 | else: 187 | b_input_ids, b_input_mask = batch 188 | with torch.no_grad(): 189 | if if_have_token_types: 190 | outputs = model(b_input_ids.to(device), 
token_type_ids=b_token_type.to(device), 191 | attention_mask=b_input_mask.to(device)) 192 | else: 193 | outputs = model(b_input_ids.to(device), attention_mask=b_input_mask.to(device)) 194 | b_proba = outputs[0] 195 | 196 | proba = b_proba.detach().cpu().numpy() 197 | # label_ids = b_labels.numpy() 198 | 199 | predictions.append(proba) 200 | # true_labels.append(label_ids) 201 | 202 | 203 | 204 | flat_predictions = np.concatenate(predictions, axis=0) 205 | y_pred = np.argmax(flat_predictions, axis=1).flatten() 206 | # y_true = np.concatenate(true_labels, axis=0) 207 | 208 | return y_pred, flat_predictions 209 | 210 | predictions, logits = model_eval_and_infer(self.model, dataloader, device=self.device, if_infer=True, if_have_token_types=False) 211 | return (logits, predictions) 212 | 213 | # ==== before Oct 2022 ==== in the past 214 | # input_ids = [] 215 | # attention_masks = [] 216 | 217 | # for idx, elem in enumerate(original_responses): 218 | 219 | # response_sentence = original_responses[idx] + ' ' + candidate 220 | 221 | # encoded_dict = self.tokenizer.encode_plus( 222 | # response_sentence, # Sentence to encode. 223 | # add_special_tokens = True, # Add '[CLS]' and '[SEP]' 224 | # max_length = 64, # Pad & truncate all sentences. 225 | # pad_to_max_length = True, 226 | # return_attention_mask = True, # Construct attn. masks. 227 | # return_tensors = 'pt', # Return pytorch tensors. 228 | # ) 229 | 230 | # input_ids.append(encoded_dict['input_ids']) 231 | # attention_masks.append(encoded_dict['attention_mask']) 232 | 233 | # input_ids = torch.cat(input_ids, dim=0) 234 | # attention_masks = torch.cat(attention_masks, dim=0) 235 | 236 | # dataset = TensorDataset(input_ids, attention_masks) 237 | 238 | # dataloader = DataLoader( 239 | # dataset, # The test samples. 240 | # sampler = SequentialSampler(dataset), # Pull out batches sequentially. 241 | # batch_size = self.batch_size # Evaluate with this batch size. 
242 | # ) 243 | 244 | # self.model.eval() 245 | 246 | # for batch in dataloader: 247 | # b_input_ids = batch[0].to(self.device) 248 | # b_input_mask = batch[1].to(self.device) 249 | 250 | # with torch.no_grad(): 251 | # (logits, ) = self.model(input_ids = b_input_ids, 252 | # token_type_ids=None, 253 | # attention_mask=b_input_mask,) 254 | # # res = self.model(input_ids = b_input_ids, 255 | # # token_type_ids=None, 256 | # # attention_mask=b_input_mask,) 257 | # # logits = res.logits 258 | 259 | 260 | # logits = logits.detach().cpu().numpy().tolist() 261 | # predictions = np.argmax(logits, axis=1).flatten() 262 | 263 | # return (logits, predictions) 264 | 265 | 266 | # inherited from the refutation classifier when considering both the tweet and reply 267 | class EvidenceClassifier(RefutationClassifier): 268 | pass 269 | # def __init__(self): 270 | # super().__init__() 271 | 272 | 273 | ''' 274 | Example: 275 | ''' 276 | 277 | import sys, os 278 | sys.path.append("./") 279 | sys.path.append("../") 280 | sys.path.append(".../") 281 | sys.path.append("..../") 282 | from MisinfoCorrect.src.variables_ext import device 283 | 284 | 285 | original_responses = [ 'I am so sorry that she is not getting it.','so she can get a better idea of what the condition entails?'] 286 | #sentences = ['why do you feel this way?', 'Let me know if you want to talk.'] 287 | candidate = 'Have you thought of directing her to sites like NAMI and Mental Health First Aid' 288 | candidate2 = ' I have been on and off medication for the majority of my life ' 289 | 290 | cadidates = [candidate, candidate2] 291 | 292 | print(f'here we use device: {device}') 293 | coherence_classifier = CoherenceClassifier(device, model_path=politeness_clf_fp) 294 | 295 | # #### #### 296 | # (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, candidate) 297 | 298 | # print(logits, predictions) 299 | 300 | # (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, 
candidate2) 301 | # print(logits, predictions) 302 | 303 | 304 | politeness_classifier = PolitenessClassifier(device, model_path=politeness_clf_fp) # sanity check 305 | refutation_classifier = RefutationClassifier(device, model_path=refutation_clf_fp) # sanity check 306 | evidence_classifier = EvidenceClassifier(device, model_path=evidence_clf_fp) # sanity check 307 | 308 | (logits, predictions,) = politeness_classifier.predict_empathy(original_responses, candidate2) 309 | print(logits, predictions) 310 | 311 | 312 | # ==== for refutation classifier ==== 313 | print('start sanity check for refutation') 314 | (logits, predictions,) = refutation_classifier.predict_empathy(original_responses, cadidates) 315 | print('the sanity check of refutation classifier version 2') 316 | print(logits, predictions) 317 | 318 | 319 | print('start sanity check for refutation') 320 | (logits, predictions,) = evidence_classifier.predict_empathy(original_responses, cadidates) 321 | print('the sanity check of evidence classifier version 2') 322 | print(logits, predictions) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _anaconda_depends=2020.07=py37_0 5 | _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 6 | _libgcc_mutex=0.1=main 7 | absl-py=1.0.0=pypi_0 8 | aiohttp=3.8.1=pypi_0 9 | aiosignal=1.2.0=pypi_0 10 | alabaster=0.7.12=py37_0 11 | altair=4.1.0=pypi_0 12 | anaconda=custom=py37_1 13 | anaconda-client=1.7.2=py37_0 14 | anaconda-navigator=1.9.12=py37_0 15 | anaconda-project=0.8.4=py_0 16 | appdirs=1.4.4=pypi_0 17 | argh=0.26.2=py37_0 18 | argon2-cffi=20.1.0=py37h7b6447c_1 19 | ase=3.21.1=pypi_0 20 | asn1crypto=1.4.0=py_0 21 | astor=0.8.1=pypi_0 22 | astroid=2.4.2=py37_0 23 | astropy=4.2=py37h27cfd23_0 24 | async-timeout=4.0.1=pypi_0 25 | 
async_generator=1.10=py37h28b3542_0 26 | asynctest=0.13.0=pypi_0 27 | atomicwrites=1.4.0=py_0 28 | attrs=20.3.0=pyhd3eb1b0_0 29 | autocorrect=1.0.0=pypi_0 30 | autopep8=1.5.4=py_0 31 | babel=2.9.0=pyhd3eb1b0_0 32 | backcall=0.2.0=py_0 33 | backports=1.0=pyhd3eb1b0_2 34 | backports-zoneinfo=0.2.1=pypi_0 35 | backports.functools_lru_cache=1.6.1=pyhd3eb1b0_0 36 | backports.shutil_get_terminal_size=1.0.0=py37_2 37 | backports.tempfile=1.0=py_1 38 | backports.weakref=1.0.post1=py_1 39 | base58=2.1.1=pypi_0 40 | beautifulsoup4=4.9.3=pyhb0f4dca_0 41 | bert-serving=0.0.1=pypi_0 42 | bert-serving-client=1.10.0=pypi_0 43 | bert-serving-server=1.10.0=pypi_0 44 | bitarray=1.6.1=py37h27cfd23_0 45 | bkcharts=0.2=py37_0 46 | blas=1.0=mkl 47 | bleach=3.2.1=py_0 48 | blessings=1.7=pypi_0 49 | blinker=1.4=pypi_0 50 | blis=0.4.1=pypi_0 51 | blosc=1.20.1=hd408876_0 52 | bokeh=2.2.3=py37_0 53 | boto=2.49.0=py37_0 54 | boto3=1.12.34=pypi_0 55 | botocore=1.15.34=pypi_0 56 | bottleneck=1.3.2=py37heb32a55_1 57 | brotli=1.0.9=he6710b0_2 58 | brotlipy=0.7.0=py37h27cfd23_1003 59 | bzip2=1.0.8=h7b6447c_0 60 | ca-certificates=2020.12.5=ha878542_0 61 | cachetools=4.2.4=pypi_0 62 | cairo=1.14.12=h8948797_3 63 | catalogue=1.0.0=pypi_0 64 | certifi=2020.12.5=py37h89c1867_1 65 | cffi=1.14.4=py37h261ae71_0 66 | chardet=3.0.4=py37h06a4308_1003 67 | charls=2.1.0=he6710b0_2 68 | charset-normalizer=2.0.9=pypi_0 69 | clean-text=0.2.1=pypi_0 70 | click=7.1.2=py_0 71 | click-config-file=0.6.0=pypi_0 72 | click-plugins=1.1.1=pypi_0 73 | cloudpickle=1.6.0=py_0 74 | clyent=1.2.2=py37_1 75 | colorama=0.4.4=pyhd3eb1b0_0 76 | conda=4.9.2=py37h89c1867_0 77 | conda-build=3.18.11=py37_0 78 | conda-env=2.6.0=1 79 | conda-package-handling=1.7.2=py37h03888b9_0 80 | conda-verify=3.4.2=py_1 81 | configobj=5.0.6=pypi_0 82 | configparser=5.2.0=pypi_0 83 | contextlib2=0.6.0.post1=py_0 84 | convokit=2.3.2.5=pypi_0 85 | cryptography=3.3.1=py37h3c74f83_0 86 | cudatoolkit=11.0.221=h6bb024c_0 87 | curl=7.71.1=hbc83047_1 88 | 
cycler=0.10.0=py37_0 89 | cymem=2.0.3=pypi_0 90 | cython=0.29.21=py37h2531618_0 91 | cytoolz=0.11.0=py37h7b6447c_0 92 | dask=2020.12.0=pyhd3eb1b0_0 93 | dask-core=2020.12.0=pyhd3eb1b0_0 94 | datasets=1.16.1=pypi_0 95 | dbus=1.13.18=hb2f20db_0 96 | decorator=4.4.2=py_0 97 | defusedxml=0.6.0=py_0 98 | diff-match-patch=20200713=py_0 99 | dill=0.3.4=pypi_0 100 | distributed=2020.12.0=py37h06a4308_0 101 | docker-pycreds=0.4.0=pypi_0 102 | docutils=0.15.2=pypi_0 103 | emoji=1.2.0=pypi_0 104 | en-core-web-sm=2.3.1=pypi_0 105 | entrypoints=0.3=py37_0 106 | et_xmlfile=1.0.1=py_1001 107 | expat=2.2.10=he6710b0_2 108 | fastai=2.6.3=pypi_0 109 | fastcache=1.1.0=py37h7b6447c_0 110 | fastcore=1.4.2=pypi_0 111 | fastdownload=0.0.5=pypi_0 112 | fastprogress=0.2.4=pypi_0 113 | filelock=3.0.12=py_0 114 | fire=0.4.0=pypi_0 115 | flake8=3.8.4=py_0 116 | flask=1.1.2=py_0 117 | fontconfig=2.13.0=h9420a91_0 118 | freetype=2.10.4=h5ab3b9f_0 119 | fribidi=1.0.10=h7b6447c_0 120 | frozenlist=1.2.0=pypi_0 121 | fsspec=2021.11.1=pypi_0 122 | ftfy=5.8=pypi_0 123 | future=0.18.2=py37_1 124 | get_terminal_size=1.0.0=haa9412d_0 125 | gevent=20.9.0=py37h7b6447c_0 126 | giflib=5.1.4=h14c3975_1 127 | gitdb=4.0.9=pypi_0 128 | gitpython=3.1.24=pypi_0 129 | glib=2.66.1=h92f7085_0 130 | glob2=0.7=py_0 131 | gmp=6.1.2=h6c8ec71_1 132 | gmpy2=2.0.8=py37h10f8cd9_2 133 | google-auth=2.3.3=pypi_0 134 | google-auth-oauthlib=0.4.6=pypi_0 135 | googledrivedownloader=0.4=pypi_0 136 | gpustat=0.6.0=pypi_0 137 | gputil=1.4.0=pypi_0 138 | graphite2=1.3.14=h23475e2_0 139 | greenlet=0.4.17=py37h7b6447c_0 140 | grpcio=1.42.0=pypi_0 141 | gst-plugins-base=1.14.0=hbbd80ab_1 142 | gstreamer=1.14.0=hb31296c_0 143 | h5py=2.10.0=py37h7918eee_0 144 | harfbuzz=2.4.0=hca77d97_1 145 | hdf5=1.10.4=hb1b8bf9_0 146 | heapdict=1.0.1=py_0 147 | html5lib=1.1=py_0 148 | huggingface-hub=0.2.1=pypi_0 149 | humanize=3.12.0=pypi_0 150 | icu=58.2=he6710b0_3 151 | idna=2.8=pypi_0 152 | imagecodecs=2020.5.30=py37hfa7d478_2 153 | 
imageio=2.9.0=py_0 154 | imagesize=1.2.0=py_0 155 | importlib-metadata=4.8.2=pypi_0 156 | importlib_metadata=2.0.0=1 157 | iniconfig=1.1.1=py_0 158 | intel-openmp=2020.2=254 159 | intervaltree=3.1.0=py_0 160 | ipykernel=5.3.4=py37h5ca1d4c_0 161 | ipython=7.19.0=py37hb070fc8_0 162 | ipython_genutils=0.2.0=pyhd3eb1b0_1 163 | ipywidgets=7.5.1=py_1 164 | isodate=0.6.0=pypi_0 165 | isort=5.6.4=py_0 166 | itsdangerous=1.1.0=py37_0 167 | jbig=2.1=hdba287a_0 168 | jdcal=1.4.1=py_0 169 | jedi=0.14.1=py37_0 170 | jeepney=0.6.0=pyhd3eb1b0_0 171 | jinja2=2.11.2=py_0 172 | jmespath=0.9.5=pypi_0 173 | joblib=0.17.0=py_0 174 | jpeg=9b=h024ee3a_2 175 | json5=0.9.5=py_0 176 | jsonschema=3.2.0=py_2 177 | jupyter=1.0.0=py37_7 178 | jupyter-contrib-core=0.3.3=pypi_0 179 | jupyter_client=6.1.7=py_0 180 | jupyter_console=6.2.0=py_0 181 | jupyter_contrib_core=0.3.3=py_2 182 | jupyter_contrib_nbextensions=0.5.1=pyhd8ed1ab_2 183 | jupyter_core=4.7.0=py37h06a4308_0 184 | jupyter_highlight_selected_word=0.2.0=py37h89c1867_1002 185 | jupyter_latex_envs=1.4.6=pyhd8ed1ab_1002 186 | jupyter_nbextensions_configurator=0.4.1=py37h89c1867_2 187 | jupyterlab=2.2.6=py_0 188 | jupyterlab_pygments=0.1.2=py_0 189 | jupyterlab_server=1.2.0=py_0 190 | jxrlib=1.1=h7b6447c_2 191 | keyring=21.4.0=py37_1 192 | kiwisolver=1.3.0=py37h2531618_0 193 | krb5=1.18.2=h173b8e3_0 194 | lazy-object-proxy=1.4.3=py37h27cfd23_2 195 | lcms2=2.11=h396b838_0 196 | ld_impl_linux-64=2.33.1=h53a641e_7 197 | libaec=1.0.4=he6710b0_1 198 | libarchive=3.4.2=h62408e4_0 199 | libcurl=7.71.1=h20c2e04_1 200 | libedit=3.1.20191231=h14c3975_1 201 | libffi=3.3=he6710b0_2 202 | libgcc-ng=9.1.0=hdf63c60_0 203 | libgfortran-ng=7.3.0=hdf63c60_0 204 | liblief=0.10.1=he6710b0_0 205 | libllvm10=10.0.1=hbcb73fb_5 206 | libllvm9=9.0.1=h4a3c616_1 207 | libpng=1.6.37=hbc83047_0 208 | libsodium=1.0.18=h7b6447c_0 209 | libspatialindex=1.9.3=he6710b0_0 210 | libssh2=1.9.0=h1ba5d50_1 211 | libstdcxx-ng=9.1.0=hdf63c60_0 212 | libtiff=4.1.0=h2733197_1 213 | 
libtool=2.4.6=h7b6447c_1005 214 | libuuid=1.0.3=h1bed415_2 215 | libuv=1.40.0=h7b6447c_0 216 | libwebp=1.0.1=h8e7db2f_0 217 | libxcb=1.14=h7b6447c_0 218 | libxml2=2.9.10=hb55368b_3 219 | libxslt=1.1.34=hc22bd24_0 220 | libzopfli=1.0.3=he6710b0_0 221 | liwc=0.5.0=pypi_0 222 | liwc-analysis=1.2.4=pypi_0 223 | llvmlite=0.34.0=py37h269e1b5_4 224 | locket=0.2.0=py37_1 225 | lxml=4.6.2=py37h9120a33_0 226 | lz4-c=1.9.2=heb0550a_3 227 | lzo=2.10=h7b6447c_2 228 | markdown=3.3.6=pypi_0 229 | markupsafe=1.1.1=py37h14c3975_1 230 | matplotlib=3.3.2=0 231 | matplotlib-base=3.3.2=py37h817c723_0 232 | mccabe=0.6.1=py37_1 233 | mistune=0.8.4=py37h14c3975_1001 234 | mkl=2020.2=256 235 | mkl-service=2.3.0=py37he8ac12f_0 236 | mkl_fft=1.2.0=py37h23d657b_0 237 | mkl_random=1.1.1=py37h0573a6f_0 238 | mock=4.0.3=pyhd3eb1b0_0 239 | more-itertools=8.6.0=pyhd3eb1b0_0 240 | mpc=1.1.0=h10f8cd9_1 241 | mpfr=4.0.2=hb69a4c5_1 242 | mpmath=1.1.0=py37_0 243 | msgpack-numpy=0.4.6.1=pypi_0 244 | msgpack-python=1.0.0=py37hfd86e86_1 245 | multidict=5.2.0=pypi_0 246 | multipledispatch=0.6.0=py37_0 247 | multiprocess=0.70.12.2=pypi_0 248 | murmurhash=1.0.2=pypi_0 249 | mwparserfromhell=0.5.4=pypi_0 250 | navigator-updater=0.2.1=py37_0 251 | nbclient=0.5.1=py_0 252 | nbconvert=6.0.7=py37_0 253 | nbformat=5.0.7=py_0 254 | ncurses=6.2=he6710b0_1 255 | nest-asyncio=1.4.3=pyhd3eb1b0_0 256 | networkx=2.5=py_0 257 | ninja=1.10.2=py37hff7bd54_0 258 | nltk=3.5=py_0 259 | nose=1.3.7=pyhd3eb1b0_1006 260 | notebook=6.1.4=py37_0 261 | numba=0.51.2=py37h04863e7_1 262 | numexpr=2.7.1=py37h63df603_0 263 | numpy=1.19.2=py37h54aff64_0 264 | numpy-base=1.19.2=py37hfa32c7d_0 265 | numpydoc=1.1.0=pyhd3eb1b0_1 266 | nvidia-ml-py3=7.352.0=pypi_0 267 | oauthlib=3.2.2=pypi_0 268 | olefile=0.46=py37_0 269 | openjpeg=2.3.0=h05c96fa_1 270 | openpyxl=3.0.5=py_0 271 | openssl=1.1.1k=h27cfd23_0 272 | packaging=21.3=pypi_0 273 | pandas=1.1.3=py37he6710b0_0 274 | pandoc=2.11=hb0f4dca_0 275 | pandocfilters=1.4.3=py37h06a4308_1 276 | 
pango=1.45.3=hd140c19_0 277 | parso=0.5.2=py_0 278 | partd=1.1.0=py_0 279 | patchelf=0.12=h2531618_1 280 | path=15.0.1=py37h06a4308_0 281 | path.py=12.5.0=0 282 | pathlib2=2.3.5=py37h06a4308_2 283 | pathtools=0.1.2=py_1 284 | patsy=0.5.1=py37_0 285 | pcre=8.44=he6710b0_0 286 | pep8=1.7.1=py37_0 287 | pexpect=4.8.0=pyhd3eb1b0_3 288 | pickleshare=0.7.5=pyhd3eb1b0_1003 289 | pillow=8.0.1=py37he98fc37_0 290 | pip=20.3.1=py37h06a4308_0 291 | pip-autoremove=0.9.1=pypi_0 292 | pixman=0.40.0=h7b6447c_0 293 | pkginfo=1.6.1=py37h06a4308_0 294 | plac=1.1.3=pypi_0 295 | pluggy=0.13.1=py37_0 296 | ply=3.11=py37_0 297 | powerlaw=1.4.6=pypi_0 298 | preshed=3.0.2=pypi_0 299 | prometheus_client=0.9.0=pyhd3eb1b0_0 300 | promise=2.3=pypi_0 301 | prompt-toolkit=3.0.8=py_0 302 | prompt_toolkit=3.0.8=0 303 | protobuf=3.19.1=pypi_0 304 | psutil=5.7.2=py37h7b6447c_0 305 | ptyprocess=0.6.0=pyhd3eb1b0_2 306 | py=1.9.0=py_0 307 | py-lief=0.10.1=py37h403a769_0 308 | pyarrow=6.0.1=pypi_0 309 | pyasn1=0.4.8=pypi_0 310 | pyasn1-modules=0.2.8=pypi_0 311 | pycocoevalcap=1.2=pypi_0 312 | pycocotools=2.0.4=pypi_0 313 | pycodestyle=2.6.0=py_0 314 | pycosat=0.6.3=py37h27cfd23_0 315 | pycparser=2.20=py_2 316 | pycrypto=2.6.1=py37h7b6447c_10 317 | pycurl=7.43.0.6=py37h1ba5d50_0 318 | pydeck=0.7.1=pypi_0 319 | pydocstyle=5.1.1=py_0 320 | pyerfa=1.7.1.1=py37h27cfd23_1 321 | pyflakes=2.2.0=py_0 322 | pygments=2.7.3=pyhd3eb1b0_0 323 | pylint=2.6.0=py37_0 324 | pympler=0.9=pypi_0 325 | pyodbc=4.0.30=py37he6710b0_0 326 | pyopenssl=20.0.0=pyhd3eb1b0_1 327 | pyparsing=2.4.7=py_0 328 | pyphen=0.10.0=pypi_0 329 | pyqt=5.9.2=py37h05f1152_2 330 | pyrsistent=0.17.3=py37h7b6447c_0 331 | pyshorteners=1.0.1=pypi_0 332 | pysocks=1.7.1=py37_1 333 | pytables=3.6.1=py37h71ec239_0 334 | pytest=6.1.2=py37h06a4308_0 335 | python=3.7.6=h0371630_2 336 | python-dateutil=2.8.1=py_0 337 | python-jsonrpc-server=0.4.0=py_0 338 | python-language-server=0.31.7=py37_0 339 | python-levenshtein=0.12.0=pypi_0 340 | 
python-libarchive-c=2.9=py_0 341 | python-louvain=0.15=pypi_0 342 | python_abi=3.7=1_cp37m 343 | pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 344 | pytorch-pretrained-bert=0.6.2=pypi_0 345 | pytz=2020.4=pyhd3eb1b0_0 346 | pytz-deprecation-shim=0.1.0.post0=pypi_0 347 | pywavelets=1.1.1=py37h7b6447c_2 348 | pyxdg=0.27=pyhd3eb1b0_0 349 | pyyaml=5.3.1=py37h7b6447c_1 350 | pyzmq=20.0.0=py37h2531618_1 351 | qdarkstyle=2.8.1=py_0 352 | qt=5.9.7=h5867ecd_1 353 | qtawesome=1.0.1=py_0 354 | qtconsole=4.7.7=py_0 355 | qtpy=1.9.0=py_0 356 | rdflib=5.0.0=pypi_0 357 | readline=7.0=h7b6447c_5 358 | regex=2017.4.5=pypi_0 359 | requests=2.28.2=pypi_0 360 | requests-oauthlib=1.3.0=pypi_0 361 | ripgrep=12.1.1=0 362 | rope=0.18.0=py_0 363 | rsa=4.8=pypi_0 364 | rtree=0.9.4=py37_1 365 | ruamel_yaml=0.15.87=py37h7b6447c_1 366 | s3transfer=0.3.3=pypi_0 367 | sacremoses=0.0.41=pypi_0 368 | scikit-image=0.17.2=py37hdf5156a_0 369 | scikit-learn=0.23.2=py37h0573a6f_0 370 | scipy=1.5.2=py37h0b6359f_0 371 | seaborn=0.11.0=py_0 372 | secretstorage=3.3.0=py37h06a4308_0 373 | send2trash=1.5.0=pyhd3eb1b0_1 374 | sentencepiece=0.1.85=pypi_0 375 | sentry-sdk=1.5.1=pypi_0 376 | seqeval=1.2.2=pypi_0 377 | setuptools=51.0.0=py37h06a4308_2 378 | shortuuid=1.0.8=pypi_0 379 | simplediff=1.0=pypi_0 380 | simplegeneric=0.8.1=py37_2 381 | simpletransformers=0.63.3=pypi_0 382 | singledispatch=3.4.0.3=py_1001 383 | sip=4.19.8=py37hf484d3e_0 384 | six=1.15.0=py37h06a4308_0 385 | sklearn=0.0=pypi_0 386 | smmap=5.0.0=pypi_0 387 | snappy=1.1.8=he6710b0_0 388 | snowballstemmer=2.0.0=py_0 389 | sortedcollections=1.2.1=py_0 390 | sortedcontainers=2.3.0=pyhd3eb1b0_0 391 | soupsieve=2.0.1=py_0 392 | spacy=2.3.2=pypi_0 393 | sphinx=3.2.1=py_0 394 | sphinxcontrib=1.0=py37_1 395 | sphinxcontrib-applehelp=1.0.2=py_0 396 | sphinxcontrib-devhelp=1.0.2=py_0 397 | sphinxcontrib-htmlhelp=1.0.3=py_0 398 | sphinxcontrib-jsmath=1.0.1=py_0 399 | sphinxcontrib-qthelp=1.0.3=py_0 400 | sphinxcontrib-serializinghtml=1.1.4=py_0 401 | 
sphinxcontrib-websupport=1.2.4=py_0 402 | spyder=4.0.1=py37_0 403 | spyder-kernels=1.8.1=py37_0 404 | sqlalchemy=1.3.20=py37he8ac12f_0 405 | sqlite=3.33.0=h62c20be_0 406 | srsly=1.0.2=pypi_0 407 | statsmodels=0.12.1=py37h27cfd23_0 408 | streamlit=1.2.0=pypi_0 409 | subprocess32=3.5.4=pypi_0 410 | sympy=1.6.2=py37h06a4308_1 411 | tbb=2020.3=hfd86e86_0 412 | tblib=1.7.0=py_0 413 | tensorboard=2.7.0=pypi_0 414 | tensorboard-data-server=0.6.1=pypi_0 415 | tensorboard-plugin-wit=1.8.0=pypi_0 416 | tensorboardx=2.0=pypi_0 417 | termcolor=1.1.0=pypi_0 418 | terminado=0.9.1=py37_0 419 | testpath=0.4.4=py_0 420 | textstat=0.7.0=pypi_0 421 | thinc=7.4.1=pypi_0 422 | threadpoolctl=2.1.0=pyh5ca1d4c_0 423 | tifffile=2020.12.4=pyhd3eb1b0_0 424 | tk=8.6.10=hbc83047_0 425 | tokenizers=0.5.2=pypi_0 426 | toml=0.10.1=py_0 427 | toolz=0.11.1=py_0 428 | toposort=1.5=pypi_0 429 | torch=1.4.0=pypi_0 430 | torch-geometric=1.7.0=pypi_0 431 | torch-scatter=2.0.6=pypi_0 432 | torch-sparse=0.6.9=pypi_0 433 | torchaudio=0.7.2=py37 434 | torchvision=0.5.0=pypi_0 435 | tornado=6.1=py37h27cfd23_0 436 | tqdm=4.62.3=pypi_0 437 | traitlets=5.0.5=py_0 438 | transformers=2.7.0=pypi_0 439 | twarc=2.8.1=pypi_0 440 | tweepy=4.12.1=pypi_0 441 | twython=3.8.2=pypi_0 442 | typed-ast=1.4.1=py37h7b6447c_0 443 | typing_extensions=3.7.4.3=py_0 444 | tzdata=2021.5=pypi_0 445 | tzlocal=4.1=pypi_0 446 | ujson=4.0.1=py37he6710b0_0 447 | unicodecsv=0.14.1=py37_0 448 | unidecode=1.1.1=pypi_0 449 | unixodbc=2.3.9=h7b6447c_0 450 | unshortenit=0.4.0=pypi_0 451 | uritools=3.0.0=pypi_0 452 | urlextract=1.0.0=pypi_0 453 | urllib3=1.24.3=pypi_0 454 | vadersentiment=3.3.2=pypi_0 455 | validators=0.18.2=pypi_0 456 | wandb=0.12.7=pypi_0 457 | wasabi=0.7.1=pypi_0 458 | watchdog=0.10.4=py37h06a4308_0 459 | wcwidth=0.2.5=py_0 460 | webencodings=0.5.1=py37_1 461 | werkzeug=1.0.1=py_0 462 | wheel=0.36.1=pyhd3eb1b0_0 463 | widgetsnbextension=3.5.1=py37_0 464 | wikipedia-api=0.5.4=pypi_0 465 | wrapt=1.11.2=py37h7b6447c_0 466 | 
wurlitzer=2.0.1=py37_0 467 | xlrd=1.2.0=py37_0 468 | xlsxwriter=1.3.7=py_0 469 | xlwt=1.3.0=py37_0 470 | xmltodict=0.12.0=py_0 471 | xxhash=2.0.2=pypi_0 472 | xz=5.2.5=h7b6447c_0 473 | yaml=0.2.5=h7b6447c_0 474 | yapf=0.30.0=py_0 475 | yarl=1.7.2=pypi_0 476 | yaspin=2.1.0=pypi_0 477 | zeromq=4.3.3=he6710b0_3 478 | zict=2.0.0=py_0 479 | zipp=3.4.0=pyhd3eb1b0_0 480 | zlib=1.2.11=h7b6447c_3 481 | zope=1.0=py37_1 482 | zope.event=4.5.0=py37_0 483 | zope.interface=5.2.0=py37h27cfd23_0 484 | zstd=1.4.5=h9ceee32_0 485 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | alabaster==0.7.12 5 | altair==4.1.0 6 | anaconda-client==1.7.2 7 | anaconda-navigator==1.9.12 8 | anaconda-project==0.8.3 9 | appdirs==1.4.4 10 | argh==0.26.2 11 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828452693/work 12 | ase==3.21.1 13 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work 14 | astor==0.8.1 15 | astroid @ file:///tmp/build/80754af9/astroid_1592495881661/work 16 | astropy @ file:///tmp/build/80754af9/astropy_1606922938720/work 17 | async-generator==1.10 18 | async-timeout==4.0.1 19 | asynctest==0.13.0 20 | atomicwrites==1.4.0 21 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work 22 | autocorrect==1.0.0 23 | autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work 24 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work 25 | backcall==0.2.0 26 | backports.functools-lru-cache @ file:///tmp/build/80754af9/backports.functools_lru_cache_1605305165209/work 27 | backports.shutil-get-terminal-size==1.0.0 28 | backports.tempfile==1.0 29 | backports.weakref==1.0.post1 30 | backports.zoneinfo==0.2.1 31 | base58==2.1.1 32 | beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work 33 | bert-serving==0.0.1 34 | 
bert-serving-client==1.10.0 35 | bert-serving-server==1.10.0 36 | bitarray @ file:///tmp/build/80754af9/bitarray_1605065136653/work 37 | bkcharts==0.2 38 | bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work 39 | blessings==1.7 40 | blinker==1.4 41 | blis==0.4.1 42 | bokeh @ file:///tmp/build/80754af9/bokeh_1603297847301/work 43 | boto==2.49.0 44 | boto3==1.12.34 45 | botocore==1.15.34 46 | Bottleneck==1.3.2 47 | brotlipy==0.7.0 48 | cachetools==4.2.4 49 | catalogue==1.0.0 50 | certifi==2020.12.5 51 | cffi @ file:///tmp/build/80754af9/cffi_1606255099073/work 52 | chardet @ file:///tmp/build/80754af9/chardet_1605303159953/work 53 | charset-normalizer==2.0.9 54 | clean-text==0.2.1 55 | click==7.1.2 56 | click-config-file==0.6.0 57 | click-plugins==1.1.1 58 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work 59 | clyent==1.2.2 60 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 61 | conda==4.9.2 62 | conda-build==3.18.11 63 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018138503/work 64 | conda-verify==3.4.2 65 | configobj==5.0.6 66 | configparser==5.2.0 67 | contextlib2==0.6.0.post1 68 | convokit==2.3.2.5 69 | cryptography @ file:///tmp/build/80754af9/cryptography_1607635305226/work 70 | cycler==0.10.0 71 | cymem==2.0.3 72 | Cython @ file:///tmp/build/80754af9/cython_1605457613176/work 73 | cytoolz==0.11.0 74 | dask @ file:///tmp/build/80754af9/dask-core_1607706933335/work 75 | datasets==1.16.1 76 | decorator==4.4.2 77 | defusedxml==0.6.0 78 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work 79 | dill==0.3.4 80 | distributed @ file:///tmp/build/80754af9/distributed_1607714018210/work 81 | docker-pycreds==0.4.0 82 | docutils==0.15.2 83 | emoji==1.2.0 84 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz 85 | entrypoints==0.3 86 | et-xmlfile==1.0.1 87 | fastai==2.6.3 88 | 
fastcache==1.1.0 89 | fastcore==1.4.2 90 | fastdownload==0.0.5 91 | fastprogress==0.2.4 92 | filelock==3.0.12 93 | fire==0.4.0 94 | flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work 95 | Flask==1.1.2 96 | frozenlist==1.2.0 97 | fsspec==2021.11.1 98 | ftfy==5.8 99 | future==0.18.2 100 | gevent @ file:///tmp/build/80754af9/gevent_1601397565838/work 101 | gitdb==4.0.9 102 | GitPython==3.1.24 103 | glob2==0.7 104 | gmpy2==2.0.8 105 | google-auth==2.3.3 106 | google-auth-oauthlib==0.4.6 107 | googledrivedownloader==0.4 108 | gpustat==0.6.0 109 | GPUtil==1.4.0 110 | greenlet @ file:///tmp/build/80754af9/greenlet_1600873995270/work 111 | grpcio==1.42.0 112 | h5py==2.10.0 113 | HeapDict==1.0.1 114 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work 115 | huggingface-hub==0.2.1 116 | humanize==3.12.0 117 | idna==2.8 118 | imagecodecs @ file:///tmp/build/80754af9/imagecodecs_1603270454756/work 119 | imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work 120 | imagesize==1.2.0 121 | importlib-metadata==4.8.2 122 | iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work 123 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work 124 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl 125 | ipython @ file:///tmp/build/80754af9/ipython_1604101195213/work 126 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 127 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work 128 | isodate==0.6.0 129 | isort @ file:///tmp/build/80754af9/isort_1602603989581/work 130 | itsdangerous==1.1.0 131 | jdcal==1.4.1 132 | jedi==0.14.1 133 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work 134 | Jinja2==2.11.2 135 | jmespath==0.9.5 136 | joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work 137 | json5==0.9.5 138 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work 139 | jupyter==1.0.0 
140 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work 141 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work 142 | jupyter-contrib-core==0.3.3 143 | jupyter-contrib-nbextensions @ file:///home/conda/feedstock_root/build_artifacts/jupyter_contrib_nbextensions_1614931162960/work 144 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1606148959479/work 145 | jupyter-highlight-selected-word @ file:///home/conda/feedstock_root/build_artifacts/jupyter_highlight_selected_word_1611341004115/work 146 | jupyter-latex-envs @ file:///home/conda/feedstock_root/build_artifacts/jupyter_latex_envs_1614852190293/work 147 | jupyter-nbextensions-configurator @ file:///home/conda/feedstock_root/build_artifacts/jupyter_nbextensions_configurator_1611341112910/work 148 | jupyterlab==2.2.6 149 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work 150 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work 151 | keyring @ file:///tmp/build/80754af9/keyring_1601490840626/work 152 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014532738/work 153 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707315973/work 154 | libarchive-c==2.9 155 | liwc==0.5.0 156 | liwc-analysis==1.2.4 157 | llvmlite==0.34.0 158 | locket==0.2.0 159 | lxml @ file:///tmp/build/80754af9/lxml_1606516847630/work 160 | Markdown==3.3.6 161 | MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1594371495811/work 162 | matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603376012865/work 163 | mccabe==0.6.1 164 | mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work 165 | mkl-fft==1.2.0 166 | mkl-random==1.1.1 167 | mkl-service==2.3.0 168 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work 169 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work 170 | mpmath==1.1.0 171 | msgpack==1.0.0 172 | msgpack-numpy==0.4.6.1 
173 | multidict==5.2.0 174 | multipledispatch==0.6.0 175 | multiprocess==0.70.12.2 176 | murmurhash==1.0.2 177 | mwparserfromhell==0.5.4 178 | navigator-updater==0.2.1 179 | nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work 180 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work 181 | nbformat==5.0.7 182 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1606153767164/work 183 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work 184 | nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work 185 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work 186 | notebook @ file:///tmp/build/80754af9/notebook_1601501580008/work 187 | numba @ file:///tmp/build/80754af9/numba_1600102479638/work 188 | numexpr @ file:///tmp/build/80754af9/numexpr_1607693749420/work 189 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603479632437/work 190 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work 191 | nvidia-ml-py3==7.352.0 192 | oauthlib==3.2.2 193 | olefile==0.46 194 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work 195 | packaging==21.3 196 | pandas @ file:///tmp/build/80754af9/pandas_1602088128026/work 197 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work 198 | parso==0.5.2 199 | partd==1.1.0 200 | path @ file:///tmp/build/80754af9/path_1607537225003/work 201 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024979554/work 202 | pathtools==0.1.2 203 | patsy==0.5.1 204 | pep8==1.7.1 205 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 206 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 207 | Pillow @ file:///tmp/build/80754af9/pillow_1603822253009/work 208 | pip-autoremove==0.9.1 209 | pkginfo==1.6.1 210 | plac==1.1.3 211 | pluggy==0.13.1 212 | ply==3.11 213 | powerlaw==1.4.6 214 | preshed==3.0.2 215 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1606344362066/work 216 | promise==2.3 
217 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work 218 | protobuf==3.19.1 219 | psutil @ file:///tmp/build/80754af9/psutil_1598370249042/work 220 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1605560620615/work/dist/ptyprocess-0.6.0-py2.py3-none-any.whl 221 | py @ file:///tmp/build/80754af9/py_1593446248552/work 222 | pyarrow==6.0.1 223 | pyasn1==0.4.8 224 | pyasn1-modules==0.2.8 225 | pycocoevalcap==1.2 226 | pycocotools==2.0.4 227 | pycodestyle==2.6.0 228 | pycosat==0.6.3 229 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 230 | pycrypto==2.6.1 231 | pycurl==7.43.0.6 232 | pydeck==0.7.1 233 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work 234 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1606860189168/work 235 | pyflakes==2.2.0 236 | Pygments @ file:///tmp/build/80754af9/pygments_1607368905949/work 237 | pylint @ file:///tmp/build/80754af9/pylint_1598624038450/work 238 | Pympler==0.9 239 | pyodbc===4.0.0-unsupported 240 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1606517880428/work 241 | pyparsing==2.4.7 242 | Pyphen==0.10.0 243 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work 244 | pyshorteners==1.0.1 245 | PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work 246 | pytest==0.0.0 247 | python-dateutil==2.8.1 248 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work 249 | python-language-server==0.31.7 250 | python-Levenshtein==0.12.0 251 | python-louvain==0.15 252 | pytorch-pretrained-bert==0.6.2 253 | pytz @ file:///tmp/build/80754af9/pytz_1606604771399/work 254 | pytz-deprecation-shim==0.1.0.post0 255 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658308664/work 256 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work 257 | PyYAML==5.3.1 258 | pyzmq==20.0.0 259 | QDarkStyle==2.8.1 260 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work 261 | qtconsole @ 
file:///tmp/build/80754af9/qtconsole_1600870028330/work 262 | QtPy==1.9.0 263 | rdflib==5.0.0 264 | regex==2017.4.5 265 | requests==2.28.2 266 | requests-oauthlib==1.3.0 267 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work 268 | rsa==4.8 269 | Rtree==0.9.4 270 | ruamel-yaml==0.15.87 271 | s3transfer==0.3.3 272 | sacremoses==0.0.41 273 | scikit-image==0.17.2 274 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376882706/work 275 | scipy @ file:///tmp/build/80754af9/scipy_1597686620742/work 276 | seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work 277 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1606864755624/work 278 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work 279 | sentencepiece==0.1.85 280 | sentry-sdk==1.5.1 281 | seqeval==1.2.2 282 | shortuuid==1.0.8 283 | simplediff==1.0 284 | simplegeneric==0.8.1 285 | simpletransformers==0.63.3 286 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work 287 | six @ file:///tmp/build/80754af9/six_1605205313296/work 288 | sklearn==0.0 289 | smmap==5.0.0 290 | snowballstemmer==2.0.0 291 | sortedcollections==1.2.1 292 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work 293 | soupsieve==2.0.1 294 | spacy==2.3.2 295 | Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work 296 | sphinxcontrib-applehelp==1.0.2 297 | sphinxcontrib-devhelp==1.0.2 298 | sphinxcontrib-htmlhelp==1.0.3 299 | sphinxcontrib-jsmath==1.0.1 300 | sphinxcontrib-qthelp==1.0.3 301 | sphinxcontrib-serializinghtml==1.1.4 302 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work 303 | spyder==4.0.1 304 | spyder-kernels==1.8.1 305 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1607563700310/work 306 | srsly==1.0.2 307 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1606925351355/work 308 | streamlit==1.2.0 309 | subprocess32==3.5.4 310 | sympy @ 
file:///tmp/build/80754af9/sympy_1605119531870/work 311 | tables==3.6.1 312 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work 313 | tensorboard==2.7.0 314 | tensorboard-data-server==0.6.1 315 | tensorboard-plugin-wit==1.8.0 316 | tensorboardX==2.0 317 | termcolor==1.1.0 318 | terminado==0.9.1 319 | testpath==0.4.4 320 | textstat==0.7.0 321 | thinc==7.4.1 322 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl 323 | tifffile @ file:///tmp/build/80754af9/tifffile_1607369857969/work 324 | tokenizers==0.5.2 325 | toml @ file:///tmp/build/80754af9/toml_1592853716807/work 326 | toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work 327 | toposort==1.5 328 | torch==1.7.1 329 | torch-geometric==1.7.0 330 | torch-scatter==2.0.6 331 | torch-sparse==0.6.9 332 | torchaudio==0.7.0a0+a853dff 333 | torchvision==0.8.2 334 | tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work 335 | tqdm==4.62.3 336 | traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work 337 | transformers==2.7.0 338 | twarc==2.8.1 339 | tweepy==4.12.1 340 | twython==3.8.2 341 | typed-ast==1.4.1 342 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1598376058250/work 343 | tzdata==2021.5 344 | tzlocal==4.1 345 | ujson @ file:///tmp/build/80754af9/ujson_1602523313803/work 346 | unicodecsv==0.14.1 347 | Unidecode==1.1.1 348 | unshortenit==0.4.0 349 | uritools==3.0.0 350 | urlextract==1.0.0 351 | urllib3==1.24.3 352 | vaderSentiment==3.3.2 353 | validators==0.18.2 354 | wandb==0.12.7 355 | wasabi==0.7.1 356 | watchdog @ file:///tmp/build/80754af9/watchdog_1606939061947/work 357 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 358 | webencodings==0.5.1 359 | Werkzeug==1.0.1 360 | widgetsnbextension==3.5.1 361 | Wikipedia-API==0.5.4 362 | wrapt==1.11.2 363 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594751868473/work 364 | xlrd==1.2.0 365 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work 
366 | xlwt==1.3.0 367 | xmltodict==0.12.0 368 | xxhash==2.0.2 369 | yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work 370 | yarl==1.7.2 371 | yaspin==2.1.0 372 | zict==2.0.0 373 | zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work 374 | zope.event==4.5.0 375 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1606940237377/work 376 | -------------------------------------------------------------------------------- /src/lsp_model_rl/optim.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch optimization """ 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | 29 | def warmup_constant(x, warmup=0.002): 30 | if x < warmup: 31 | return x/warmup 32 | return 1.0 33 | 34 | 35 | def warmup_linear(x, warmup=0.002): 36 | if x < warmup: 37 | return x/warmup 38 | return (1.0 - x)/(1.0 - warmup) 39 | 40 | 41 | def noam_decay(step, warmup_steps, model_size): 42 | """Learning rate schedule described in 43 | https://arxiv.org/pdf/1706.03762.pdf. 
44 | """ 45 | return ( 46 | model_size ** (-0.5) * 47 | min(step ** (-0.5), step * warmup_steps**(-1.5))) 48 | 49 | 50 | def noamwd_decay(step, warmup_steps, 51 | model_size, rate=0.5, decay_steps=1000, start_step=500): 52 | """Learning rate schedule optimized for huge batches 53 | """ 54 | return ( 55 | model_size ** (-0.5) * 56 | min(step ** (-0.5), step * warmup_steps**(-1.5)) * 57 | rate ** (max(step - start_step + decay_steps, 0) // decay_steps)) 58 | 59 | 60 | def exponential_decay(step, rate, decay_steps, start_step=0): 61 | """A standard exponential decay, scaling the learning rate by :obj:`rate` 62 | every :obj:`decay_steps` steps. 63 | """ 64 | return rate ** (max(step - start_step + decay_steps, 0) // decay_steps) 65 | 66 | 67 | def rsqrt_decay(step, warmup_steps): 68 | """Decay based on the reciprocal of the step square root.""" 69 | return 1.0 / math.sqrt(max(step, warmup_steps)) 70 | 71 | 72 | SCHEDULES = { 73 | 'warmup_cosine': warmup_cosine, 74 | 'warmup_constant': warmup_constant, 75 | 'warmup_linear': warmup_linear, 76 | } 77 | 78 | 79 | class Adam(Optimizer): 80 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 81 | Params: 82 | lr: learning rate 83 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 84 | t_total: total number of training steps for the learning 85 | rate schedule, -1 means constant learning rate. Default: -1 86 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 87 | b1: Adams b1. Default: 0.9 88 | b2: Adams b2. Default: 0.999 89 | e: Adams epsilon. Default: 1e-6 90 | weight_decay_rate: Weight decay. Default: 0.01 91 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 92 | """ 93 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 94 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 95 | max_grad_norm=1.0): 96 | if not lr >= 0.0: 97 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 98 | if schedule not in SCHEDULES: 99 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 100 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 101 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 102 | if not 0.0 <= b1 < 1.0: 103 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 104 | if not 0.0 <= b2 < 1.0: 105 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 106 | if not e >= 0.0: 107 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 108 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 109 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 110 | max_grad_norm=max_grad_norm) 111 | super(Adam, self).__init__(params, defaults) 112 | 113 | def get_lr(self): 114 | lr = [] 115 | for group in self.param_groups: 116 | for p in group['params']: 117 | state = self.state[p] 118 | if len(state) == 0: 119 | return [0] 120 | if group['t_total'] != -1: 121 | schedule_fct = SCHEDULES[group['schedule']] 122 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 123 | else: 124 | lr_scheduled = group['lr'] 125 | lr.append(lr_scheduled) 126 | return lr 127 | 128 | def to(self, device): 129 | """ Move the optimizer state to a specified device""" 130 | for state in self.state.values(): 131 | state['exp_avg'].to(device) 132 | state['exp_avg_sq'].to(device) 133 | 134 | def initialize_step(self, initial_step): 135 | """Initialize state with a defined step (but we don't have stored averaged). 136 | Arguments: 137 | initial_step (int): Initial step number. 
138 | """ 139 | for group in self.param_groups: 140 | for p in group['params']: 141 | state = self.state[p] 142 | # State initialization 143 | state['step'] = initial_step 144 | # Exponential moving average of gradient values 145 | state['exp_avg'] = torch.zeros_like(p.data) 146 | # Exponential moving average of squared gradient values 147 | state['exp_avg_sq'] = torch.zeros_like(p.data) 148 | 149 | def step(self, closure=None): 150 | """Performs a single optimization step. 151 | 152 | Arguments: 153 | closure (callable, optional): A closure that reevaluates the model 154 | and returns the loss. 155 | """ 156 | loss = None 157 | if closure is not None: 158 | loss = closure() 159 | 160 | for group in self.param_groups: 161 | for p in group['params']: 162 | if p.grad is None: 163 | continue 164 | grad = p.grad.data 165 | if grad.is_sparse: 166 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 167 | 168 | state = self.state[p] 169 | 170 | # State initialization 171 | if len(state) == 0: 172 | state['step'] = 0 173 | # Exponential moving average of gradient values 174 | state['next_m'] = torch.zeros_like(p.data) 175 | # Exponential moving average of squared gradient values 176 | state['next_v'] = torch.zeros_like(p.data) 177 | 178 | next_m, next_v = state['next_m'], state['next_v'] 179 | beta1, beta2 = group['b1'], group['b2'] 180 | 181 | # Add grad clipping 182 | if group['max_grad_norm'] > 0: 183 | clip_grad_norm_(p, group['max_grad_norm']) 184 | 185 | # Decay the first and second moment running average coefficient 186 | # In-place operations to update the averages at the same time 187 | next_m.mul_(beta1).add_(1 - beta1, grad) 188 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 189 | update = next_m / (next_v.sqrt() + group['e']) 190 | 191 | # Just adding the square of the weights to the loss function is *not* 192 | # the correct way of using L2 regularization/weight decay with Adam, 193 | # since that will 
interact with the m and v parameters in strange ways. 194 | # 195 | # Instead we want ot decay the weights in a manner that doesn't interact 196 | # with the m/v parameters. This is equivalent to adding the square 197 | # of the weights to the loss with plain (non-momentum) SGD. 198 | if group['weight_decay_rate'] > 0.0: 199 | update += group['weight_decay_rate'] * p.data 200 | 201 | if group['t_total'] != -1: 202 | schedule_fct = SCHEDULES[group['schedule']] 203 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 204 | else: 205 | lr_scheduled = group['lr'] 206 | 207 | update_with_lr = lr_scheduled * update 208 | p.data.add_(-update_with_lr) 209 | 210 | state['step'] += 1 211 | 212 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 213 | # bias_correction1 = 1 - beta1 ** state['step'] 214 | # bias_correction2 = 1 - beta2 ** state['step'] 215 | 216 | return loss 217 | 218 | 219 | class Adamax(Optimizer): 220 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 221 | Params: 222 | lr: learning rate 223 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 224 | t_total: total number of training steps for the learning 225 | rate schedule, -1 means constant learning rate. Default: -1 226 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 227 | b1: Adams b1. Default: 0.9 228 | b2: Adams b2. Default: 0.999 229 | e: Adams epsilon. Default: 1e-6 230 | weight_decay_rate: Weight decay. Default: 0.01 231 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 232 | """ 233 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 234 | betas=(0.9, 0.999), eps=1e-6, weight_decay_rate=0.01, 235 | max_grad_norm=1.0): 236 | if not lr >= 0.0: 237 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 238 | if schedule not in SCHEDULES: 239 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 240 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 241 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 242 | if not 0.0 <= eps: 243 | raise ValueError("Invalid epsilon value: {}".format(eps)) 244 | if not 0.0 <= betas[0] < 1.0: 245 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 246 | if not 0.0 <= betas[1] < 1.0: 247 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 248 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 249 | betas=betas, eps=eps, weight_decay_rate=weight_decay_rate, 250 | max_grad_norm=max_grad_norm) 251 | super(Adamax, self).__init__(params, defaults) 252 | 253 | def get_lr(self): 254 | lr = [] 255 | for group in self.param_groups: 256 | for p in group['params']: 257 | state = self.state[p] 258 | if len(state) == 0: 259 | return [0] 260 | if group['t_total'] != -1: 261 | schedule_fct = SCHEDULES[group['schedule']] 262 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 263 | else: 264 | lr_scheduled = group['lr'] 265 | lr.append(lr_scheduled) 266 | return lr 267 | 268 | def to(self, device): 269 | """ Move the optimizer state to a specified device""" 270 | for state in self.state.values(): 271 | state['exp_avg'].to(device) 272 | state['exp_avg_sq'].to(device) 273 | 274 | def initialize_step(self, initial_step): 275 | """Initialize state with a defined step (but we don't have stored averaged). 276 | Arguments: 277 | initial_step (int): Initial step number. 
278 | """ 279 | for group in self.param_groups: 280 | for p in group['params']: 281 | state = self.state[p] 282 | # State initialization 283 | state['step'] = initial_step 284 | # Exponential moving average of gradient values 285 | state['exp_avg'] = torch.zeros_like(p.data) 286 | # Exponential moving average of squared gradient values 287 | state['exp_avg_sq'] = torch.zeros_like(p.data) 288 | 289 | def step(self, closure=None): 290 | """Performs a single optimization step. 291 | 292 | Arguments: 293 | closure (callable, optional): A closure that reevaluates the model 294 | and returns the loss. 295 | """ 296 | loss = None 297 | if closure is not None: 298 | loss = closure() 299 | 300 | for group in self.param_groups: 301 | for p in group['params']: 302 | if p.grad is None: 303 | continue 304 | grad = p.grad.data 305 | if grad.is_sparse: 306 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 307 | 308 | state = self.state[p] 309 | 310 | # State initialization 311 | if len(state) == 0: 312 | state['step'] = 0 313 | # Exponential moving average of gradient values 314 | state['exp_avg'] = torch.zeros_like(p.data) 315 | state['exp_inf'] = torch.zeros_like(p.data) 316 | 317 | exp_avg, exp_inf = state['exp_avg'], state['exp_inf'] 318 | beta1, beta2 = group['betas'] 319 | eps = group['eps'] 320 | # Add grad clipping 321 | if group['max_grad_norm'] > 0: 322 | clip_grad_norm_(p, group['max_grad_norm']) 323 | 324 | # Update biased first moment estimate. 325 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 326 | # Update the exponentially weighted infinity norm. 
327 | norm_buf = torch.cat([ 328 | exp_inf.mul_(beta2).unsqueeze(0), 329 | grad.abs().add_(eps).unsqueeze_(0) 330 | ], 0) 331 | torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long())) 332 | update = exp_avg / (exp_inf + eps) 333 | 334 | if group['weight_decay_rate'] > 0.0: 335 | update += group['weight_decay_rate'] * p.data 336 | 337 | if group['t_total'] != -1: 338 | schedule_fct = SCHEDULES[group['schedule']] 339 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 340 | else: 341 | lr_scheduled = group['lr'] 342 | 343 | 344 | # Just adding the square of the weights to the loss function is *not* 345 | # the correct way of using L2 regularization/weight decay with Adam, 346 | # since that will interact with the m and v parameters in strange ways. 347 | # 348 | # Instead we want ot decay the weights in a manner that doesn't interact 349 | # with the m/v parameters. This is equivalent to adding the square 350 | # of the weights to the loss with plain (non-momentum) SGD. 351 | if group['weight_decay_rate'] > 0.0: 352 | update += group['weight_decay_rate'] * p.data 353 | 354 | if group['t_total'] != -1: 355 | schedule_fct = SCHEDULES[group['schedule']] 356 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 357 | else: 358 | lr_scheduled = group['lr'] 359 | 360 | update_with_lr = lr_scheduled * update 361 | p.data.add_(-update_with_lr) 362 | 363 | state['step'] += 1 364 | 365 | return loss 366 | 367 | -------------------------------------------------------------------------------- /src/lsp_model_rl/util/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | 7 | import fnmatch 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import sys 13 | import tarfile 14 | import tempfile 15 | from contextlib import contextmanager 16 | from functools import partial, wraps 17 | from hashlib import sha256 18 | from typing import Optional 19 | from urllib.parse import urlparse 20 | from zipfile import ZipFile, is_zipfile 21 | 22 | import requests 23 | from filelock import FileLock 24 | from tqdm.auto import tqdm 25 | 26 | # from . import __version__ 27 | __version__ = "2.8.0" 28 | 29 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 30 | 31 | try: 32 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 33 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 34 | if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): 35 | import torch 36 | 37 | _torch_available = True # pylint: disable=invalid-name 38 | logger.info("PyTorch version {} available.".format(torch.__version__)) 39 | else: 40 | logger.info("Disabling PyTorch because USE_TF is set") 41 | _torch_available = False 42 | except ImportError: 43 | _torch_available = False # pylint: disable=invalid-name 44 | 45 | try: 46 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 47 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 48 | 49 | if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): 50 | import tensorflow as tf 51 | 52 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 53 | _tf_available = True # pylint: disable=invalid-name 54 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 55 | else: 56 | logger.info("Disabling Tensorflow because USE_TORCH is set") 57 | _tf_available = False 58 | except (ImportError, AssertionError): 59 | _tf_available = False # pylint: disable=invalid-name 60 | 61 | try: 62 | from torch.hub import _get_torch_home 63 | 64 | torch_cache_home = _get_torch_home() 65 | except ImportError: 66 | 
torch_cache_home = os.path.expanduser( 67 | os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) 68 | ) 69 | default_cache_path = os.path.join(torch_cache_home, "transformers") 70 | 71 | try: 72 | from pathlib import Path 73 | 74 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 75 | os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)) 76 | ) 77 | except (AttributeError, ImportError): 78 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( 79 | "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) 80 | ) 81 | 82 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 83 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 84 | 85 | WEIGHTS_NAME = "pytorch_model.bin" 86 | TF2_WEIGHTS_NAME = "tf_model.h5" 87 | TF_WEIGHTS_NAME = "model.ckpt" 88 | CONFIG_NAME = "config.json" 89 | MODEL_CARD_NAME = "modelcard.json" 90 | 91 | 92 | MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] 93 | DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] 94 | DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] 95 | 96 | S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" 97 | CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net" 98 | 99 | 100 | def is_torch_available(): 101 | return _torch_available 102 | 103 | 104 | def is_tf_available(): 105 | return _tf_available 106 | 107 | 108 | def add_start_docstrings(*docstr): 109 | def docstring_decorator(fn): 110 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 111 | return fn 112 | 113 | return docstring_decorator 114 | 115 | 116 | def add_start_docstrings_to_callable(*docstr): 117 | def docstring_decorator(fn): 118 | class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) 119 | intro = " The {} forward method, overrides the :func:`__call__` 
special method.".format(class_name) 120 | note = r""" 121 | 122 | .. note:: 123 | Although the recipe for forward pass needs to be defined within 124 | this function, one should call the :class:`Module` instance afterwards 125 | instead of this since the former takes care of running the 126 | pre and post processing steps while the latter silently ignores them. 127 | """ 128 | fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 129 | return fn 130 | 131 | return docstring_decorator 132 | 133 | 134 | def add_end_docstrings(*docstr): 135 | def docstring_decorator(fn): 136 | fn.__doc__ = fn.__doc__ + "".join(docstr) 137 | return fn 138 | 139 | return docstring_decorator 140 | 141 | 142 | def is_remote_url(url_or_filename): 143 | parsed = urlparse(url_or_filename) 144 | return parsed.scheme in ("http", "https") 145 | 146 | 147 | def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: 148 | endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX 149 | if postfix is None: 150 | return "/".join((endpoint, identifier)) 151 | else: 152 | return "/".join((endpoint, identifier, postfix)) 153 | 154 | 155 | def url_to_filename(url, etag=None): 156 | """ 157 | Convert `url` into a hashed filename in a repeatable way. 158 | If `etag` is specified, append its hash to the url's, delimited 159 | by a period. 160 | If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name 161 | so that TF 2.0 can identify it as a HDF5 file 162 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) 163 | """ 164 | url_bytes = url.encode("utf-8") 165 | url_hash = sha256(url_bytes) 166 | filename = url_hash.hexdigest() 167 | 168 | if etag: 169 | etag_bytes = etag.encode("utf-8") 170 | etag_hash = sha256(etag_bytes) 171 | filename += "." 
+ etag_hash.hexdigest() 172 | 173 | if url.endswith(".h5"): 174 | filename += ".h5" 175 | 176 | return filename 177 | 178 | 179 | def filename_to_url(filename, cache_dir=None): 180 | """ 181 | Return the url and etag (which may be ``None``) stored for `filename`. 182 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = TRANSFORMERS_CACHE 186 | if isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | cache_path = os.path.join(cache_dir, filename) 190 | if not os.path.exists(cache_path): 191 | raise EnvironmentError("file {} not found".format(cache_path)) 192 | 193 | meta_path = cache_path + ".json" 194 | if not os.path.exists(meta_path): 195 | raise EnvironmentError("file {} not found".format(meta_path)) 196 | 197 | with open(meta_path, encoding="utf-8") as meta_file: 198 | metadata = json.load(meta_file) 199 | url = metadata["url"] 200 | etag = metadata["etag"] 201 | 202 | return url, etag 203 | 204 | 205 | def cached_path( 206 | url_or_filename, 207 | cache_dir=None, 208 | force_download=False, 209 | proxies=None, 210 | resume_download=False, 211 | user_agent=None, 212 | extract_compressed_file=False, 213 | force_extract=False, 214 | local_files_only=False, 215 | ) -> Optional[str]: 216 | """ 217 | Given something that might be a URL (or might be a local path), 218 | determine which. If it's a URL, download the file and cache it, and 219 | return the path to the cached file. If it's already a local path, 220 | make sure the file exists and then return the path. 221 | Args: 222 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). 223 | force_download: if True, re-dowload the file even if it's already cached in the cache dir. 224 | resume_download: if True, resume the download if incompletly recieved file is found. 225 | user_agent: Optional string or dict that will be appended to the user-agent on remote requests. 
226 | extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed 227 | file in a folder along the archive. 228 | force_extract: if True when extract_compressed_file is True and the archive was already extracted, 229 | re-extract the archive and overide the folder where it was extracted. 230 | 231 | Return: 232 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). 233 | Local path (string) otherwise 234 | """ 235 | if cache_dir is None: 236 | cache_dir = TRANSFORMERS_CACHE 237 | if isinstance(url_or_filename, Path): 238 | url_or_filename = str(url_or_filename) 239 | if isinstance(cache_dir, Path): 240 | cache_dir = str(cache_dir) 241 | 242 | if is_remote_url(url_or_filename): 243 | # URL, so get it from the cache (downloading if necessary) 244 | output_path = get_from_cache( 245 | url_or_filename, 246 | cache_dir=cache_dir, 247 | force_download=force_download, 248 | proxies=proxies, 249 | resume_download=resume_download, 250 | user_agent=user_agent, 251 | local_files_only=local_files_only, 252 | ) 253 | elif os.path.exists(url_or_filename): 254 | # File, and it exists. 255 | output_path = url_or_filename 256 | elif urlparse(url_or_filename).scheme == "": 257 | # File, but it doesn't exist. 258 | raise EnvironmentError("file {} not found".format(url_or_filename)) 259 | else: 260 | # Something unknown 261 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 262 | 263 | if extract_compressed_file: 264 | if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): 265 | return output_path 266 | 267 | # Path where we extract compressed archives 268 | # We avoid '.' 
in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" 269 | output_dir, output_file = os.path.split(output_path) 270 | output_extract_dir_name = output_file.replace(".", "-") + "-extracted" 271 | output_path_extracted = os.path.join(output_dir, output_extract_dir_name) 272 | 273 | if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: 274 | return output_path_extracted 275 | 276 | # Prevent parallel extractions 277 | lock_path = output_path + ".lock" 278 | with FileLock(lock_path): 279 | shutil.rmtree(output_path_extracted, ignore_errors=True) 280 | os.makedirs(output_path_extracted) 281 | if is_zipfile(output_path): 282 | with ZipFile(output_path, "r") as zip_file: 283 | zip_file.extractall(output_path_extracted) 284 | zip_file.close() 285 | elif tarfile.is_tarfile(output_path): 286 | tar_file = tarfile.open(output_path) 287 | tar_file.extractall(output_path_extracted) 288 | tar_file.close() 289 | else: 290 | raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) 291 | 292 | return output_path_extracted 293 | 294 | return output_path 295 | 296 | 297 | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): 298 | ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) 299 | if is_torch_available(): 300 | ua += "; torch/{}".format(torch.__version__) 301 | if is_tf_available(): 302 | ua += "; tensorflow/{}".format(tf.__version__) 303 | if isinstance(user_agent, dict): 304 | ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) 305 | elif isinstance(user_agent, str): 306 | ua += "; " + user_agent 307 | headers = {"user-agent": ua} 308 | if resume_size > 0: 309 | headers["Range"] = "bytes=%d-" % (resume_size,) 310 | response = requests.get(url, stream=True, proxies=proxies, headers=headers) 311 | if response.status_code == 416: # Range not satisfiable 312 | return 313 | content_length = 
response.headers.get("Content-Length") 314 | total = resume_size + int(content_length) if content_length is not None else None 315 | progress = tqdm( 316 | unit="B", 317 | unit_scale=True, 318 | total=total, 319 | initial=resume_size, 320 | desc="Downloading", 321 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET), 322 | ) 323 | for chunk in response.iter_content(chunk_size=1024): 324 | if chunk: # filter out keep-alive new chunks 325 | progress.update(len(chunk)) 326 | temp_file.write(chunk) 327 | progress.close() 328 | 329 | 330 | def get_from_cache( 331 | url, 332 | cache_dir=None, 333 | force_download=False, 334 | proxies=None, 335 | etag_timeout=10, 336 | resume_download=False, 337 | user_agent=None, 338 | local_files_only=False, 339 | ) -> Optional[str]: 340 | """ 341 | Given a URL, look for the corresponding file in the local cache. 342 | If it's not there, download it. Then return the path to the cached file. 343 | 344 | Return: 345 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). 346 | Local path (string) otherwise 347 | """ 348 | if cache_dir is None: 349 | cache_dir = TRANSFORMERS_CACHE 350 | if isinstance(cache_dir, Path): 351 | cache_dir = str(cache_dir) 352 | 353 | os.makedirs(cache_dir, exist_ok=True) 354 | 355 | etag = None 356 | if not local_files_only: 357 | try: 358 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) 359 | if response.status_code == 200: 360 | etag = response.headers.get("ETag") 361 | except (EnvironmentError, requests.exceptions.Timeout): 362 | # etag is already None 363 | pass 364 | 365 | filename = url_to_filename(url, etag) 366 | 367 | # get cache path to put the file 368 | cache_path = os.path.join(cache_dir, filename) 369 | 370 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. 
371 | # try to get the last downloaded one 372 | if etag is None: 373 | if os.path.exists(cache_path): 374 | return cache_path 375 | else: 376 | matching_files = [ 377 | file 378 | for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") 379 | if not file.endswith(".json") and not file.endswith(".lock") 380 | ] 381 | if len(matching_files) > 0: 382 | return os.path.join(cache_dir, matching_files[-1]) 383 | else: 384 | # If files cannot be found and local_files_only=True, 385 | # the models might've been found if local_files_only=False 386 | # Notify the user about that 387 | if local_files_only: 388 | raise ValueError( 389 | "Cannot find the requested files in the cached path and outgoing traffic has been" 390 | " disabled. To enable model look-ups and downloads online, set 'local_files_only'" 391 | " to False." 392 | ) 393 | return None 394 | 395 | # From now on, etag is not None. 396 | if os.path.exists(cache_path) and not force_download: 397 | return cache_path 398 | 399 | # Prevent parallel downloads of the same file with a lock. 400 | lock_path = cache_path + ".lock" 401 | with FileLock(lock_path): 402 | 403 | # If the download just completed while the lock was activated. 404 | if os.path.exists(cache_path) and not force_download: 405 | # Even if returning early like here, the lock will be released. 406 | return cache_path 407 | 408 | if resume_download: 409 | incomplete_path = cache_path + ".incomplete" 410 | 411 | @contextmanager 412 | def _resumable_file_manager(): 413 | with open(incomplete_path, "a+b") as f: 414 | yield f 415 | 416 | temp_file_manager = _resumable_file_manager 417 | if os.path.exists(incomplete_path): 418 | resume_size = os.stat(incomplete_path).st_size 419 | else: 420 | resume_size = 0 421 | else: 422 | temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) 423 | resume_size = 0 424 | 425 | # Download to temporary file, then copy to cache dir once finished. 
426 | # Otherwise you get corrupt cache entries if the download gets interrupted. 427 | with temp_file_manager() as temp_file: 428 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 429 | 430 | http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) 431 | 432 | logger.info("storing %s in cache at %s", url, cache_path) 433 | os.replace(temp_file.name, cache_path) 434 | 435 | logger.info("creating metadata file for %s", cache_path) 436 | meta = {"url": url, "etag": etag} 437 | meta_path = cache_path + ".json" 438 | with open(meta_path, "w") as meta_file: 439 | json.dump(meta, meta_file) 440 | 441 | return cache_path 442 | 443 | 444 | class cached_property(property): 445 | """ 446 | Descriptor that mimics @property but caches output in member variable. 447 | 448 | From tensorflow_datasets 449 | 450 | Built-in in functools from Python 3.8. 451 | """ 452 | 453 | def __get__(self, obj, objtype=None): 454 | # See docs.python.org/3/howto/descriptor.html#properties 455 | if obj is None: 456 | return self 457 | if self.fget is None: 458 | raise AttributeError("unreadable attribute") 459 | attr = "__cached_" + self.fget.__name__ 460 | cached = getattr(obj, attr, None) 461 | if cached is None: 462 | cached = self.fget(obj) 463 | setattr(obj, attr, cached) 464 | return cached 465 | 466 | 467 | def torch_required(func): 468 | # Chose a different decorator name than in tests so it's clear they are not the same. 469 | @wraps(func) 470 | def wrapper(*args, **kwargs): 471 | if is_torch_available(): 472 | return func(*args, **kwargs) 473 | else: 474 | raise ImportError(f"Method `{func.__name__}` requires PyTorch.") 475 | 476 | return wrapper 477 | 478 | 479 | def tf_required(func): 480 | # Chose a different decorator name than in tests so it's clear they are not the same. 
481 | @wraps(func) 482 | def wrapper(*args, **kwargs): 483 | if is_tf_available(): 484 | return func(*args, **kwargs) 485 | else: 486 | raise ImportError(f"Method `{func.__name__}` requires TF.") 487 | 488 | return wrapper -------------------------------------------------------------------------------- /src/train_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | ''' 4 | * @Desc: train GPT2 from scratch/ fine tuning. 5 | Modified based on Huggingface GPT-2 implementation 6 | ''' 7 | import json 8 | import os 9 | import sys 10 | import argparse 11 | import logging 12 | import time 13 | import tqdm 14 | import datetime 15 | import torch 16 | 17 | import numpy as np 18 | 19 | from os.path import join 20 | from torch.distributed import get_rank, get_world_size 21 | 22 | from lsp_model_rl import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam 23 | from lsp_model_rl import GPT2LMHeadModel_v2 24 | from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length 25 | from gpt2_training.eval_utils import eval_model_loss 26 | 27 | from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader 28 | 29 | 30 | from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list 31 | 32 | import sys, os 33 | sys.path.append("../") 34 | sys.path.append(".../") 35 | import MisinfoCorrect.src.variables_ext as cfg 36 | 37 | import time 38 | 39 | 40 | 41 | ############################ 42 | 43 | logging.basicConfig( 44 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 45 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 46 | logger = logging.getLogger(__name__) 47 | 48 | INF = 100000000 49 | CACHE_EMPTY_STEP = 1000 # previously, 1000, let's use 50 | EVAL_STEP = 100000 51 | 52 | ######################################################################### 53 | # 
# Prepare Parser
##########################################################################

parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str,
                    help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_seq_length", type=int, default=128)

parser.add_argument("--skip_eval", action='store_true',
                    help='If true, skip evaluation.')
parser.add_argument("--use_baseline", action='store_true',
                    help='If true, use baseline for RL.')
parser.add_argument("--init_checkpoint", type=str)
parser.add_argument("--train_input_file", type=str)
parser.add_argument("--eval_input_file", type=str)
parser.add_argument("--continue_from", type=int, default=0)

parser.add_argument("--train_batch_size", type=int, default=4,
                    help="batch size now means per GPU per step")
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
                    help="to increase effective batch size "
                         "and reduce synchronization")
parser.add_argument("--eval_batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--num_optim_steps", type=int, default=1000000,
                    help="new API specifies num update steps")
parser.add_argument("--valid_step", type=int, default=1000,
                    help="how many optim steps between validations")
parser.add_argument("--warmup_proportion", type=float, default=0.1)
parser.add_argument("--warmup_steps", type=int, default=16000)

parser.add_argument("--normalize_data", type=boolean_string, default=True)
parser.add_argument("--fp16", type=boolean_string, default=False)
parser.add_argument("--lr_schedule", type=str,
                    choices=['noam', 'noamwd', 'BERT', 'None'], default='noam')
parser.add_argument("--loss_scale", type=float, default=0)
parser.add_argument("--no_token_id", type=boolean_string, default=True)

parser.add_argument("--output_dir", type=str)
parser.add_argument("--log_dir", type=str)
parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar')

# distributed
parser.add_argument('--local_rank', type=int, default=-1,
                    help='for torch.distributed')
parser.add_argument('--config', help='JSON config file')


# do normal parsing
args = parser.parse_args()

if args.config is not None:
    # override argparse defaults by config JSON (the file is closed after reading)
    with open(args.config) as config_file:
        opts = json.load(config_file)
    for k, v in opts.items():
        if isinstance(v, str):
            # PHILLY ENV special cases
            if 'PHILLY_JOB_DIRECTORY' in v:
                v = v.replace('PHILLY_JOB_DIRECTORY',
                              os.environ['PHILLY_JOB_DIRECTORY'])
            elif 'PHILLY_LOG_DIRECTORY' in v:
                v = v.replace('PHILLY_LOG_DIRECTORY',
                              os.environ['PHILLY_LOG_DIRECTORY'])
        setattr(args, k, v)

    # command line should override config JSON
    argv = sys.argv[1:]
    overrides, _ = parser.parse_known_args(argv)
    for k, v in vars(overrides).items():
        if f'--{k}' in argv:
            setattr(args, k, v)
    setattr(args, 'local_rank', overrides.local_rank)


# Explicit check rather than `assert`, which is stripped under `python -O`.
if args.train_batch_size % args.gradient_accumulation_steps != 0:
    raise ValueError('batch size % gradient accumulation steps != 0!')
args.train_batch_size = (args.train_batch_size
                         // args.gradient_accumulation_steps)
logger.info('train batch size = {}, '
            'new train batch size (after gradient accumulation) = {}'.format(
                args.train_batch_size*args.gradient_accumulation_steps,
                args.train_batch_size))


if args.local_rank == -1:
    logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
    # Fallback values, used when variables_ext does not provide a device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    sys.path.append("./")
    sys.path.append("../")
    sys.path.append(".../")
    sys.path.append("..../")
    try:
        # variables_ext is the project-wide device setting, but it only defines
        # `device`/`n_gpu` when torch.cuda.is_available(); on CPU-only hosts
        # the import of these names fails and the fallback above is kept.
        from MisinfoCorrect.src.variables_ext import device, n_gpu
    except ImportError:
        logger.info('variables_ext defines no device; using {}'.format(device))

    args.device, args.n_gpu = device, n_gpu
else:
    # distributed training
    print('args.local_rank:', args.local_rank)
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    # Initializes the distributed backend which will take care of
    # synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
    n_gpu = torch.distributed.get_world_size()
    args.device, args.n_gpu = device, 1
logger.info("device: {} n_gpu: {}, distributed training: {}, "
            "16-bits training: {}".format(
                device, n_gpu, bool(args.local_rank != -1), args.fp16))

np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
# ==== TODO ==== by ext: use this output_dir
output_dir = join(args.output_dir,
                  'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate,
                                               args.train_batch_size, n_gpu,
                                               timestamp))
log_dir = args.log_dir if args.log_dir is not None and len(args.log_dir) > 0 else output_dir
# Only the main process creates the output directory.
if args.local_rank == -1 or get_rank() == 0:
    os.makedirs(output_dir, exist_ok=True)

logger.info('Input Argument Information')
args_dict = vars(args)
for a in args_dict:
    logger.info('%-28s %s' % (a, args_dict[a]))


#########################################################################
# Prepare Data Set
##########################################################################
enc = GPT2Tokenizer.from_pretrained('gpt2-medium')
# NOTE(review): the three tokens added here are empty strings; they may be
# special-token strings that were lost from this copy of the file — confirm
# the intended values.
enc.add_tokens(['', '', ''])
eos = enc.encoder["<|endoftext|>"]

config = GPT2Config.from_json_file(
    join(args.model_name_or_path, 'config.json'))

# ext: single GPU or CPU
if args.local_rank == -1:
    train_dataloader = BucketingDataLoader(args.train_input_file,
                                           args.train_batch_size,
                                           args.max_seq_length)
# ext: multiple GPUs (loader receives this process's rank and the world size)
else:
    train_dataloader = DistributedBucketingDataLoader(
        get_rank(), get_world_size(),
        args.train_input_file, args.train_batch_size,
        args.max_seq_length)

# Evaluation loaders (DynamicBatchingLoader / get_eval_list_same_length) are
# disabled in this version.


#########################################################################
# Prepare Model and Optimizer
##########################################################################

gpt2_model = GPT2LMHeadModel_v2.from_pretrained(args.model_name_or_path)  # ext
print("we use version 2 of GPT model from ext")

# Grow the embedding matrix to the tokenizer's size after add_tokens above.
gpt2_model.resize_token_embeddings(len(enc))

# load_model receives --init_checkpoint and args; presumably it restores the
# checkpoint and places the model on args.device — confirm in train_utils.
model = load_model(gpt2_model, args.init_checkpoint,
                   args, verbose=True)
if args.local_rank != -1:
    # when from scratch make sure initial models are the same
    params = [p.data for p in model.parameters()]
    all_reduce_and_rescale_tensors(
        params, float(torch.distributed.get_world_size()))

# Count trainable parameters only.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info('Number of parameter = {}'.format(total_params))

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'ln']   # no decay for bias and LayerNorm (ln)
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

if args.fp16:
    logger.info('in fp16, using FusedAdam')
    try:
        from apex.optimizers import FP16_Optimizer
        from apex.optimizers import FusedAdam
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.") from err

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    # loss_scale == 0 selects dynamic loss scaling; otherwise a fixed scale.
    if args.loss_scale == 0:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
                                   verbose=False)
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   verbose=False)
else:
    optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
                     max_grad_norm=1.0)

#########################################################################
# Training !
279 | ########################################################################## 280 | 281 | if args.local_rank == -1 or get_rank() == 0: 282 | train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1) 283 | eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1) 284 | print('epoch,global_step,step,mean_loss,n_token_real,' 285 | 'n_token_total,epoch_time', file=train_logger) 286 | print('epoch,global_step,step,eval_loss', file=eval_logger) 287 | 288 | global_step = 0 # multiple batches one global_step where we update the gradient 289 | step = 0 # one batch, one step 290 | epoch = 0 # the default definition: all training examples one epoch 291 | 292 | if args.continue_from: 293 | global_step = args.continue_from 294 | step = global_step*2 - 1 295 | 296 | 297 | if args.local_rank != -1: 298 | n_gpu = 1 299 | if args.local_rank == -1 or get_rank() == 0: 300 | if args.pbar: 301 | pbar = tqdm.tqdm(total=args.num_optim_steps, desc=f"training") 302 | else: 303 | pbar = None 304 | 305 | # pbar = None 306 | 307 | while True: 308 | 309 | if args.use_baseline: 310 | moving_avg_cnt = 20 311 | cumsum = [0.0] 312 | moving_avg_idx = 0 313 | 314 | model.train() 315 | (tr_loss, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0 316 | n_token_real, n_token_total = 0, 0 317 | train_start_time_epoch = time.time() 318 | tr_reward = 0.0 319 | 320 | # print('iteration started') 321 | 322 | for batch in train_dataloader: 323 | # torch.cuda.empty_cache() 324 | # activate new training mode 325 | seq_len = batch[0].shape[1] 326 | batch = tuple(t for t in batch) 327 | 328 | input_ids, position_ids, token_ids, seeker_post, response_post, = batch 329 | 330 | input_ids = input_ids.to(device) # ext: here, input_ids, we have seeker_post+response_post 331 | position_ids = position_ids.to(device) 332 | token_ids = token_ids.to(device) 333 | 334 | if args.no_token_id: 335 | token_ids = None 336 | 337 | forward_pass_start_time = time.time() 338 | 339 | # ext: in the released code: 
we do not use the baseline 340 | # which means, we use the relative distance for the downstream task. 341 | if args.use_baseline: 342 | if len(cumsum) >= moving_avg_cnt: 343 | baseline_val = (cumsum[moving_avg_idx] - cumsum[moving_avg_idx-moving_avg_cnt])/moving_avg_cnt 344 | else: 345 | baseline_val = cumsum[moving_avg_idx] 346 | 347 | loss, reward = model(input_ids, position_ids=position_ids, token_type_ids=token_ids, seeker_post=seeker_post, response_post=response_post, eos=eos, tokenizer=enc, baseline_val=baseline_val) 348 | 349 | cumsum.append(cumsum[moving_avg_idx-1] + reward) 350 | moving_avg_idx+=1 351 | # ext: in the release code, we use this one. 352 | else: 353 | loss, reward = model(input_ids, position_ids=position_ids, token_type_ids=token_ids, 354 | seeker_post=seeker_post, response_post=response_post, eos=eos, tokenizer=enc) 355 | 356 | forward_pass_end_time = time.time() 357 | 358 | backward_pass_start_time = time.time() 359 | 360 | # print(f"the loss is: {loss}") # the loss is a negative value 361 | 362 | if n_gpu > 1: 363 | loss = loss.mean() 364 | loss = loss / (args.train_batch_size / input_ids.shape[0]) 365 | if args.fp16: 366 | optimizer.backward(loss) 367 | else: 368 | loss.backward() 369 | # ==== add by ext to see the reward change: 370 | if n_gpu > 1: 371 | reward = reward.mean() 372 | reward = reward / (args.train_batch_size / input_ids.shape[0]) 373 | 374 | 375 | backward_pass_end_time = time.time() 376 | 377 | tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0]) 378 | tr_reward += float(reward.item()) * (args.train_batch_size / input_ids.shape[0]) 379 | 380 | nb_tr_examples += input_ids.size(0) 381 | nb_tr_steps += 1 382 | 383 | mean_loss = tr_loss / nb_tr_steps 384 | mean_reward = tr_reward / nb_tr_steps 385 | 386 | n_token_total += input_ids.shape[0] * input_ids.shape[1] 387 | n_token_real += (input_ids != 0).sum().item() 388 | 389 | # gradient update 390 | step += 1 391 | print(f'the step is: {step}, the time 
is: {time.time()}') # added by ext to monitor the running time 392 | # ext: it seems that only update gradient after multiple batches 393 | # gradient_accumulation_steps by default is 2. 394 | if step % args.gradient_accumulation_steps == 0: 395 | # ext: it seems to optimize the learning rate 396 | set_lr(optimizer, global_step, 397 | args.lr_schedule, args.learning_rate, 398 | args.warmup_steps, args.warmup_proportion, 399 | config.n_embd, args.num_optim_steps) 400 | 401 | if args.local_rank != -1: 402 | grads = [p.grad.data for p in model.parameters() 403 | if p.requires_grad and p.grad is not None] 404 | all_reduce_and_rescale_tensors(grads, float(1)) 405 | 406 | optimizer.step() 407 | optimizer.zero_grad() 408 | global_step += 1 409 | 410 | # Print log info to file 411 | if args.local_rank != -1: 412 | mean_loss = sum(all_gather_list(mean_loss)) / get_world_size() 413 | n_token_real_all_proc = sum(all_gather_list(n_token_real)) 414 | n_token_total_all_proc = sum(all_gather_list(n_token_total)) 415 | else: 416 | n_token_real_all_proc = n_token_real 417 | n_token_total_all_proc = n_token_total 418 | 419 | if args.local_rank == -1 or get_rank() == 0: 420 | epoch_time = time.time() - train_start_time_epoch 421 | 422 | # print('step:', global_step, 'time:', forward_pass_end_time - forward_pass_start_time, backward_pass_end_time - backward_pass_start_time) 423 | 424 | if pbar is not None: 425 | pbar.set_postfix_str( 426 | f"tok/s: {n_token_real_all_proc//epoch_time//1000}k " 427 | f"epoch: {epoch}") 428 | pbar.update(1) 429 | # commented by ext: append the result to train_log.txt 430 | print('{},{},{},{},{},{},{}'.format( 431 | epoch+1, global_step+1, step+1, mean_loss, 432 | n_token_real_all_proc, n_token_total_all_proc, epoch_time), 433 | file=train_logger) 434 | 435 | if global_step % args.valid_step == 0: 436 | if args.local_rank == -1 or get_rank() == 0: 437 | # by ext: TODO: check how to save by these lines 438 | # only rank 0 process evaluate 439 | torch.save( 
440 | {k: (v.cpu() if v is not None else None) # save to cpu tensors 441 | for k, v in model.state_dict().items()}, 442 | join(output_dir, 443 | f'GP2-pretrain-step-{global_step}.pkl')) 444 | 445 | # eval_loss, eval_ppl = eval_model_loss( 446 | # model, enc, eval_dataloader_loss, epoch, args) 447 | # enable generation step evaluation for now 448 | # gen_response = eval_model_generation( 449 | # model, enc, eval_dataloader_gen, epoch, args) 450 | ''' 451 | # probably use beam search only for test set 452 | if False: 453 | gen_response_beam = eval_model_generation( 454 | model, enc, eval_dataloader_gen, epoch, args, 455 | use_beam_search=True, beam_width=3) 456 | ''' 457 | # print('{},{},{},{},{}'.format( 458 | # epoch+1, global_step+1, step+1, eval_loss, eval_ppl), 459 | # file=eval_logger) 460 | logger.info('current learning rate: ' 461 | + str(optimizer.param_groups[0]['lr'])) 462 | model.train() 463 | if global_step >= args.num_optim_steps: 464 | break 465 | # ==== commented by ext ==== 466 | # if (step+1) % CACHE_EMPTY_STEP == 0: 467 | # torch.cuda.empty_cache() 468 | # ==== by ext ====: after one batch, we print some results 469 | print(f'epoch: {epoch + 1}, mean loss:{mean_loss}, mean reward:{mean_reward}', file=train_logger) 470 | # # ==== by ext ====: add here 471 | # if (step+1) % CACHE_EMPTY_STEP == 0: 472 | # torch.cuda.empty_cache() 473 | # ==== or after one epoch, we clear something, since the % may not hold 474 | # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 475 | with torch.cuda.device(cfg.device): 476 | torch.cuda.empty_cache() 477 | if global_step >= args.num_optim_steps: 478 | break 479 | epoch += 1 480 | # ==== by ext ==== save the model some checkpoints 481 | if epoch % cfg.every_k_epoch_save_model == 0: 482 | save_dir = f"{output_dir}/epoch_{epoch}" 483 | os.makedirs(save_dir, exist_ok=True) 484 | model.save_pretrained(save_dir) 485 | # ==== end ==== 486 | 487 | 488 | if args.local_rank == -1 or get_rank() 
== 0: 489 | if pbar is not None: 490 | pbar.close() 491 | train_logger.close() 492 | eval_logger.close() 493 | 494 | # ==== save model by ext ==== 495 | model.save_pretrained(output_dir) 496 | print(f"yes, we run the program and save the model at: {output_dir}") 497 | 498 | 499 | -------------------------------------------------------------------------------- /src/lsp_model_rl/modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """PyTorch OpenAI GPT-2 model.""" 3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | 6 | import logging 7 | import copy 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import CrossEntropyLoss 12 | from typing import Iterable, Optional, Tuple 13 | from torch import Tensor 14 | import time 15 | import nltk 16 | import numpy as np 17 | 18 | from transformers import GPT2PreTrainedModel, GPT2Model 19 | 20 | from pytorch_pretrained_bert.modeling_gpt2 import GPT2LMHead, Attention, Block, \ 21 | LayerNorm, MLP 22 | 23 | # from generate import top_filtering 24 | 25 | from .rewards import calc_rewards 26 | 27 | # from .generation_utils import GenerationMixin 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | import sys, os 33 | sys.path.append("../") 34 | sys.path.append(".../") 35 | import MisinfoCorrect.src.variables_ext as cfg 36 | 37 | class AttentionFP16(Attention): 38 | def __init__(self, nx, n_ctx, config, scale=False): 39 | super(AttentionFP16, self).__init__(nx, n_ctx, config, scale) 40 | 41 | def _attn(self, q, k, v): 42 | w = torch.matmul(q, k) 43 | if self.scale: 44 | w = w / math.sqrt(v.size(-1)) 45 | nd, ns = w.size(-2), w.size(-1) 46 | b = self.bias[:, :, ns-nd:ns, :ns] 47 | w = w * b - 1e4 * (1 - b) # point out by Yen-Chun, FP16 overflow 48 | 49 | w = nn.Softmax(dim=-1)(w) 50 | return torch.matmul(w, v) 51 | 52 | 53 | class BlockFP16(Block): 54 | def __init__(self, n_ctx, config, 
scale=False): 55 | super(BlockFP16, self).__init__(n_ctx, config, scale) 56 | nx = config.n_embd 57 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 58 | self.attn = AttentionFP16(nx, n_ctx, config, scale) 59 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 60 | self.mlp = MLP(4 * nx, config) 61 | 62 | 63 | class GPT2ModelFP16(GPT2Model): 64 | def __init__(self, config): 65 | # super(GPT2ModelFP16, self).__init__(config) 66 | super().__init__(config) 67 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 68 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 69 | block = BlockFP16(config.n_ctx, config, scale=True) 70 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 71 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 72 | 73 | self.init_weights() 74 | 75 | class GPT2LMHeadModel(GPT2PreTrainedModel): 76 | def __init__(self, config): 77 | super(GPT2LMHeadModel, self).__init__(config) 78 | self.transformer = GPT2Model(config) 79 | # lm_head generated hidden state to lm_logits -> vocabulary distribution 80 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # GPT2LMHead(self.transformer.wte.weight, config) 81 | self.position_num_labels = 2 # insert/replace: only two 82 | self.lambda_position = 0.1 83 | # position classifier as mentioned in the paper 84 | # actually, the head is a linear layer for the classification 85 | self.position_classifier = GPT2ClassificationHead(num_labels = self.position_num_labels) #GPT2LMHead(self.transformer.wte.weight, config) 86 | self.init_weights() 87 | 88 | def set_tied(self): 89 | """ Make sure we are sharing the embeddings 90 | """ 91 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 92 | 93 | def get_output_embeddings(self): 94 | return self.lm_head 95 | 96 | def padding_tensor_3D(self, sequences, max_len): 97 | """ 98 | :param sequences: list of tensors 99 | :return: 100 | """ 101 | num = len(sequences) 102 | 
out_dims = (num, max_len, *sequences[0].shape[1:]) 103 | out_tensor = sequences[0].data.new(*out_dims).fill_(0) 104 | 105 | # print('out_tensor:', out_tensor.shape) 106 | 107 | mask = sequences[0].data.new(*out_dims).fill_(0) 108 | for i, tensor in enumerate(sequences): 109 | length = tensor.size(0) 110 | # print('length:', length) 111 | out_tensor[i, :length] = tensor 112 | mask[i, :length] = 1 113 | return out_tensor, mask 114 | 115 | def padding_tensor_2D(self, sequences, max_len): 116 | """ 117 | :param sequences: list of tensors 118 | :return: 119 | """ 120 | num = len(sequences) 121 | out_dims = (num, max_len) 122 | out_tensor = sequences[0].data.new(*out_dims).fill_(0) 123 | 124 | # print('out_tensor:', out_tensor.shape) 125 | 126 | mask = sequences[0].data.new(*out_dims).fill_(0) 127 | for i, tensor in enumerate(sequences): 128 | length = min(tensor.size(0), max_len) 129 | # print('length:', length) 130 | out_tensor[i, :length] = tensor[:length] 131 | mask[i, :length] = 1 132 | return out_tensor, mask 133 | 134 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, position_labels=None, past=None, seeker_post=None, response_post=None, top_k=60, top_p=0.92, temperature=0.9, eos=None, tokenizer=None, baseline_val=0): 135 | 136 | transformer_start_time = time.time() 137 | 138 | # Forward Transformer Pass 139 | # self.transformer is a GPT model: no generation at this moment 140 | hidden_states, presents = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past=past) 141 | # res = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past_key_values=past) # past_key_values from past # by ext 142 | # hidden_states, presents = res.last_hidden_state, None # by ext last_hidden_state 143 | 144 | transformer_end_time = time.time() 145 | 146 | # Get LM and position logits 147 | lm_logits = self.lm_head(hidden_states) # 148 | 149 | if tokenizer is None: 
150 | return lm_logits, presents 151 | 152 | # ext: i think X2 to expand it to 2k+1: even if we just k positions, we double it. 153 | # -1: we select the final dimension data: --- dimension-driven understanding 154 | # like a.shape = 2,3,4 -> :,-1,:->2,4 155 | position_logits = self.position_classifier(hidden_states[:, -1, :]) # X2: shape 156 | 157 | # A1 (Selecting a position) 158 | probs_position = torch.softmax(position_logits.view(-1, self.position_num_labels), -1) # (batch_size, num_position) 159 | all_positions = torch.argmax(probs_position, 1) 160 | all_positions = all_positions.squeeze() 161 | 162 | all_positions = all_positions.cpu().numpy().tolist() 163 | 164 | # A2 (Candidate Sentence): Sample from DialoGPT 165 | all_outputs = [] 166 | all_output_ids = [] 167 | all_output_logits = [] 168 | sample_dialogpt_start_time = time.time() 169 | 170 | for ii, _ in enumerate(input_ids): 171 | curr_seeker = tokenizer.encode(seeker_post[ii] + tokenizer.eos_token) 172 | curr_seeker = torch.tensor([curr_seeker,]) 173 | 174 | # ==== ext: device control ==== 175 | # ==== ext: device control ==== 176 | curr_seeker = curr_seeker # .to(device) # by ext from cuda for .to(device), use the version2 177 | # self.generate: is a method of class from transformers.PreTrainedModel 178 | # here parameters, like top_p, top_k, temperature 179 | generated_output = self.generate(input_ids = curr_seeker, 180 | max_length=1000, 181 | pad_token_id=tokenizer.eos_token_id, 182 | top_p=0.92, 183 | top_k=60, 184 | temperature=1, 185 | num_return_sequences=1) 186 | 187 | curr_output = tokenizer.decode(generated_output[:, curr_seeker.shape[-1]:][0], skip_special_tokens=True) 188 | 189 | curr_output_ids = generated_output[:, curr_seeker.shape[-1]:][0] 190 | curr_output_ids = curr_output_ids[:hidden_states.shape[1]] 191 | curr_position_ids = torch.tensor(range(len(curr_output_ids)), dtype=torch.long) # ext: .to(device) for .to(device), use the version2 192 | 193 | curr_output_logits = lm_logits[ii, 
range(curr_output_ids.shape[0]), curr_output_ids] 194 | 195 | all_outputs.append(curr_output) 196 | all_output_ids.append(curr_output_ids) 197 | all_output_logits.append(curr_output_logits) 198 | 199 | log_softmax = nn.LogSoftmax(1) 200 | 201 | all_output_logits, _ = self.padding_tensor_2D(all_output_logits, hidden_states.shape[1]) 202 | all_output_logits = log_softmax(all_output_logits) 203 | 204 | sample_dialogpt_end_time = time.time() 205 | 206 | 207 | # Calculate Reward 208 | 209 | rewritten_response = [] 210 | 211 | for idx, _ in enumerate(all_outputs): 212 | curr_seeker_post = seeker_post[idx] 213 | curr_response = response_post[idx] 214 | curr_output = all_outputs[idx] 215 | curr_position = all_positions[idx] 216 | 217 | curr_response_li = nltk.sent_tokenize(curr_response) 218 | 219 | if curr_position == 0: 220 | curr_rewritten_response = curr_response 221 | 222 | else: 223 | curr_rewritten_response_li = curr_response_li[:curr_position] + [curr_output] + curr_response_li[curr_position:] 224 | curr_rewritten_response = '. 
'.join(curr_rewritten_response_li) 225 | 226 | rewritten_response.append(curr_rewritten_response) 227 | 228 | reward_start_time = time.time() 229 | # TODO: ext: we rewrite this part 230 | # ==== partner's ==== 231 | # we change the reward in v2, the follow child class 232 | # reward = calc_rewards(seeker_post, response_post, rewritten_response, _empathy_change=True, _perplexity=True) 233 | reward = calc_rewards(seeker_post, response_post, rewritten_response, 234 | _politeness=True, 235 | _coherence=False, _perplexity=True) 236 | reward_end_time = time.time() 237 | 238 | batches = np.arange(input_ids.shape[0]).tolist() 239 | 240 | rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,]) + torch.mean(torch.log(probs_position[batches, all_positions]) )) 241 | 242 | return rl_loss, reward 243 | 244 | # ext: by search on github repo: this function is not used 245 | def forward_pointwise(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): 246 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 247 | # import pdb; pdb.set_trace() 248 | lm_logits = self.lm_head(hidden_states) 249 | if lm_labels is not None: 250 | # loss_fct = CrossEntropyLoss(ignore_index=-1) 251 | # loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) 252 | loss_fct1 = CrossEntropyLoss(ignore_index=-1, reduction='none') 253 | loss1 = loss_fct1(lm_logits.view(-1, lm_logits.size(-1)), 254 | lm_labels.view(-1)) 255 | loss1 = loss1.view(lm_labels.size(0), lm_labels.size(1)) 256 | label_size = torch.sum(lm_labels != -1, dim=1).type(loss1.type()) 257 | loss1 = torch.sum(loss1, dim=1)/label_size 258 | ppl1 = torch.exp(loss1) 259 | 260 | return loss1, ppl1 261 | return lm_logits, presents 262 | 263 | def prepare_inputs_for_generation(self, input_ids, **kwargs): 264 | return {"input_ids": input_ids} 265 | 266 | # created by ext for our purpose 267 | class GPT2LMHeadModel_v2(GPT2LMHeadModel): 268 | 
def __init__(self, config): 269 | super(GPT2LMHeadModel_v2, self).__init__(config) 270 | self.transformer = GPT2Model(config) 271 | # lm_head generated hidden state to lm_logits -> vocabulary distribution 272 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # GPT2LMHead(self.transformer.wte.weight, config) 273 | # self.position_num_labels = 2 # insert/replace: only two 274 | # self.lambda_position = 0.1 # commented for inheritation: ext 275 | # TODO: maybe, when we load config, we do not find the weight and it is missing 276 | # position classifier as mentioned in the paper 277 | # actually, the head is a linear layer for the classification 278 | # self.position_classifier = GPT2ClassificationHead(num_labels = self.position_num_labels) #GPT2LMHead(self.transformer.wte.weight, config) 279 | self.init_weights() 280 | 281 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, position_labels=None, 282 | past=None, seeker_post=None, response_post=None, top_k=60, top_p=0.92, temperature=0.9, eos=None, 283 | tokenizer=None, baseline_val=0): 284 | # print("==== use the forward function in version 2 ====") 285 | 286 | transformer_start_time = time.time() 287 | 288 | # Forward Transformer Pass 289 | # self.transformer is a GPT model: no generation at this moment 290 | # from the code setup: when we have batches: input_ids can be n-batch-size*length-of-a sentence/input-token-id 291 | hidden_states, presents = self.transformer(input_ids=input_ids, position_ids=position_ids, 292 | token_type_ids=token_type_ids, past=past) 293 | # res = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past_key_values=past) # past_key_values from past # by ext 294 | # hidden_states, presents = res.last_hidden_state, None # by ext last_hidden_state 295 | 296 | transformer_end_time = time.time() 297 | 298 | # Get LM and position logits 299 | lm_logits = self.lm_head(hidden_states) # 300 | 301 | 
if tokenizer is None: 302 | return lm_logits, presents 303 | 304 | # ext: i think X2 to expand it to 2k+1: even if we just k positions, we double it. 305 | # -1: we select the final dimension data: --- dimension-driven understanding 306 | # like a.shape = 2,3,4 -> :,-1,:->2,4 307 | # position_logits = self.position_classifier(hidden_states[:, -1, :]) # X2: shape 308 | 309 | # A1 (Selecting a position) 310 | # probs_position = torch.softmax(position_logits.view(-1, self.position_num_labels),-1) # (batch_size, num_position) 311 | # all_positions = torch.argmax(probs_position, 1) 312 | # all_positions = all_positions.squeeze() 313 | 314 | # all_positions = all_positions.cpu().numpy().tolist() 315 | 316 | # A2 (Candidate Sentence): Sample from DialoGPT 317 | all_outputs = [] 318 | all_output_ids = [] 319 | all_output_logits = [] 320 | sample_dialogpt_start_time = time.time() 321 | 322 | for ii, _ in enumerate(input_ids): 323 | curr_seeker = tokenizer.encode(seeker_post[ii] + tokenizer.eos_token) 324 | curr_seeker = torch.tensor([curr_seeker, ]) 325 | 326 | # ==== ext: device control ==== 327 | import sys, os 328 | sys.path.append("../") 329 | sys.path.append(".../") 330 | from MisinfoCorrect.src.variables_ext import device 331 | # ==== ext: device control ==== 332 | curr_seeker = curr_seeker.to(device) # by ext from cuda 333 | # self.generate: is a method of class from transformers.PreTrainedModel 334 | # here parameters, like top_p, top_k, temperature 335 | generated_output = self.generate(input_ids=curr_seeker, 336 | max_length=140, # by ext from 1000 to 140 337 | pad_token_id=tokenizer.eos_token_id, 338 | top_p=0.92, 339 | top_k=60, 340 | temperature=1, 341 | num_return_sequences=1) 342 | 343 | curr_output = tokenizer.decode(generated_output[:, curr_seeker.shape[-1]:][0], skip_special_tokens=True) 344 | 345 | curr_output_ids = generated_output[:, curr_seeker.shape[-1]:][0] 346 | curr_output_ids = curr_output_ids[:hidden_states.shape[1]] # ext: TODO: expanded usage? 
347 | # TODO: or nonsense due to the large number of hidden state 348 | curr_position_ids = torch.tensor(range(len(curr_output_ids)), dtype=torch.long).to( 349 | device) # by ext 350 | 351 | # ext: here, we propability of each vob in the vocabulary, like p1*p2*p3 352 | curr_output_logits = lm_logits[ii, range(curr_output_ids.shape[0]), curr_output_ids] 353 | 354 | all_outputs.append(curr_output) 355 | all_output_ids.append(curr_output_ids) 356 | all_output_logits.append(curr_output_logits) 357 | 358 | log_softmax = nn.LogSoftmax(1) 359 | 360 | all_output_logits, _ = self.padding_tensor_2D(all_output_logits, hidden_states.shape[1]) 361 | all_output_logits = log_softmax(all_output_logits) 362 | 363 | sample_dialogpt_end_time = time.time() 364 | 365 | # Calculate Reward 366 | 367 | rewritten_response = [] 368 | 369 | for idx, _ in enumerate(all_outputs): 370 | curr_seeker_post = seeker_post[idx] 371 | curr_response = response_post[idx] 372 | curr_output = all_outputs[idx] 373 | # curr_position = all_positions[idx] 374 | 375 | curr_response_li = nltk.sent_tokenize(curr_response) 376 | 377 | # ==== previous partner ==== 378 | # if curr_position == 0: 379 | # curr_rewritten_response = curr_response 380 | # else: 381 | # curr_rewritten_response_li = curr_response_li[:curr_position] + [curr_output] + curr_response_li[ 382 | # curr_position:] 383 | # curr_rewritten_response = '. '.join(curr_rewritten_response_li) 384 | # ==== version 2 ==== 385 | 386 | curr_rewritten_response_li = [curr_output] 387 | curr_rewritten_response = '. 
'.join(curr_rewritten_response_li) 388 | 389 | rewritten_response.append(curr_rewritten_response) 390 | 391 | reward_start_time = time.time() 392 | # TODO: ext: we rewrite this part 393 | # ==== partner's ==== 394 | # reward = calc_rewards(seeker_post, response_post, rewritten_response, _empathy_change=True, _perplexity=True) 395 | # ==== by ext ==== 396 | # ======== debug =======: seeker_post: list, ['i like it', 'i forget it'] 397 | # print(f'the current seeker post is: {seeker_post}') 398 | # print(f'the current rewritten_response is: {rewritten_response}') 399 | # exit(0) 400 | 401 | if_cut_responses_with_more_characters = True 402 | if if_cut_responses_with_more_characters: 403 | rewritten_response_new = [] 404 | for one_response in rewritten_response: 405 | if len(one_response) > 280: 406 | rewritten_response_new.append(one_response[:280]) 407 | else: 408 | rewritten_response_new.append(one_response) 409 | 410 | rewritten_response = rewritten_response_new 411 | 412 | reward = calc_rewards(seeker_post, None, rewritten_response, 413 | _politeness=cfg.if_politeness, 414 | _refutation=cfg.if_refutation, 415 | _evidence=cfg.if_evidence, 416 | _perplexity=cfg.if_perplexity, 417 | _relevance=cfg.if_relevance) 418 | reward_end_time = time.time() 419 | 420 | batches = np.arange(input_ids.shape[0]).tolist() 421 | 422 | # rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,]) + torch.mean( 423 | # torch.log(probs_position[batches, all_positions]))) 424 | # ==== version 2 ==== 425 | # log addition -> 0 for probability 1 426 | # in the release code, baseline_val = 0 427 | # reward is positive: like cross-entropy setup 428 | # here, -1 to change the negative value to the positive value 429 | # some data examples: reward: 0.50, rl_loss is -95, then prob after log: ~200 430 | # please note: the baseline_val is 0 if we do not initiate it 431 | rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,])) 432 | # print(f"the reward, 
baseline_val, all_output_logits[batches,], rl_loss is:" 433 | # f"\n {reward}\n{baseline_val},\n{all_output_logits[batches,]},\n{rl_loss}") 434 | # exit 435 | return rl_loss, reward 436 | 437 | class GPT2ClassificationHead(nn.Module): 438 | """Head for sentence-level classification tasks.""" 439 | 440 | def __init__(self, hidden_dropout_prob=0.1, hidden_size=1024, num_labels=2): 441 | super().__init__() 442 | 443 | # self.dense = nn.Linear(hidden_size, hidden_size) # previous by the 444 | self.dense = nn.Linear(768, hidden_size) # by ext from 768*2 to 1024, 445 | self.dropout = nn.Dropout(hidden_dropout_prob) 446 | self.out_proj = nn.Linear(hidden_size, num_labels) 447 | 448 | def forward(self, features, **kwargs): 449 | print(f"***{features.shape}***") 450 | x = features[:, :] 451 | x = self.dropout(x) 452 | print(f"***{x.shape}***") 453 | x = self.dense(x) 454 | x = torch.relu(x) 455 | x = self.dropout(x) 456 | x = self.out_proj(x) 457 | return x 458 | -------------------------------------------------------------------------------- /src/lsp_model_rl/util/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
""" Configuration base class and utilities."""


import copy
import json
import logging
import os
from typing import Dict, Optional, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


logger = logging.getLogger(__name__)


class PretrainedConfig(object):
    r""" Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.

        Note:
            A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
            It only affects the model's configuration.

        Class attributes (overridden by derived classes):
            - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
            - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.

        Args:
            finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
                Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
            num_labels (:obj:`int`, `optional`, defaults to `2`):
                Number of classes to use when the model is a classification model (sequences/tokens).
                Setting it (also after construction) regenerates ``id2label`` / ``label2id``.
            empathy_num_labels / rationale_num_labels (:obj:`int`, `optional`, defaults to `2`):
                Extra label counts stored on the config by this project; their consumers are not visible here.
            output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model returns attention weights.
            output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model returns all hidden-states.
            torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Is the model used with Torchscript (for PyTorch models).

        Any additional keyword argument is stored verbatim as an attribute of the config.
    """
    pretrained_config_archive_map: Dict[str, str] = {}
    model_type: str = ""

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        # Goes through the `num_labels` property setter, which also builds default
        # id2label/label2id maps; explicit id2label/label2id kwargs below override them.
        self.num_labels = kwargs.pop("num_labels", 2)
        self.empathy_num_labels = kwargs.pop("empathy_num_labels", 2)
        self.rationale_num_labels = kwargs.pop("rationale_num_labels", 2)
        self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)})
        # Keys may arrive as strings after a JSON round-trip; normalise them back to ints.
        self.id2label = {int(key): value for key, value in self.id2label.items()}
        self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
        self.label2id = {key: int(value) for key, value in self.label2id.items()}

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                # e.g. a kwarg that collides with a read-only property on a subclass
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise

    @property
    def num_labels(self):
        """Number of classification labels (stored as ``_num_labels``)."""
        return self._num_labels

    @num_labels.setter
    def num_labels(self, num_labels):
        # Changing the label count resets both label maps to the generic LABEL_i names.
        self._num_labels = num_labels
        self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
        self.id2label = {int(key): value for key, value in self.id2label.items()}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
        self.label2id = {key: int(value) for key, value in self.label2id.items()}

    def save_pretrained(self, save_directory):
        """
        Save a configuration object to the directory `save_directory`, so that it
        can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`string`):
                Directory where the configuration JSON file will be saved. It must already exist.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        # Only values that differ from a default PretrainedConfig() are written.
        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
        r"""

        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.

        Args:
            pretrained_model_name_or_path (:obj:`string`):
                either:
                  - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
                    download, e.g.: ``bert-base-uncased``.
                  - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
                    our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                  - a path to a `directory` containing a configuration file saved using the
                    :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
                  - a path or url to a saved configuration JSON `file`, e.g.:
                    ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`string`, `optional`):
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            kwargs (:obj:`Dict[str, any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
                controlled by the `return_unused_kwargs` keyword parameter.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies (:obj:`Dict`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.:
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
                The proxies are used on each request.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Forwarded to ``cached_path``.
            return_unused_kwargs: (`optional`) bool:
                If False, then this function returns just the final configuration object.
                If True, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
                of kwargs which has not been used to update `config` and is otherwise ignored.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
        for instantiating a Config using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (:obj:`string`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
            pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
                A map of `shortcut names` to `url`. By default, will use the current class attribute.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object,
            and the remaining kwargs (download-related keys removed).

        Raises:
            EnvironmentError: if the file cannot be resolved/downloaded or is not valid JSON.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if pretrained_config_archive_map is None:
            pretrained_config_archive_map = cls.pretrained_config_archive_map

        # Resolution order: known shortcut name -> local directory -> local file / URL -> hub identifier.
        if pretrained_model_name_or_path in pretrained_config_archive_map:
            config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            if pretrained_model_name_or_path in pretrained_config_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file
                )
            else:
                msg = (
                    "Can't load '{}'. Make sure that:\n\n"
                    "- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    "- or '{}' is the correct path to a directory containing a '{}' file\n\n".format(
                        pretrained_model_name_or_path,
                        pretrained_model_name_or_path,
                        pretrained_model_name_or_path,
                        CONFIG_NAME,
                    )
                )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            # Only raised by _dict_from_json_file, so resolved_config_file is bound here.
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
        """
        Constructs a `Config` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
                from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
                method.
            kwargs (:obj:`Dict[str, any]`):
                Additional parameters from which to initialize the configuration object. Only keys that are already
                attributes of the built config are applied; with ``return_unused_kwargs=True`` the rest are returned.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            # JSON object keys are always strings; layer indices must be ints.
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
        """
        Constructs a `Config` from the path to a json file of parameters.

        Args:
            json_file (:obj:`string`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        """Read `json_file` as UTF-8 and return the parsed JSON object."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        # Objects such as None or ints may have no __dict__, which used to raise
        # AttributeError; defer to Python's fallback so e.g. `config == None` is False.
        # (Defining __eq__ without __hash__ leaves instances unhashable, as before.)
        if not isinstance(other, PretrainedConfig):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self):
        """
        Removes all attributes from config which correspond to the default
        config attributes for better readability and serializes to a Python
        dictionary.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if key not in default_config_dict or value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        return output

    def to_json_string(self, use_diff=True):
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`):
                If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.

        Returns:
            :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path, use_diff=True):
        """
        Save this instance to a json file.

        Args:
            json_file_path (:obj:`string`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`):
                If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict):
        """
        Updates attributes of this class
        with attributes from `config_dict`.

        Args:
            :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)
# --------------------------------------------------------------------------------