├── src
│ ├── lsp_model_rl
│ │ ├── util
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── configuration_roberta.py
│ │ │ ├── configuration_bert.py
│ │ │ ├── file_utils.py
│ │ │ └── configuration_utils.py
│ │ ├── __init__.py
│ │ ├── automated_metrics.py
│ │ ├── empathy_classifier_bi_encoder_attention.py
│ │ ├── rewards.py
│ │ ├── coherence_classifier2.py
│ │ ├── optim.py
│ │ └── modeling_gpt2.py
│ ├── env.py
│ ├── variables_ext.py
│ ├── gpt2_training
│ │ ├── eval_utils.py
│ │ ├── distributed.py
│ │ └── train_utils.py
│ ├── process_data.py
│ ├── data_loader.py
│ └── train_model.py
├── asset
│ └── social_media.jpg
├── dataset
│ ├── README.md
│ ├── csv2tsv.ipynb
│ ├── sample_data.tsv
│ └── sample_data.csv
├── .gitignore
├── README.md
├── environment.yml
└── requirements.txt
/src/lsp_model_rl/util/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.8.0"
--------------------------------------------------------------------------------
/asset/social_media.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/claws-lab/MisinfoCorrect/HEAD/asset/social_media.jpg
--------------------------------------------------------------------------------
/src/env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | import os
4 |
5 |
6 | END_OF_TURN_TOKEN = '<|endofturn|>'
7 | END_OF_TEXT_TOKEN = '<|endoftext|>'
8 | PROJECT_FOLDER = os.path.dirname(__file__)
--------------------------------------------------------------------------------
/src/lsp_model_rl/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | from transformers import GPT2Tokenizer
3 | from transformers import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
4 | from transformers import GPT2Config, GPT2Model
5 |
6 |
7 | from .modeling_gpt2 import GPT2LMHeadModel
8 | from .modeling_gpt2 import GPT2LMHeadModel_v2
9 | from .optim import Adam
10 |
11 |
--------------------------------------------------------------------------------
/src/variables_ext.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # ==== gpu server ====
4 | if torch.cuda.is_available():
5 |     device, n_gpu = "cuda:0", 1
6 | else:
7 |     device, n_gpu = "cpu", 0  # fall back to CPU so downstream imports of `device` still work
8 | # the paths for three reward functions
9 | clf_main_fp = r'./../../MisinfoCorrect_support_models'
10 |
11 | politeness_clf_fp = f"{clf_main_fp}/politeness_clf.pt"
12 |
13 | refutation_clf_fp = f"{clf_main_fp}/refutation_clf.pt"
14 |
15 | evidence_clf_fp = f"{clf_main_fp}/evidence_clf.pt"
16 |
17 |
18 | # general text reward
19 | if_perplexity = True
20 | if_relevance = True # i.e., the mentioned coherence reward in the paper
21 | # counter response reward
22 | # set each flag to True/False to enable/disable the corresponding reward
23 | if_politeness = True
24 | if_refutation = True
25 | if_evidence = True
26 |
27 | # initially 5, 3
28 | every_k_epoch_save_model = 3
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset
2 |
3 | 1. In-the-wild social media data containing 754 annotated (misinformation tweet, counter-misinformation reply) pairs. Below are the dataset statistics:
4 |
5 |
6 |
7 |
8 |
9 | 2. Crowdsourced data containing 591 (misinformation tweet, human-written counter-misinformation reply) pairs. Note that, compared to the social media data, these 591 human-written replies refute the misinformation, are polite, and provide evidence, per the requirements in the paper.
10 | 3. Our dataset can be found [here](https://www.dropbox.com/sh/5u2mdo53tgh3vrh/AADfYHqhQbt0A2gUciT583E0a?dl=0).
11 | 4. We are aware of the changes to the Twitter API. If you have problems accessing the whole dataset or the code, please contact Bing He (bhe46@gatech.edu).
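
If you want to inspect the sample files programmatically, a minimal sketch is below; the column names are placeholders introduced here, since the files ship without a header row.

```python
import pandas as pd

# sample_data.tsv: one tab-separated (misinformation tweet, counter-reply) pair per line, no header
pairs = pd.read_csv('sample_data.tsv', sep='\t', header=None,
                    names=['misinfo_tweet', 'counter_reply'])
print(pairs.shape)
print(pairs.iloc[0]['counter_reply'])
```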
--------------------------------------------------------------------------------
/dataset/csv2tsv.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "\n",
11 | "df = pd.read_csv('./sample_data.csv')"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "df.to_csv('./sample_data.tsv', index=False, header=False, sep='\\t')"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": []
29 | }
30 | ],
31 | "metadata": {
32 | "kernelspec": {
33 | "display_name": "base",
34 | "language": "python",
35 | "name": "python3"
36 | },
37 | "language_info": {
38 | "codemirror_mode": {
39 | "name": "ipython",
40 | "version": 3
41 | },
42 | "file_extension": ".py",
43 | "mimetype": "text/x-python",
44 | "name": "python",
45 | "nbconvert_exporter": "python",
46 | "pygments_lexer": "ipython3",
47 | "version": "3.7.6"
48 | },
49 | "orig_nbformat": 4,
50 | "vscode": {
51 | "interpreter": {
52 | "hash": "442c277bbd2a2bd37b27aa6e5dadfa8a95da05ea7996590fcf9f52713a791d05"
53 | }
54 | }
55 | },
56 | "nbformat": 4,
57 | "nbformat_minor": 2
58 | }
59 |
--------------------------------------------------------------------------------
/dataset/sample_data.tsv:
--------------------------------------------------------------------------------
1 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again.
2 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again.
3 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again.
4 | It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research. Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again.
5 |
--------------------------------------------------------------------------------
/dataset/sample_data.csv:
--------------------------------------------------------------------------------
1 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again."
2 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again."
3 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again."
4 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again."
5 | "It’s not a vaccine, it’s gene therapy. Gene therapy is an experimental technique. It’s the same tech- nology used in cloning, DNA editing, and stem cell research.","Sorry to see you think in this way. It is not correct. The vaccine is not gene therapy. Instead, it uses mRNA to generate spike protein to protect people. Please do not say the misinformation again."
6 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/util/activations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def swish(x):
12 | return x * torch.sigmoid(x)
13 |
14 |
15 | def _gelu_python(x):
16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19 | This is now written in C in torch.nn.functional
20 | Also see https://arxiv.org/abs/1606.08415
21 | """
22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23 |
24 |
25 | def gelu_new(x):
26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27 | Also see https://arxiv.org/abs/1606.08415
28 | """
29 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30 |
31 |
32 | if torch.__version__ < "1.4.0":
33 | gelu = _gelu_python
34 | else:
35 | gelu = F.gelu
36 | try:
37 | import torch_xla # noqa F401
38 |
39 | logger.warning(
40 | "The torch_xla package was detected in the python environment. PyTorch/XLA and JIT is untested,"
41 | " no activation function will be traced with JIT."
42 | )
43 | except ImportError:
44 | gelu_new = torch.jit.script(gelu_new)
45 |
46 | ACT2FN = {
47 | "relu": F.relu,
48 | "swish": swish,
49 | "gelu": gelu,
50 | "tanh": torch.tanh,
51 | "gelu_new": gelu_new,
52 | }
53 |
54 |
55 | def get_activation(activation_string):
56 | if activation_string in ACT2FN:
57 | return ACT2FN[activation_string]
58 | else:
59 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # MisinfoCorrect details
2 |
3 | .DS_Store
4 |
5 | command.sh
6 |
7 | *.pt
8 | *.pth
9 | *.db
10 | models/medium/*
11 | output/*
12 | logs/*
13 |
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
145 |
--------------------------------------------------------------------------------
/src/gpt2_training/eval_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | import torch
4 | import logging
5 |
6 | import numpy as np
7 |
8 | from pycocoevalcap.bleu.bleu import Bleu
9 | from collections import defaultdict
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | EOS_ID = 50256
14 |
15 |
16 | def cal_BLEU_4(generated, reference, is_corpus=False):
17 | BLEUscore = [0.0, 0.0, 0.0, 0.0]
18 | for idx, g in enumerate(generated):
19 | if is_corpus:
20 | score, scores = Bleu(4).compute_score(reference, {0: [g]})
21 | else:
22 | score, scores = Bleu(4).compute_score({0: [reference[0][idx]]},
23 | {0: [g]})
24 | for i, s in zip([0, 1, 2, 3], score):
25 | BLEUscore[i] += s
26 | BLEUscore[0] = BLEUscore[0]/len(generated)
27 | BLEUscore[1] = BLEUscore[1]/len(generated)
28 | BLEUscore[2] = BLEUscore[2]/len(generated)
29 | BLEUscore[3] = BLEUscore[3]/len(generated)
30 | return BLEUscore
31 |
32 |
33 | def cal_entropy(generated):
34 | etp_score = [0.0, 0.0, 0.0, 0.0]
35 | div_score = [0.0, 0.0, 0.0, 0.0]
36 | counter = [defaultdict(int), defaultdict(int),
37 | defaultdict(int), defaultdict(int)]
38 | for gg in generated:
39 | g = gg.rstrip().split()
40 | for n in range(4):
41 | for idx in range(len(g)-n):
42 | ngram = ' '.join(g[idx:idx+n+1])
43 | counter[n][ngram] += 1
44 | for n in range(4):
45 | total = sum(counter[n].values()) + 1e-10
46 | for v in counter[n].values():
47 | etp_score[n] += - (v+0.0) / total * (np.log(v+0.0) - np.log(total))
48 | div_score[n] = (len(counter[n].values())+0.0) / total
49 | return etp_score, div_score
50 |
51 |
52 | def eval_model_loss(model, tokenizer, eval_dataloader, epoch_id, args):
53 | # use the same signature with eval_model_generation
54 | logger.info('compute eval model loss, using eval mode, '
55 | 'please change it back to train after calling this function')
56 | model.eval()
57 | tot_loss = []
58 | tot_ppl = []
59 | tot_sample = []
60 | with torch.no_grad():
61 | for step, batch in enumerate(eval_dataloader):
62 | batch = tuple(t.to(args.device) for t in batch)
63 | input_ids, position_ids, token_ids, label_ids, src_len, _ = batch
64 | if args.no_token_id:
65 | token_ids = None
66 | n_sample = input_ids.shape[0]
67 | loss, ppl = model(input_ids, position_ids, token_ids, label_ids)
68 | tot_loss.append(loss.mean().item() * n_sample)
69 | tot_ppl.append(ppl.mean().item() * n_sample)
70 | tot_sample.append(n_sample)
71 | print(f"\n Epoch {epoch_id}: Val loss {np.sum(tot_loss) / np.sum(tot_sample)} Val ppl {np.sum(tot_ppl) / np.sum(tot_sample)} ")
72 | return np.sum(tot_loss) / np.sum(tot_sample), np.sum(tot_ppl) / np.sum(tot_sample)
73 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/util/configuration_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ RoBERTa configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_bert import BertConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
28 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
29 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
30 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
31 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
32 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
33 | }
34 |
35 |
36 | class RobertaConfig(BertConfig):
37 | r"""
38 | This is the configuration class to store the configuration of an :class:`~transformers.RobertaModel`.
39 |     It is used to instantiate a RoBERTa model according to the specified arguments, defining the model
40 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
41 |     the BERT bert-base-uncased architecture.
42 |
43 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
44 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
45 | for more information.
46 |
47 | The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`.
48 | It reuses the same defaults. Please check the parent class for more information.
49 |
50 | Example::
51 |
52 | from transformers import RobertaConfig, RobertaModel
53 |
54 | # Initializing a RoBERTa configuration
55 | configuration = RobertaConfig()
56 |
57 | # Initializing a model from the configuration
58 | model = RobertaModel(configuration)
59 |
60 | # Accessing the model configuration
61 | configuration = model.config
62 |
63 | Attributes:
64 | pretrained_config_archive_map (Dict[str, str]):
65 | A dictionary containing all the available pre-trained checkpoints.
66 | """
67 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
68 | model_type = "roberta"
69 |
70 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
71 | """Constructs FlaubertConfig.
72 | """
73 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
--------------------------------------------------------------------------------
/src/gpt2_training/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | """ Pytorch Distributed utils
4 | # NOTE: copied from OpenNMT-py
5 | This piece of code was heavily inspired by the equivalent of Fairseq-py
6 | https://github.com/pytorch/fairseq
7 | """
8 |
9 |
10 | from __future__ import print_function
11 |
12 | import math
13 | import pickle
14 | import torch.distributed
15 |
16 |
17 | def is_master(opt, device_id):
18 | return opt.gpu_ranks[device_id] == 0
19 |
20 |
21 | def multi_init(opt, device_id, logger=None):
22 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
23 | master_ip=opt.master_ip,
24 | master_port=opt.master_port)
25 | dist_world_size = opt.world_size
26 | torch.distributed.init_process_group(
27 | backend=opt.gpu_backend, init_method=dist_init_method,
28 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id])
29 | gpu_rank = torch.distributed.get_rank()
30 | if not is_master(opt, device_id) and logger is not None:
31 | logger.disabled = True
32 |
33 | return gpu_rank
34 |
35 |
36 | def all_reduce_and_rescale_tensors(tensors, rescale_denom,
37 | buffer_size=10485760):
38 | """All-reduce and rescale tensors in chunks of the specified size.
39 |
40 | Args:
41 | tensors: list of Tensors to all-reduce
42 | rescale_denom: denominator for rescaling summed Tensors
43 | buffer_size: all-reduce chunk size in bytes
44 | """
45 | # buffer size in bytes, determine equiv. # of elements based on data type
46 | buffer_t = tensors[0].new(
47 | math.ceil(buffer_size / tensors[0].element_size())).zero_()
48 | buffer = []
49 |
50 | def all_reduce_buffer():
51 | # copy tensors into buffer_t
52 | offset = 0
53 | for t in buffer:
54 | numel = t.numel()
55 | buffer_t[offset:offset+numel].copy_(t.view(-1))
56 | offset += numel
57 |
58 | # all-reduce and rescale
59 | torch.distributed.all_reduce(buffer_t[:offset])
60 | buffer_t.div_(rescale_denom)
61 |
62 | # copy all-reduced buffer back into tensors
63 | offset = 0
64 | for t in buffer:
65 | numel = t.numel()
66 | t.view(-1).copy_(buffer_t[offset:offset+numel])
67 | offset += numel
68 |
69 | filled = 0
70 | for t in tensors:
71 | sz = t.numel() * t.element_size()
72 | if sz > buffer_size:
73 | # tensor is bigger than buffer, all-reduce and rescale directly
74 | torch.distributed.all_reduce(t)
75 | t.div_(rescale_denom)
76 | elif filled + sz > buffer_size:
77 | # buffer is full, all-reduce and replace buffer with grad
78 | all_reduce_buffer()
79 | buffer = [t]
80 | filled = sz
81 | else:
82 | # add tensor to buffer
83 | buffer.append(t)
84 | filled += sz
85 |
86 | if len(buffer) > 0:
87 | all_reduce_buffer()
88 |
89 |
90 | def all_gather_list(data, max_size=4096):
91 | """Gathers arbitrary data from all nodes into a list."""
92 | world_size = torch.distributed.get_world_size()
93 | if not hasattr(all_gather_list, '_in_buffer') or \
94 | max_size != all_gather_list._in_buffer.size():
95 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
96 | all_gather_list._out_buffers = [
97 | torch.cuda.ByteTensor(max_size)
98 | for i in range(world_size)
99 | ]
100 | in_buffer = all_gather_list._in_buffer
101 | out_buffers = all_gather_list._out_buffers
102 |
103 | enc = pickle.dumps(data)
104 | enc_size = len(enc)
105 | if enc_size + 2 > max_size:
106 | raise ValueError(
107 | 'encoded data exceeds max_size: {}'.format(enc_size + 2))
108 | assert max_size < 255*256
109 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
110 | in_buffer[1] = enc_size % 255
111 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
112 |
113 | torch.distributed.all_gather(out_buffers, in_buffer.cuda())
114 |
115 | results = []
116 | for i in range(world_size):
117 | out_buffer = out_buffers[i]
118 | size = (255 * out_buffer[0].item()) + out_buffer[1].item()
119 |
120 | bytes_list = bytes(out_buffer[2:size+2].tolist())
121 | result = pickle.loads(bytes_list)
122 | results.append(result)
123 | return results
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reinforcement Learning-based Counter-Misinformation Response Generation: A Case Study of COVID-19 Vaccine Misinformation
2 | This repository contains code and data for our ACM WWW 2023 publication on counter-misinformation response generation.
3 |
4 | You may access the PDF here: [PDF](https://faculty.cc.gatech.edu/~srijan/pubs/he-www23-misinfocorrect.pdf)
5 |
6 | If our code or data helps you in your research, please cite:
7 |
8 | ```
9 | @inproceedings{he2023reinforcement,
10 | title={Reinforcement Learning-based Counter-Misinformation Response Generation: A Case Study of COVID-19 Vaccine Misinformation},
11 | author={He, Bing and Ahamad, Mustaque and Kumar, Srijan},
12 | booktitle={Proceedings of the ACM Web Conference 2023},
13 | year={2023}
14 | }
15 | ```
16 |
17 | ## Introduction
18 |
19 | COVID-19 vaccine misinformation on social media reduces vaccine uptake and threatens public health. While fact-checkers debunk these false claims, they do not interact with misinformation spreaders. Ordinary users, who produce 96% of counter-misinformation responses, often lack evidence and are rude in their replies. This study aims to create a counter-misinformation response generation model that empowers users to effectively correct misinformation. We first create two novel datasets of misinformation and counter-misinformation responses from social media and crowdsourcing. We then propose a reinforcement-learning-based text generation model that rewards the generator for increasing politeness, refutation attitude, and factuality while retaining fluency and relevance. Extensive experiments show that the model outperforms baselines in generating high-quality responses, demonstrating the potential of generative text models for social good.
20 |
21 |
22 | ## Quickstart
23 |
24 | ### 1. Setup and Installation
25 |
26 | Our framework runs on Python 3. The required modules can be installed using:
27 | ```
28 | $ pip install -r requirements.txt
29 | ```
30 |
31 | or
32 |
33 | ```
34 | $ conda env create -f environment.yml
35 | ```
36 |
37 |
38 | ### 2. Prepare dataset
39 |
40 | A sample raw input data file is available in [dataset/sample_data.tsv](dataset/sample_data.tsv). Each line in the file contains a tweet and its counter-response, separated by a tab. This input file can be converted into the format recognized by the model using the following command:
41 | ```
42 | python src/process_data.py --corpus dataset/sample_data.tsv --if_input_src_only
43 | ```
44 |
45 | Running this command will generate a folder named `dataset/sample_data.128len.if_input_src_only.db`.
46 |
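The generated folder is a Python shelve DB whose values are gzip-compressed JSON chunks (see `src/process_data.py`). Below is a minimal sketch for peeking into it, assuming the command above was run on the sample file:

```python
import gzip, json, shelve

# Open the shelve DB written by src/process_data.py and decode the first chunk.
with shelve.open('dataset/sample_data.128len.if_input_src_only.db/db', flag='r') as db:
    print(list(db.keys()))  # e.g. ['chunk_0', ...]
    chunk = json.loads(gzip.decompress(db['chunk_0']).decode('utf-8'))
    print(len(chunk), 'examples in chunk_0')
```
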
47 | ### 3. Training the model
48 | For training our model on the sample input data, run the following command:
49 |
50 | ```
51 | python src/train_model.py \
52 | --model_name_or_path models/medium/ \
53 | --train_input_file dataset/sample_data.128len.if_input_src_only.db \
54 | --output_dir output/ \
55 | --log_dir output/ \
56 | --train_batch_size 4 \
57 | --num_optim_steps 100
58 | ```
59 | Before running this command:
60 | 1. You will need a DialoGPT-like transformer model for initialization (`model_name_or_path`), ideally fine-tuned on your dataset (see the warm-up start in the paper).
61 | 2. You will need to separately train the reward classifiers used by the reinforcement learning framework. Here, we use three reward functions: politeness, refutation, and evidence classifiers. Their locations are given by `politeness_clf_fp`, `refutation_clf_fp`, and `evidence_clf_fp` in `src/variables_ext.py`.
62 | 3. To configure the reward parameters, please refer to `src/variables_ext.py` (see the excerpt after this list).
63 | 4. As a sanity check for the released codebase, the current version uses a single GPU. Please revise `n_gpu` in `src/variables_ext.py` when multiple GPUs are available. For the paper, we ran our experiments on an NVIDIA DGX-1 with 8 V100 GPUs.
64 | 5. The repository is based on [DialoGPT](https://github.com/microsoft/DialoGPT) and [PARTNER](https://github.com/behavioral-data/PARTNER) and uses a similar code structure and environment.
65 |
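For reference, the reward configuration lives in `src/variables_ext.py`; the excerpt below shows the classifier paths and the reward toggles (point the paths at your own trained checkpoints):

```python
# src/variables_ext.py (excerpt)
clf_main_fp = r'./../../MisinfoCorrect_support_models'

politeness_clf_fp = f"{clf_main_fp}/politeness_clf.pt"
refutation_clf_fp = f"{clf_main_fp}/refutation_clf.pt"
evidence_clf_fp = f"{clf_main_fp}/evidence_clf.pt"

# general text rewards
if_perplexity = True   # fluency
if_relevance = True    # the coherence/relevance reward in the paper
# counter-response rewards
if_politeness = True
if_refutation = True
if_evidence = True
```
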
66 | ## Dataset
67 |
68 | 1. In-the-wild social media data containing 754 annotated (misinformation tweet, counter-misinformation reply) pairs. Below are the dataset statistics:
69 |
70 |
71 |
72 |
73 |
74 | 2. Crowdsourced data containing 591 (misinformation tweet, human-written counter-misinformation reply) pairs. Note that, compared to the social media data, these 591 human-written replies refute the misinformation, are polite, and provide evidence, per the requirements in the paper.
75 | 3. Our dataset can be found [here](https://www.dropbox.com/sh/5u2mdo53tgh3vrh/AADfYHqhQbt0A2gUciT583E0a?dl=0).
76 | 4. We are aware of the changes to the Twitter API. If you have problems accessing the whole dataset or the code, please contact Bing He (bhe46@gatech.edu).
77 |
78 | If you have any questions, please contact the author, Bing He (bhe46@gatech.edu).
79 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/automated_metrics.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
2 | from nltk import ngrams
3 | import numpy as np
4 | from numpy import dot
5 | from numpy.linalg import norm
6 |
7 | # before Oct 2022
8 | #from bert_serving.client import BertClient
9 | #bc = BertClient()
10 | # after Oct 2022
11 | from transformers import BertTokenizer, BertModel
12 | import torch
13 |
14 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
15 | bert_model = BertModel.from_pretrained("bert-base-uncased")
16 |
17 |
18 |
19 | import torch
20 | import math
21 | from transformers import GPT2LMHeadModel, GPT2Tokenizer
22 |
23 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
24 |
25 |
26 | # ==== : device control ====
27 | import sys, os
28 | sys.path.append("./")
29 | sys.path.append("../")
30 | sys.path.append(".../")
31 | sys.path.append("..../")
32 | from MisinfoCorrect.src.variables_ext import device
33 | # ==== : device control ====
34 |
35 | GPT_model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium').to(device)
36 | GPT_tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')
37 |
38 | GPT_tokenizer.pad_token = GPT_tokenizer.eos_token
39 |
40 | GPT_model.eval()
41 |
42 | MAX_LEN = 64
43 |
44 | # BLEU
45 | def bleu(predicted, target):
46 |
47 | all_bleus = []
48 |
49 | smoothie = SmoothingFunction().method4
50 |
51 | for idx, elem in enumerate(predicted):
52 | try:
53 | curr_bleu = sentence_bleu([str(target[idx])], str(predicted[idx]), smoothing_function=smoothie)
54 | all_bleus.append(curr_bleu)
55 | except ZeroDivisionError:
56 | continue
57 |
58 | return np.mean(all_bleus)
59 |
60 |
61 | # Perplexity
62 | def perplexity(predicted):
63 |
64 | BATCH_SIZE = 1
65 |
66 | tokenized_input = GPT_tokenizer.batch_encode_plus(predicted, max_length=MAX_LEN, pad_to_max_length=True, truncation=True)
67 |
68 | input_ids = tokenized_input['input_ids']
69 | attention_masks = tokenized_input['attention_mask']
70 |
71 | input_ids = torch.tensor(input_ids)
72 | attention_masks = torch.tensor(attention_masks)
73 |
74 | data = TensorDataset(input_ids, attention_masks)
75 |
76 | sampler = SequentialSampler(data)
77 | dataloader = DataLoader(data, sampler=sampler, batch_size = BATCH_SIZE)
78 |
79 | all_loss = []
80 |
81 | with torch.no_grad():
82 |
83 | for batch in dataloader:
84 | b_input = batch[0].to(device)
85 | b_attn = batch[1].to(device)
86 |
87 | outputs = GPT_model(b_input, attention_mask=b_attn, labels=b_input)
88 |
89 | loss, logits = outputs[:2]
90 | all_loss.append(loss.item())
91 |
92 | return math.exp(np.mean(all_loss))
93 |
94 |
95 | # Diversity
96 |
97 | # Distinct-1, Distinct-2 (# of unigrams and bigrams divided by the total number of words)
98 |
99 | def distinct(predicted):
100 |
101 | UNIGRAMS = set()
102 | BIGRAMS = set()
103 |
104 | NUM_WORDS = 0
105 |
106 | for idx, elem in enumerate(predicted):
107 | curr_unigrams = ngrams(str(elem).split(), 1)
108 | curr_bigrams = ngrams(str(elem).split(), 2)
109 |
110 | NUM_WORDS += len(str(elem).split())
111 |
112 | for unigram in curr_unigrams:
113 | UNIGRAMS.add(' '.join(unigram).strip())
114 |
115 | for bigram in curr_bigrams:
116 | BIGRAMS.add(' '.join(bigram).strip())
117 |
118 |
119 |
120 | DISTINCT_1 = len(UNIGRAMS) / NUM_WORDS
121 | DISTINCT_2 = len(BIGRAMS) / NUM_WORDS
122 |
123 | return DISTINCT_1, DISTINCT_2
124 |
125 |
126 | # Entropy
127 | # def entropy():
128 |
129 |
130 | # Specificity
131 | # similar to relevance
132 | def specificity(seeker_post, predicted):
133 |
134 |     # Get embeddings (uses the legacy bert-serving BertClient `bc`, which is commented out above; prefer relevance() below)
135 | seeker_post_embeddings = bc.encode(seeker_post)
136 | predicted_response_embeddings = bc.encode(predicted)
137 |
138 | # Compute cosine similarity
139 |
140 | all_cos_sim = []
141 |
142 | for idx, elem in enumerate(seeker_post_embeddings):
143 | a = seeker_post_embeddings[idx]
144 | b = predicted_response_embeddings[idx]
145 |
146 | cos_sim = dot(a, b)/(norm(a)*norm(b))
147 | all_cos_sim.append(cos_sim)
148 |
149 | return np.mean(all_cos_sim)
150 |
151 | # ext: relevance reward - coherence reward
152 | def relevance(seeker_post, predicted):
153 | from numpy import dot
154 | from numpy.linalg import norm
155 | import numpy as np
156 | def encode(post):
157 |
158 | inputs = bert_tokenizer.encode_plus(post, return_tensors="pt", add_special_tokens=True)
159 | outputs = bert_model(**inputs)
160 |
161 |         # outputs[0]: hidden states of all tokens in the sentence; outputs[1]: pooled output for the whole sentence
162 | last_hidden_states = outputs[1]
163 | return last_hidden_states
164 |
165 | # Get embeddings
166 | seeker_post_embeddings = encode(seeker_post)
167 | predicted_response_embeddings = encode(predicted)
168 |
169 | # Compute cosine similarity
170 |
171 | all_cos_sim = []
172 |
173 | for idx, elem in enumerate(seeker_post_embeddings):
174 | a = seeker_post_embeddings[idx]
175 | b = predicted_response_embeddings[idx]
176 |
177 | a = a.detach().numpy()
178 | b = b.detach().numpy()
179 |
180 | cos_sim = dot(a, b)/(norm(a)*norm(b))
181 | all_cos_sim.append(cos_sim)
182 |
183 | return np.mean(all_cos_sim)
184 |
185 |
186 |
187 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/empathy_classifier_bi_encoder_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import codecs
3 | import numpy as np
4 |
5 |
6 | import pandas as pd
7 | import re
8 | import csv
9 | import numpy as np
10 | import sys
11 |
12 | import time
13 |
14 | from sklearn.metrics import f1_score
15 |
16 | from transformers import RobertaTokenizer
17 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
18 | from torch.utils.data import TensorDataset, random_split
19 |
20 |
21 | from .util.models import BiEncoderAttentionWithRationaleClassification
22 | from transformers import AdamW, RobertaConfig
23 |
24 | import datetime
25 |
26 |
27 | class EmpathyClassifier():
28 |
29 | def __init__(self,
30 | device,
31 | ER_model_path = './',
32 | IP_model_path = './',
33 | EX_model_path = './',
34 | batch_size=2):
35 |
36 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
37 | self.batch_size = batch_size
38 | self.device = device
39 |
40 | self.model_ER = BiEncoderAttentionWithRationaleClassification()
41 | self.model_IP = BiEncoderAttentionWithRationaleClassification()
42 | self.model_EX = BiEncoderAttentionWithRationaleClassification()
43 |
44 | ER_weights = torch.load(ER_model_path, map_location=self.device)
45 | self.model_ER.load_state_dict(ER_weights)
46 |
47 | IP_weights = torch.load(IP_model_path, map_location=self.device)
48 | self.model_IP.load_state_dict(IP_weights)
49 |
50 | EX_weights = torch.load(EX_model_path, map_location=self.device)
51 | self.model_EX.load_state_dict(EX_weights)
52 |
53 | self.model_ER.to(self.device)
54 | self.model_IP.to(self.device)
55 | self.model_EX.to(self.device)
56 |
57 | self.model_ER.eval()
58 | self.model_IP.eval()
59 | self.model_EX.eval()
60 |
61 |
62 | def predict_empathy(self, seeker_posts, response_posts):
63 |
64 | input_ids_SP = []
65 | attention_masks_SP = []
66 |
67 | for sent in seeker_posts:
68 |
69 | encoded_dict = self.tokenizer.encode_plus(
70 | sent, # Sentence to encode.
71 | add_special_tokens = True, # Add '[CLS]' and '[SEP]'
72 | max_length = 64, # Pad & truncate all sentences.
73 | truncation=True,
74 | pad_to_max_length = True,
75 | return_attention_mask = True, # Construct attn. masks.
76 | return_tensors = 'pt', # Return pytorch tensors.
77 | )
78 |
79 | input_ids_SP.append(encoded_dict['input_ids'])
80 | attention_masks_SP.append(encoded_dict['attention_mask'])
81 |
82 |
83 | input_ids_RP = []
84 | attention_masks_RP = []
85 |
86 | for sent in response_posts:
87 | encoded_dict = self.tokenizer.encode_plus(
88 | sent, # Sentence to encode.
89 | add_special_tokens = True, # Add '[CLS]' and '[SEP]'
90 | max_length = 64, # Pad & truncate all sentences.
91 | truncation=True,
92 | pad_to_max_length = True,
93 | return_attention_mask = True, # Construct attn. masks.
94 | return_tensors = 'pt', # Return pytorch tensors.
95 | )
96 |
97 | input_ids_RP.append(encoded_dict['input_ids'])
98 | attention_masks_RP.append(encoded_dict['attention_mask'])
99 |
100 | input_ids_SP = torch.cat(input_ids_SP, dim=0)
101 | attention_masks_SP = torch.cat(attention_masks_SP, dim=0)
102 |
103 | input_ids_RP = torch.cat(input_ids_RP, dim=0)
104 | attention_masks_RP = torch.cat(attention_masks_RP, dim=0)
105 |
106 | dataset = TensorDataset(input_ids_SP, attention_masks_SP, input_ids_RP, attention_masks_RP)
107 |
108 | dataloader = DataLoader(
109 | dataset, # The test samples.
110 | sampler = SequentialSampler(dataset), # Pull out batches sequentially.
111 | batch_size = self.batch_size # Evaluate with this batch size.
112 | )
113 |
114 | self.model_ER.eval()
115 | self.model_IP.eval()
116 | self.model_EX.eval()
117 |
118 | for batch in dataloader:
119 | b_input_ids_SP = batch[0].to(self.device)
120 | b_input_mask_SP = batch[1].to(self.device)
121 | b_input_ids_RP = batch[2].to(self.device)
122 | b_input_mask_RP = batch[3].to(self.device)
123 |
124 | with torch.no_grad():
125 | (logits_empathy_ER, logits_rationale_ER,) = self.model_ER(input_ids_SP = b_input_ids_SP,
126 | input_ids_RP = b_input_ids_RP,
127 | token_type_ids_SP=None,
128 | token_type_ids_RP=None,
129 | attention_mask_SP=b_input_mask_SP,
130 | attention_mask_RP=b_input_mask_RP)
131 |
132 | (logits_empathy_IP, logits_rationale_IP,) = self.model_IP(input_ids_SP = b_input_ids_SP,
133 | input_ids_RP = b_input_ids_RP,
134 | token_type_ids_SP=None,
135 | token_type_ids_RP=None,
136 | attention_mask_SP=b_input_mask_SP,
137 | attention_mask_RP=b_input_mask_RP)
138 |
139 | (logits_empathy_EX, logits_rationale_EX,) = self.model_EX(input_ids_SP = b_input_ids_SP,
140 | input_ids_RP = b_input_ids_RP,
141 | token_type_ids_SP=None,
142 | token_type_ids_RP=None,
143 | attention_mask_SP=b_input_mask_SP,
144 | attention_mask_RP=b_input_mask_RP)
145 |
146 |
147 | logits_empathy_ER = logits_empathy_ER.detach().cpu().numpy().tolist()
148 | predictions_ER = np.argmax(logits_empathy_ER, axis=1).flatten()
149 |
150 | logits_empathy_IP = logits_empathy_IP.detach().cpu().numpy().tolist()
151 | predictions_IP = np.argmax(logits_empathy_IP, axis=1).flatten()
152 |
153 | logits_empathy_EX = logits_empathy_EX.detach().cpu().numpy().tolist()
154 | predictions_EX = np.argmax(logits_empathy_EX, axis=1).flatten()
155 |
156 | return (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX)
157 |
158 |
159 | '''
160 | Example:
161 | '''
162 |
163 | if __name__ == "__main__":  # run the example only when this file is executed directly, not on import
164 |     import sys, os
165 |     sys.path.append("../")
166 |     from MisinfoCorrect.src.variables_ext import device
167 |
168 |     seeker_posts = ['I need help', 'I want to talk to someone', 'I do not have any friends.']
169 |     response_posts = ['why do you feel this way?', 'I understand how you feel', 'do you want to talk about it?']
170 |
171 |     empathy_classifier = EmpathyClassifier(device)
172 |
173 |     (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy(seeker_posts, response_posts)
174 |
175 |     print(predictions_ER, predictions_IP, predictions_EX)
176 |
177 |
--------------------------------------------------------------------------------
/src/gpt2_training/train_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import os
5 | import logging
6 | import torch
7 | from collections import defaultdict
8 |
9 | from env import END_OF_TEXT_TOKEN
10 | from lsp_model_rl.optim import warmup_linear, noam_decay, noamwd_decay
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | SEQ_LENGTH_SHRINK_PROP = 0.9
16 |
17 | # ==== understand by ext: it seems like device setup ====
18 | def load_model(model, checkpoint, args, verbose=False):
19 | n_gpu = args.n_gpu
20 | device = args.device
21 | if checkpoint is None or checkpoint == "None":
22 | if verbose:
23 | logger.info('no checkpoint provided for %s!' % model._get_name())
24 | else:
25 | if not os.path.exists(checkpoint):
26 | raise ValueError('checkpoint %s not exist' % checkpoint)
27 | if verbose:
28 | logger.info('loading finetuned model from %s' % checkpoint)
29 | model_state_dict = torch.load(checkpoint)
30 |
31 | model_state_dict = fix_state_dict_namespace(model_state_dict)
32 |
33 | start_model = model
34 | if (hasattr(model, "transformer")
35 | and all(not s.startswith('transformer.')
36 | for s in model_state_dict.keys())):
37 |             logger.info('loading transformer only')
38 | start_model = model.transformer
39 | start_model.load_state_dict(model_state_dict)
40 |
41 | if args.fp16:
42 | logger.info('in fp16, model.half() activated')
43 | model.half()
44 | model.to(device)
45 | if n_gpu > 1:
46 | logging.info('data parallel because more than one gpu')
47 | model = torch.nn.DataParallel(model)
48 | return model
49 |
50 |
51 | def fix_state_dict_namespace(model_state_dict):
52 | old_keys = []
53 | new_keys = []
54 | for t in model_state_dict:
55 | new_key = t
56 | if t.startswith('module.'):
57 | new_key = t.replace('module.', '')
58 | old_keys.append(t)
59 | new_keys.append(new_key)
60 |
61 | for old_key, new_key in zip(old_keys, new_keys):
62 | model_state_dict[new_key] = model_state_dict.pop(old_key)
63 |
64 | return model_state_dict
65 |
66 |
67 |
68 | class InputFeatures(object):
69 | def __init__(self, conv_id, input_ids, position_ids, token_type_ids,
70 | seeker_post, response_post, input_len=None):
71 | self.conv_id = conv_id
72 | self.input_ids = input_ids
73 | self.position_ids = position_ids
74 | self.token_type_ids = token_type_ids
75 |
76 | self.seeker_post = seeker_post
77 | self.response_post = response_post
78 |
79 | if input_len is None:
80 | self.input_len = len(input_ids)
81 | else:
82 | self.input_len = input_len
83 |
84 |
85 |
86 |
87 | class RedditExample(object):
88 | def __init__(self, conv_id, context, response):
89 | self.conv_id = conv_id
90 | self.context = context
91 | self.response = response
92 |
93 | def __repr__(self):
94 | return 'conv_id = {}\ncontext = {}\nresponse = {}'.format(
95 | self.conv_id, self.context, self.response)
96 |
97 | def __str__(self):
98 | return self.__repr__()
99 |
100 |
101 | def boolean_string(s):
102 | if s.lower() not in {'false', 'true'}:
103 | raise ValueError('Not a valid boolean string')
104 | return s.lower() == 'true'
105 |
106 |
107 | def get_eval_list_same_length(input_file, tokenizer, max_batch_size,
108 | norm=True):
109 | examples = []
110 | with open(input_file, 'r', encoding="utf-8") as f:
111 | content = [l.split('\t') for l in f.read().splitlines()]
112 |
113 | context, response = [c[0] for c in content], [c[1:] for c in content]
114 | i = 0
115 | for src, tgt_all in zip(context, response):
116 | for tgt in tgt_all:
117 | if norm:
118 | src_line = ' '.join(src.strip().split())
119 | tgt_line = ' '.join(tgt.strip().split())
120 | else:
121 | src_line = src.strip()
122 | tgt_line = tgt.strip()
123 | examples.append(RedditExample(i, src_line, tgt_line))
124 | i += 1
125 |
126 | def featurize(example):
127 | conv_id = example.conv_id
128 | context_id = tokenizer.encode(example.context)
129 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
130 |
131 | response_id = tokenizer.encode(example.response)
132 | input_ids = context_id + [end_of_text_id]
133 | lm_labels = response_id
134 |
135 | position_ids = list(range(len(input_ids)))
136 |
137 | token_type_id = [0] * len(input_ids)
138 |
139 | return InputFeatures(conv_id, input_ids, position_ids, token_type_id,
140 | lm_labels, len(context_id), len(response_id))
141 |
142 | def batch_feature_same_len(features):
143 | input_ids = torch.stack([torch.tensor(f.choices_features['input_ids'],
144 | dtype=torch.long)
145 | for f in features])
146 | position_ids = torch.stack(
147 | [torch.tensor(f.choices_features['position_ids'], dtype=torch.long)
148 | for f in features])
149 | token_type_ids = torch.stack(
150 | [torch.tensor(f.choices_features['token_type_ids'],
151 | dtype=torch.long)
152 | for f in features])
153 | labels = torch.nn.utils.rnn.pad_sequence(
154 | [torch.tensor(f.lm_labels, dtype=torch.long) for f in features],
155 | batch_first=True, padding_value=-1)
156 |
157 | context_len = torch.tensor([f.context_len for f in features],
158 | dtype=torch.long)
159 | response_len = torch.tensor([f.response_len for f in features],
160 | dtype=torch.long)
161 | return (input_ids, position_ids, token_type_ids, labels,
162 | context_len, response_len)
163 |
164 | features = [featurize(e) for e in examples]
165 | dataloader_pre = defaultdict(list)
166 | for f in features:
167 | dataloader_pre[f.context_len].append(f)
168 |
169 | dataloader = []
170 | for l in sorted(dataloader_pre):
171 | f = batch_feature_same_len(dataloader_pre[l])
172 | if len(f[0]) <= max_batch_size:
173 | dataloader.append(f)
174 | else:
175 | start_index = 0
176 | while True:
177 | dataloader.append([ff[start_index:start_index + max_batch_size]
178 | for ff in f])
179 | start_index += max_batch_size
180 | if start_index >= len(f[0]):
181 | break
182 | return dataloader
183 |
184 |
185 | def set_lr(optimizer, step, schedule, lr,
186 | warmup_steps, warmup_proportion, n_embd, tot_steps):
187 | if schedule == 'None':
188 | lr_this_step = lr
189 | elif schedule == 'noam': # transformer like
190 | lr_this_step = lr * 1e4 * noam_decay(step+1, warmup_steps, n_embd)
191 | elif schedule == 'noamwd': # transformer like
192 | lr_this_step = lr * 1e4 * noamwd_decay(step+1, warmup_steps, n_embd)
193 | else:
194 | lr_this_step = lr * warmup_linear(step / tot_steps,
195 | warmup_proportion)
196 | for param_group in optimizer.param_groups:
197 | param_group['lr'] = lr_this_step
198 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/rewards.py:
--------------------------------------------------------------------------------
1 | from .automated_metrics import *
2 | import numpy as np
3 | import nltk
4 | # input: edited sentence, initial sentence, weights
5 | # from .empathy_classifier_bi_encoder_attention import empathy_classifier
6 | from .coherence_classifier2 import coherence_classifier
7 |
8 | from .coherence_classifier2 import politeness_classifier
9 | from .coherence_classifier2 import refutation_classifier
10 | from .coherence_classifier2 import evidence_classifier
11 |
12 | orig_sent_list = ["It might help to re-install Python if possible","The version might behind the bug."]
13 | new_sent_list = ["The version might be the reason for the bug.","The version might be the reason behind the bug."]
14 |
15 | # w: weight for the importance of the rewards
16 | w = {'edit':1,'bleu':1,'dist1':10,'dist2':1,'pp':10,'spec':1,'empathy':1,
17 | 'politeness':1, 'refutation':1, 'evidence':1, 'coherence':0.1, 'relevance':0.1}
18 |
19 | def calc_rewards(seeker_posts, original_responses, generated_responses, candidates = None,
20 | _edit=False,
21 | _bleu=False,
22 | _distinct=False,
23 | _perplexity=False,
24 | _specificity=False,
25 | _empathy=False,
26 | _empathy_change=False,
27 | _coherence=False,
28 | _add_noise=True,
29 | _pick_categorical='',
30 | _empathy_adaptive=False,
31 | _politeness=False,
32 | _refutation=False,
33 | _evidence=False,
34 | _relevance=False,
35 | NOISE=0.00001):
36 |
37 | total_score = 0
38 |
39 | if _edit:
40 | edit = edit_level_jaccard(original_responses, generated_responses)
41 | total_score += edit*w['edit']
42 |
43 | if _bleu:
44 | bleu_score = bleu(generated_responses, original_responses)
45 | total_score += bleu_score*w['bleu']
46 |
47 | if _distinct:
48 | distinct_1, distinct_2 = distinct(generated_responses)
49 | total_score += distinct_1*w['dist1']+distinct_2*w['dist2']
50 |
51 | if _perplexity:
52 | perplexity_score = perplexity(generated_responses)
53 | total_score += perplexity_score*w['pp']
54 |
55 | if _specificity:
56 | specificity_score = 0 # specificity(seeker_posts, generated_responses)
57 | total_score += specificity_score*w['spec']
58 |
59 | if _empathy:
60 | empathy_score = calc_empathy_score(seeker_posts, generated_responses)
61 | total_score += empathy_score*w['empathy']
62 |
63 | if _empathy_change:
64 | prev_empathy_score = calc_empathy_score(seeker_posts, original_responses)
65 | curr_empathy_score = calc_empathy_score(seeker_posts, generated_responses)
66 |
67 | total_score += curr_empathy_score - prev_empathy_score
68 |
69 |
70 | if _coherence:
71 | reward_val = relevance(seeker_posts, generated_responses)
72 | total_score += reward_val*w['coherence']
73 | if _politeness:
74 | reward_val = calc_politeness_score(generated_responses, "")
75 | total_score += reward_val*w['politeness']
76 | if _refutation:
77 | reward_val = calc_refutation_score(seeker_posts, generated_responses)
78 | total_score += reward_val*w['refutation']
79 | if _evidence:
80 | reward_val = calc_evidence_score(seeker_posts, generated_responses)
81 | total_score += reward_val*w['evidence']
82 | if _relevance:
83 | reward_val = relevance(seeker_posts, generated_responses)
84 | total_score += reward_val*w['relevance']
85 |
86 |
87 | if _add_noise:
88 | total_score -= NOISE
89 |
90 | if _empathy_adaptive:
91 | _,_,_,_,ER_score, IP_score, EX_score = calc_empathy_score_3dim(seeker_posts, generated_responses)
92 | total_score += ((2-ER_score)*ER_score+(2-IP_score)*IP_score+(2-EX_score)*EX_score)*w['empathy']*0.5
93 |
94 | return total_score
95 |
96 |
97 | def edit_level_jaccard(orig_sent_list, new_sent_list):
98 | total_score = 0
99 | for i, orig_sent in enumerate(orig_sent_list):
100 | total_score += (nltk.jaccard_distance(set(orig_sent), set(new_sent_list[i])))
101 | return total_score/len(orig_sent_list)
102 |
103 | edit_level_jaccard(orig_sent_list, new_sent_list)
104 |
105 | def calc_empathy_score(seeker_posts, generated_responses):
106 | batch_score = 0
107 | for i in range(len(seeker_posts)):
108 | (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy([seeker_posts[i]], [generated_responses[i]])
109 | batch_score += ((predictions_ER[0]+predictions_IP[0]+predictions_EX[0])*0.5)
110 |
111 | return batch_score/len(seeker_posts)
112 |
113 | def calc_empathy_score_3dim(seeker_posts, generated_responses):
114 | batch_score = 0
115 | ER_score_list =[]
116 | IP_score_list =[]
117 | EX_score_list =[]
118 |
119 | for i in range(len(seeker_posts)):
120 | try:
121 | (logits_empathy_ER, predictions_ER, logits_empathy_IP, predictions_IP, logits_empathy_EX, predictions_EX) = empathy_classifier.predict_empathy([seeker_posts[i]], [generated_responses[i]])
122 | batch_score += ((predictions_ER[0]+predictions_IP[0]+predictions_EX[0]))
123 | ER_score_list.append(predictions_ER[0])
124 | IP_score_list.append(predictions_IP[0])
125 | EX_score_list.append(predictions_EX[0])
126 | except:
127 | print('Error:', seeker_posts[i], generated_responses[i])
128 |
129 | ER_score = np.sum(ER_score_list)/len(seeker_posts)
130 | IP_score = np.sum(IP_score_list)/len(seeker_posts)
131 | EX_score = np.sum(EX_score_list)/len(seeker_posts)
132 |
133 | return batch_score/len(seeker_posts),ER_score_list,IP_score_list,EX_score_list, ER_score, IP_score, EX_score
134 |
135 | def log2prob(logs):
136 | probs = np.divide(np.exp(logs), (1+np.exp(logs)))
137 | return probs
138 |
139 | def calc_coherence_score(original_responses, candidate): # original_response: list of strings, candidate: string
140 | (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, candidate)
141 | logs_1 = [log[1] for log in logits]
142 | score = np.mean(log2prob(logs_1))
143 | return score
144 |
145 | ## for politeness, we only use the original response: 1 input
146 | def calc_politeness_score(original_responses, candidate): # original_response: list of strings, candidate: string
147 | (logits, predictions,) = politeness_classifier.predict_empathy(original_responses, candidate)
148 |     # here logits: dim 0: impolite; dim 1: polite (normal and polite)
149 | logs_1 = [log[1] for log in logits]
150 | score = np.mean(log2prob(logs_1))
151 | return score
152 |
153 | ## for refutation scores, we should use two inputs:
154 | def calc_refutation_score(original_responses, candidate): # original_response: list of strings, candidate: string
155 | (logits, predictions,) = refutation_classifier.predict_empathy(original_responses, candidate)
156 |     # here logits: dim 0: non-refuting; dim 1: refuting
157 | logs_1 = [log[1] for log in logits]
158 | score = np.mean(log2prob(logs_1))
159 | return score
160 |
161 | def calc_evidence_score(original_responses, candidate): # original_response: list of strings, candidate: string
162 | (logits, predictions,) = evidence_classifier.predict_empathy(original_responses, candidate)
163 |     # here logits: dim 0: no evidence; dim 1: provides evidence
164 | logs_1 = [log[1] for log in logits]
165 | score = np.mean(log2prob(logs_1))
166 | return score
167 |
168 |
169 |
170 |
171 |
172 |
--------------------------------------------------------------------------------
/src/process_data.py:
--------------------------------------------------------------------------------
1 | """
2 | preprocess input data into features and store them as a binary python shelve DB
3 | each chunk is a gzipped JSON string
4 | """
5 | import argparse
6 | import gzip
7 | import json
8 | import subprocess as sp
9 | import shelve
10 | import os
11 | from os.path import dirname, exists, join
12 |
13 | import torch
14 | from lsp_model_rl import GPT2Tokenizer
15 | from tqdm import tqdm
16 |
17 | from env import END_OF_TEXT_TOKEN
18 | from gpt2_training.train_utils import InputFeatures as InputFeatures
19 |
20 |
21 | def _get_file_len(corpus):
22 | n_line = int(sp.check_output(f"wc -l {corpus}".split(),
23 | universal_newlines=True).split()[0])
24 | return n_line
25 |
26 |
27 | def _norm_text(text):
28 | w, *toks = text.strip().split()
29 | try:
30 | w = float(w)
31 | except Exception:
32 | toks = [w] + toks
33 | w = 1.0
34 | return w, ' '.join(toks)
35 |
36 |
37 | def _get_inputs_from_text(text, tokenizer, if_input_src_only=False):
38 | src_seeker_post = text.strip().split('\t')[0].strip() # srcs.split('')[0].strip()[4:]
39 | src_response_post = text.strip().split('\t')[1].strip() # srcs.split('')[1].strip()
40 |
41 | if if_input_src_only:
42 |         # ==== condition 1 ====: but this produces many split tokens in the generated text
43 |         # srcs = src_seeker_post + ' ' + ' '
44 |         # ==== condition 2 ====: like the DialoGPT setup
45 | srcs = src_seeker_post + ' <|endoftext|> '
46 | else:
47 | srcs = src_seeker_post + ' ' + src_response_post + ' '
48 |
49 | inputs = []
50 |
51 | for src in srcs.split(' EOS '):
52 | context_id = tokenizer.encode(src)
53 | inputs.append(context_id)
54 |
55 | # pos = [pos,]
56 | return inputs, src_seeker_post, src_response_post
57 |
58 |
59 | def _make_features(id_, inputs, src_seeker_post, src_response_post, tokenizer, max_len):
60 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
61 | features = []
62 | sents = []
63 | ws = []
64 | ps = []
65 | len_ = 0
66 | i = 0
67 | # ids: list of list: [[1,2,3]]
68 | for ids in inputs:
69 | if len(ids) > max_len:
70 | ids = ids[:max_len]
71 |
72 | len_ += (len(ids) + 1)
73 | sents.append(ids)
74 |         # : it seems sents is just a concatenated list of the token-id lists
75 | if len(sents) >= 1:
76 | feat = _make_feature(id_ + i, sents, src_seeker_post, src_response_post, end_of_text_id)
77 | if feat is not None:
78 | features.append(feat)
79 |
80 | return features
81 |
82 |
83 | def _make_feature(id_, sents, src_seeker_post, src_response_post, eos):
84 | # : change from list of list to a single list
85 | input_ids = [i for s in sents for i in s+[eos]][:-1]
86 |
87 | weights = []
88 | token_type_ids = [] # this becomes round ids
89 |
90 |     # : in the input index list we have a split token in the middle,
91 | # like: source reply
92 | split_id = toker.encode("")[0]
93 |
94 | curr_id = 0
95 |
96 | for i, s in enumerate(input_ids):
97 |
98 | if s == split_id:
99 | curr_id = 1
100 |
101 | token_type_ids.append(curr_id)
102 |
103 |
104 | # TODO: handle trailing -1's
105 | # input_ids = input_ids[:i+1]
106 | # token_type_ids = token_type_ids[:i+1]
107 |
108 | # pad to multiples of 8
109 | while len(input_ids) % 8 != 0:
110 | input_ids.append(0)
111 | token_type_ids.append(0)
112 |
113 | position_ids = list(range(len(input_ids)))
114 | assert (len(input_ids) == len(position_ids) == len(token_type_ids))
115 | assert len(input_ids) % 8 == 0
116 |
117 | if len(input_ids) == 0:
118 | import pdb
119 | pdb.set_trace()
120 |
121 | # : here, we have three important ids:
122 | # 1) input_ids: sequential; 2) position_ids: provide the position information
123 | # 3) token_type: for the related split/EOS tagging
124 | feature = InputFeatures(id_, input_ids, position_ids, token_type_ids,
125 | src_seeker_post, src_response_post)
126 |
127 | return feature
128 |
129 |
130 | toker = GPT2Tokenizer.from_pretrained('gpt2-medium')
131 | toker.add_tokens(['', '', ''])
132 |
133 | def main(args):
134 |
135 | attrs = []
136 | if args.if_input_src_only:
137 | print(f"we only keep source text in the input")
138 | attrs.append("if_input_src_only")
139 | if args.reverse:
140 | attrs.append('reverse')
141 | if args.two_turn:
142 | attrs.append('2turn')
143 | if attrs:
144 | db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.'
145 | f'{".".join(attrs)}.db/db')
146 | else:
147 | db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db'
148 | print(f"the db path is: {db_path}")
149 | if exists(dirname(db_path)):
150 | raise ValueError('Found existing DB, please backup')
151 | else:
152 | os.makedirs(dirname(db_path))
153 | with open(args.corpus, "r", encoding="utf-8") as reader, \
154 | shelve.open(db_path, 'n') as db:
155 | chunk = []
156 | n_chunk = 0
157 | n_example = 0
158 | for line in tqdm(reader, total=_get_file_len(args.corpus)):
159 | # print('line:', line)
160 | # print('n_chunk:', len(chunk))
161 |
162 | try:
163 | if len(chunk) >= args.chunk_size:
164 | # save and renew chunk
165 | db[f'chunk_{n_chunk}'] = gzip.compress(
166 | json.dumps(chunk[:args.chunk_size]).encode('utf-8'))
167 | chunk = chunk[args.chunk_size:]
168 | n_chunk += 1
169 |
170 | # : inputs will have the whole seeker post+response post after tokenization and encoding
171 | inputs, src_seeker_post, src_response_post = _get_inputs_from_text(line, toker, if_input_src_only=args.if_input_src_only)
172 | if args.reverse:
173 | inputs = list(reversed(inputs))
174 | if args.two_turn:
175 | inputs = inputs[:2]
176 | # : when running the code, it is one line with only one seeker post and response post
177 | # print(f"inputs is: {inputs}\n{type(inputs)}")
178 | # inputs: list of list: [[22087, 13, 10478]]
179 | features = _make_features(n_example, inputs, src_seeker_post, src_response_post,
180 | toker, args.max_seq_len)
181 |
182 | # print('features:', features) # : no need to see it:
183 | # features: []
184 |
185 |
186 | for feature in features:
187 | chunk.append(vars(feature))
188 | n_example += 1
189 | except Exception as e:
190 | print('!!! prepro exception !!!', e)
191 | continue
192 | # save last chunk
193 | db[f'chunk_{n_chunk}'] = gzip.compress(
194 | json.dumps(chunk).encode('utf-8'))
195 | # save relevant information to reproduce
196 | meta = {'n_example': n_example,
197 | 'chunk_size': args.chunk_size,
198 | 'max_seq_len': args.max_seq_len,
199 | 'reverse': args.reverse,
200 | 'two_turn': args.two_turn}
201 | with open(join(dirname(db_path), 'meta.json'), 'w') as writer:
202 | json.dump(meta, writer, indent=4)
203 | torch.save(toker, join(dirname(db_path), 'tokenizer.pt'))
204 |
205 |
206 | if __name__ == '__main__':
207 | parser = argparse.ArgumentParser()
208 | parser.add_argument('--corpus', required=True,
209 | help='file name of training corpus (should be .tsv)')
210 | parser.add_argument('--chunk_size', type=int, default=65536,
211 | help='num of data examples in a storing chunk')
212 | parser.add_argument('--max_seq_len', type=int, default=128,
213 | help='discard data longer than this')
214 | parser.add_argument('--reverse', action='store_true',
215 | help='reverse the src tgt')
216 | parser.add_argument('--two_turn', action='store_true',
217 | help='take only the first 2 turns: for toy examples? guessed by ')
218 | parser.add_argument('--if_input_src_only', action='store_true', default=False,
219 | help='by : input with src only')
220 |
221 | args = parser.parse_args()
222 |
223 | main(args)
224 |
--------------------------------------------------------------------------------
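Once process_data.py has been run, the resulting shelve DB can be inspected directly. The sketch below assumes an invocation like `python process_data.py --corpus train.tsv --max_seq_len 128`; the DB path `train.128len.db/db` merely follows the naming scheme in main() and is illustrative only:

import gzip
import json
import shelve

# chunks are stored under keys 'chunk_0', 'chunk_1', ... as gzipped JSON strings
with shelve.open('train.128len.db/db', flag='r') as db:
    chunk = json.loads(gzip.decompress(db['chunk_0']).decode('utf-8'))
    print(len(chunk), 'features in the first chunk')
    print(sorted(chunk[0].keys()))  # the fields produced by InputFeatures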
/src/lsp_model_rl/util/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_utils import PretrainedConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42 | "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
43 | "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
44 | "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
45 | "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
46 | "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47 | "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48 | "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49 | }
50 |
51 |
52 | class BertConfig(PretrainedConfig):
53 | r"""
54 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
55 |     It is used to instantiate a BERT model according to the specified arguments, defining the model
56 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
57 |     the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
58 |
59 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
60 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
61 | for more information.
62 |
63 |
64 | Args:
65 | vocab_size (:obj:`int`, optional, defaults to 30522):
66 | Vocabulary size of the BERT model. Defines the different tokens that
67 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
68 | hidden_size (:obj:`int`, optional, defaults to 768):
69 | Dimensionality of the encoder layers and the pooler layer.
70 | num_hidden_layers (:obj:`int`, optional, defaults to 12):
71 | Number of hidden layers in the Transformer encoder.
72 | num_attention_heads (:obj:`int`, optional, defaults to 12):
73 | Number of attention heads for each attention layer in the Transformer encoder.
74 | intermediate_size (:obj:`int`, optional, defaults to 3072):
75 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
76 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
77 | The non-linear activation function (function or string) in the encoder and pooler.
78 | If string, "gelu", "relu", "swish" and "gelu_new" are supported.
79 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
80 |             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
81 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
82 | The dropout ratio for the attention probabilities.
83 | max_position_embeddings (:obj:`int`, optional, defaults to 512):
84 | The maximum sequence length that this model might ever be used with.
85 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
86 | type_vocab_size (:obj:`int`, optional, defaults to 2):
87 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
88 | initializer_range (:obj:`float`, optional, defaults to 0.02):
89 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
90 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
91 | The epsilon used by the layer normalization layers.
92 |
93 | Example::
94 |
95 | from transformers import BertModel, BertConfig
96 |
97 | # Initializing a BERT bert-base-uncased style configuration
98 | configuration = BertConfig()
99 |
100 | # Initializing a model from the bert-base-uncased style configuration
101 | model = BertModel(configuration)
102 |
103 | # Accessing the model configuration
104 | configuration = model.config
105 |
106 | Attributes:
107 | pretrained_config_archive_map (Dict[str, str]):
108 | A dictionary containing all the available pre-trained checkpoints.
109 | """
110 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
111 | model_type = "bert"
112 |
113 | def __init__(
114 | self,
115 | vocab_size=30522,
116 | hidden_size=768,
117 | num_hidden_layers=12,
118 | num_attention_heads=12,
119 | intermediate_size=3072,
120 | hidden_act="gelu",
121 | hidden_dropout_prob=0.1,
122 | attention_probs_dropout_prob=0.1,
123 | max_position_embeddings=512,
124 | type_vocab_size=2,
125 | initializer_range=0.02,
126 | layer_norm_eps=1e-12,
127 | pad_token_id=0,
128 | **kwargs
129 | ):
130 | super().__init__(pad_token_id=pad_token_id, **kwargs)
131 |
132 | self.vocab_size = vocab_size
133 | self.hidden_size = hidden_size
134 | self.num_hidden_layers = num_hidden_layers
135 | self.num_attention_heads = num_attention_heads
136 | self.hidden_act = hidden_act
137 | self.intermediate_size = intermediate_size
138 | self.hidden_dropout_prob = hidden_dropout_prob
139 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
140 | self.max_position_embeddings = max_position_embeddings
141 | self.type_vocab_size = type_vocab_size
142 | self.initializer_range = initializer_range
143 | self.layer_norm_eps = layer_norm_eps
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | import gzip
4 | import json
5 | import math
6 | import random
7 | import shelve
8 | import torch
9 |
10 | import subprocess as sp
11 |
12 | from math import ceil
13 | from torch.utils.data import DataLoader, Sampler, Dataset
14 | from torch.nn.utils.rnn import pad_sequence
15 |
16 | from env import END_OF_TEXT_TOKEN
17 | from gpt2_training.train_utils import (InputFeatures,
18 | RedditExample)
19 |
20 |
21 | class BucketSampler(Sampler):
22 | """
23 | this sampler will sort data by sequence length
24 | """
25 | def __init__(self, lens, bucket_size, batch_size,
26 | droplast=False, shuffle=True):
27 | self._lens = lens
28 | self._batch_size = batch_size
29 | self._bucket_size = bucket_size
30 | self._droplast = droplast
31 | self._shuf = shuffle
32 |
33 | def __iter__(self):
34 | ids = list(range(len(self._lens)))
35 | if self._shuf:
36 | random.shuffle(ids)
37 | buckets = [sorted(ids[i:i+self._bucket_size],
38 | key=lambda i: self._lens[i], reverse=True)
39 | for i in range(0, len(ids), self._bucket_size)]
40 | batches = [bucket[i:i+self._batch_size]
41 | for bucket in buckets
42 | for i in range(0, len(bucket), self._batch_size)]
43 | if self._droplast:
44 | batches = [batch for batch in batches
45 | if len(batch) == self._batch_size]
46 | if self._shuf:
47 | random.shuffle(batches)
48 | return iter(batches)
49 |
50 | def __len__(self):
51 | bucket_sizes = ([self._bucket_size]
52 | * (len(self._lens) // self._bucket_size)
53 | + [len(self._lens) % self._bucket_size])
54 | if self._droplast:
55 | return sum(s//self._batch_size for s in bucket_sizes)
56 | else:
57 | return sum(math.ceil(s/self._batch_size) for s in bucket_sizes)
58 |
59 |
60 | class GPT2FeatureDataset(Dataset):
61 | """ pytorch dataset for GPT2 training """
62 | def __init__(self, features, max_len=None):
63 | self.features = features
64 | self.max_len = max_len # this max_len do truncate
65 |
66 | def __getitem__(self, i):
67 | feat_dict = self.features[i]
68 | if self.max_len is not None and feat_dict['input_len'] > self.max_len:
69 |             # truncate on the left side (context)
70 | feat_dict['input_ids'] = feat_dict['input_ids'][-self.max_len:]
71 | feat_dict['position_ids'] = feat_dict['position_ids'][
72 | -self.max_len:]
73 | feat_dict['token_type_ids'] = feat_dict['token_type_ids'][
74 | -self.max_len:]
75 | # feat_dict['lm_labels'] = feat_dict['lm_labels'][-self.max_len:]
76 | # feat_dict['pos_labels'] = feat_dict['pos_labels']
77 | try:
78 | for s in ['context_len', 'response_len']:
79 | if s in feat_dict.keys():
80 | print("db file missing "+s)
81 | del feat_dict[s]
82 | except Exception:
83 | import pdb
84 | pdb.set_trace()
85 |
86 | feat = InputFeatures(**feat_dict)
87 | return feat
88 |
89 | def __len__(self):
90 | return len(self.features)
91 |
92 | @staticmethod
93 | def collate(features):
94 | input_ids = pad_sequence([torch.tensor(f.input_ids, dtype=torch.long)
95 | for f in features],
96 | batch_first=True, padding_value=0)
97 | position_ids = pad_sequence([torch.tensor(f.position_ids,
98 | dtype=torch.long)
99 | for f in features],
100 | batch_first=True, padding_value=0)
101 | token_type_ids = pad_sequence([torch.tensor(f.token_type_ids,
102 | dtype=torch.long)
103 | for f in features],
104 | batch_first=True, padding_value=0)
105 | # labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
106 | # for f in features],
107 | # batch_first=True, padding_value=-1)
108 | # pos_labels = torch.tensor([torch.tensor(f.pos_label, dtype=torch.long) for f in features])
109 | seeker_post = [f.seeker_post for f in features]
110 | response_post = [f.response_post for f in features]
111 |
112 | return (input_ids, position_ids, token_type_ids, seeker_post, response_post)
113 |
114 |
115 | class BucketingDataLoader(object):
116 | """ this loads shelve db chunks and then convert to mini-batch loader"""
117 | def __init__(self, db_name, batch_size, max_seq_len,
118 | bucket=100, shuffle=True):
119 | print(f'*** {db_name}/db****')
120 | self.db = shelve.open(f'{db_name}/db', 'r')
121 | self.batch_size = batch_size
122 | self.max_len = max_seq_len
123 | self.bucket_size = bucket * batch_size
124 | self.shuffle = shuffle
125 |
126 | def _get_keys(self):
127 | keys = list(self.db.keys())
128 | return keys
129 |
130 | def __iter__(self):
131 | keys = self._get_keys()
132 | if self.shuffle:
133 | random.shuffle(keys)
134 | for key in keys:
135 | chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8'))
136 | # discard long examples
137 | trunc_chunk = []
138 | lens = []
139 | for feat in chunk:
140 | if feat['input_len'] > self.max_len:
141 | continue
142 | trunc_chunk.append(feat)
143 | lens.append(feat['input_len'])
144 |
145 | # print('trunc_chunk:', trunc_chunk)
146 |
147 | dataset = GPT2FeatureDataset(trunc_chunk, self.max_len)
148 | sampler = BucketSampler(lens, self.bucket_size, self.batch_size,
149 | droplast=True, shuffle=self.shuffle)
150 | loader = DataLoader(dataset, batch_sampler=sampler,
151 | num_workers=0, # can test multi-worker
152 | collate_fn=GPT2FeatureDataset.collate)
153 | yield from loader
154 |
155 | def __len__(self):
156 | raise NotImplementedError()
157 |
158 | def __del__(self):
159 |         self.db.close()  # ext: consider not closing it and see how that goes
160 | # pass
161 |
162 |
163 | class DistributedBucketingDataLoader(BucketingDataLoader):
164 | """ distributed version """
165 | def __init__(self, rank, num_replica, *args, **kwargs):
166 | super().__init__(*args, **kwargs)
167 | self.rank = rank
168 | self.num_replica = num_replica
169 |
170 | def _get_keys(self):
171 | keys = list(self.db.keys())[self.rank::self.num_replica]
172 | return keys
173 |
174 |
175 | def convert_examples_to_features_dynamic(examples, tokenizer,
176 | max_seq_length=512):
177 | """
178 | do not pad
179 | """
180 | def featurize(example):
181 | conv_id = example.conv_id
182 | context_id = tokenizer.encode(example.context)
183 | end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
184 |
185 | # response is provided in example
186 | response_id = tokenizer.encode(example.response)
187 |
188 | input_ids_len = len(context_id) + len(response_id) + 2
189 | if input_ids_len > max_seq_length:
190 | if len(context_id) > input_ids_len - max_seq_length:
191 | # cut context from beginning if length of context + response is too long
192 | # and len of context is long enough to cut
193 | context_id = context_id[input_ids_len - max_seq_length:]
194 | else:
195 | # cut response from end if length of context + response is too long
196 | # and len of response is long enough to cut
197 | # if no response is available, discard the data
198 | if max_seq_length-len(context_id)-2 < 0:
199 | return None
200 | response_id = response_id[:max_seq_length-len(context_id)-2]
201 |
202 | input_ids = context_id + [end_of_text_id] + response_id + [end_of_text_id]
203 |
204 |         # the label is simply the next token in the sequence. MASK all context_id tokens except for the last one
205 | # lm_labels = [-1] * len(context_id) + response_id + [end_of_text_id] + [-1]
206 |
207 | position_ids = list(range(len(input_ids)))
208 |
209 | token_type_id = [0] * len(input_ids)
210 |
211 | return InputFeatures(conv_id, input_ids, position_ids, token_type_id,
212 | len(context_id), len(response_id))
213 |
214 | # discard None feature
215 | features = [f for f in [featurize(ex) for ex in examples] if f is not None]
216 | return features
217 |
218 |
219 | class DynamicBatchingLoader(object):
220 | """ this loader takes raw text file, used for validate perplexity """
221 | def __init__(self, corpus_file, tokenizer, normalize_data,
222 | batch_size, max_seq_length):
223 | self.corpus = corpus_file
224 | self.toker = tokenizer
225 | self.norm = normalize_data
226 | self.bs = batch_size
227 | self.max_seq_length = max_seq_length
228 | self.num_examples = self.get_len(corpus_file)
229 |
230 | def __iter__(self, epoch=1):
231 | if epoch > 0:
232 | for epoch in range(epoch):
233 | yield from self._iter_epoch()
234 | else:
235 | while True:
236 | yield from self._iter_epoch()
237 |
238 | def __len__(self):
239 | return ceil(self.num_examples/self.bs)
240 |
241 | def _iter_epoch(self):
242 | try:
243 | with open(self.corpus, 'r', encoding="utf-8") as corpus:
244 | i = 0
245 | while True:
246 | examples = []
247 | cur_bs = 0
248 | while True:
249 | line = next(corpus).encode('utf-8').decode('utf-8')
250 | contents = line.split('\t')
251 | src, tgt_all = contents[0], contents[1:]
252 | for tgt in tgt_all:
253 | if self.norm:
254 | src_line = ' '.join(src.strip().split())
255 | tgt_line = ' '.join(tgt.strip().split())
256 | else:
257 | src_line = src.strip()
258 | tgt_line = tgt.strip()
259 | examples.append(
260 | RedditExample(i, src_line, tgt_line),
261 | )
262 | i += 1
263 | cur_bs += 1
264 | if cur_bs >= self.bs:
265 | break
266 | features = convert_examples_to_features_dynamic(
267 | examples, self.toker, self.max_seq_length)
268 | batch = self._batch_feature(features)
269 | yield batch
270 | except StopIteration:
271 | pass
272 |
273 | def _batch_feature(self, features):
274 | input_ids = pad_sequence([torch.tensor(f.choices_features['input_ids'],
275 | dtype=torch.long)
276 | for f in features],
277 | batch_first=True, padding_value=0)
278 | position_ids = pad_sequence(
279 | [torch.tensor(f.choices_features['position_ids'], dtype=torch.long)
280 | for f in features],
281 | batch_first=True, padding_value=0)
282 | token_type_ids = pad_sequence(
283 | [torch.tensor(f.choices_features['token_type_ids'],
284 | dtype=torch.long)
285 | for f in features],
286 | batch_first=True, padding_value=0)
287 | # labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
288 | # for f in features],
289 | # batch_first=True, padding_value=-1)
290 | context_len = torch.tensor([f.context_len for f in features],
291 | dtype=torch.long)
292 | response_len = torch.tensor([f.response_len for f in features],
293 | dtype=torch.long)
294 | return (input_ids, position_ids, token_type_ids,
295 | context_len, response_len)
296 |
297 | def get_len(self, corpus):
298 | n_line = int(sp.check_output(f"wc -l {corpus}".split(),
299 | universal_newlines=True).split()[0])
300 | return n_line
301 |
--------------------------------------------------------------------------------
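To make the bucketing behaviour concrete, here is a small sketch of BucketSampler on toy lengths, with shuffling disabled so the grouping is visible; it assumes src/ is on the Python path so data_loader.py is importable:

from data_loader import BucketSampler

# toy sequence lengths for eight examples
lens = [12, 80, 33, 7, 64, 20, 50, 9]

sampler = BucketSampler(lens, bucket_size=4, batch_size=2,
                        droplast=False, shuffle=False)
for batch in sampler:
    # each batch groups indices whose sequences have similar lengths
    print(batch, [lens[i] for i in batch])
# e.g. [1, 2] -> lengths [80, 33], then [0, 3] -> lengths [12, 7], ...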
/src/lsp_model_rl/coherence_classifier2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import codecs
3 | import numpy as np
4 |
5 | import tqdm
6 | from tqdm import tqdm
7 |
8 | import pandas as pd
9 | import re
10 | import csv
11 | import numpy as np
12 |
13 | import time
14 |
15 | from sklearn.metrics import f1_score
16 |
17 | from transformers import RobertaTokenizer
18 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
19 | from torch.utils.data import TensorDataset, random_split
20 |
21 | from transformers import AdamW, RobertaConfig, RobertaForSequenceClassification
22 |
23 | import datetime
24 | #
25 | import sys, os
26 | sys.path.append("../")
27 | sys.path.append(".../")
28 | from MisinfoCorrect.src.variables_ext import politeness_clf_fp, refutation_clf_fp, evidence_clf_fp
29 |
30 | class CoherenceClassifier():
31 | def __init__(self,
32 | device,
33 | model_path,
34 | batch_size = 2):
35 |
36 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
37 | self.batch_size = batch_size
38 | self.device = device
39 |
40 | self.model = RobertaForSequenceClassification.from_pretrained(
41 | "roberta-base", # Use the 12-layer BERT model, with an uncased vocab.
42 | num_labels = 2, # The number of output labels--2 for binary classification.
43 | # You can increase this for multi-class tasks.
44 | output_attentions = False, # Whether the model returns attentions weights.
45 | output_hidden_states = False, # Whether the model returns all hidden-states.
46 | )
47 |
48 |         # DataParallel is commented out by ext due to missing-key errors when loading the state dict;
49 |         # a possible reason is that the keys are allocated across different GPUs -- be careful of this (TODO)
50 |         # OR, as another solution, wrap with nn.DataParallel() after the loading:
51 |         # that only works if the checkpoint was also saved through nn.DataParallel -- interesting
52 |         # self.model = torch.nn.DataParallel(self.model) # TODO: check the details?? by ext
53 | print(f"the model path is: {model_path}")
54 | weights = torch.load(model_path, map_location=self.device) # commented by ext
55 | self.model.load_state_dict(weights)
56 |
57 | self.model.to(self.device)
58 |
59 |
60 | def predict_empathy(self, original_responses, candidate):
61 |
62 | input_ids = []
63 | attention_masks = []
64 |
65 | for idx, elem in enumerate(original_responses):
66 |
67 | response_sentence = original_responses[idx] + ' ' + candidate
68 |
69 | encoded_dict = self.tokenizer.encode_plus(
70 | response_sentence, # Sentence to encode.
71 | add_special_tokens = True, # Add '[CLS]' and '[SEP]'
72 | max_length = 64, # Pad & truncate all sentences.
73 | pad_to_max_length = True,
74 | return_attention_mask = True, # Construct attn. masks.
75 | return_tensors = 'pt', # Return pytorch tensors.
76 | )
77 |
78 | input_ids.append(encoded_dict['input_ids'])
79 | attention_masks.append(encoded_dict['attention_mask'])
80 |
81 | input_ids = torch.cat(input_ids, dim=0)
82 | attention_masks = torch.cat(attention_masks, dim=0)
83 |
84 | dataset = TensorDataset(input_ids, attention_masks)
85 |
86 | dataloader = DataLoader(
87 | dataset, # The test samples.
88 | sampler = SequentialSampler(dataset), # Pull out batches sequentially.
89 | batch_size = self.batch_size # Evaluate with this batch size.
90 | )
91 |
92 | self.model.eval()
93 |
94 | for batch in dataloader:
95 | b_input_ids = batch[0].to(self.device)
96 | b_input_mask = batch[1].to(self.device)
97 |
98 | with torch.no_grad():
99 | (logits, ) = self.model(input_ids = b_input_ids,
100 | token_type_ids=None,
101 | attention_mask=b_input_mask,)
102 | # res = self.model(input_ids = b_input_ids,
103 | # token_type_ids=None,
104 | # attention_mask=b_input_mask,)
105 | # logits = res.logits
106 |
107 |
108 | logits = logits.detach().cpu().numpy().tolist()
109 | predictions = np.argmax(logits, axis=1).flatten()
110 |
111 | return (logits, predictions)
112 |
113 |
114 | class PolitenessClassifier(CoherenceClassifier):
115 | pass
116 | # def __init__(self):
117 | # super().__init__()
118 |
119 | class RefutationClassifier(CoherenceClassifier):
120 | def __init__(self, device, model_path, batch_size=2):
121 | super().__init__(device, model_path, batch_size)
122 |
123 | def predict_empathy(self, original_responses, candidate):
124 |
125 |
126 | test_encode = self.tokenizer.batch_encode_plus(list(zip(original_responses, candidate)),
127 | padding='max_length', truncation=True, max_length=128, return_tensors='pt', pad_to_max_length=True)
128 |
129 | test_seq = torch.tensor(test_encode['input_ids'])
130 | test_mask = torch.tensor(test_encode['attention_mask'])
131 | # test_token = torch.tensor(test_encode['token_type_ids'])
132 |
133 |
134 | test_data = TensorDataset(test_seq, test_mask) #, test_token
135 |
136 | dataloader = DataLoader(
137 | test_data, # The test samples.
138 | sampler = SequentialSampler(test_data), # Pull out batches sequentially.
139 | batch_size = self.batch_size # Evaluate with this batch size.
140 | )
141 |
142 | def model_eval_and_infer(model, prediction_dataloader, device, if_infer=False, if_have_token_types=True):
143 |
144 | model.eval()
145 | if not if_infer:
146 | predictions , true_labels = [], []
147 |
148 | for batch in tqdm(prediction_dataloader):
149 |
150 | if if_have_token_types:
151 | b_input_ids, b_input_mask, b_token_type, b_labels = batch
152 | else:
153 | b_input_ids, b_input_mask, b_labels = batch
154 | with torch.no_grad():
155 |
156 | if if_have_token_types:
157 |
158 | outputs = model(b_input_ids.to(device), token_type_ids=b_token_type.to(device),
159 | attention_mask=b_input_mask.to(device))
160 | else:
161 | outputs = model(b_input_ids.to(device), attention_mask=b_input_mask.to(device))
162 |
163 | b_proba = outputs[0]
164 |
165 | proba = b_proba.detach().cpu().numpy()
166 | label_ids = b_labels.numpy()
167 |
168 | predictions.append(proba)
169 | true_labels.append(label_ids)
170 |
171 | print(' DONE.')
172 |
173 | flat_predictions = np.concatenate(predictions, axis=0)
174 | y_pred = np.argmax(flat_predictions, axis=1).flatten()
175 | y_true = np.concatenate(true_labels, axis=0)
176 |
177 |
178 | return y_pred, y_true
179 | if if_infer:
180 | predictions = []
181 |
182 | for batch in prediction_dataloader:
183 |
184 | if if_have_token_types:
185 | b_input_ids, b_input_mask, b_token_type = batch
186 | else:
187 | b_input_ids, b_input_mask = batch
188 | with torch.no_grad():
189 | if if_have_token_types:
190 | outputs = model(b_input_ids.to(device), token_type_ids=b_token_type.to(device),
191 | attention_mask=b_input_mask.to(device))
192 | else:
193 | outputs = model(b_input_ids.to(device), attention_mask=b_input_mask.to(device))
194 | b_proba = outputs[0]
195 |
196 | proba = b_proba.detach().cpu().numpy()
197 | # label_ids = b_labels.numpy()
198 |
199 | predictions.append(proba)
200 | # true_labels.append(label_ids)
201 |
202 |
203 |
204 | flat_predictions = np.concatenate(predictions, axis=0)
205 | y_pred = np.argmax(flat_predictions, axis=1).flatten()
206 | # y_true = np.concatenate(true_labels, axis=0)
207 |
208 | return y_pred, flat_predictions
209 |
210 | predictions, logits = model_eval_and_infer(self.model, dataloader, device=self.device, if_infer=True, if_have_token_types=False)
211 | return (logits, predictions)
212 |
213 | # ==== before Oct 2022 ==== in the past
214 | # input_ids = []
215 | # attention_masks = []
216 |
217 | # for idx, elem in enumerate(original_responses):
218 |
219 | # response_sentence = original_responses[idx] + ' ' + candidate
220 |
221 | # encoded_dict = self.tokenizer.encode_plus(
222 | # response_sentence, # Sentence to encode.
223 | # add_special_tokens = True, # Add '[CLS]' and '[SEP]'
224 | # max_length = 64, # Pad & truncate all sentences.
225 | # pad_to_max_length = True,
226 | # return_attention_mask = True, # Construct attn. masks.
227 | # return_tensors = 'pt', # Return pytorch tensors.
228 | # )
229 |
230 | # input_ids.append(encoded_dict['input_ids'])
231 | # attention_masks.append(encoded_dict['attention_mask'])
232 |
233 | # input_ids = torch.cat(input_ids, dim=0)
234 | # attention_masks = torch.cat(attention_masks, dim=0)
235 |
236 | # dataset = TensorDataset(input_ids, attention_masks)
237 |
238 | # dataloader = DataLoader(
239 | # dataset, # The test samples.
240 | # sampler = SequentialSampler(dataset), # Pull out batches sequentially.
241 | # batch_size = self.batch_size # Evaluate with this batch size.
242 | # )
243 |
244 | # self.model.eval()
245 |
246 | # for batch in dataloader:
247 | # b_input_ids = batch[0].to(self.device)
248 | # b_input_mask = batch[1].to(self.device)
249 |
250 | # with torch.no_grad():
251 | # (logits, ) = self.model(input_ids = b_input_ids,
252 | # token_type_ids=None,
253 | # attention_mask=b_input_mask,)
254 | # # res = self.model(input_ids = b_input_ids,
255 | # # token_type_ids=None,
256 | # # attention_mask=b_input_mask,)
257 | # # logits = res.logits
258 |
259 |
260 | # logits = logits.detach().cpu().numpy().tolist()
261 | # predictions = np.argmax(logits, axis=1).flatten()
262 |
263 | # return (logits, predictions)
264 |
265 |
266 | # inherited from the refutation classifier when considering both the tweet and reply
267 | class EvidenceClassifier(RefutationClassifier):
268 | pass
269 | # def __init__(self):
270 | # super().__init__()
271 |
272 |
273 | '''
274 | Example:
275 | '''
276 |
277 | import sys, os
278 | sys.path.append("./")
279 | sys.path.append("../")
280 | sys.path.append(".../")
281 | sys.path.append("..../")
282 | from MisinfoCorrect.src.variables_ext import device
283 |
284 |
285 | original_responses = [ 'I am so sorry that she is not getting it.','so she can get a better idea of what the condition entails?']
286 | #sentences = ['why do you feel this way?', 'Let me know if you want to talk.']
287 | candidate = 'Have you thought of directing her to sites like NAMI and Mental Health First Aid'
288 | candidate2 = ' I have been on and off medication for the majority of my life '
289 |
290 | candidates = [candidate, candidate2]
291 |
292 | print(f'here we use device: {device}')
293 | coherence_classifier = CoherenceClassifier(device, model_path=politeness_clf_fp)
294 |
295 | # #### ####
296 | # (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, candidate)
297 |
298 | # print(logits, predictions)
299 |
300 | # (logits, predictions,) = coherence_classifier.predict_empathy(original_responses, candidate2)
301 | # print(logits, predictions)
302 |
303 |
304 | politeness_classifier = PolitenessClassifier(device, model_path=politeness_clf_fp) # sanity check
305 | refutation_classifier = RefutationClassifier(device, model_path=refutation_clf_fp) # sanity check
306 | evidence_classifier = EvidenceClassifier(device, model_path=evidence_clf_fp) # sanity check
307 |
308 | (logits, predictions,) = politeness_classifier.predict_empathy(original_responses, candidate2)
309 | print(logits, predictions)
310 |
311 |
312 | # ==== for refutation classifier ====
313 | print('start sanity check for refutation')
314 | (logits, predictions,) = refutation_classifier.predict_empathy(original_responses, candidates)
315 | print('the sanity check of refutation classifier version 2')
316 | print(logits, predictions)
317 |
318 |
319 | print('start sanity check for evidence')
320 | (logits, predictions,) = evidence_classifier.predict_empathy(original_responses, candidates)
321 | print('the sanity check of evidence classifier version 2')
322 | print(logits, predictions)
--------------------------------------------------------------------------------
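The two classifier families in this file expect different input formats: CoherenceClassifier/PolitenessClassifier encode each (context, candidate) pair as a single concatenated string, while RefutationClassifier/EvidenceClassifier encode (tweet, reply) as a RoBERTa sentence pair via batch_encode_plus. A small sketch of the two encodings with illustrative strings, mirroring the tokenizer arguments used above:

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)

tweet = 'the misinformation tweet text'
reply = 'a candidate counter-misinformation reply'

# coherence / politeness: one concatenated sequence
single = tok.encode_plus(tweet + ' ' + reply, add_special_tokens=True,
                         max_length=64, pad_to_max_length=True,
                         return_attention_mask=True, return_tensors='pt')

# refutation / evidence: a (tweet, reply) sentence pair
pair = tok.batch_encode_plus([(tweet, reply)], max_length=128,
                             pad_to_max_length=True, return_tensors='pt')

print(single['input_ids'].shape, pair['input_ids'].shape)  # (1, 64) and (1, 128)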
/environment.yml:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _anaconda_depends=2020.07=py37_0
5 | _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
6 | _libgcc_mutex=0.1=main
7 | absl-py=1.0.0=pypi_0
8 | aiohttp=3.8.1=pypi_0
9 | aiosignal=1.2.0=pypi_0
10 | alabaster=0.7.12=py37_0
11 | altair=4.1.0=pypi_0
12 | anaconda=custom=py37_1
13 | anaconda-client=1.7.2=py37_0
14 | anaconda-navigator=1.9.12=py37_0
15 | anaconda-project=0.8.4=py_0
16 | appdirs=1.4.4=pypi_0
17 | argh=0.26.2=py37_0
18 | argon2-cffi=20.1.0=py37h7b6447c_1
19 | ase=3.21.1=pypi_0
20 | asn1crypto=1.4.0=py_0
21 | astor=0.8.1=pypi_0
22 | astroid=2.4.2=py37_0
23 | astropy=4.2=py37h27cfd23_0
24 | async-timeout=4.0.1=pypi_0
25 | async_generator=1.10=py37h28b3542_0
26 | asynctest=0.13.0=pypi_0
27 | atomicwrites=1.4.0=py_0
28 | attrs=20.3.0=pyhd3eb1b0_0
29 | autocorrect=1.0.0=pypi_0
30 | autopep8=1.5.4=py_0
31 | babel=2.9.0=pyhd3eb1b0_0
32 | backcall=0.2.0=py_0
33 | backports=1.0=pyhd3eb1b0_2
34 | backports-zoneinfo=0.2.1=pypi_0
35 | backports.functools_lru_cache=1.6.1=pyhd3eb1b0_0
36 | backports.shutil_get_terminal_size=1.0.0=py37_2
37 | backports.tempfile=1.0=py_1
38 | backports.weakref=1.0.post1=py_1
39 | base58=2.1.1=pypi_0
40 | beautifulsoup4=4.9.3=pyhb0f4dca_0
41 | bert-serving=0.0.1=pypi_0
42 | bert-serving-client=1.10.0=pypi_0
43 | bert-serving-server=1.10.0=pypi_0
44 | bitarray=1.6.1=py37h27cfd23_0
45 | bkcharts=0.2=py37_0
46 | blas=1.0=mkl
47 | bleach=3.2.1=py_0
48 | blessings=1.7=pypi_0
49 | blinker=1.4=pypi_0
50 | blis=0.4.1=pypi_0
51 | blosc=1.20.1=hd408876_0
52 | bokeh=2.2.3=py37_0
53 | boto=2.49.0=py37_0
54 | boto3=1.12.34=pypi_0
55 | botocore=1.15.34=pypi_0
56 | bottleneck=1.3.2=py37heb32a55_1
57 | brotli=1.0.9=he6710b0_2
58 | brotlipy=0.7.0=py37h27cfd23_1003
59 | bzip2=1.0.8=h7b6447c_0
60 | ca-certificates=2020.12.5=ha878542_0
61 | cachetools=4.2.4=pypi_0
62 | cairo=1.14.12=h8948797_3
63 | catalogue=1.0.0=pypi_0
64 | certifi=2020.12.5=py37h89c1867_1
65 | cffi=1.14.4=py37h261ae71_0
66 | chardet=3.0.4=py37h06a4308_1003
67 | charls=2.1.0=he6710b0_2
68 | charset-normalizer=2.0.9=pypi_0
69 | clean-text=0.2.1=pypi_0
70 | click=7.1.2=py_0
71 | click-config-file=0.6.0=pypi_0
72 | click-plugins=1.1.1=pypi_0
73 | cloudpickle=1.6.0=py_0
74 | clyent=1.2.2=py37_1
75 | colorama=0.4.4=pyhd3eb1b0_0
76 | conda=4.9.2=py37h89c1867_0
77 | conda-build=3.18.11=py37_0
78 | conda-env=2.6.0=1
79 | conda-package-handling=1.7.2=py37h03888b9_0
80 | conda-verify=3.4.2=py_1
81 | configobj=5.0.6=pypi_0
82 | configparser=5.2.0=pypi_0
83 | contextlib2=0.6.0.post1=py_0
84 | convokit=2.3.2.5=pypi_0
85 | cryptography=3.3.1=py37h3c74f83_0
86 | cudatoolkit=11.0.221=h6bb024c_0
87 | curl=7.71.1=hbc83047_1
88 | cycler=0.10.0=py37_0
89 | cymem=2.0.3=pypi_0
90 | cython=0.29.21=py37h2531618_0
91 | cytoolz=0.11.0=py37h7b6447c_0
92 | dask=2020.12.0=pyhd3eb1b0_0
93 | dask-core=2020.12.0=pyhd3eb1b0_0
94 | datasets=1.16.1=pypi_0
95 | dbus=1.13.18=hb2f20db_0
96 | decorator=4.4.2=py_0
97 | defusedxml=0.6.0=py_0
98 | diff-match-patch=20200713=py_0
99 | dill=0.3.4=pypi_0
100 | distributed=2020.12.0=py37h06a4308_0
101 | docker-pycreds=0.4.0=pypi_0
102 | docutils=0.15.2=pypi_0
103 | emoji=1.2.0=pypi_0
104 | en-core-web-sm=2.3.1=pypi_0
105 | entrypoints=0.3=py37_0
106 | et_xmlfile=1.0.1=py_1001
107 | expat=2.2.10=he6710b0_2
108 | fastai=2.6.3=pypi_0
109 | fastcache=1.1.0=py37h7b6447c_0
110 | fastcore=1.4.2=pypi_0
111 | fastdownload=0.0.5=pypi_0
112 | fastprogress=0.2.4=pypi_0
113 | filelock=3.0.12=py_0
114 | fire=0.4.0=pypi_0
115 | flake8=3.8.4=py_0
116 | flask=1.1.2=py_0
117 | fontconfig=2.13.0=h9420a91_0
118 | freetype=2.10.4=h5ab3b9f_0
119 | fribidi=1.0.10=h7b6447c_0
120 | frozenlist=1.2.0=pypi_0
121 | fsspec=2021.11.1=pypi_0
122 | ftfy=5.8=pypi_0
123 | future=0.18.2=py37_1
124 | get_terminal_size=1.0.0=haa9412d_0
125 | gevent=20.9.0=py37h7b6447c_0
126 | giflib=5.1.4=h14c3975_1
127 | gitdb=4.0.9=pypi_0
128 | gitpython=3.1.24=pypi_0
129 | glib=2.66.1=h92f7085_0
130 | glob2=0.7=py_0
131 | gmp=6.1.2=h6c8ec71_1
132 | gmpy2=2.0.8=py37h10f8cd9_2
133 | google-auth=2.3.3=pypi_0
134 | google-auth-oauthlib=0.4.6=pypi_0
135 | googledrivedownloader=0.4=pypi_0
136 | gpustat=0.6.0=pypi_0
137 | gputil=1.4.0=pypi_0
138 | graphite2=1.3.14=h23475e2_0
139 | greenlet=0.4.17=py37h7b6447c_0
140 | grpcio=1.42.0=pypi_0
141 | gst-plugins-base=1.14.0=hbbd80ab_1
142 | gstreamer=1.14.0=hb31296c_0
143 | h5py=2.10.0=py37h7918eee_0
144 | harfbuzz=2.4.0=hca77d97_1
145 | hdf5=1.10.4=hb1b8bf9_0
146 | heapdict=1.0.1=py_0
147 | html5lib=1.1=py_0
148 | huggingface-hub=0.2.1=pypi_0
149 | humanize=3.12.0=pypi_0
150 | icu=58.2=he6710b0_3
151 | idna=2.8=pypi_0
152 | imagecodecs=2020.5.30=py37hfa7d478_2
153 | imageio=2.9.0=py_0
154 | imagesize=1.2.0=py_0
155 | importlib-metadata=4.8.2=pypi_0
156 | importlib_metadata=2.0.0=1
157 | iniconfig=1.1.1=py_0
158 | intel-openmp=2020.2=254
159 | intervaltree=3.1.0=py_0
160 | ipykernel=5.3.4=py37h5ca1d4c_0
161 | ipython=7.19.0=py37hb070fc8_0
162 | ipython_genutils=0.2.0=pyhd3eb1b0_1
163 | ipywidgets=7.5.1=py_1
164 | isodate=0.6.0=pypi_0
165 | isort=5.6.4=py_0
166 | itsdangerous=1.1.0=py37_0
167 | jbig=2.1=hdba287a_0
168 | jdcal=1.4.1=py_0
169 | jedi=0.14.1=py37_0
170 | jeepney=0.6.0=pyhd3eb1b0_0
171 | jinja2=2.11.2=py_0
172 | jmespath=0.9.5=pypi_0
173 | joblib=0.17.0=py_0
174 | jpeg=9b=h024ee3a_2
175 | json5=0.9.5=py_0
176 | jsonschema=3.2.0=py_2
177 | jupyter=1.0.0=py37_7
178 | jupyter-contrib-core=0.3.3=pypi_0
179 | jupyter_client=6.1.7=py_0
180 | jupyter_console=6.2.0=py_0
181 | jupyter_contrib_core=0.3.3=py_2
182 | jupyter_contrib_nbextensions=0.5.1=pyhd8ed1ab_2
183 | jupyter_core=4.7.0=py37h06a4308_0
184 | jupyter_highlight_selected_word=0.2.0=py37h89c1867_1002
185 | jupyter_latex_envs=1.4.6=pyhd8ed1ab_1002
186 | jupyter_nbextensions_configurator=0.4.1=py37h89c1867_2
187 | jupyterlab=2.2.6=py_0
188 | jupyterlab_pygments=0.1.2=py_0
189 | jupyterlab_server=1.2.0=py_0
190 | jxrlib=1.1=h7b6447c_2
191 | keyring=21.4.0=py37_1
192 | kiwisolver=1.3.0=py37h2531618_0
193 | krb5=1.18.2=h173b8e3_0
194 | lazy-object-proxy=1.4.3=py37h27cfd23_2
195 | lcms2=2.11=h396b838_0
196 | ld_impl_linux-64=2.33.1=h53a641e_7
197 | libaec=1.0.4=he6710b0_1
198 | libarchive=3.4.2=h62408e4_0
199 | libcurl=7.71.1=h20c2e04_1
200 | libedit=3.1.20191231=h14c3975_1
201 | libffi=3.3=he6710b0_2
202 | libgcc-ng=9.1.0=hdf63c60_0
203 | libgfortran-ng=7.3.0=hdf63c60_0
204 | liblief=0.10.1=he6710b0_0
205 | libllvm10=10.0.1=hbcb73fb_5
206 | libllvm9=9.0.1=h4a3c616_1
207 | libpng=1.6.37=hbc83047_0
208 | libsodium=1.0.18=h7b6447c_0
209 | libspatialindex=1.9.3=he6710b0_0
210 | libssh2=1.9.0=h1ba5d50_1
211 | libstdcxx-ng=9.1.0=hdf63c60_0
212 | libtiff=4.1.0=h2733197_1
213 | libtool=2.4.6=h7b6447c_1005
214 | libuuid=1.0.3=h1bed415_2
215 | libuv=1.40.0=h7b6447c_0
216 | libwebp=1.0.1=h8e7db2f_0
217 | libxcb=1.14=h7b6447c_0
218 | libxml2=2.9.10=hb55368b_3
219 | libxslt=1.1.34=hc22bd24_0
220 | libzopfli=1.0.3=he6710b0_0
221 | liwc=0.5.0=pypi_0
222 | liwc-analysis=1.2.4=pypi_0
223 | llvmlite=0.34.0=py37h269e1b5_4
224 | locket=0.2.0=py37_1
225 | lxml=4.6.2=py37h9120a33_0
226 | lz4-c=1.9.2=heb0550a_3
227 | lzo=2.10=h7b6447c_2
228 | markdown=3.3.6=pypi_0
229 | markupsafe=1.1.1=py37h14c3975_1
230 | matplotlib=3.3.2=0
231 | matplotlib-base=3.3.2=py37h817c723_0
232 | mccabe=0.6.1=py37_1
233 | mistune=0.8.4=py37h14c3975_1001
234 | mkl=2020.2=256
235 | mkl-service=2.3.0=py37he8ac12f_0
236 | mkl_fft=1.2.0=py37h23d657b_0
237 | mkl_random=1.1.1=py37h0573a6f_0
238 | mock=4.0.3=pyhd3eb1b0_0
239 | more-itertools=8.6.0=pyhd3eb1b0_0
240 | mpc=1.1.0=h10f8cd9_1
241 | mpfr=4.0.2=hb69a4c5_1
242 | mpmath=1.1.0=py37_0
243 | msgpack-numpy=0.4.6.1=pypi_0
244 | msgpack-python=1.0.0=py37hfd86e86_1
245 | multidict=5.2.0=pypi_0
246 | multipledispatch=0.6.0=py37_0
247 | multiprocess=0.70.12.2=pypi_0
248 | murmurhash=1.0.2=pypi_0
249 | mwparserfromhell=0.5.4=pypi_0
250 | navigator-updater=0.2.1=py37_0
251 | nbclient=0.5.1=py_0
252 | nbconvert=6.0.7=py37_0
253 | nbformat=5.0.7=py_0
254 | ncurses=6.2=he6710b0_1
255 | nest-asyncio=1.4.3=pyhd3eb1b0_0
256 | networkx=2.5=py_0
257 | ninja=1.10.2=py37hff7bd54_0
258 | nltk=3.5=py_0
259 | nose=1.3.7=pyhd3eb1b0_1006
260 | notebook=6.1.4=py37_0
261 | numba=0.51.2=py37h04863e7_1
262 | numexpr=2.7.1=py37h63df603_0
263 | numpy=1.19.2=py37h54aff64_0
264 | numpy-base=1.19.2=py37hfa32c7d_0
265 | numpydoc=1.1.0=pyhd3eb1b0_1
266 | nvidia-ml-py3=7.352.0=pypi_0
267 | oauthlib=3.2.2=pypi_0
268 | olefile=0.46=py37_0
269 | openjpeg=2.3.0=h05c96fa_1
270 | openpyxl=3.0.5=py_0
271 | openssl=1.1.1k=h27cfd23_0
272 | packaging=21.3=pypi_0
273 | pandas=1.1.3=py37he6710b0_0
274 | pandoc=2.11=hb0f4dca_0
275 | pandocfilters=1.4.3=py37h06a4308_1
276 | pango=1.45.3=hd140c19_0
277 | parso=0.5.2=py_0
278 | partd=1.1.0=py_0
279 | patchelf=0.12=h2531618_1
280 | path=15.0.1=py37h06a4308_0
281 | path.py=12.5.0=0
282 | pathlib2=2.3.5=py37h06a4308_2
283 | pathtools=0.1.2=py_1
284 | patsy=0.5.1=py37_0
285 | pcre=8.44=he6710b0_0
286 | pep8=1.7.1=py37_0
287 | pexpect=4.8.0=pyhd3eb1b0_3
288 | pickleshare=0.7.5=pyhd3eb1b0_1003
289 | pillow=8.0.1=py37he98fc37_0
290 | pip=20.3.1=py37h06a4308_0
291 | pip-autoremove=0.9.1=pypi_0
292 | pixman=0.40.0=h7b6447c_0
293 | pkginfo=1.6.1=py37h06a4308_0
294 | plac=1.1.3=pypi_0
295 | pluggy=0.13.1=py37_0
296 | ply=3.11=py37_0
297 | powerlaw=1.4.6=pypi_0
298 | preshed=3.0.2=pypi_0
299 | prometheus_client=0.9.0=pyhd3eb1b0_0
300 | promise=2.3=pypi_0
301 | prompt-toolkit=3.0.8=py_0
302 | prompt_toolkit=3.0.8=0
303 | protobuf=3.19.1=pypi_0
304 | psutil=5.7.2=py37h7b6447c_0
305 | ptyprocess=0.6.0=pyhd3eb1b0_2
306 | py=1.9.0=py_0
307 | py-lief=0.10.1=py37h403a769_0
308 | pyarrow=6.0.1=pypi_0
309 | pyasn1=0.4.8=pypi_0
310 | pyasn1-modules=0.2.8=pypi_0
311 | pycocoevalcap=1.2=pypi_0
312 | pycocotools=2.0.4=pypi_0
313 | pycodestyle=2.6.0=py_0
314 | pycosat=0.6.3=py37h27cfd23_0
315 | pycparser=2.20=py_2
316 | pycrypto=2.6.1=py37h7b6447c_10
317 | pycurl=7.43.0.6=py37h1ba5d50_0
318 | pydeck=0.7.1=pypi_0
319 | pydocstyle=5.1.1=py_0
320 | pyerfa=1.7.1.1=py37h27cfd23_1
321 | pyflakes=2.2.0=py_0
322 | pygments=2.7.3=pyhd3eb1b0_0
323 | pylint=2.6.0=py37_0
324 | pympler=0.9=pypi_0
325 | pyodbc=4.0.30=py37he6710b0_0
326 | pyopenssl=20.0.0=pyhd3eb1b0_1
327 | pyparsing=2.4.7=py_0
328 | pyphen=0.10.0=pypi_0
329 | pyqt=5.9.2=py37h05f1152_2
330 | pyrsistent=0.17.3=py37h7b6447c_0
331 | pyshorteners=1.0.1=pypi_0
332 | pysocks=1.7.1=py37_1
333 | pytables=3.6.1=py37h71ec239_0
334 | pytest=6.1.2=py37h06a4308_0
335 | python=3.7.6=h0371630_2
336 | python-dateutil=2.8.1=py_0
337 | python-jsonrpc-server=0.4.0=py_0
338 | python-language-server=0.31.7=py37_0
339 | python-levenshtein=0.12.0=pypi_0
340 | python-libarchive-c=2.9=py_0
341 | python-louvain=0.15=pypi_0
342 | python_abi=3.7=1_cp37m
343 | pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0
344 | pytorch-pretrained-bert=0.6.2=pypi_0
345 | pytz=2020.4=pyhd3eb1b0_0
346 | pytz-deprecation-shim=0.1.0.post0=pypi_0
347 | pywavelets=1.1.1=py37h7b6447c_2
348 | pyxdg=0.27=pyhd3eb1b0_0
349 | pyyaml=5.3.1=py37h7b6447c_1
350 | pyzmq=20.0.0=py37h2531618_1
351 | qdarkstyle=2.8.1=py_0
352 | qt=5.9.7=h5867ecd_1
353 | qtawesome=1.0.1=py_0
354 | qtconsole=4.7.7=py_0
355 | qtpy=1.9.0=py_0
356 | rdflib=5.0.0=pypi_0
357 | readline=7.0=h7b6447c_5
358 | regex=2017.4.5=pypi_0
359 | requests=2.28.2=pypi_0
360 | requests-oauthlib=1.3.0=pypi_0
361 | ripgrep=12.1.1=0
362 | rope=0.18.0=py_0
363 | rsa=4.8=pypi_0
364 | rtree=0.9.4=py37_1
365 | ruamel_yaml=0.15.87=py37h7b6447c_1
366 | s3transfer=0.3.3=pypi_0
367 | sacremoses=0.0.41=pypi_0
368 | scikit-image=0.17.2=py37hdf5156a_0
369 | scikit-learn=0.23.2=py37h0573a6f_0
370 | scipy=1.5.2=py37h0b6359f_0
371 | seaborn=0.11.0=py_0
372 | secretstorage=3.3.0=py37h06a4308_0
373 | send2trash=1.5.0=pyhd3eb1b0_1
374 | sentencepiece=0.1.85=pypi_0
375 | sentry-sdk=1.5.1=pypi_0
376 | seqeval=1.2.2=pypi_0
377 | setuptools=51.0.0=py37h06a4308_2
378 | shortuuid=1.0.8=pypi_0
379 | simplediff=1.0=pypi_0
380 | simplegeneric=0.8.1=py37_2
381 | simpletransformers=0.63.3=pypi_0
382 | singledispatch=3.4.0.3=py_1001
383 | sip=4.19.8=py37hf484d3e_0
384 | six=1.15.0=py37h06a4308_0
385 | sklearn=0.0=pypi_0
386 | smmap=5.0.0=pypi_0
387 | snappy=1.1.8=he6710b0_0
388 | snowballstemmer=2.0.0=py_0
389 | sortedcollections=1.2.1=py_0
390 | sortedcontainers=2.3.0=pyhd3eb1b0_0
391 | soupsieve=2.0.1=py_0
392 | spacy=2.3.2=pypi_0
393 | sphinx=3.2.1=py_0
394 | sphinxcontrib=1.0=py37_1
395 | sphinxcontrib-applehelp=1.0.2=py_0
396 | sphinxcontrib-devhelp=1.0.2=py_0
397 | sphinxcontrib-htmlhelp=1.0.3=py_0
398 | sphinxcontrib-jsmath=1.0.1=py_0
399 | sphinxcontrib-qthelp=1.0.3=py_0
400 | sphinxcontrib-serializinghtml=1.1.4=py_0
401 | sphinxcontrib-websupport=1.2.4=py_0
402 | spyder=4.0.1=py37_0
403 | spyder-kernels=1.8.1=py37_0
404 | sqlalchemy=1.3.20=py37he8ac12f_0
405 | sqlite=3.33.0=h62c20be_0
406 | srsly=1.0.2=pypi_0
407 | statsmodels=0.12.1=py37h27cfd23_0
408 | streamlit=1.2.0=pypi_0
409 | subprocess32=3.5.4=pypi_0
410 | sympy=1.6.2=py37h06a4308_1
411 | tbb=2020.3=hfd86e86_0
412 | tblib=1.7.0=py_0
413 | tensorboard=2.7.0=pypi_0
414 | tensorboard-data-server=0.6.1=pypi_0
415 | tensorboard-plugin-wit=1.8.0=pypi_0
416 | tensorboardx=2.0=pypi_0
417 | termcolor=1.1.0=pypi_0
418 | terminado=0.9.1=py37_0
419 | testpath=0.4.4=py_0
420 | textstat=0.7.0=pypi_0
421 | thinc=7.4.1=pypi_0
422 | threadpoolctl=2.1.0=pyh5ca1d4c_0
423 | tifffile=2020.12.4=pyhd3eb1b0_0
424 | tk=8.6.10=hbc83047_0
425 | tokenizers=0.5.2=pypi_0
426 | toml=0.10.1=py_0
427 | toolz=0.11.1=py_0
428 | toposort=1.5=pypi_0
429 | torch=1.4.0=pypi_0
430 | torch-geometric=1.7.0=pypi_0
431 | torch-scatter=2.0.6=pypi_0
432 | torch-sparse=0.6.9=pypi_0
433 | torchaudio=0.7.2=py37
434 | torchvision=0.5.0=pypi_0
435 | tornado=6.1=py37h27cfd23_0
436 | tqdm=4.62.3=pypi_0
437 | traitlets=5.0.5=py_0
438 | transformers=2.7.0=pypi_0
439 | twarc=2.8.1=pypi_0
440 | tweepy=4.12.1=pypi_0
441 | twython=3.8.2=pypi_0
442 | typed-ast=1.4.1=py37h7b6447c_0
443 | typing_extensions=3.7.4.3=py_0
444 | tzdata=2021.5=pypi_0
445 | tzlocal=4.1=pypi_0
446 | ujson=4.0.1=py37he6710b0_0
447 | unicodecsv=0.14.1=py37_0
448 | unidecode=1.1.1=pypi_0
449 | unixodbc=2.3.9=h7b6447c_0
450 | unshortenit=0.4.0=pypi_0
451 | uritools=3.0.0=pypi_0
452 | urlextract=1.0.0=pypi_0
453 | urllib3=1.24.3=pypi_0
454 | vadersentiment=3.3.2=pypi_0
455 | validators=0.18.2=pypi_0
456 | wandb=0.12.7=pypi_0
457 | wasabi=0.7.1=pypi_0
458 | watchdog=0.10.4=py37h06a4308_0
459 | wcwidth=0.2.5=py_0
460 | webencodings=0.5.1=py37_1
461 | werkzeug=1.0.1=py_0
462 | wheel=0.36.1=pyhd3eb1b0_0
463 | widgetsnbextension=3.5.1=py37_0
464 | wikipedia-api=0.5.4=pypi_0
465 | wrapt=1.11.2=py37h7b6447c_0
466 | wurlitzer=2.0.1=py37_0
467 | xlrd=1.2.0=py37_0
468 | xlsxwriter=1.3.7=py_0
469 | xlwt=1.3.0=py37_0
470 | xmltodict=0.12.0=py_0
471 | xxhash=2.0.2=pypi_0
472 | xz=5.2.5=h7b6447c_0
473 | yaml=0.2.5=h7b6447c_0
474 | yapf=0.30.0=py_0
475 | yarl=1.7.2=pypi_0
476 | yaspin=2.1.0=pypi_0
477 | zeromq=4.3.3=he6710b0_3
478 | zict=2.0.0=py_0
479 | zipp=3.4.0=pyhd3eb1b0_0
480 | zlib=1.2.11=h7b6447c_3
481 | zope=1.0=py37_1
482 | zope.event=4.5.0=py37_0
483 | zope.interface=5.2.0=py37h27cfd23_0
484 | zstd=1.4.5=h9ceee32_0
485 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | aiohttp==3.8.1
3 | aiosignal==1.2.0
4 | alabaster==0.7.12
5 | altair==4.1.0
6 | anaconda-client==1.7.2
7 | anaconda-navigator==1.9.12
8 | anaconda-project==0.8.3
9 | appdirs==1.4.4
10 | argh==0.26.2
11 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828452693/work
12 | ase==3.21.1
13 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
14 | astor==0.8.1
15 | astroid @ file:///tmp/build/80754af9/astroid_1592495881661/work
16 | astropy @ file:///tmp/build/80754af9/astropy_1606922938720/work
17 | async-generator==1.10
18 | async-timeout==4.0.1
19 | asynctest==0.13.0
20 | atomicwrites==1.4.0
21 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
22 | autocorrect==1.0.0
23 | autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
24 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work
25 | backcall==0.2.0
26 | backports.functools-lru-cache @ file:///tmp/build/80754af9/backports.functools_lru_cache_1605305165209/work
27 | backports.shutil-get-terminal-size==1.0.0
28 | backports.tempfile==1.0
29 | backports.weakref==1.0.post1
30 | backports.zoneinfo==0.2.1
31 | base58==2.1.1
32 | beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
33 | bert-serving==0.0.1
34 | bert-serving-client==1.10.0
35 | bert-serving-server==1.10.0
36 | bitarray @ file:///tmp/build/80754af9/bitarray_1605065136653/work
37 | bkcharts==0.2
38 | bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
39 | blessings==1.7
40 | blinker==1.4
41 | blis==0.4.1
42 | bokeh @ file:///tmp/build/80754af9/bokeh_1603297847301/work
43 | boto==2.49.0
44 | boto3==1.12.34
45 | botocore==1.15.34
46 | Bottleneck==1.3.2
47 | brotlipy==0.7.0
48 | cachetools==4.2.4
49 | catalogue==1.0.0
50 | certifi==2020.12.5
51 | cffi @ file:///tmp/build/80754af9/cffi_1606255099073/work
52 | chardet @ file:///tmp/build/80754af9/chardet_1605303159953/work
53 | charset-normalizer==2.0.9
54 | clean-text==0.2.1
55 | click==7.1.2
56 | click-config-file==0.6.0
57 | click-plugins==1.1.1
58 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
59 | clyent==1.2.2
60 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
61 | conda==4.9.2
62 | conda-build==3.18.11
63 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018138503/work
64 | conda-verify==3.4.2
65 | configobj==5.0.6
66 | configparser==5.2.0
67 | contextlib2==0.6.0.post1
68 | convokit==2.3.2.5
69 | cryptography @ file:///tmp/build/80754af9/cryptography_1607635305226/work
70 | cycler==0.10.0
71 | cymem==2.0.3
72 | Cython @ file:///tmp/build/80754af9/cython_1605457613176/work
73 | cytoolz==0.11.0
74 | dask @ file:///tmp/build/80754af9/dask-core_1607706933335/work
75 | datasets==1.16.1
76 | decorator==4.4.2
77 | defusedxml==0.6.0
78 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
79 | dill==0.3.4
80 | distributed @ file:///tmp/build/80754af9/distributed_1607714018210/work
81 | docker-pycreds==0.4.0
82 | docutils==0.15.2
83 | emoji==1.2.0
84 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
85 | entrypoints==0.3
86 | et-xmlfile==1.0.1
87 | fastai==2.6.3
88 | fastcache==1.1.0
89 | fastcore==1.4.2
90 | fastdownload==0.0.5
91 | fastprogress==0.2.4
92 | filelock==3.0.12
93 | fire==0.4.0
94 | flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
95 | Flask==1.1.2
96 | frozenlist==1.2.0
97 | fsspec==2021.11.1
98 | ftfy==5.8
99 | future==0.18.2
100 | gevent @ file:///tmp/build/80754af9/gevent_1601397565838/work
101 | gitdb==4.0.9
102 | GitPython==3.1.24
103 | glob2==0.7
104 | gmpy2==2.0.8
105 | google-auth==2.3.3
106 | google-auth-oauthlib==0.4.6
107 | googledrivedownloader==0.4
108 | gpustat==0.6.0
109 | GPUtil==1.4.0
110 | greenlet @ file:///tmp/build/80754af9/greenlet_1600873995270/work
111 | grpcio==1.42.0
112 | h5py==2.10.0
113 | HeapDict==1.0.1
114 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
115 | huggingface-hub==0.2.1
116 | humanize==3.12.0
117 | idna==2.8
118 | imagecodecs @ file:///tmp/build/80754af9/imagecodecs_1603270454756/work
119 | imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
120 | imagesize==1.2.0
121 | importlib-metadata==4.8.2
122 | iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
123 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
124 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
125 | ipython @ file:///tmp/build/80754af9/ipython_1604101195213/work
126 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
127 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
128 | isodate==0.6.0
129 | isort @ file:///tmp/build/80754af9/isort_1602603989581/work
130 | itsdangerous==1.1.0
131 | jdcal==1.4.1
132 | jedi==0.14.1
133 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work
134 | Jinja2==2.11.2
135 | jmespath==0.9.5
136 | joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
137 | json5==0.9.5
138 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
139 | jupyter==1.0.0
140 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
141 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
142 | jupyter-contrib-core==0.3.3
143 | jupyter-contrib-nbextensions @ file:///home/conda/feedstock_root/build_artifacts/jupyter_contrib_nbextensions_1614931162960/work
144 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1606148959479/work
145 | jupyter-highlight-selected-word @ file:///home/conda/feedstock_root/build_artifacts/jupyter_highlight_selected_word_1611341004115/work
146 | jupyter-latex-envs @ file:///home/conda/feedstock_root/build_artifacts/jupyter_latex_envs_1614852190293/work
147 | jupyter-nbextensions-configurator @ file:///home/conda/feedstock_root/build_artifacts/jupyter_nbextensions_configurator_1611341112910/work
148 | jupyterlab==2.2.6
149 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
150 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
151 | keyring @ file:///tmp/build/80754af9/keyring_1601490840626/work
152 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014532738/work
153 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707315973/work
154 | libarchive-c==2.9
155 | liwc==0.5.0
156 | liwc-analysis==1.2.4
157 | llvmlite==0.34.0
158 | locket==0.2.0
159 | lxml @ file:///tmp/build/80754af9/lxml_1606516847630/work
160 | Markdown==3.3.6
161 | MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1594371495811/work
162 | matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603376012865/work
163 | mccabe==0.6.1
164 | mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
165 | mkl-fft==1.2.0
166 | mkl-random==1.1.1
167 | mkl-service==2.3.0
168 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work
169 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
170 | mpmath==1.1.0
171 | msgpack==1.0.0
172 | msgpack-numpy==0.4.6.1
173 | multidict==5.2.0
174 | multipledispatch==0.6.0
175 | multiprocess==0.70.12.2
176 | murmurhash==1.0.2
177 | mwparserfromhell==0.5.4
178 | navigator-updater==0.2.1
179 | nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
180 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
181 | nbformat==5.0.7
182 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1606153767164/work
183 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
184 | nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
185 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work
186 | notebook @ file:///tmp/build/80754af9/notebook_1601501580008/work
187 | numba @ file:///tmp/build/80754af9/numba_1600102479638/work
188 | numexpr @ file:///tmp/build/80754af9/numexpr_1607693749420/work
189 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603479632437/work
190 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
191 | nvidia-ml-py3==7.352.0
192 | oauthlib==3.2.2
193 | olefile==0.46
194 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
195 | packaging==21.3
196 | pandas @ file:///tmp/build/80754af9/pandas_1602088128026/work
197 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
198 | parso==0.5.2
199 | partd==1.1.0
200 | path @ file:///tmp/build/80754af9/path_1607537225003/work
201 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024979554/work
202 | pathtools==0.1.2
203 | patsy==0.5.1
204 | pep8==1.7.1
205 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
206 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
207 | Pillow @ file:///tmp/build/80754af9/pillow_1603822253009/work
208 | pip-autoremove==0.9.1
209 | pkginfo==1.6.1
210 | plac==1.1.3
211 | pluggy==0.13.1
212 | ply==3.11
213 | powerlaw==1.4.6
214 | preshed==3.0.2
215 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1606344362066/work
216 | promise==2.3
217 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
218 | protobuf==3.19.1
219 | psutil @ file:///tmp/build/80754af9/psutil_1598370249042/work
220 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1605560620615/work/dist/ptyprocess-0.6.0-py2.py3-none-any.whl
221 | py @ file:///tmp/build/80754af9/py_1593446248552/work
222 | pyarrow==6.0.1
223 | pyasn1==0.4.8
224 | pyasn1-modules==0.2.8
225 | pycocoevalcap==1.2
226 | pycocotools==2.0.4
227 | pycodestyle==2.6.0
228 | pycosat==0.6.3
229 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
230 | pycrypto==2.6.1
231 | pycurl==7.43.0.6
232 | pydeck==0.7.1
233 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
234 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1606860189168/work
235 | pyflakes==2.2.0
236 | Pygments @ file:///tmp/build/80754af9/pygments_1607368905949/work
237 | pylint @ file:///tmp/build/80754af9/pylint_1598624038450/work
238 | Pympler==0.9
239 | pyodbc===4.0.0-unsupported
240 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1606517880428/work
241 | pyparsing==2.4.7
242 | Pyphen==0.10.0
243 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
244 | pyshorteners==1.0.1
245 | PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work
246 | pytest==0.0.0
247 | python-dateutil==2.8.1
248 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
249 | python-language-server==0.31.7
250 | python-Levenshtein==0.12.0
251 | python-louvain==0.15
252 | pytorch-pretrained-bert==0.6.2
253 | pytz @ file:///tmp/build/80754af9/pytz_1606604771399/work
254 | pytz-deprecation-shim==0.1.0.post0
255 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658308664/work
256 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
257 | PyYAML==5.3.1
258 | pyzmq==20.0.0
259 | QDarkStyle==2.8.1
260 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
261 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
262 | QtPy==1.9.0
263 | rdflib==5.0.0
264 | regex==2017.4.5
265 | requests==2.28.2
266 | requests-oauthlib==1.3.0
267 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work
268 | rsa==4.8
269 | Rtree==0.9.4
270 | ruamel-yaml==0.15.87
271 | s3transfer==0.3.3
272 | sacremoses==0.0.41
273 | scikit-image==0.17.2
274 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376882706/work
275 | scipy @ file:///tmp/build/80754af9/scipy_1597686620742/work
276 | seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
277 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1606864755624/work
278 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
279 | sentencepiece==0.1.85
280 | sentry-sdk==1.5.1
281 | seqeval==1.2.2
282 | shortuuid==1.0.8
283 | simplediff==1.0
284 | simplegeneric==0.8.1
285 | simpletransformers==0.63.3
286 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
287 | six @ file:///tmp/build/80754af9/six_1605205313296/work
288 | sklearn==0.0
289 | smmap==5.0.0
290 | snowballstemmer==2.0.0
291 | sortedcollections==1.2.1
292 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work
293 | soupsieve==2.0.1
294 | spacy==2.3.2
295 | Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
296 | sphinxcontrib-applehelp==1.0.2
297 | sphinxcontrib-devhelp==1.0.2
298 | sphinxcontrib-htmlhelp==1.0.3
299 | sphinxcontrib-jsmath==1.0.1
300 | sphinxcontrib-qthelp==1.0.3
301 | sphinxcontrib-serializinghtml==1.1.4
302 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
303 | spyder==4.0.1
304 | spyder-kernels==1.8.1
305 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1607563700310/work
306 | srsly==1.0.2
307 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1606925351355/work
308 | streamlit==1.2.0
309 | subprocess32==3.5.4
310 | sympy @ file:///tmp/build/80754af9/sympy_1605119531870/work
311 | tables==3.6.1
312 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
313 | tensorboard==2.7.0
314 | tensorboard-data-server==0.6.1
315 | tensorboard-plugin-wit==1.8.0
316 | tensorboardX==2.0
317 | termcolor==1.1.0
318 | terminado==0.9.1
319 | testpath==0.4.4
320 | textstat==0.7.0
321 | thinc==7.4.1
322 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
323 | tifffile @ file:///tmp/build/80754af9/tifffile_1607369857969/work
324 | tokenizers==0.5.2
325 | toml @ file:///tmp/build/80754af9/toml_1592853716807/work
326 | toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
327 | toposort==1.5
328 | torch==1.7.1
329 | torch-geometric==1.7.0
330 | torch-scatter==2.0.6
331 | torch-sparse==0.6.9
332 | torchaudio==0.7.0a0+a853dff
333 | torchvision==0.8.2
334 | tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
335 | tqdm==4.62.3
336 | traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
337 | transformers==2.7.0
338 | twarc==2.8.1
339 | tweepy==4.12.1
340 | twython==3.8.2
341 | typed-ast==1.4.1
342 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1598376058250/work
343 | tzdata==2021.5
344 | tzlocal==4.1
345 | ujson @ file:///tmp/build/80754af9/ujson_1602523313803/work
346 | unicodecsv==0.14.1
347 | Unidecode==1.1.1
348 | unshortenit==0.4.0
349 | uritools==3.0.0
350 | urlextract==1.0.0
351 | urllib3==1.24.3
352 | vaderSentiment==3.3.2
353 | validators==0.18.2
354 | wandb==0.12.7
355 | wasabi==0.7.1
356 | watchdog @ file:///tmp/build/80754af9/watchdog_1606939061947/work
357 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
358 | webencodings==0.5.1
359 | Werkzeug==1.0.1
360 | widgetsnbextension==3.5.1
361 | Wikipedia-API==0.5.4
362 | wrapt==1.11.2
363 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594751868473/work
364 | xlrd==1.2.0
365 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
366 | xlwt==1.3.0
367 | xmltodict==0.12.0
368 | xxhash==2.0.2
369 | yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
370 | yarl==1.7.2
371 | yaspin==2.1.0
372 | zict==2.0.0
373 | zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
374 | zope.event==4.5.0
375 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1606940237377/work
376 |
--------------------------------------------------------------------------------
/src/lsp_model_rl/optim.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch optimization """
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.nn.utils import clip_grad_norm_
21 |
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | if x < warmup:
25 | return x/warmup
26 |     return 0.5 * (1.0 + math.cos(math.pi * x))  # math.cos, since x is a Python float rather than a tensor
27 |
28 |
29 | def warmup_constant(x, warmup=0.002):
30 | if x < warmup:
31 | return x/warmup
32 | return 1.0
33 |
34 |
35 | def warmup_linear(x, warmup=0.002):
36 | if x < warmup:
37 | return x/warmup
38 | return (1.0 - x)/(1.0 - warmup)
39 |
40 |
41 | def noam_decay(step, warmup_steps, model_size):
42 | """Learning rate schedule described in
43 | https://arxiv.org/pdf/1706.03762.pdf.
44 | """
45 | return (
46 | model_size ** (-0.5) *
47 | min(step ** (-0.5), step * warmup_steps**(-1.5)))
48 |
49 |
50 | def noamwd_decay(step, warmup_steps,
51 | model_size, rate=0.5, decay_steps=1000, start_step=500):
52 | """Learning rate schedule optimized for huge batches
53 | """
54 | return (
55 | model_size ** (-0.5) *
56 | min(step ** (-0.5), step * warmup_steps**(-1.5)) *
57 | rate ** (max(step - start_step + decay_steps, 0) // decay_steps))
58 |
59 |
60 | def exponential_decay(step, rate, decay_steps, start_step=0):
61 | """A standard exponential decay, scaling the learning rate by :obj:`rate`
62 | every :obj:`decay_steps` steps.
63 | """
64 | return rate ** (max(step - start_step + decay_steps, 0) // decay_steps)
65 |
66 |
67 | def rsqrt_decay(step, warmup_steps):
68 | """Decay based on the reciprocal of the step square root."""
69 | return 1.0 / math.sqrt(max(step, warmup_steps))
70 |
71 |
72 | SCHEDULES = {
73 | 'warmup_cosine': warmup_cosine,
74 | 'warmup_constant': warmup_constant,
75 | 'warmup_linear': warmup_linear,
76 | }
77 |
78 |
79 | class Adam(Optimizer):
80 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
81 | Params:
82 | lr: learning rate
83 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
84 | t_total: total number of training steps for the learning
85 | rate schedule, -1 means constant learning rate. Default: -1
86 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
87 |         b1: Adam's b1 (first-moment decay rate). Default: 0.9
88 |         b2: Adam's b2 (second-moment decay rate). Default: 0.999
89 |         e: Adam's epsilon. Default: 1e-6
90 | weight_decay_rate: Weight decay. Default: 0.01
91 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
92 | """
93 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
94 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
95 | max_grad_norm=1.0):
96 | if not lr >= 0.0:
97 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
98 | if schedule not in SCHEDULES:
99 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
100 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
101 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
102 | if not 0.0 <= b1 < 1.0:
103 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
104 | if not 0.0 <= b2 < 1.0:
105 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
106 | if not e >= 0.0:
107 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
108 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
109 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
110 | max_grad_norm=max_grad_norm)
111 | super(Adam, self).__init__(params, defaults)
112 |
113 | def get_lr(self):
114 | lr = []
115 | for group in self.param_groups:
116 | for p in group['params']:
117 | state = self.state[p]
118 | if len(state) == 0:
119 | return [0]
120 | if group['t_total'] != -1:
121 | schedule_fct = SCHEDULES[group['schedule']]
122 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
123 | else:
124 | lr_scheduled = group['lr']
125 | lr.append(lr_scheduled)
126 | return lr
127 |
128 | def to(self, device):
129 | """ Move the optimizer state to a specified device"""
130 | for state in self.state.values():
131 |             state['exp_avg'] = state['exp_avg'].to(device)
132 |             state['exp_avg_sq'] = state['exp_avg_sq'].to(device)
133 |
134 | def initialize_step(self, initial_step):
135 | """Initialize state with a defined step (but we don't have stored averaged).
136 | Arguments:
137 | initial_step (int): Initial step number.
138 | """
139 | for group in self.param_groups:
140 | for p in group['params']:
141 | state = self.state[p]
142 | # State initialization
143 | state['step'] = initial_step
144 | # Exponential moving average of gradient values
145 | state['exp_avg'] = torch.zeros_like(p.data)
146 | # Exponential moving average of squared gradient values
147 | state['exp_avg_sq'] = torch.zeros_like(p.data)
148 |
149 | def step(self, closure=None):
150 | """Performs a single optimization step.
151 |
152 | Arguments:
153 | closure (callable, optional): A closure that reevaluates the model
154 | and returns the loss.
155 | """
156 | loss = None
157 | if closure is not None:
158 | loss = closure()
159 |
160 | for group in self.param_groups:
161 | for p in group['params']:
162 | if p.grad is None:
163 | continue
164 | grad = p.grad.data
165 | if grad.is_sparse:
166 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
167 |
168 | state = self.state[p]
169 |
170 | # State initialization
171 | if len(state) == 0:
172 | state['step'] = 0
173 | # Exponential moving average of gradient values
174 | state['next_m'] = torch.zeros_like(p.data)
175 | # Exponential moving average of squared gradient values
176 | state['next_v'] = torch.zeros_like(p.data)
177 |
178 | next_m, next_v = state['next_m'], state['next_v']
179 | beta1, beta2 = group['b1'], group['b2']
180 |
181 | # Add grad clipping
182 | if group['max_grad_norm'] > 0:
183 | clip_grad_norm_(p, group['max_grad_norm'])
184 |
185 | # Decay the first and second moment running average coefficient
186 | # In-place operations to update the averages at the same time
187 | next_m.mul_(beta1).add_(1 - beta1, grad)
188 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
189 | update = next_m / (next_v.sqrt() + group['e'])
190 |
191 | # Just adding the square of the weights to the loss function is *not*
192 | # the correct way of using L2 regularization/weight decay with Adam,
193 | # since that will interact with the m and v parameters in strange ways.
194 | #
195 |                 # Instead we want to decay the weights in a manner that doesn't interact
196 | # with the m/v parameters. This is equivalent to adding the square
197 | # of the weights to the loss with plain (non-momentum) SGD.
198 | if group['weight_decay_rate'] > 0.0:
199 | update += group['weight_decay_rate'] * p.data
200 |
201 | if group['t_total'] != -1:
202 | schedule_fct = SCHEDULES[group['schedule']]
203 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
204 | else:
205 | lr_scheduled = group['lr']
206 |
207 | update_with_lr = lr_scheduled * update
208 | p.data.add_(-update_with_lr)
209 |
210 | state['step'] += 1
211 |
212 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
213 | # bias_correction1 = 1 - beta1 ** state['step']
214 | # bias_correction2 = 1 - beta2 ** state['step']
215 |
216 | return loss
217 |
218 |
219 | class Adamax(Optimizer):
220 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
221 | Params:
222 | lr: learning rate
223 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
224 | t_total: total number of training steps for the learning
225 | rate schedule, -1 means constant learning rate. Default: -1
226 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
227 |         betas: coefficients used for computing running averages of the
228 |             gradient and its infinity norm. Default: (0.9, 0.999)
229 |         eps: term added to the denominator for numerical stability. Default: 1e-6
230 | weight_decay_rate: Weight decay. Default: 0.01
231 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
232 | """
233 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
234 | betas=(0.9, 0.999), eps=1e-6, weight_decay_rate=0.01,
235 | max_grad_norm=1.0):
236 | if not lr >= 0.0:
237 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
238 | if schedule not in SCHEDULES:
239 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
240 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
241 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
242 | if not 0.0 <= eps:
243 | raise ValueError("Invalid epsilon value: {}".format(eps))
244 | if not 0.0 <= betas[0] < 1.0:
245 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
246 | if not 0.0 <= betas[1] < 1.0:
247 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
248 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
249 | betas=betas, eps=eps, weight_decay_rate=weight_decay_rate,
250 | max_grad_norm=max_grad_norm)
251 | super(Adamax, self).__init__(params, defaults)
252 |
253 | def get_lr(self):
254 | lr = []
255 | for group in self.param_groups:
256 | for p in group['params']:
257 | state = self.state[p]
258 | if len(state) == 0:
259 | return [0]
260 | if group['t_total'] != -1:
261 | schedule_fct = SCHEDULES[group['schedule']]
262 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
263 | else:
264 | lr_scheduled = group['lr']
265 | lr.append(lr_scheduled)
266 | return lr
267 |
268 | def to(self, device):
269 | """ Move the optimizer state to a specified device"""
270 | for state in self.state.values():
271 |             state['exp_avg'] = state['exp_avg'].to(device)
272 |             state['exp_avg_sq'] = state['exp_avg_sq'].to(device)
273 |
274 | def initialize_step(self, initial_step):
275 | """Initialize state with a defined step (but we don't have stored averaged).
276 | Arguments:
277 | initial_step (int): Initial step number.
278 | """
279 | for group in self.param_groups:
280 | for p in group['params']:
281 | state = self.state[p]
282 | # State initialization
283 | state['step'] = initial_step
284 | # Exponential moving average of gradient values
285 | state['exp_avg'] = torch.zeros_like(p.data)
286 | # Exponential moving average of squared gradient values
287 | state['exp_avg_sq'] = torch.zeros_like(p.data)
288 |
289 | def step(self, closure=None):
290 | """Performs a single optimization step.
291 |
292 | Arguments:
293 | closure (callable, optional): A closure that reevaluates the model
294 | and returns the loss.
295 | """
296 | loss = None
297 | if closure is not None:
298 | loss = closure()
299 |
300 | for group in self.param_groups:
301 | for p in group['params']:
302 | if p.grad is None:
303 | continue
304 | grad = p.grad.data
305 | if grad.is_sparse:
306 |                     raise RuntimeError('Adamax does not support sparse gradients, please consider SparseAdam instead')
307 |
308 | state = self.state[p]
309 |
310 | # State initialization
311 | if len(state) == 0:
312 | state['step'] = 0
313 | # Exponential moving average of gradient values
314 | state['exp_avg'] = torch.zeros_like(p.data)
315 | state['exp_inf'] = torch.zeros_like(p.data)
316 |
317 | exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
318 | beta1, beta2 = group['betas']
319 | eps = group['eps']
320 | # Add grad clipping
321 | if group['max_grad_norm'] > 0:
322 | clip_grad_norm_(p, group['max_grad_norm'])
323 |
324 | # Update biased first moment estimate.
325 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
326 | # Update the exponentially weighted infinity norm.
327 | norm_buf = torch.cat([
328 | exp_inf.mul_(beta2).unsqueeze(0),
329 | grad.abs().add_(eps).unsqueeze_(0)
330 | ], 0)
331 | torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
332 | update = exp_avg / (exp_inf + eps)
333 |
334 |                 # Just adding the square of the weights to the loss function is *not*
335 |                 # the correct way of using L2 regularization/weight decay with Adam,
336 |                 # since that will interact with the m and v parameters in strange ways.
337 |                 #
338 |                 # Instead we want to decay the weights in a manner that doesn't interact
339 |                 # with the m/v parameters. This is equivalent to adding the square
340 |                 # of the weights to the loss with plain (non-momentum) SGD.
341 |                 if group['weight_decay_rate'] > 0.0:
342 |                     update += group['weight_decay_rate'] * p.data
343 |
344 |                 if group['t_total'] != -1:
345 |                     schedule_fct = SCHEDULES[group['schedule']]
346 |                     lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
347 |                 else:
348 |                     lr_scheduled = group['lr']
349 |
350 |                 update_with_lr = lr_scheduled * update
351 |                 p.data.add_(-update_with_lr)
352 |
353 |                 state['step'] += 1
354 |
355 |         return loss
356 |
357 |
--------------------------------------------------------------------------------
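For reference, a minimal usage sketch of the `Adam` optimizer defined above. The toy model, hyperparameters, and the assumption that `src/` is on `sys.path` are illustrative and not taken from the training script; see `/src/train_model.py` below for the actual wiring.

import torch
from lsp_model_rl.optim import Adam

model = torch.nn.Linear(10, 2)  # toy model; any nn.Module works

# warm up over the first 10% of 1,000 total steps, then decay linearly ('warmup_linear')
optimizer = Adam(model.parameters(), lr=1e-5, warmup=0.1, t_total=1000,
                 schedule='warmup_linear', weight_decay_rate=0.01, max_grad_norm=1.0)

for _ in range(4):  # a few dummy optimization steps
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

As in the class above, no bias correction is applied, and the learning rate is scaled by the chosen schedule as a function of state['step'] / t_total.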
/src/lsp_model_rl/util/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import fnmatch
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import sys
13 | import tarfile
14 | import tempfile
15 | from contextlib import contextmanager
16 | from functools import partial, wraps
17 | from hashlib import sha256
18 | from typing import Optional
19 | from urllib.parse import urlparse
20 | from zipfile import ZipFile, is_zipfile
21 |
22 | import requests
23 | from filelock import FileLock
24 | from tqdm.auto import tqdm
25 |
26 | # from . import __version__
27 | __version__ = "2.8.0"
28 |
29 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
30 |
31 | try:
32 | USE_TF = os.environ.get("USE_TF", "AUTO").upper()
33 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
34 | if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
35 | import torch
36 |
37 | _torch_available = True # pylint: disable=invalid-name
38 | logger.info("PyTorch version {} available.".format(torch.__version__))
39 | else:
40 | logger.info("Disabling PyTorch because USE_TF is set")
41 | _torch_available = False
42 | except ImportError:
43 | _torch_available = False # pylint: disable=invalid-name
44 |
45 | try:
46 | USE_TF = os.environ.get("USE_TF", "AUTO").upper()
47 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
48 |
49 | if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
50 | import tensorflow as tf
51 |
52 | assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
53 | _tf_available = True # pylint: disable=invalid-name
54 | logger.info("TensorFlow version {} available.".format(tf.__version__))
55 | else:
56 | logger.info("Disabling Tensorflow because USE_TORCH is set")
57 | _tf_available = False
58 | except (ImportError, AssertionError):
59 | _tf_available = False # pylint: disable=invalid-name
60 |
61 | try:
62 | from torch.hub import _get_torch_home
63 |
64 | torch_cache_home = _get_torch_home()
65 | except ImportError:
66 | torch_cache_home = os.path.expanduser(
67 | os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
68 | )
69 | default_cache_path = os.path.join(torch_cache_home, "transformers")
70 |
71 | try:
72 | from pathlib import Path
73 |
74 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
75 | os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
76 | )
77 | except (AttributeError, ImportError):
78 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
79 | "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
80 | )
81 |
82 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
83 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
84 |
85 | WEIGHTS_NAME = "pytorch_model.bin"
86 | TF2_WEIGHTS_NAME = "tf_model.h5"
87 | TF_WEIGHTS_NAME = "model.ckpt"
88 | CONFIG_NAME = "config.json"
89 | MODEL_CARD_NAME = "modelcard.json"
90 |
91 |
92 | MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
93 | DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
94 | DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
95 |
96 | S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
97 | CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
98 |
99 |
100 | def is_torch_available():
101 | return _torch_available
102 |
103 |
104 | def is_tf_available():
105 | return _tf_available
106 |
107 |
108 | def add_start_docstrings(*docstr):
109 | def docstring_decorator(fn):
110 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
111 | return fn
112 |
113 | return docstring_decorator
114 |
115 |
116 | def add_start_docstrings_to_callable(*docstr):
117 | def docstring_decorator(fn):
118 | class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
119 | intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
120 | note = r"""
121 |
122 | .. note::
123 | Although the recipe for forward pass needs to be defined within
124 | this function, one should call the :class:`Module` instance afterwards
125 | instead of this since the former takes care of running the
126 | pre and post processing steps while the latter silently ignores them.
127 | """
128 | fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
129 | return fn
130 |
131 | return docstring_decorator
132 |
133 |
134 | def add_end_docstrings(*docstr):
135 | def docstring_decorator(fn):
136 | fn.__doc__ = fn.__doc__ + "".join(docstr)
137 | return fn
138 |
139 | return docstring_decorator
140 |
141 |
142 | def is_remote_url(url_or_filename):
143 | parsed = urlparse(url_or_filename)
144 | return parsed.scheme in ("http", "https")
145 |
146 |
147 | def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
148 | endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
149 | if postfix is None:
150 | return "/".join((endpoint, identifier))
151 | else:
152 | return "/".join((endpoint, identifier, postfix))
153 |
154 |
155 | def url_to_filename(url, etag=None):
156 | """
157 | Convert `url` into a hashed filename in a repeatable way.
158 | If `etag` is specified, append its hash to the url's, delimited
159 | by a period.
160 | If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
161 | so that TF 2.0 can identify it as a HDF5 file
162 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
163 | """
164 | url_bytes = url.encode("utf-8")
165 | url_hash = sha256(url_bytes)
166 | filename = url_hash.hexdigest()
167 |
168 | if etag:
169 | etag_bytes = etag.encode("utf-8")
170 | etag_hash = sha256(etag_bytes)
171 | filename += "." + etag_hash.hexdigest()
172 |
173 | if url.endswith(".h5"):
174 | filename += ".h5"
175 |
176 | return filename
177 |
178 |
179 | def filename_to_url(filename, cache_dir=None):
180 | """
181 | Return the url and etag (which may be ``None``) stored for `filename`.
182 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
183 | """
184 | if cache_dir is None:
185 | cache_dir = TRANSFORMERS_CACHE
186 | if isinstance(cache_dir, Path):
187 | cache_dir = str(cache_dir)
188 |
189 | cache_path = os.path.join(cache_dir, filename)
190 | if not os.path.exists(cache_path):
191 | raise EnvironmentError("file {} not found".format(cache_path))
192 |
193 | meta_path = cache_path + ".json"
194 | if not os.path.exists(meta_path):
195 | raise EnvironmentError("file {} not found".format(meta_path))
196 |
197 | with open(meta_path, encoding="utf-8") as meta_file:
198 | metadata = json.load(meta_file)
199 | url = metadata["url"]
200 | etag = metadata["etag"]
201 |
202 | return url, etag
203 |
204 |
205 | def cached_path(
206 | url_or_filename,
207 | cache_dir=None,
208 | force_download=False,
209 | proxies=None,
210 | resume_download=False,
211 | user_agent=None,
212 | extract_compressed_file=False,
213 | force_extract=False,
214 | local_files_only=False,
215 | ) -> Optional[str]:
216 | """
217 | Given something that might be a URL (or might be a local path),
218 | determine which. If it's a URL, download the file and cache it, and
219 | return the path to the cached file. If it's already a local path,
220 | make sure the file exists and then return the path.
221 | Args:
222 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
223 |         force_download: if True, re-download the file even if it's already cached in the cache dir.
224 |         resume_download: if True, resume the download if an incompletely received file is found.
225 |         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
226 |         extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
227 |             file in a folder alongside the archive.
228 |         force_extract: if True when extract_compressed_file is True and the archive was already extracted,
229 |             re-extract the archive and override the folder where it was extracted.
230 |
231 | Return:
232 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
233 | Local path (string) otherwise
234 | """
235 | if cache_dir is None:
236 | cache_dir = TRANSFORMERS_CACHE
237 | if isinstance(url_or_filename, Path):
238 | url_or_filename = str(url_or_filename)
239 | if isinstance(cache_dir, Path):
240 | cache_dir = str(cache_dir)
241 |
242 | if is_remote_url(url_or_filename):
243 | # URL, so get it from the cache (downloading if necessary)
244 | output_path = get_from_cache(
245 | url_or_filename,
246 | cache_dir=cache_dir,
247 | force_download=force_download,
248 | proxies=proxies,
249 | resume_download=resume_download,
250 | user_agent=user_agent,
251 | local_files_only=local_files_only,
252 | )
253 | elif os.path.exists(url_or_filename):
254 | # File, and it exists.
255 | output_path = url_or_filename
256 | elif urlparse(url_or_filename).scheme == "":
257 | # File, but it doesn't exist.
258 | raise EnvironmentError("file {} not found".format(url_or_filename))
259 | else:
260 | # Something unknown
261 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
262 |
263 | if extract_compressed_file:
264 | if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
265 | return output_path
266 |
267 | # Path where we extract compressed archives
268 | # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
269 | output_dir, output_file = os.path.split(output_path)
270 | output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
271 | output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
272 |
273 | if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
274 | return output_path_extracted
275 |
276 | # Prevent parallel extractions
277 | lock_path = output_path + ".lock"
278 | with FileLock(lock_path):
279 | shutil.rmtree(output_path_extracted, ignore_errors=True)
280 | os.makedirs(output_path_extracted)
281 | if is_zipfile(output_path):
282 | with ZipFile(output_path, "r") as zip_file:
283 | zip_file.extractall(output_path_extracted)
284 | zip_file.close()
285 | elif tarfile.is_tarfile(output_path):
286 | tar_file = tarfile.open(output_path)
287 | tar_file.extractall(output_path_extracted)
288 | tar_file.close()
289 | else:
290 | raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
291 |
292 | return output_path_extracted
293 |
294 | return output_path
295 |
296 |
297 | def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
298 | ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
299 | if is_torch_available():
300 | ua += "; torch/{}".format(torch.__version__)
301 | if is_tf_available():
302 | ua += "; tensorflow/{}".format(tf.__version__)
303 | if isinstance(user_agent, dict):
304 | ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
305 | elif isinstance(user_agent, str):
306 | ua += "; " + user_agent
307 | headers = {"user-agent": ua}
308 | if resume_size > 0:
309 | headers["Range"] = "bytes=%d-" % (resume_size,)
310 | response = requests.get(url, stream=True, proxies=proxies, headers=headers)
311 | if response.status_code == 416: # Range not satisfiable
312 | return
313 | content_length = response.headers.get("Content-Length")
314 | total = resume_size + int(content_length) if content_length is not None else None
315 | progress = tqdm(
316 | unit="B",
317 | unit_scale=True,
318 | total=total,
319 | initial=resume_size,
320 | desc="Downloading",
321 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
322 | )
323 | for chunk in response.iter_content(chunk_size=1024):
324 | if chunk: # filter out keep-alive new chunks
325 | progress.update(len(chunk))
326 | temp_file.write(chunk)
327 | progress.close()
328 |
329 |
330 | def get_from_cache(
331 | url,
332 | cache_dir=None,
333 | force_download=False,
334 | proxies=None,
335 | etag_timeout=10,
336 | resume_download=False,
337 | user_agent=None,
338 | local_files_only=False,
339 | ) -> Optional[str]:
340 | """
341 | Given a URL, look for the corresponding file in the local cache.
342 | If it's not there, download it. Then return the path to the cached file.
343 |
344 | Return:
345 | None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
346 | Local path (string) otherwise
347 | """
348 | if cache_dir is None:
349 | cache_dir = TRANSFORMERS_CACHE
350 | if isinstance(cache_dir, Path):
351 | cache_dir = str(cache_dir)
352 |
353 | os.makedirs(cache_dir, exist_ok=True)
354 |
355 | etag = None
356 | if not local_files_only:
357 | try:
358 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
359 | if response.status_code == 200:
360 | etag = response.headers.get("ETag")
361 | except (EnvironmentError, requests.exceptions.Timeout):
362 | # etag is already None
363 | pass
364 |
365 | filename = url_to_filename(url, etag)
366 |
367 | # get cache path to put the file
368 | cache_path = os.path.join(cache_dir, filename)
369 |
370 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
371 | # try to get the last downloaded one
372 | if etag is None:
373 | if os.path.exists(cache_path):
374 | return cache_path
375 | else:
376 | matching_files = [
377 | file
378 | for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
379 | if not file.endswith(".json") and not file.endswith(".lock")
380 | ]
381 | if len(matching_files) > 0:
382 | return os.path.join(cache_dir, matching_files[-1])
383 | else:
384 | # If files cannot be found and local_files_only=True,
385 | # the models might've been found if local_files_only=False
386 | # Notify the user about that
387 | if local_files_only:
388 | raise ValueError(
389 | "Cannot find the requested files in the cached path and outgoing traffic has been"
390 | " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
391 | " to False."
392 | )
393 | return None
394 |
395 | # From now on, etag is not None.
396 | if os.path.exists(cache_path) and not force_download:
397 | return cache_path
398 |
399 | # Prevent parallel downloads of the same file with a lock.
400 | lock_path = cache_path + ".lock"
401 | with FileLock(lock_path):
402 |
403 | # If the download just completed while the lock was activated.
404 | if os.path.exists(cache_path) and not force_download:
405 | # Even if returning early like here, the lock will be released.
406 | return cache_path
407 |
408 | if resume_download:
409 | incomplete_path = cache_path + ".incomplete"
410 |
411 | @contextmanager
412 | def _resumable_file_manager():
413 | with open(incomplete_path, "a+b") as f:
414 | yield f
415 |
416 | temp_file_manager = _resumable_file_manager
417 | if os.path.exists(incomplete_path):
418 | resume_size = os.stat(incomplete_path).st_size
419 | else:
420 | resume_size = 0
421 | else:
422 | temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
423 | resume_size = 0
424 |
425 | # Download to temporary file, then copy to cache dir once finished.
426 | # Otherwise you get corrupt cache entries if the download gets interrupted.
427 | with temp_file_manager() as temp_file:
428 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
429 |
430 | http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
431 |
432 | logger.info("storing %s in cache at %s", url, cache_path)
433 | os.replace(temp_file.name, cache_path)
434 |
435 | logger.info("creating metadata file for %s", cache_path)
436 | meta = {"url": url, "etag": etag}
437 | meta_path = cache_path + ".json"
438 | with open(meta_path, "w") as meta_file:
439 | json.dump(meta, meta_file)
440 |
441 | return cache_path
442 |
443 |
444 | class cached_property(property):
445 | """
446 | Descriptor that mimics @property but caches output in member variable.
447 |
448 | From tensorflow_datasets
449 |
450 | Built-in in functools from Python 3.8.
451 | """
452 |
453 | def __get__(self, obj, objtype=None):
454 | # See docs.python.org/3/howto/descriptor.html#properties
455 | if obj is None:
456 | return self
457 | if self.fget is None:
458 | raise AttributeError("unreadable attribute")
459 | attr = "__cached_" + self.fget.__name__
460 | cached = getattr(obj, attr, None)
461 | if cached is None:
462 | cached = self.fget(obj)
463 | setattr(obj, attr, cached)
464 | return cached
465 |
466 |
467 | def torch_required(func):
468 | # Chose a different decorator name than in tests so it's clear they are not the same.
469 | @wraps(func)
470 | def wrapper(*args, **kwargs):
471 | if is_torch_available():
472 | return func(*args, **kwargs)
473 | else:
474 | raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
475 |
476 | return wrapper
477 |
478 |
479 | def tf_required(func):
480 | # Chose a different decorator name than in tests so it's clear they are not the same.
481 | @wraps(func)
482 | def wrapper(*args, **kwargs):
483 | if is_tf_available():
484 | return func(*args, **kwargs)
485 | else:
486 | raise ImportError(f"Method `{func.__name__}` requires TF.")
487 |
488 | return wrapper
--------------------------------------------------------------------------------
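A minimal sketch of how the caching helpers above are used. The URL and cache directory are placeholders, and the import assumes `src/` is on `sys.path`.

from lsp_model_rl.util.file_utils import cached_path

# A remote file is downloaded once and cached; later calls return the cached copy.
url = "https://example.com/some_model_file.bin"  # placeholder URL, not a real asset
local_path = cached_path(url, cache_dir="./cache", resume_download=True)

# An existing local path is validated and passed through unchanged.
assert cached_path(local_path, cache_dir="./cache") == local_path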
/src/train_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | '''
4 | * @Desc: train GPT2 from scratch / fine-tune it.
5 |         Modified based on the Huggingface GPT-2 implementation.
6 | '''
7 | import json
8 | import os
9 | import sys
10 | import argparse
11 | import logging
12 | import time
13 | import tqdm
14 | import datetime
15 | import torch
16 |
17 | import numpy as np
18 |
19 | from os.path import join
20 | from torch.distributed import get_rank, get_world_size
21 |
22 | from lsp_model_rl import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam
23 | from lsp_model_rl import GPT2LMHeadModel_v2
24 | from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length
25 | from gpt2_training.eval_utils import eval_model_loss
26 |
27 | from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader
28 |
29 |
30 | from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list
31 |
32 | import sys, os
33 | sys.path.append("../")     # MisinfoCorrect/ when the script is run from src/
34 | sys.path.append("../../")  # parent of MisinfoCorrect/, so the package import below resolves
35 | import MisinfoCorrect.src.variables_ext as cfg
36 |
37 | import time
38 |
39 |
40 |
41 | ############################
42 |
43 | logging.basicConfig(
44 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
45 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
46 | logger = logging.getLogger(__name__)
47 |
48 | INF = 100000000
49 | CACHE_EMPTY_STEP = 1000  # empty the CUDA cache every 1000 steps
50 | EVAL_STEP = 100000
51 |
52 | #########################################################################
53 | # Prepare Parser
54 | ##########################################################################
55 |
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--model_name_or_path', type=str,
58 | help='pretrained model name or path to local checkpoint')
59 | parser.add_argument("--seed", type=int, default=42)
60 | parser.add_argument("--max_seq_length", type=int, default=128)
61 |
62 | parser.add_argument("--skip_eval", action='store_true',
63 | help='If true, skip evaluation.')
64 | parser.add_argument("--use_baseline", action='store_true',
65 | help='If true, use baseline for RL.')
66 | parser.add_argument("--init_checkpoint", type=str)
67 | parser.add_argument("--train_input_file", type=str)
68 | parser.add_argument("--eval_input_file", type=str)
69 | parser.add_argument("--continue_from", type=int, default=0)
70 |
71 | parser.add_argument("--train_batch_size", type=int, default=4,
72 | help="batch size now means per GPU per step")
73 | parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
74 | help="to increase effective batch size "
75 | "and reduce synchronization")
76 | parser.add_argument("--eval_batch_size", type=int, default=4)
77 | parser.add_argument("--learning_rate", type=float, default=1e-5)
78 | parser.add_argument("--num_optim_steps", type=int, default=1000000,
79 | help="new API specifies num update steps")
80 | parser.add_argument("--valid_step", type=int, default=1000,
81 | help="how many optim steps between validations")
82 | parser.add_argument("--warmup_proportion", type=float, default=0.1)
83 | parser.add_argument("--warmup_steps", type=int, default=16000)
84 |
85 | parser.add_argument("--normalize_data", type=boolean_string, default=True)
86 | parser.add_argument("--fp16", type=boolean_string, default=False)
87 | parser.add_argument("--lr_schedule", type=str,
88 | choices=['noam', 'noamwd', 'BERT', 'None'], default='noam')
89 | parser.add_argument("--loss_scale", type=float, default=0)
90 | parser.add_argument("--no_token_id", type=boolean_string, default=True)
91 |
92 | parser.add_argument("--output_dir", type=str)
93 | parser.add_argument("--log_dir", type=str)
94 | parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar')
95 |
96 | # distributed
97 | parser.add_argument('--local_rank', type=int, default=-1,
98 | help='for torch.distributed')
99 | parser.add_argument('--config', help='JSON config file')
100 |
101 |
102 | # do normal parsing
103 | args = parser.parse_args()
104 |
105 | if args.config is not None:
106 | # override argparse defaults by config JSON
107 | opts = json.load(open(args.config))
108 | for k, v in opts.items():
109 | if isinstance(v, str):
110 | # PHILLY ENV special cases
111 | if 'PHILLY_JOB_DIRECTORY' in v:
112 | v = v.replace('PHILLY_JOB_DIRECTORY',
113 | os.environ['PHILLY_JOB_DIRECTORY'])
114 | elif 'PHILLY_LOG_DIRECTORY' in v:
115 | v = v.replace('PHILLY_LOG_DIRECTORY',
116 | os.environ['PHILLY_LOG_DIRECTORY'])
117 | setattr(args, k, v)
118 |
119 | # command line should override config JSON
120 | argv = sys.argv[1:]
121 | overrides, _ = parser.parse_known_args(argv)
122 | for k, v in vars(overrides).items():
123 | if f'--{k}' in argv:
124 | setattr(args, k, v)
125 | setattr(args, 'local_rank', overrides.local_rank)
126 |
127 |
128 | assert args.train_batch_size % args.gradient_accumulation_steps == 0, \
129 | 'batch size % gradient accumulation steps != 0!'
130 | args.train_batch_size = (args.train_batch_size
131 | // args.gradient_accumulation_steps)
132 | logger.info('train batch size = {}, '
133 | 'new train batch size (after gradient accumulation) = {}'.format(
134 | args.train_batch_size*args.gradient_accumulation_steps,
135 | args.train_batch_size))
136 |
137 |
138 | if args.local_rank == -1:
139 | logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
140 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141 | # n_gpu = torch.cuda.device_count()
142 | # args.device, args.n_gpu = device, n_gpu
143 | # n_gpu = 1
144 | # device = torch.device("cuda:7")
145 |
146 |     import sys, os
147 |     sys.path.append("./")
148 |     sys.path.append("../")
149 |     sys.path.append("../../")     # parent of MisinfoCorrect/, so the package import below resolves
150 |     sys.path.append("../../../")  # one more level up, in case the script is launched from elsewhere
151 |     from MisinfoCorrect.src.variables_ext import device, n_gpu
152 |
153 | args.device, args.n_gpu = device, n_gpu
154 | else:
155 | # distributed training
156 | print('args.local_rank:', args.local_rank)
157 | torch.cuda.set_device(args.local_rank)
158 | device = torch.device("cuda", args.local_rank)
159 | # Initializes the distributed backend which will take care of
160 |     # synchronizing nodes/GPUs
161 | torch.distributed.init_process_group(backend='nccl')
162 | n_gpu = torch.distributed.get_world_size()
163 | args.device, args.n_gpu = device, 1
164 | logger.info("device: {} n_gpu: {}, distributed training: {}, "
165 | "16-bits training: {}".format(
166 | device, n_gpu, bool(args.local_rank != -1), args.fp16))
167 |
168 | np.random.seed(args.seed)
169 | torch.random.manual_seed(args.seed)
170 | torch.cuda.manual_seed(args.seed)
171 | if n_gpu > 0:
172 | torch.cuda.manual_seed_all(args.seed)
173 |
174 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
175 | # ==== TODO ==== by ext: use this output_dir
176 | output_dir = join(args.output_dir,
177 | 'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate,
178 | args.train_batch_size, n_gpu,
179 | timestamp))
180 | log_dir = args.log_dir if args.log_dir is not None and len(args.log_dir) > 0 else output_dir
181 | if args.local_rank == -1 or get_rank() == 0:
182 | os.makedirs(output_dir, exist_ok=True)
183 |
184 | logger.info('Input Argument Information')
185 | args_dict = vars(args)
186 | for a in args_dict:
187 | logger.info('%-28s %s' % (a, args_dict[a]))
188 |
189 |
190 | #########################################################################
191 | # Prepare Data Set
192 | ##########################################################################
193 | enc = GPT2Tokenizer.from_pretrained('gpt2-medium')
194 | enc.add_tokens(['', '', ''])
195 | eos = enc.encoder["<|endoftext|>"]
196 |
197 | config = GPT2Config.from_json_file(
198 | join(args.model_name_or_path, 'config.json'))
199 |
200 | # ext: single GPU or CPU
201 | if args.local_rank == -1:
202 | train_dataloader = BucketingDataLoader(args.train_input_file,
203 | args.train_batch_size,
204 | args.max_seq_length)
205 | # ext: multiple GPUs
206 | else:
207 | train_dataloader = DistributedBucketingDataLoader(
208 | get_rank(), get_world_size(),
209 | args.train_input_file, args.train_batch_size,
210 | args.max_seq_length)
211 |
212 | # eval_dataloader_loss = DynamicBatchingLoader(
213 | # args.eval_input_file, enc, args.normalize_data,
214 | # args.eval_batch_size, args.max_seq_length)
215 |
216 | # eval_dataloader_gen = get_eval_list_same_length(
217 | # args.eval_input_file, enc, args.eval_batch_size, True)
218 |
219 |
220 | #########################################################################
221 | # Prepare Model and Optimizer
222 | ##########################################################################
223 |
224 | gpt2_model = GPT2LMHeadModel_v2.from_pretrained(args.model_name_or_path) # ext
225 | print("we use version 2 of GPT model from ext")
226 |
227 | gpt2_model.resize_token_embeddings(len(enc))
228 |
229 | # ==== understand by ext: it seems like device setup ====
230 | model = load_model(gpt2_model, args.init_checkpoint,
231 | args, verbose=True)
232 | if args.local_rank != -1:
233 | # when from scratch make sure initial models are the same
234 | params = [p.data for p in model.parameters()]
235 | all_reduce_and_rescale_tensors(
236 | params, float(torch.distributed.get_world_size()))
237 |
238 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
239 | total_params = sum([np.prod(p.size()) for p in model_parameters])
240 | logger.info('Number of parameter = {}'.format(total_params))
241 |
242 | param_optimizer = list(model.named_parameters())
243 | no_decay = ['bias', 'ln'] # no decay for bias and LayerNorm (ln)
244 | optimizer_grouped_parameters = [
245 | {'params': [p for n, p in param_optimizer
246 | if not any(nd in n for nd in no_decay)],
247 | 'weight_decay': 0.01},
248 | {'params': [p for n, p in param_optimizer
249 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
250 | ]
251 |
252 | if args.fp16:
253 | logger.info('in fp16, using FusedAdam')
254 | try:
255 | from apex.optimizers import FP16_Optimizer
256 | from apex.optimizers import FusedAdam
257 | except ImportError:
258 | raise ImportError(
259 | "Please install apex from https://www.github.com/nvidia/apex "
260 | "to use distributed and fp16 training.")
261 |
262 | optimizer = FusedAdam(optimizer_grouped_parameters,
263 | lr=args.learning_rate,
264 | bias_correction=False,
265 | max_grad_norm=1.0)
266 | if args.loss_scale == 0:
267 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
268 | verbose=False)
269 | else:
270 | optimizer = FP16_Optimizer(optimizer,
271 | static_loss_scale=args.loss_scale,
272 | verbose=False)
273 | else:
274 | optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
275 | max_grad_norm=1.0)
276 |
277 | #########################################################################
278 | # Training !
279 | ##########################################################################
280 |
281 | if args.local_rank == -1 or get_rank() == 0:
282 | train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
283 | eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
284 | print('epoch,global_step,step,mean_loss,n_token_real,'
285 | 'n_token_total,epoch_time', file=train_logger)
286 | print('epoch,global_step,step,eval_loss', file=eval_logger)
287 |
288 | global_step = 0  # one global_step per gradient update (may span several accumulated batches)
289 | step = 0  # one batch, one step
290 | epoch = 0  # one pass over all training examples
291 |
292 | if args.continue_from:
293 | global_step = args.continue_from
294 | step = global_step*2 - 1
295 |
296 |
297 | if args.local_rank != -1:
298 | n_gpu = 1
299 | if args.local_rank == -1 or get_rank() == 0:
300 | if args.pbar:
301 | pbar = tqdm.tqdm(total=args.num_optim_steps, desc=f"training")
302 | else:
303 | pbar = None
304 |
305 | # pbar = None
306 |
307 | while True:
308 |
309 | if args.use_baseline:
310 | moving_avg_cnt = 20
311 | cumsum = [0.0]
312 | moving_avg_idx = 0
313 |
314 | model.train()
315 | (tr_loss, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0
316 | n_token_real, n_token_total = 0, 0
317 | train_start_time_epoch = time.time()
318 | tr_reward = 0.0
319 |
320 | # print('iteration started')
321 |
322 | for batch in train_dataloader:
323 | # torch.cuda.empty_cache()
324 | # activate new training mode
325 | seq_len = batch[0].shape[1]
326 | batch = tuple(t for t in batch)
327 |
328 | input_ids, position_ids, token_ids, seeker_post, response_post, = batch
329 |
330 | input_ids = input_ids.to(device) # ext: here, input_ids, we have seeker_post+response_post
331 | position_ids = position_ids.to(device)
332 | token_ids = token_ids.to(device)
333 |
334 | if args.no_token_id:
335 | token_ids = None
336 |
337 | forward_pass_start_time = time.time()
338 |
339 | # ext: in the released code: we do not use the baseline
340 | # which means, we use the relative distance for the downstream task.
341 | if args.use_baseline:
342 | if len(cumsum) >= moving_avg_cnt:
343 | baseline_val = (cumsum[moving_avg_idx] - cumsum[moving_avg_idx-moving_avg_cnt])/moving_avg_cnt
344 | else:
345 | baseline_val = cumsum[moving_avg_idx]
346 |
347 | loss, reward = model(input_ids, position_ids=position_ids, token_type_ids=token_ids, seeker_post=seeker_post, response_post=response_post, eos=eos, tokenizer=enc, baseline_val=baseline_val)
348 |
349 | cumsum.append(cumsum[moving_avg_idx-1] + reward)
350 | moving_avg_idx+=1
351 |         # ext: in the released code, we use this branch.
352 | else:
353 | loss, reward = model(input_ids, position_ids=position_ids, token_type_ids=token_ids,
354 | seeker_post=seeker_post, response_post=response_post, eos=eos, tokenizer=enc)
355 |
356 | forward_pass_end_time = time.time()
357 |
358 | backward_pass_start_time = time.time()
359 |
360 | # print(f"the loss is: {loss}") # the loss is a negative value
361 |
362 | if n_gpu > 1:
363 | loss = loss.mean()
364 | loss = loss / (args.train_batch_size / input_ids.shape[0])
365 | if args.fp16:
366 | optimizer.backward(loss)
367 | else:
368 | loss.backward()
369 |         # ==== added by ext to track the reward change:
370 | if n_gpu > 1:
371 | reward = reward.mean()
372 | reward = reward / (args.train_batch_size / input_ids.shape[0])
373 |
374 |
375 | backward_pass_end_time = time.time()
376 |
377 | tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0])
378 | tr_reward += float(reward.item()) * (args.train_batch_size / input_ids.shape[0])
379 |
380 | nb_tr_examples += input_ids.size(0)
381 | nb_tr_steps += 1
382 |
383 | mean_loss = tr_loss / nb_tr_steps
384 | mean_reward = tr_reward / nb_tr_steps
385 |
386 | n_token_total += input_ids.shape[0] * input_ids.shape[1]
387 | n_token_real += (input_ids != 0).sum().item()
388 |
389 | # gradient update
390 | step += 1
391 | print(f'the step is: {step}, the time is: {time.time()}') # added by ext to monitor the running time
392 |         # ext: gradients are only applied after accumulating multiple batches;
393 |         # gradient_accumulation_steps defaults to 2.
394 | if step % args.gradient_accumulation_steps == 0:
395 |             # ext: adjust the learning rate according to the schedule
396 | set_lr(optimizer, global_step,
397 | args.lr_schedule, args.learning_rate,
398 | args.warmup_steps, args.warmup_proportion,
399 | config.n_embd, args.num_optim_steps)
400 |
401 | if args.local_rank != -1:
402 | grads = [p.grad.data for p in model.parameters()
403 | if p.requires_grad and p.grad is not None]
404 | all_reduce_and_rescale_tensors(grads, float(1))
405 |
406 | optimizer.step()
407 | optimizer.zero_grad()
408 | global_step += 1
409 |
410 | # Print log info to file
411 | if args.local_rank != -1:
412 | mean_loss = sum(all_gather_list(mean_loss)) / get_world_size()
413 | n_token_real_all_proc = sum(all_gather_list(n_token_real))
414 | n_token_total_all_proc = sum(all_gather_list(n_token_total))
415 | else:
416 | n_token_real_all_proc = n_token_real
417 | n_token_total_all_proc = n_token_total
418 |
419 | if args.local_rank == -1 or get_rank() == 0:
420 | epoch_time = time.time() - train_start_time_epoch
421 |
422 | # print('step:', global_step, 'time:', forward_pass_end_time - forward_pass_start_time, backward_pass_end_time - backward_pass_start_time)
423 |
424 | if pbar is not None:
425 | pbar.set_postfix_str(
426 | f"tok/s: {n_token_real_all_proc//epoch_time//1000}k "
427 | f"epoch: {epoch}")
428 | pbar.update(1)
429 | # commented by ext: append the result to train_log.txt
430 | print('{},{},{},{},{},{},{}'.format(
431 | epoch+1, global_step+1, step+1, mean_loss,
432 | n_token_real_all_proc, n_token_total_all_proc, epoch_time),
433 | file=train_logger)
434 |
435 | if global_step % args.valid_step == 0:
436 | if args.local_rank == -1 or get_rank() == 0:
437 |                 # by ext: TODO: check how these lines save the model
438 | # only rank 0 process evaluate
439 | torch.save(
440 | {k: (v.cpu() if v is not None else None) # save to cpu tensors
441 | for k, v in model.state_dict().items()},
442 | join(output_dir,
443 | f'GP2-pretrain-step-{global_step}.pkl'))
444 |
445 | # eval_loss, eval_ppl = eval_model_loss(
446 | # model, enc, eval_dataloader_loss, epoch, args)
447 | # enable generation step evaluation for now
448 | # gen_response = eval_model_generation(
449 | # model, enc, eval_dataloader_gen, epoch, args)
450 | '''
451 | # probably use beam search only for test set
452 | if False:
453 | gen_response_beam = eval_model_generation(
454 | model, enc, eval_dataloader_gen, epoch, args,
455 | use_beam_search=True, beam_width=3)
456 | '''
457 | # print('{},{},{},{},{}'.format(
458 | # epoch+1, global_step+1, step+1, eval_loss, eval_ppl),
459 | # file=eval_logger)
460 | logger.info('current learning rate: '
461 | + str(optimizer.param_groups[0]['lr']))
462 | model.train()
463 | if global_step >= args.num_optim_steps:
464 | break
465 | # ==== commented by ext ====
466 | # if (step+1) % CACHE_EMPTY_STEP == 0:
467 | # torch.cuda.empty_cache()
468 | # ==== by ext ====: after one batch, we print some results
469 | print(f'epoch: {epoch + 1}, mean loss:{mean_loss}, mean reward:{mean_reward}', file=train_logger)
470 | # # ==== by ext ====: add here
471 | # if (step+1) % CACHE_EMPTY_STEP == 0:
472 | # torch.cuda.empty_cache()
473 |         # ==== or, after one epoch, we clear the cache, since the % condition may not hold
474 | # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
475 | with torch.cuda.device(cfg.device):
476 | torch.cuda.empty_cache()
477 | if global_step >= args.num_optim_steps:
478 | break
479 | epoch += 1
480 |     # ==== by ext ==== save model checkpoints every few epochs
481 | if epoch % cfg.every_k_epoch_save_model == 0:
482 | save_dir = f"{output_dir}/epoch_{epoch}"
483 | os.makedirs(save_dir, exist_ok=True)
484 | model.save_pretrained(save_dir)
485 | # ==== end ====
486 |
487 |
488 | if args.local_rank == -1 or get_rank() == 0:
489 | if pbar is not None:
490 | pbar.close()
491 | train_logger.close()
492 | eval_logger.close()
493 |
494 | # ==== save model by ext ====
495 | model.save_pretrained(output_dir)
496 | print(f"yes, we run the program and save the model at: {output_dir}")
497 |
498 |
499 |
--------------------------------------------------------------------------------
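Note on the loop above: it follows a REINFORCE-style recipe in which the model's forward pass returns an (rl_loss, reward) pair, gradients are applied only every gradient_accumulation_steps batches, and an optional moving-average baseline can be subtracted from the reward (disabled in the released configuration). The following is a minimal sketch of that pattern, assuming model(...) returns (loss, reward) as in train_model.py; the function and variable names here are illustrative, not the repository's exact API.

# Simplified sketch -- illustrative names, not the repository's exact code.
from collections import deque

def train_reinforce(model, dataloader, optimizer, accum_steps=2, use_baseline=False, window=20):
    recent_rewards = deque(maxlen=window)      # moving-average reward baseline (optional)
    model.train()
    for step, batch in enumerate(dataloader, start=1):
        baseline = sum(recent_rewards) / len(recent_rewards) if (use_baseline and recent_rewards) else 0.0
        # assumed signature: the model computes rl_loss = -(reward - baseline) * log-prob term internally
        loss, reward = model(*batch, baseline_val=baseline)
        (loss / accum_steps).backward()        # scale so accumulated gradients match one larger batch
        recent_rewards.append(float(reward))
        if step % accum_steps == 0:            # apply gradients only every accum_steps batches
            optimizer.step()
            optimizer.zero_grad()
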
/src/lsp_model_rl/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """PyTorch OpenAI GPT-2 model."""
3 |
4 | from __future__ import absolute_import, division, print_function, unicode_literals
5 |
6 | import logging
7 | import copy
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import CrossEntropyLoss
12 | from typing import Iterable, Optional, Tuple
13 | from torch import Tensor
14 | import time
15 | import nltk
16 | import numpy as np
17 |
18 | from transformers import GPT2PreTrainedModel, GPT2Model
19 |
20 | from pytorch_pretrained_bert.modeling_gpt2 import GPT2LMHead, Attention, Block, \
21 | LayerNorm, MLP
22 |
23 | # from generate import top_filtering
24 |
25 | from .rewards import calc_rewards
26 |
27 | # from .generation_utils import GenerationMixin
28 |
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 | import sys, os
33 | sys.path.append("../")
34 | sys.path.append("../../")  # so that the MisinfoCorrect package can be imported below
35 | import MisinfoCorrect.src.variables_ext as cfg
36 |
37 | class AttentionFP16(Attention):
38 | def __init__(self, nx, n_ctx, config, scale=False):
39 | super(AttentionFP16, self).__init__(nx, n_ctx, config, scale)
40 |
41 | def _attn(self, q, k, v):
42 | w = torch.matmul(q, k)
43 | if self.scale:
44 | w = w / math.sqrt(v.size(-1))
45 | nd, ns = w.size(-2), w.size(-1)
46 | b = self.bias[:, :, ns-nd:ns, :ns]
47 | w = w * b - 1e4 * (1 - b) # point out by Yen-Chun, FP16 overflow
48 |
49 | w = nn.Softmax(dim=-1)(w)
50 | return torch.matmul(w, v)
51 |
52 |
53 | class BlockFP16(Block):
54 | def __init__(self, n_ctx, config, scale=False):
55 | super(BlockFP16, self).__init__(n_ctx, config, scale)
56 | nx = config.n_embd
57 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
58 | self.attn = AttentionFP16(nx, n_ctx, config, scale)
59 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
60 | self.mlp = MLP(4 * nx, config)
61 |
62 |
63 | class GPT2ModelFP16(GPT2Model):
64 | def __init__(self, config):
65 | # super(GPT2ModelFP16, self).__init__(config)
66 | super().__init__(config)
67 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
68 | self.wpe = nn.Embedding(config.n_positions, config.n_embd)
69 | block = BlockFP16(config.n_ctx, config, scale=True)
70 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
71 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
72 |
73 | self.init_weights()
74 |
75 | class GPT2LMHeadModel(GPT2PreTrainedModel):
76 | def __init__(self, config):
77 | super(GPT2LMHeadModel, self).__init__(config)
78 | self.transformer = GPT2Model(config)
79 | # lm_head generated hidden state to lm_logits -> vocabulary distribution
80 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # GPT2LMHead(self.transformer.wte.weight, config)
81 | self.position_num_labels = 2 # insert/replace: only two
82 | self.lambda_position = 0.1
83 | # position classifier as mentioned in the paper
84 | # actually, the head is a linear layer for the classification
85 | self.position_classifier = GPT2ClassificationHead(num_labels = self.position_num_labels) #GPT2LMHead(self.transformer.wte.weight, config)
86 | self.init_weights()
87 |
88 | def set_tied(self):
89 | """ Make sure we are sharing the embeddings
90 | """
91 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
92 |
93 | def get_output_embeddings(self):
94 | return self.lm_head
95 |
96 | def padding_tensor_3D(self, sequences, max_len):
97 | """
98 | :param sequences: list of tensors
99 | :return:
100 | """
101 | num = len(sequences)
102 | out_dims = (num, max_len, *sequences[0].shape[1:])
103 | out_tensor = sequences[0].data.new(*out_dims).fill_(0)
104 |
105 | # print('out_tensor:', out_tensor.shape)
106 |
107 | mask = sequences[0].data.new(*out_dims).fill_(0)
108 | for i, tensor in enumerate(sequences):
109 | length = tensor.size(0)
110 | # print('length:', length)
111 | out_tensor[i, :length] = tensor
112 | mask[i, :length] = 1
113 | return out_tensor, mask
114 |
115 | def padding_tensor_2D(self, sequences, max_len):
116 | """
117 | :param sequences: list of tensors
118 | :return:
119 | """
120 | num = len(sequences)
121 | out_dims = (num, max_len)
122 | out_tensor = sequences[0].data.new(*out_dims).fill_(0)
123 |
124 | # print('out_tensor:', out_tensor.shape)
125 |
126 | mask = sequences[0].data.new(*out_dims).fill_(0)
127 | for i, tensor in enumerate(sequences):
128 | length = min(tensor.size(0), max_len)
129 | # print('length:', length)
130 | out_tensor[i, :length] = tensor[:length]
131 | mask[i, :length] = 1
132 | return out_tensor, mask
133 |
134 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, position_labels=None, past=None, seeker_post=None, response_post=None, top_k=60, top_p=0.92, temperature=0.9, eos=None, tokenizer=None, baseline_val=0):
135 |
136 | transformer_start_time = time.time()
137 |
138 | # Forward Transformer Pass
139 | # self.transformer is a GPT model: no generation at this moment
140 | hidden_states, presents = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past=past)
141 | # res = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past_key_values=past) # past_key_values from past # by ext
142 | # hidden_states, presents = res.last_hidden_state, None # by ext last_hidden_state
143 |
144 | transformer_end_time = time.time()
145 |
146 | # Get LM and position logits
147 | lm_logits = self.lm_head(hidden_states) #
148 |
149 | if tokenizer is None:
150 | return lm_logits, presents
151 |
152 |         # ext: I think the X2 expands it to 2k+1: even if we have just k positions, we double it.
153 |         # -1: we take the hidden state at the last position of the sequence dimension
154 |         # e.g., a.shape = (2, 3, 4) -> a[:, -1, :] has shape (2, 4)
155 | position_logits = self.position_classifier(hidden_states[:, -1, :]) # X2: shape
156 |
157 | # A1 (Selecting a position)
158 | probs_position = torch.softmax(position_logits.view(-1, self.position_num_labels), -1) # (batch_size, num_position)
159 | all_positions = torch.argmax(probs_position, 1)
160 | all_positions = all_positions.squeeze()
161 |
162 | all_positions = all_positions.cpu().numpy().tolist()
163 |
164 | # A2 (Candidate Sentence): Sample from DialoGPT
165 | all_outputs = []
166 | all_output_ids = []
167 | all_output_logits = []
168 | sample_dialogpt_start_time = time.time()
169 |
170 | for ii, _ in enumerate(input_ids):
171 | curr_seeker = tokenizer.encode(seeker_post[ii] + tokenizer.eos_token)
172 | curr_seeker = torch.tensor([curr_seeker,])
173 |
174 | # ==== ext: device control ====
175 | # ==== ext: device control ====
176 |             curr_seeker = curr_seeker # .to(device) removed by ext; for device placement, use version 2 (GPT2LMHeadModel_v2)
177 | # self.generate: is a method of class from transformers.PreTrainedModel
178 | # here parameters, like top_p, top_k, temperature
179 | generated_output = self.generate(input_ids = curr_seeker,
180 | max_length=1000,
181 | pad_token_id=tokenizer.eos_token_id,
182 | top_p=0.92,
183 | top_k=60,
184 | temperature=1,
185 | num_return_sequences=1)
186 |
187 | curr_output = tokenizer.decode(generated_output[:, curr_seeker.shape[-1]:][0], skip_special_tokens=True)
188 |
189 | curr_output_ids = generated_output[:, curr_seeker.shape[-1]:][0]
190 | curr_output_ids = curr_output_ids[:hidden_states.shape[1]]
191 |             curr_position_ids = torch.tensor(range(len(curr_output_ids)), dtype=torch.long) # ext: .to(device) removed; for device placement, use version 2
192 |
193 | curr_output_logits = lm_logits[ii, range(curr_output_ids.shape[0]), curr_output_ids]
194 |
195 | all_outputs.append(curr_output)
196 | all_output_ids.append(curr_output_ids)
197 | all_output_logits.append(curr_output_logits)
198 |
199 | log_softmax = nn.LogSoftmax(1)
200 |
201 | all_output_logits, _ = self.padding_tensor_2D(all_output_logits, hidden_states.shape[1])
202 | all_output_logits = log_softmax(all_output_logits)
203 |
204 | sample_dialogpt_end_time = time.time()
205 |
206 |
207 | # Calculate Reward
208 |
209 | rewritten_response = []
210 |
211 | for idx, _ in enumerate(all_outputs):
212 | curr_seeker_post = seeker_post[idx]
213 | curr_response = response_post[idx]
214 | curr_output = all_outputs[idx]
215 | curr_position = all_positions[idx]
216 |
217 | curr_response_li = nltk.sent_tokenize(curr_response)
218 |
219 | if curr_position == 0:
220 | curr_rewritten_response = curr_response
221 |
222 | else:
223 | curr_rewritten_response_li = curr_response_li[:curr_position] + [curr_output] + curr_response_li[curr_position:]
224 | curr_rewritten_response = '. '.join(curr_rewritten_response_li)
225 |
226 | rewritten_response.append(curr_rewritten_response)
227 |
228 | reward_start_time = time.time()
229 | # TODO: ext: we rewrite this part
230 | # ==== partner's ====
231 |         # we change the reward in v2, the following child class
232 | # reward = calc_rewards(seeker_post, response_post, rewritten_response, _empathy_change=True, _perplexity=True)
233 | reward = calc_rewards(seeker_post, response_post, rewritten_response,
234 | _politeness=True,
235 | _coherence=False, _perplexity=True)
236 | reward_end_time = time.time()
237 |
238 | batches = np.arange(input_ids.shape[0]).tolist()
239 |
240 | rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,]) + torch.mean(torch.log(probs_position[batches, all_positions]) ))
241 |
242 | return rl_loss, reward
243 |
244 | # ext: by search on github repo: this function is not used
245 | def forward_pointwise(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
246 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
247 | # import pdb; pdb.set_trace()
248 | lm_logits = self.lm_head(hidden_states)
249 | if lm_labels is not None:
250 | # loss_fct = CrossEntropyLoss(ignore_index=-1)
251 | # loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
252 | loss_fct1 = CrossEntropyLoss(ignore_index=-1, reduction='none')
253 | loss1 = loss_fct1(lm_logits.view(-1, lm_logits.size(-1)),
254 | lm_labels.view(-1))
255 | loss1 = loss1.view(lm_labels.size(0), lm_labels.size(1))
256 | label_size = torch.sum(lm_labels != -1, dim=1).type(loss1.type())
257 | loss1 = torch.sum(loss1, dim=1)/label_size
258 | ppl1 = torch.exp(loss1)
259 |
260 | return loss1, ppl1
261 | return lm_logits, presents
262 |
263 | def prepare_inputs_for_generation(self, input_ids, **kwargs):
264 | return {"input_ids": input_ids}
265 |
266 | # created by ext for our purpose
267 | class GPT2LMHeadModel_v2(GPT2LMHeadModel):
268 | def __init__(self, config):
269 | super(GPT2LMHeadModel_v2, self).__init__(config)
270 | self.transformer = GPT2Model(config)
271 | # lm_head generated hidden state to lm_logits -> vocabulary distribution
272 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # GPT2LMHead(self.transformer.wte.weight, config)
273 | # self.position_num_labels = 2 # insert/replace: only two
274 |         # self.lambda_position = 0.1 # commented out for inheritance: ext
275 |         # TODO: maybe when we load the config, the weight is not found and is reported as missing
276 | # position classifier as mentioned in the paper
277 | # actually, the head is a linear layer for the classification
278 | # self.position_classifier = GPT2ClassificationHead(num_labels = self.position_num_labels) #GPT2LMHead(self.transformer.wte.weight, config)
279 | self.init_weights()
280 |
281 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, position_labels=None,
282 | past=None, seeker_post=None, response_post=None, top_k=60, top_p=0.92, temperature=0.9, eos=None,
283 | tokenizer=None, baseline_val=0):
284 | # print("==== use the forward function in version 2 ====")
285 |
286 | transformer_start_time = time.time()
287 |
288 | # Forward Transformer Pass
289 | # self.transformer is a GPT model: no generation at this moment
290 |         # from the code setup: with batches, input_ids has shape (batch_size, sequence_length) of input token ids
291 | hidden_states, presents = self.transformer(input_ids=input_ids, position_ids=position_ids,
292 | token_type_ids=token_type_ids, past=past)
293 | # res = self.transformer(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, past_key_values=past) # past_key_values from past # by ext
294 | # hidden_states, presents = res.last_hidden_state, None # by ext last_hidden_state
295 |
296 | transformer_end_time = time.time()
297 |
298 | # Get LM and position logits
299 | lm_logits = self.lm_head(hidden_states) #
300 |
301 | if tokenizer is None:
302 | return lm_logits, presents
303 |
304 | # ext: i think X2 to expand it to 2k+1: even if we just k positions, we double it.
305 | # -1: we select the final dimension data: --- dimension-driven understanding
306 | # like a.shape = 2,3,4 -> :,-1,:->2,4
307 | # position_logits = self.position_classifier(hidden_states[:, -1, :]) # X2: shape
308 |
309 | # A1 (Selecting a position)
310 | # probs_position = torch.softmax(position_logits.view(-1, self.position_num_labels),-1) # (batch_size, num_position)
311 | # all_positions = torch.argmax(probs_position, 1)
312 | # all_positions = all_positions.squeeze()
313 |
314 | # all_positions = all_positions.cpu().numpy().tolist()
315 |
316 | # A2 (Candidate Sentence): Sample from DialoGPT
317 | all_outputs = []
318 | all_output_ids = []
319 | all_output_logits = []
320 | sample_dialogpt_start_time = time.time()
321 |
322 | for ii, _ in enumerate(input_ids):
323 | curr_seeker = tokenizer.encode(seeker_post[ii] + tokenizer.eos_token)
324 | curr_seeker = torch.tensor([curr_seeker, ])
325 |
326 | # ==== ext: device control ====
327 | import sys, os
328 | sys.path.append("../")
329 |             sys.path.append("../../")  # so that the MisinfoCorrect package can be imported below
330 | from MisinfoCorrect.src.variables_ext import device
331 | # ==== ext: device control ====
332 | curr_seeker = curr_seeker.to(device) # by ext from cuda
333 | # self.generate: is a method of class from transformers.PreTrainedModel
334 | # here parameters, like top_p, top_k, temperature
335 | generated_output = self.generate(input_ids=curr_seeker,
336 | max_length=140, # by ext from 1000 to 140
337 | pad_token_id=tokenizer.eos_token_id,
338 | top_p=0.92,
339 | top_k=60,
340 | temperature=1,
341 | num_return_sequences=1)
342 |
343 | curr_output = tokenizer.decode(generated_output[:, curr_seeker.shape[-1]:][0], skip_special_tokens=True)
344 |
345 | curr_output_ids = generated_output[:, curr_seeker.shape[-1]:][0]
346 |             curr_output_ids = curr_output_ids[:hidden_states.shape[1]] # ext: TODO: is this truncation needed?
347 |             # TODO: or is it a no-op given the large hidden-state length
348 | curr_position_ids = torch.tensor(range(len(curr_output_ids)), dtype=torch.long).to(
349 | device) # by ext
350 |
351 |             # ext: here, we take the probability of each token in the vocabulary, like p1*p2*p3
352 | curr_output_logits = lm_logits[ii, range(curr_output_ids.shape[0]), curr_output_ids]
353 |
354 | all_outputs.append(curr_output)
355 | all_output_ids.append(curr_output_ids)
356 | all_output_logits.append(curr_output_logits)
357 |
358 | log_softmax = nn.LogSoftmax(1)
359 |
360 | all_output_logits, _ = self.padding_tensor_2D(all_output_logits, hidden_states.shape[1])
361 | all_output_logits = log_softmax(all_output_logits)
362 |
363 | sample_dialogpt_end_time = time.time()
364 |
365 | # Calculate Reward
366 |
367 | rewritten_response = []
368 |
369 | for idx, _ in enumerate(all_outputs):
370 | curr_seeker_post = seeker_post[idx]
371 | curr_response = response_post[idx]
372 | curr_output = all_outputs[idx]
373 | # curr_position = all_positions[idx]
374 |
375 | curr_response_li = nltk.sent_tokenize(curr_response)
376 |
377 | # ==== previous partner ====
378 | # if curr_position == 0:
379 | # curr_rewritten_response = curr_response
380 | # else:
381 | # curr_rewritten_response_li = curr_response_li[:curr_position] + [curr_output] + curr_response_li[
382 | # curr_position:]
383 | # curr_rewritten_response = '. '.join(curr_rewritten_response_li)
384 | # ==== version 2 ====
385 |
386 | curr_rewritten_response_li = [curr_output]
387 | curr_rewritten_response = '. '.join(curr_rewritten_response_li)
388 |
389 | rewritten_response.append(curr_rewritten_response)
390 |
391 | reward_start_time = time.time()
392 | # TODO: ext: we rewrite this part
393 | # ==== partner's ====
394 | # reward = calc_rewards(seeker_post, response_post, rewritten_response, _empathy_change=True, _perplexity=True)
395 | # ==== by ext ====
396 | # ======== debug =======: seeker_post: list, ['i like it', 'i forget it']
397 | # print(f'the current seeker post is: {seeker_post}')
398 | # print(f'the current rewritten_response is: {rewritten_response}')
399 | # exit(0)
400 |
401 |         if_cut_responses_with_more_characters = True  # truncate responses longer than 280 characters (tweet length limit)
402 | if if_cut_responses_with_more_characters:
403 | rewritten_response_new = []
404 | for one_response in rewritten_response:
405 | if len(one_response) > 280:
406 | rewritten_response_new.append(one_response[:280])
407 | else:
408 | rewritten_response_new.append(one_response)
409 |
410 | rewritten_response = rewritten_response_new
411 |
412 | reward = calc_rewards(seeker_post, None, rewritten_response,
413 | _politeness=cfg.if_politeness,
414 | _refutation=cfg.if_refutation,
415 | _evidence=cfg.if_evidence,
416 | _perplexity=cfg.if_perplexity,
417 | _relevance=cfg.if_relevance)
418 | reward_end_time = time.time()
419 |
420 | batches = np.arange(input_ids.shape[0]).tolist()
421 |
422 | # rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,]) + torch.mean(
423 | # torch.log(probs_position[batches, all_positions])))
424 | # ==== version 2 ====
425 |         # log addition -> 0 for probability 1
426 |         # in the released code, baseline_val = 0
427 |         # reward is positive: like a cross-entropy setup
428 |         # here, the leading minus turns the negative log-prob term into a positive value
429 |         # example values: reward ~0.50, rl_loss ~ -95, magnitude of the mean log-prob term ~200
430 |         # note: baseline_val is 0 if we do not initialize it
431 | rl_loss = - (reward - baseline_val) * (-torch.mean(all_output_logits[batches,]))
432 | # print(f"the reward, baseline_val, all_output_logits[batches,], rl_loss is:"
433 | # f"\n {reward}\n{baseline_val},\n{all_output_logits[batches,]},\n{rl_loss}")
434 | # exit
435 | return rl_loss, reward
436 |
437 | class GPT2ClassificationHead(nn.Module):
438 | """Head for sentence-level classification tasks."""
439 |
440 | def __init__(self, hidden_dropout_prob=0.1, hidden_size=1024, num_labels=2):
441 | super().__init__()
442 |
443 |         # self.dense = nn.Linear(hidden_size, hidden_size) # previous version
444 |         self.dense = nn.Linear(768, hidden_size) # by ext: input dim 768 (GPT-2 hidden size); hidden_size changed from 768*2 to 1024
445 | self.dropout = nn.Dropout(hidden_dropout_prob)
446 | self.out_proj = nn.Linear(hidden_size, num_labels)
447 |
448 | def forward(self, features, **kwargs):
449 | print(f"***{features.shape}***")
450 | x = features[:, :]
451 | x = self.dropout(x)
452 | print(f"***{x.shape}***")
453 | x = self.dense(x)
454 | x = torch.relu(x)
455 | x = self.dropout(x)
456 | x = self.out_proj(x)
457 | return x
458 |
--------------------------------------------------------------------------------
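For the policy-gradient loss in GPT2LMHeadModel_v2.forward above, rl_loss = -(reward - baseline_val) * (-mean(log-probs)). Below is a tiny numeric check; the values are made up only to match the order of magnitude mentioned in the in-code comment (reward around 0.5, log-prob term of a few hundred), not taken from a real run.

import torch

# Toy check of rl_loss = -(reward - baseline_val) * (-mean_log_prob); illustrative values only.
mean_log_prob = torch.tensor(-190.0)   # mean of the padded per-token log-probabilities
reward, baseline_val = 0.5, 0.0        # baseline_val stays 0 in the released code
rl_loss = -(reward - baseline_val) * (-mean_log_prob)
print(rl_loss)  # tensor(-95.): a higher reward makes the loss more negative, pushing log-likelihood up
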
/src/lsp_model_rl/util/configuration_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Configuration base class and utilities."""
17 |
18 |
19 | import copy
20 | import json
21 | import logging
22 | import os
23 | from typing import Dict, Optional, Tuple
24 |
25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26 |
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class PretrainedConfig(object):
32 | r""" Base class for all configuration classes.
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34 |
35 | Note:
36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37 | It only affects the model's configuration.
38 |
39 | Class attributes (overridden by derived classes):
40 | - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
41 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
42 |
43 | Args:
44 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
45 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
46 | num_labels (:obj:`int`, `optional`, defaults to `2`):
47 | Number of classes to use when the model is a classification model (sequences/tokens)
48 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
49 |             Whether the model should return attention weights.
50 |         output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
51 |             Whether the model should return all hidden states.
52 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
53 | Is the model used with Torchscript (for PyTorch models).
54 | """
55 | pretrained_config_archive_map: Dict[str, str] = {}
56 | model_type: str = ""
57 |
58 | def __init__(self, **kwargs):
59 | # Attributes with defaults
60 | self.output_attentions = kwargs.pop("output_attentions", False)
61 | self.output_hidden_states = kwargs.pop("output_hidden_states", False)
62 | self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
63 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
64 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
65 | self.pruned_heads = kwargs.pop("pruned_heads", {})
66 |
67 |         # is_decoder is used in encoder-decoder models to differentiate the encoder from the decoder
68 | self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
69 | self.is_decoder = kwargs.pop("is_decoder", False)
70 |
71 | # Parameters for sequence generation
72 | self.max_length = kwargs.pop("max_length", 20)
73 | self.min_length = kwargs.pop("min_length", 0)
74 | self.do_sample = kwargs.pop("do_sample", False)
75 | self.early_stopping = kwargs.pop("early_stopping", False)
76 | self.num_beams = kwargs.pop("num_beams", 1)
77 | self.temperature = kwargs.pop("temperature", 1.0)
78 | self.top_k = kwargs.pop("top_k", 50)
79 | self.top_p = kwargs.pop("top_p", 1.0)
80 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
81 | self.length_penalty = kwargs.pop("length_penalty", 1.0)
82 | self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
83 | self.bad_words_ids = kwargs.pop("bad_words_ids", None)
84 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
85 |
86 | # Fine-tuning task arguments
87 | self.architectures = kwargs.pop("architectures", None)
88 | self.finetuning_task = kwargs.pop("finetuning_task", None)
89 | self.num_labels = kwargs.pop("num_labels", 2)
90 | self.empathy_num_labels = kwargs.pop("empathy_num_labels", 2)
91 | self.rationale_num_labels = kwargs.pop("rationale_num_labels", 2)
92 | self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)})
93 | self.id2label = dict((int(key), value) for key, value in self.id2label.items())
94 | self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
95 | self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
96 |
97 | # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
98 | self.prefix = kwargs.pop("prefix", None)
99 | self.bos_token_id = kwargs.pop("bos_token_id", None)
100 | self.pad_token_id = kwargs.pop("pad_token_id", None)
101 | self.eos_token_id = kwargs.pop("eos_token_id", None)
102 | self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
103 |
104 | # task specific arguments
105 | self.task_specific_params = kwargs.pop("task_specific_params", None)
106 |
107 | # TPU arguments
108 | self.xla_device = kwargs.pop("xla_device", None)
109 |
110 | # Additional attributes without default values
111 | for key, value in kwargs.items():
112 | try:
113 | setattr(self, key, value)
114 | except AttributeError as err:
115 | logger.error("Can't set {} with value {} for {}".format(key, value, self))
116 | raise err
117 |
118 | @property
119 | def num_labels(self):
120 | return self._num_labels
121 |
122 | @num_labels.setter
123 | def num_labels(self, num_labels):
124 | self._num_labels = num_labels
125 | self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
126 | self.id2label = dict((int(key), value) for key, value in self.id2label.items())
127 | self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
128 | self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
129 |
130 | def save_pretrained(self, save_directory):
131 | """
132 | Save a configuration object to the directory `save_directory`, so that it
133 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
134 |
135 | Args:
136 | save_directory (:obj:`string`):
137 | Directory where the configuration JSON file will be saved.
138 | """
139 | assert os.path.isdir(
140 | save_directory
141 | ), "Saving path should be a directory where the model and configuration can be saved"
142 |
143 | # If we save using the predefined names, we can load using `from_pretrained`
144 | output_config_file = os.path.join(save_directory, CONFIG_NAME)
145 |
146 | self.to_json_file(output_config_file, use_diff=True)
147 | logger.info("Configuration saved in {}".format(output_config_file))
148 |
149 | @classmethod
150 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
151 | r"""
152 |
153 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
154 |
155 | Args:
156 | pretrained_model_name_or_path (:obj:`string`):
157 | either:
158 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
159 | download, e.g.: ``bert-base-uncased``.
160 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
161 | our S3, e.g.: ``dbmdz/bert-base-german-cased``.
162 | - a path to a `directory` containing a configuration file saved using the
163 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
164 | - a path or url to a saved configuration JSON `file`, e.g.:
165 | ``./my_model_directory/configuration.json``.
166 | cache_dir (:obj:`string`, `optional`):
167 | Path to a directory in which a downloaded pre-trained model
168 | configuration should be cached if the standard cache should not be used.
169 | kwargs (:obj:`Dict[str, any]`, `optional`):
170 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded
171 | values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
172 | controlled by the `return_unused_kwargs` keyword parameter.
173 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
174 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
175 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
176 |                 Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
177 | proxies (:obj:`Dict`, `optional`):
178 | A dictionary of proxy servers to use by protocol or endpoint, e.g.:
179 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
180 | The proxies are used on each request.
181 | return_unused_kwargs: (`optional`) bool:
182 | If False, then this function returns just the final configuration object.
183 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
184 |                 dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part
185 | of kwargs which has not been used to update `config` and is otherwise ignored.
186 |
187 | Returns:
188 | :class:`PretrainedConfig`: An instance of a configuration object
189 |
190 | Examples::
191 |
192 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
193 | # derived class: BertConfig
194 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
195 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
196 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
197 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
198 | assert config.output_attention == True
199 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
200 | foo=False, return_unused_kwargs=True)
201 | assert config.output_attention == True
202 | assert unused_kwargs == {'foo': False}
203 |
204 | """
205 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
206 | return cls.from_dict(config_dict, **kwargs)
207 |
208 | @classmethod
209 | def get_config_dict(
210 | cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
211 | ) -> Tuple[Dict, Dict]:
212 | """
213 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
214 | for instantiating a Config using `from_dict`.
215 |
216 | Parameters:
217 | pretrained_model_name_or_path (:obj:`string`):
218 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
219 | pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
220 | A map of `shortcut names` to `url`. By default, will use the current class attribute.
221 |
222 | Returns:
223 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
224 |
225 | """
226 | cache_dir = kwargs.pop("cache_dir", None)
227 | force_download = kwargs.pop("force_download", False)
228 | resume_download = kwargs.pop("resume_download", False)
229 | proxies = kwargs.pop("proxies", None)
230 | local_files_only = kwargs.pop("local_files_only", False)
231 |
232 | if pretrained_config_archive_map is None:
233 | pretrained_config_archive_map = cls.pretrained_config_archive_map
234 |
235 | if pretrained_model_name_or_path in pretrained_config_archive_map:
236 | config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
237 | elif os.path.isdir(pretrained_model_name_or_path):
238 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
239 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
240 | config_file = pretrained_model_name_or_path
241 | else:
242 | config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME)
243 |
244 | try:
245 | # Load from URL or cache if already cached
246 | resolved_config_file = cached_path(
247 | config_file,
248 | cache_dir=cache_dir,
249 | force_download=force_download,
250 | proxies=proxies,
251 | resume_download=resume_download,
252 | local_files_only=local_files_only,
253 | )
254 | # Load config dict
255 | if resolved_config_file is None:
256 | raise EnvironmentError
257 | config_dict = cls._dict_from_json_file(resolved_config_file)
258 |
259 | except EnvironmentError:
260 | if pretrained_model_name_or_path in pretrained_config_archive_map:
261 | msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
262 | config_file
263 | )
264 | else:
265 | msg = (
266 | "Can't load '{}'. Make sure that:\n\n"
267 | "- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
268 | "- or '{}' is the correct path to a directory containing a '{}' file\n\n".format(
269 | pretrained_model_name_or_path,
270 | pretrained_model_name_or_path,
271 | pretrained_model_name_or_path,
272 | CONFIG_NAME,
273 | )
274 | )
275 | raise EnvironmentError(msg)
276 |
277 | except json.JSONDecodeError:
278 | msg = (
279 | "Couldn't reach server at '{}' to download configuration file or "
280 | "configuration file is not a valid JSON file. "
281 | "Please check network or file content here: {}.".format(config_file, resolved_config_file)
282 | )
283 | raise EnvironmentError(msg)
284 |
285 | if resolved_config_file == config_file:
286 | logger.info("loading configuration file {}".format(config_file))
287 | else:
288 | logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
289 |
290 | return config_dict, kwargs
291 |
292 | @classmethod
293 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
294 | """
295 | Constructs a `Config` from a Python dictionary of parameters.
296 |
297 | Args:
298 | config_dict (:obj:`Dict[str, any]`):
299 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
300 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
301 | method.
302 | kwargs (:obj:`Dict[str, any]`):
303 | Additional parameters from which to initialize the configuration object.
304 |
305 | Returns:
306 | :class:`PretrainedConfig`: An instance of a configuration object
307 | """
308 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
309 |
310 | config = cls(**config_dict)
311 |
312 | if hasattr(config, "pruned_heads"):
313 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
314 |
315 | # Update config with kwargs if needed
316 | to_remove = []
317 | for key, value in kwargs.items():
318 | if hasattr(config, key):
319 | setattr(config, key, value)
320 | to_remove.append(key)
321 | for key in to_remove:
322 | kwargs.pop(key, None)
323 |
324 | logger.info("Model config %s", str(config))
325 | if return_unused_kwargs:
326 | return config, kwargs
327 | else:
328 | return config
329 |
330 | @classmethod
331 | def from_json_file(cls, json_file: str) -> "PretrainedConfig":
332 | """
333 | Constructs a `Config` from the path to a json file of parameters.
334 |
335 | Args:
336 | json_file (:obj:`string`):
337 | Path to the JSON file containing the parameters.
338 |
339 | Returns:
340 | :class:`PretrainedConfig`: An instance of a configuration object
341 |
342 | """
343 | config_dict = cls._dict_from_json_file(json_file)
344 | return cls(**config_dict)
345 |
346 | @classmethod
347 | def _dict_from_json_file(cls, json_file: str):
348 | with open(json_file, "r", encoding="utf-8") as reader:
349 | text = reader.read()
350 | return json.loads(text)
351 |
352 | def __eq__(self, other):
353 | return self.__dict__ == other.__dict__
354 |
355 | def __repr__(self):
356 | return "{} {}".format(self.__class__.__name__, self.to_json_string())
357 |
358 | def to_diff_dict(self):
359 | """
360 | Removes all attributes from config which correspond to the default
361 | config attributes for better readability and serializes to a Python
362 | dictionary.
363 |
364 | Returns:
365 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
366 | """
367 | config_dict = self.to_dict()
368 |
369 | # get the default config dict
370 | default_config_dict = PretrainedConfig().to_dict()
371 |
372 | serializable_config_dict = {}
373 |
374 | # only serialize values that differ from the default config
375 | for key, value in config_dict.items():
376 | if key not in default_config_dict or value != default_config_dict[key]:
377 | serializable_config_dict[key] = value
378 |
379 | return serializable_config_dict
380 |
381 | def to_dict(self):
382 | """
383 | Serializes this instance to a Python dictionary.
384 |
385 | Returns:
386 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
387 | """
388 | output = copy.deepcopy(self.__dict__)
389 | if hasattr(self.__class__, "model_type"):
390 | output["model_type"] = self.__class__.model_type
391 | return output
392 |
393 | def to_json_string(self, use_diff=True):
394 | """
395 | Serializes this instance to a JSON string.
396 |
397 | Args:
398 | use_diff (:obj:`bool`):
399 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
400 |
401 | Returns:
402 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
403 | """
404 | if use_diff is True:
405 | config_dict = self.to_diff_dict()
406 | else:
407 | config_dict = self.to_dict()
408 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
409 |
410 | def to_json_file(self, json_file_path, use_diff=True):
411 | """
412 | Save this instance to a json file.
413 |
414 | Args:
415 | json_file_path (:obj:`string`):
416 | Path to the JSON file in which this configuration instance's parameters will be saved.
417 | use_diff (:obj:`bool`):
418 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
419 | """
420 | with open(json_file_path, "w", encoding="utf-8") as writer:
421 | writer.write(self.to_json_string(use_diff=use_diff))
422 |
423 | def update(self, config_dict: Dict):
424 | """
425 | Updates attributes of this class
426 | with attributes from `config_dict`.
427 |
428 | Args:
429 | :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
430 | """
431 | for key, value in config_dict.items():
432 | setattr(self, key, value)
--------------------------------------------------------------------------------
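A brief usage sketch of the configuration utilities above: DemoConfig is a hypothetical subclass (not part of the repository), and the import assumes src/ is on sys.path with the package's dependencies installed. It shows how the num_labels setter rebuilds id2label/label2id, how to_diff_dict keeps only non-default values, and how update applies a dictionary of attributes.

from lsp_model_rl.util.configuration_utils import PretrainedConfig  # assumes src/ is on sys.path

class DemoConfig(PretrainedConfig):  # hypothetical subclass for illustration only
    model_type = "demo"

cfg = DemoConfig(num_labels=3, temperature=0.7)
print(cfg.id2label)        # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'} -- rebuilt by the num_labels setter
print(cfg.to_diff_dict())  # only attributes that differ from the defaults, e.g. temperature and the labels
cfg.update({"top_k": 10})  # equivalent to setattr(cfg, "top_k", 10)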