├── .gitignore
├── GPT2ForwardBackward
│   ├── __init__.py
│   ├── configuration_opengpt2.py
│   ├── modeling_opengpt2.py
│   └── padded_encoder.py
├── LICENSE
├── README.md
├── abductive.sh
├── bleuloss.py
├── cold_decoding.py
├── commongen.sh
├── counterfactual.sh
├── data
│   ├── abductive
│   │   └── small_data.json
│   ├── commongen
│   │   ├── commongen.dev.jsonl
│   │   └── commongen.test_noref.jsonl
│   └── counterfactual
│       └── dev_data.json
├── requirements.txt
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /GPT2ForwardBackward/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qkaren/COLD_decoding/0cda54b366b4fd3c74550436a869c50593f40ee2/GPT2ForwardBackward/__init__.py -------------------------------------------------------------------------------- /GPT2ForwardBackward/configuration_opengpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT-2 configuration """ 17 | 18 | 19 | # coding=utf-8 20 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 21 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 22 | # 23 | # Licensed under the Apache License, Version 2.0 (the "License"); 24 | # you may not use this file except in compliance with the License. 25 | # You may obtain a copy of the License at 26 | # 27 | # http://www.apache.org/licenses/LICENSE-2.0 28 | # 29 | # Unless required by applicable law or agreed to in writing, software 30 | # distributed under the License is distributed on an "AS IS" BASIS, 31 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | # See the License for the specific language governing permissions and 33 | # limitations under the License. 
34 | """ OpenAI GPT-2 configuration """ 35 | 36 | from transformers.configuration_utils import PretrainedConfig 37 | from transformers.utils import logging 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 43 | "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", 44 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", 45 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", 46 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", 47 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", 48 | } 49 | 50 | 51 | class OpenGPT2Config(PretrainedConfig): 52 | """ 53 | This is the configuration class to store the configuration of a :class:`~transformers.GPT2Model` or a 54 | :class:`~transformers.TFGPT2Model`. It is used to instantiate a GPT-2 model according to the specified arguments, 55 | defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration 56 | to that of the GPT-2 `small `__ architecture. 57 | 58 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 59 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 60 | 61 | 62 | Args: 63 | vocab_size (:obj:`int`, `optional`, defaults to 50257): 64 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 65 | :obj:`inputs_ids` passed when calling :class:`~transformers.GPT2Model` or 66 | :class:`~transformers.TFGPT2Model`. 67 | n_positions (:obj:`int`, `optional`, defaults to 1024): 68 | The maximum sequence length that this model might ever be used with. Typically set this to something large 69 | just in case (e.g., 512 or 1024 or 2048). 70 | n_ctx (:obj:`int`, `optional`, defaults to 1024): 71 | Dimensionality of the causal mask (usually same as n_positions). 72 | n_embd (:obj:`int`, `optional`, defaults to 768): 73 | Dimensionality of the embeddings and hidden states. 74 | n_layer (:obj:`int`, `optional`, defaults to 12): 75 | Number of hidden layers in the Transformer encoder. 76 | n_head (:obj:`int`, `optional`, defaults to 12): 77 | Number of attention heads for each attention layer in the Transformer encoder. 78 | n_inner (:obj:`int`, `optional`, defaults to None): 79 | Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd 80 | activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu"`): 81 | Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`. 82 | resid_pdrop (:obj:`float`, `optional`, defaults to 0.1): 83 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 84 | embd_pdrop (:obj:`int`, `optional`, defaults to 0.1): 85 | The dropout ratio for the embeddings. 86 | attn_pdrop (:obj:`float`, `optional`, defaults to 0.1): 87 | The dropout ratio for the attention. 88 | layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): 89 | The epsilon to use in the layer normalization layers 90 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 91 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
92 | summary_type (:obj:`string`, `optional`, defaults to :obj:`"cls_index"`): 93 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 94 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 95 | 96 | Has to be one of the following options: 97 | 98 | - :obj:`"last"`: Take the last token hidden state (like XLNet). 99 | - :obj:`"first"`: Take the first token hidden state (like BERT). 100 | - :obj:`"mean"`: Take the mean of all tokens hidden states. 101 | - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 102 | - :obj:`"attn"`: Not implemented now, use multi-head attention. 103 | summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`): 104 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 105 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 106 | 107 | Whether or not to add a projection after the vector extraction. 108 | summary_activation (:obj:`str`, `optional`): 109 | Argument used when doing sequence summary. Used in for the multiple choice head in 110 | :class:`~transformers.GPT2DoubleHeadsModel`. 111 | 112 | Pass :obj:`"tanh"` for a tanh activation to the output, any other value will result in no activation. 113 | summary_proj_to_labels (:obj:`bool`, `optional`, defaults to :obj:`True`): 114 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 115 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 116 | 117 | Whether the projection outputs should have :obj:`config.num_labels` or :obj:`config.hidden_size` classes. 118 | summary_first_dropout (:obj:`float`, `optional`, defaults to 0.1): 119 | Argument used when doing sequence summary, used in the models :class:`~transformers.GPT2DoubleHeadsModel` 120 | and :class:`~transformers.TFGPT2DoubleHeadsModel`. 121 | 122 | The dropout ratio to be used after the projection and activation. 123 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): 124 | Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. 125 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 126 | Whether or not the model should return the last key/values attentions (not used by all models). 
127 | 128 | Example:: 129 | 130 | >>> from transformers import GPT2Model, GPT2Config 131 | 132 | >>> # Initializing a GPT2 configuration 133 | >>> configuration = GPT2Config() 134 | 135 | >>> # Initializing a model from the configuration 136 | >>> model = GPT2Model(configuration) 137 | 138 | >>> # Accessing the model configuration 139 | >>> configuration = model.config 140 | """ 141 | 142 | #model_type = "gpt2" 143 | #keys_to_ignore_at_inference = ["past_key_values"] 144 | 145 | def __init__( 146 | self, 147 | vocab_size=50257, 148 | max_position_embeddings=1024, 149 | intermediate_size=3072, 150 | hidden_size=768, 151 | num_hidden_layers=12, 152 | num_attention_heads=12, 153 | n_inner=None, 154 | hidden_act="gelu", 155 | hidden_dropout_prob=0.1, 156 | attention_probs_dropout_prob=0.1, 157 | #attn_pdrop=0.1, 158 | layer_norm_epsilon=1e-5, 159 | initializer_range=0.02, 160 | summary_type="cls_index", 161 | summary_use_proj=True, 162 | summary_activation=None, 163 | summary_proj_to_labels=True, 164 | summary_first_dropout=0.1, 165 | gradient_checkpointing=False, 166 | use_cache=True, 167 | bos_token_id=50256, 168 | eos_token_id=50256, 169 | **kwargs 170 | ): 171 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 172 | 173 | 174 | self.vocab_size = vocab_size 175 | self.hidden_size = hidden_size 176 | self.num_hidden_layers = num_hidden_layers 177 | self.num_attention_heads = num_attention_heads 178 | self.hidden_act = hidden_act 179 | self.intermediate_size = intermediate_size 180 | self.hidden_dropout_prob = hidden_dropout_prob 181 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 182 | self.max_position_embeddings = max_position_embeddings 183 | self.initializer_range = initializer_range 184 | self.pad_token_id = 0 185 | 186 | 187 | self.layer_norm_epsilon = layer_norm_epsilon 188 | self.initializer_range = initializer_range 189 | self.summary_type = summary_type 190 | self.summary_use_proj = summary_use_proj 191 | self.summary_activation = summary_activation 192 | self.summary_first_dropout = summary_first_dropout 193 | self.summary_proj_to_labels = summary_proj_to_labels 194 | self.gradient_checkpointing = gradient_checkpointing 195 | self.use_cache = use_cache 196 | 197 | self.bos_token_id = bos_token_id 198 | self.eos_token_id = eos_token_id 199 | 200 | ''' 201 | @property 202 | def max_position_embeddings(self): 203 | return self.n_positions 204 | 205 | @property 206 | def hidden_size(self): 207 | return self.n_embd 208 | 209 | @property 210 | def num_attention_heads(self): 211 | return self.n_head 212 | 213 | @property 214 | def num_hidden_layers(self): 215 | return self.n_layer 216 | ''' 217 | 218 | 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /GPT2ForwardBackward/modeling_opengpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch open-GPT-2 model.""" 17 | 18 | import os 19 | from dataclasses import dataclass 20 | from typing import List, Optional, Tuple 21 | 22 | import math 23 | 24 | import torch 25 | import torch.nn as nn 26 | from torch.nn import CrossEntropyLoss, MSELoss 27 | 28 | from transformers.activations import ACT2FN 29 | 30 | from transformers.file_utils import ( 31 | ModelOutput, 32 | add_code_sample_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | replace_return_docstrings, 36 | ) 37 | 38 | from transformers.modeling_outputs import ( 39 | BaseModelOutputWithPastAndCrossAttentions, 40 | CausalLMOutputWithCrossAttentions, 41 | SequenceClassifierOutputWithPast, 42 | ) 43 | 44 | from transformers.modeling_utils import ( 45 | Conv1D, 46 | PreTrainedModel, 47 | SequenceSummary, 48 | find_pruneable_heads_and_indices, 49 | prune_conv1d_layer, 50 | ) 51 | 52 | from transformers.utils import logging 53 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 54 | 55 | from configuration_opengpt2 import OpenGPT2Config 56 | 57 | ## PETER: replaced these relative imports with imports from the package (above) 58 | ''' 59 | from ...activations import ACT2FN 60 | from ...file_utils import ( 61 | ModelOutput, 62 | add_code_sample_docstrings, 63 | add_start_docstrings, 64 | add_start_docstrings_to_model_forward, 65 | replace_return_docstrings, 66 | ) 67 | from ...modeling_outputs import ( 68 | BaseModelOutputWithPastAndCrossAttentions, 69 | CausalLMOutputWithCrossAttentions, 70 | SequenceClassifierOutputWithPast, 71 | ) 72 | from ...modeling_utils import ( 73 | Conv1D, 74 | PreTrainedModel, 75 | SequenceSummary, 76 | find_pruneable_heads_and_indices, 77 | prune_conv1d_layer, 78 | ) 79 | from ...utils import logging 80 | from ...utils.model_parallel_utils import assert_device_map, get_device_map 81 | ''' 82 | 83 | 84 | 85 | logger = logging.get_logger(__name__) 86 | 87 | _CONFIG_FOR_DOC = "GPT2Config" 88 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 89 | 90 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 91 | "gpt2", 92 | "gpt2-medium", 93 | "gpt2-large", 94 | "gpt2-xl", 95 | "distilgpt2", 96 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 97 | ] 98 | 99 | 100 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 101 | """Load tf checkpoints in a pytorch model""" 102 | try: 103 | import re 104 | 105 | import tensorflow as tf 106 | except ImportError: 107 | logger.error( 108 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 109 | "https://www.tensorflow.org/install/ for installation instructions." 
110 | ) 111 | raise 112 | tf_path = os.path.abspath(gpt2_checkpoint_path) 113 | logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) 114 | # Load weights from TF model 115 | init_vars = tf.train.list_variables(tf_path) 116 | names = [] 117 | arrays = [] 118 | for name, shape in init_vars: 119 | logger.info("Loading TF weight {} with shape {}".format(name, shape)) 120 | array = tf.train.load_variable(tf_path, name) 121 | names.append(name) 122 | arrays.append(array.squeeze()) 123 | 124 | for name, array in zip(names, arrays): 125 | name = name[6:] # skip "model/" 126 | name = name.split("/") 127 | pointer = model 128 | for m_name in name: 129 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 130 | scope_names = re.split(r"(\d+)", m_name) 131 | else: 132 | scope_names = [m_name] 133 | if scope_names[0] == "w" or scope_names[0] == "g": 134 | pointer = getattr(pointer, "weight") 135 | elif scope_names[0] == "b": 136 | pointer = getattr(pointer, "bias") 137 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 138 | pointer = getattr(pointer, scope_names[0]) 139 | pointer = getattr(pointer, "weight") 140 | else: 141 | pointer = getattr(pointer, scope_names[0]) 142 | if len(scope_names) >= 2: 143 | num = int(scope_names[1]) 144 | pointer = pointer[num] 145 | try: 146 | assert ( 147 | pointer.shape == array.shape 148 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 149 | except AssertionError as e: 150 | e.args += (pointer.shape, array.shape) 151 | raise 152 | logger.info("Initialize PyTorch weight {}".format(name)) 153 | pointer.data = torch.from_numpy(array) 154 | return model 155 | 156 | 157 | class Attention(nn.Module): 158 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): 159 | super().__init__() 160 | 161 | # don't allow cross attention for now 162 | assert(not is_cross_attention) 163 | 164 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 165 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 166 | assert n_state % config.num_attention_heads == 0 167 | self.register_buffer( 168 | "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx) 169 | ) 170 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 171 | self.n_head = config.num_attention_heads 172 | self.split_size = n_state 173 | self.scale = scale 174 | self.is_cross_attention = is_cross_attention 175 | if self.is_cross_attention: 176 | assert(False) # PETER: not allowed in this version 177 | #self.c_attn = Conv1D(2 * n_state, nx) 178 | #self.q_attn = Conv1D(n_state, nx) 179 | else: 180 | # PETER: opengpt2/grover breaks c_attn into 3 convolutions 181 | self.c_attn_q = Conv1D(n_state, nx) 182 | self.c_attn_k = Conv1D(n_state, nx) 183 | self.c_attn_v = Conv1D(n_state, nx) 184 | self.c_proj = Conv1D(n_state, nx) 185 | self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob) 186 | self.resid_dropout = nn.Dropout(config.hidden_dropout_prob) 187 | self.pruned_heads = set() 188 | 189 | def prune_heads(self, heads): 190 | assert(False) # PETER: not implemented for opengpt2 (code below still uses gpt2 params) 191 | 192 | if len(heads) == 0: 193 | return 194 | heads, index = find_pruneable_heads_and_indices( 195 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads 196 | ) 197 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 198 | 199 | # Prune conv1d layers 200 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, 
dim=1) 201 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 202 | 203 | # Update hyper params 204 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 205 | self.n_head = self.n_head - len(heads) 206 | self.pruned_heads = self.pruned_heads.union(heads) 207 | 208 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): 209 | w = torch.matmul(q, k) 210 | if self.scale: 211 | w = w / (float(v.size(-1)) ** 0.5) 212 | nd, ns = w.size(-2), w.size(-1) 213 | 214 | ## good up to here 215 | 216 | if not self.is_cross_attention: 217 | # if only "normal" attention layer implements causal mask 218 | nd, ns = w.size(-2), w.size(-1) 219 | b = self.bias[:, :, ns-nd:ns, :ns] 220 | w = w * b - 1e4 * (1 - b) 221 | 222 | # PETER: this was the huggingface code, replaced with mine above to match Grover code 223 | #mask = self.bias[:, :, ns - nd : ns, :ns] 224 | #w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) 225 | 226 | if attention_mask is not None: 227 | # Apply the attention mask 228 | w = w + attention_mask 229 | 230 | w = nn.Softmax(dim=-1)(w) 231 | 232 | # PETER: not included in Grover tf code 233 | #w = self.attn_dropout(w) 234 | 235 | # PETER: not included in Grover tf code 236 | # Mask heads if we want to 237 | # if head_mask is not None: 238 | # w = w * head_mask 239 | 240 | outputs = (torch.matmul(w, v),) 241 | if output_attentions: 242 | outputs += (w,) 243 | return outputs 244 | 245 | def merge_heads(self, x): 246 | x = x.permute(0, 2, 1, 3).contiguous() 247 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 248 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 249 | 250 | def split_heads(self, x, k=False): 251 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 252 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 253 | if k: 254 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 255 | else: 256 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 257 | 258 | def forward( 259 | self, 260 | hidden_states, 261 | layer_past=None, 262 | attention_mask=None, 263 | head_mask=None, 264 | encoder_hidden_states=None, 265 | encoder_attention_mask=None, 266 | use_cache=False, 267 | output_attentions=False, 268 | ): 269 | if encoder_hidden_states is not None: 270 | assert(False) # PETER: for now, encoder_hidden_states functionality is not included 271 | 272 | 273 | assert hasattr( 274 | self, "q_attn" 275 | ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`." 
276 | query = self.q_attn(hidden_states) 277 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 278 | attention_mask = encoder_attention_mask 279 | else: 280 | query, key, value = self.c_attn_q(hidden_states), self.c_attn_k(hidden_states), self.c_attn_v(hidden_states) 281 | 282 | # PETER: replacing below code with above (to handle Conv1D difference 283 | #query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 284 | 285 | 286 | 287 | query = self.split_heads(query) 288 | key = self.split_heads(key, k=True) 289 | value = self.split_heads(value) 290 | if layer_past is not None: 291 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 292 | key = torch.cat((past_key, key), dim=-1) 293 | value = torch.cat((past_value, value), dim=-2) 294 | 295 | if use_cache is True: 296 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 297 | else: 298 | present = None 299 | 300 | 301 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) 302 | 303 | a = attn_outputs[0] 304 | 305 | a = self.merge_heads(a) 306 | a = self.c_proj(a) 307 | a = self.resid_dropout(a) 308 | 309 | return (a, present) + attn_outputs[1:] # a, present, (attentions) 310 | 311 | 312 | # PETER: added this for the residual_MLP below 313 | def gelu(x): 314 | cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 315 | return x * cdf 316 | 317 | # PETER: this is rewritten, different enough from MLP (in original OpenAI gpt2) to 318 | # just write it from scratch 319 | class residual_MLP(nn.Module): 320 | def __init__(self, intermediate_size, config): # in MLP: n_state=3072 (4 * n_embd) 321 | super(residual_MLP, self).__init__() 322 | nx = config.hidden_size 323 | 324 | self.act = gelu 325 | self.linear_intermediate = nn.Linear(nx, intermediate_size) 326 | self.linear_output = nn.Linear(intermediate_size,nx) 327 | self.ln_0 = nn.LayerNorm(nx, eps=1e-5) 328 | self.ln_1 = nn.LayerNorm(nx, eps=1e-5) 329 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 330 | 331 | def forward(self, x): 332 | x_norm = self.ln_0(x) 333 | intermediate = self.act(self.linear_intermediate(x_norm)) 334 | output_for_resid = self.dropout(self.linear_output(intermediate)) 335 | layer_output = self.ln_1(x + output_for_resid) 336 | return layer_output 337 | 338 | 339 | 340 | class Block(nn.Module): 341 | def __init__(self, n_ctx, config, scale=False): 342 | super().__init__() 343 | hidden_size = config.hidden_size 344 | 345 | # PETER: changed this to match grover config 346 | inner_dim = config.intermediate_size #config.n_inner if config.n_inner is not None else 4 * hidden_size 347 | # self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 348 | self.attn = Attention(hidden_size, n_ctx, config, scale) 349 | # self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 350 | 351 | # PETER: just removed the cross-attention option below 352 | if config.add_cross_attention: 353 | assert(False) 354 | # self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True) 355 | # self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 356 | 357 | # PETER: using our redisual_MLP 358 | self.mlp = residual_MLP(config.intermediate_size, config) # MLP(inner_dim, config) 359 | 360 | def forward( 361 | self, 362 | hidden_states, 363 | layer_past=None, 364 | attention_mask=None, 365 | head_mask=None, 366 | encoder_hidden_states=None, 367 | 
encoder_attention_mask=None, 368 | use_cache=False, 369 | output_attentions=False, 370 | ): 371 | 372 | attn_outputs = self.attn( 373 | hidden_states, 374 | layer_past=layer_past, 375 | attention_mask=attention_mask, 376 | head_mask=head_mask, 377 | use_cache=use_cache, 378 | output_attentions=output_attentions, 379 | ) 380 | 381 | # PETER: this code is replaced with above, 382 | #attn_outputs = self.attn( 383 | # self.ln_1(hidden_states), 384 | # layer_past=layer_past, 385 | # attention_mask=attention_mask, 386 | # head_mask=head_mask, 387 | # use_cache=use_cache, 388 | # output_attentions=output_attentions, 389 | #) 390 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 391 | 392 | 393 | outputs = attn_outputs[1:] 394 | # residual connection 395 | hidden_states = attn_output + hidden_states 396 | 397 | if encoder_hidden_states is not None: 398 | # add one self-attention block for cross-attention 399 | assert(False) # PETER: for now, encoder_hidden_states functionality is not included 400 | 401 | assert hasattr( 402 | self, "crossattention" 403 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 404 | cross_attn_outputs = self.crossattention( 405 | self.ln_cross_attn(hidden_states), 406 | attention_mask=attention_mask, 407 | head_mask=head_mask, 408 | encoder_hidden_states=encoder_hidden_states, 409 | encoder_attention_mask=encoder_attention_mask, 410 | output_attentions=output_attentions, 411 | ) 412 | attn_output = cross_attn_outputs[0] 413 | # residual connection 414 | hidden_states = hidden_states + attn_output 415 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 416 | 417 | feed_forward_hidden_states = self.mlp(hidden_states) 418 | 419 | # PETER: replaced this code with above, because no layer-norm here in Grover code 420 | #feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) 421 | 422 | # PETER: we don't do this residual connection in Grover 423 | # residual connection 424 | #hidden_states = hidden_states + feed_forward_hidden_states 425 | 426 | if use_cache: 427 | outputs = (feed_forward_hidden_states,) + outputs 428 | else: 429 | outputs = (feed_forward_hidden_states,) + outputs[1:] 430 | 431 | return outputs # hidden_states, present, (attentions, cross_attentions) 432 | 433 | 434 | class OpenGPT2PreTrainedModel(PreTrainedModel): 435 | """ 436 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 437 | models. 
438 | """ 439 | 440 | config_class = OpenGPT2Config 441 | load_tf_weights = load_tf_weights_in_gpt2 442 | base_model_prefix = "transformer" 443 | is_parallelizable = True 444 | 445 | def __init__(self, *inputs, **kwargs): 446 | super().__init__(*inputs, **kwargs) 447 | 448 | def _init_weights(self, module): 449 | """Initialize the weights.""" 450 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): 451 | # Slightly different from the TF version which uses truncated_normal for initialization 452 | # cf https://github.com/pytorch/pytorch/pull/5617 453 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 454 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: 455 | module.bias.data.zero_() 456 | elif isinstance(module, nn.LayerNorm): 457 | module.bias.data.zero_() 458 | module.weight.data.fill_(1.0) 459 | 460 | 461 | @dataclass 462 | class OpenGPT2DoubleHeadsModelOutput(ModelOutput): 463 | """ 464 | Base class for outputs of models predicting if two sentences are consecutive or not. 465 | 466 | Args: 467 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): 468 | Language modeling loss. 469 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): 470 | Multiple choice classification loss. 471 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): 472 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 473 | mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): 474 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 475 | past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 476 | List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, 477 | batch_size, num_heads, sequence_length, embed_size_per_head)`). 478 | 479 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see 480 | :obj:`past_key_values` input) to speed up sequential decoding. 481 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 482 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 483 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 484 | 485 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 486 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 487 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 488 | sequence_length, sequence_length)`. 489 | 490 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 491 | heads. 
492 | """ 493 | 494 | loss: Optional[torch.FloatTensor] = None 495 | mc_loss: Optional[torch.FloatTensor] = None 496 | logits: torch.FloatTensor = None 497 | mc_logits: torch.FloatTensor = None 498 | past_key_values: Optional[List[torch.FloatTensor]] = None 499 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 500 | attentions: Optional[Tuple[torch.FloatTensor]] = None 501 | 502 | 503 | OpenGPT2_START_DOCSTRING = r""" 504 | 505 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 506 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 507 | pruning heads etc.) 508 | 509 | This model is also a PyTorch `torch.nn.Module `__ 510 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 511 | general usage and behavior. 512 | 513 | Parameters: 514 | config (:class:`~transformers.OpenGPT2Config`): Model configuration class with all the parameters of the model. 515 | Initializing with a config file does not load the weights associated with the model, only the 516 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 517 | weights. 518 | """ 519 | 520 | OpenGPT2_INPUTS_DOCSTRING = r""" 521 | Args: 522 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): 523 | :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else 524 | ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input 525 | sequence tokens in the vocabulary. 526 | 527 | If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be 528 | passed as ``input_ids``. 529 | 530 | Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See 531 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 532 | details. 533 | 534 | `What are input IDs? <../glossary.html#input-ids>`__ 535 | past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): 536 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 537 | :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which 538 | have their past given to this model should not be passed as ``input_ids`` as they have already been 539 | computed. 540 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 541 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 542 | 543 | - 1 for tokens that are **not masked**, 544 | - 0 for tokens that are **masked**. 545 | 546 | `What are attention masks? <../glossary.html#attention-mask>`__ 547 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): 548 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 549 | 1]``: 550 | 551 | - 0 corresponds to a `sentence A` token, 552 | - 1 corresponds to a `sentence B` token. 553 | 554 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 555 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 556 | Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range ``[0, 557 | config.max_position_embeddings - 1]``. 558 | 559 | `What are position IDs? <../glossary.html#position-ids>`_ 560 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 561 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 562 | 563 | - 1 indicates the head is **not masked**, 564 | - 0 indicates the head is **masked**. 565 | 566 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 567 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 568 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 569 | vectors than the model's internal embedding lookup matrix. 570 | 571 | If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see 572 | :obj:`past_key_values`). 573 | use_cache (:obj:`bool`, `optional`): 574 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 575 | decoding (see :obj:`past_key_values`). 576 | output_attentions (:obj:`bool`, `optional`): 577 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 578 | tensors for more detail. 579 | output_hidden_states (:obj:`bool`, `optional`): 580 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 581 | more detail. 582 | return_dict (:obj:`bool`, `optional`): 583 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 584 | """ 585 | PARALLELIZE_DOCSTRING = r""" 586 | This is an experimental feature and is a subject to change at a moment's notice. 587 | 588 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 589 | it will evenly distribute blocks across all devices. 590 | 591 | Args: 592 | device_map (:obj:`Dict[int, list]`, optional, defaults to None): 593 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 594 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 595 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the 596 | following number of attention modules: 597 | 598 | - gpt2: 12 599 | - gpt2-medium: 24 600 | - gpt2-large: 36 601 | - gpt2-xl: 48 602 | 603 | Example:: 604 | 605 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 606 | model = OpenGPT2LMHeadModel.from_pretrained('gpt2-xl') 607 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 608 | 609 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 610 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 611 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} 612 | model.parallelize(device_map) 613 | """ 614 | DEPARALLELIZE_DOCSTRING = r""" 615 | Moves the model to cpu from a model parallel state. 
616 | 617 | Example:: 618 | 619 | # On a 4 GPU machine with gpt2-large: 620 | model = OpenGPT2LMHeadModel.from_pretrained('gpt2-large') 621 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], 622 | 623 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 624 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 625 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} 626 | model.parallelize(device_map) # Splits the model across several devices 627 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 628 | """ 629 | 630 | 631 | @add_start_docstrings( 632 | "The bare OpenGPT2 Model transformer outputting raw hidden-states without any specific head on top.", 633 | OpenGPT2_START_DOCSTRING, 634 | ) 635 | class OpenGPT2Model(OpenGPT2PreTrainedModel): 636 | def __init__(self, config): 637 | super().__init__(config) 638 | 639 | 640 | self.output_hidden_states = True #config.output_hidden_states 641 | self.output_attentions = True #config.output_attentions 642 | 643 | 644 | self.wte = nn.Embedding(config.vocab_size, config.hidden_size)#config.n_embd) 645 | self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)# n_positions, config.n_embd) 646 | self.drop = nn.Dropout(config.hidden_dropout_prob) #embd_pdrop) 647 | self.h = nn.ModuleList([Block(config.max_position_embeddings, config, scale=True) for _ in range(config.num_hidden_layers)]) 648 | self.ln_embed = nn.LayerNorm(config.hidden_size, eps=1e-5)# 649 | 650 | # PETER: replaced the below block with above (to match Grover config) 651 | 652 | #self.wte = nn.Embedding(config.vocab_size, config.n_embd) 653 | #self.wpe = nn.Embedding(config.n_positions, config.n_embd) 654 | #self.drop = nn.Dropout(config.embd_pdrop) 655 | #self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 656 | #self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 657 | 658 | 659 | self.init_weights() 660 | # Model parallel 661 | self.model_parallel = False 662 | self.device_map = None 663 | 664 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 665 | def parallelize(self, device_map=None): 666 | # Check validity of device_map 667 | self.device_map = ( 668 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 669 | ) 670 | assert_device_map(self.device_map, len(self.h)) 671 | self.model_parallel = True 672 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 673 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 674 | self.wte = self.wte.to(self.first_device) 675 | self.wpe = self.wpe.to(self.first_device) 676 | # Load onto devices 677 | for k, v in self.device_map.items(): 678 | for block in v: 679 | cuda_device = "cuda:" + str(k) 680 | self.h[block] = self.h[block].to(cuda_device) 681 | # ln_f to last 682 | self.ln_embed = self.ln_embed.to(self.last_device) 683 | 684 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 685 | def deparallelize(self): 686 | self.model_parallel = False 687 | self.device_map = None 688 | self.first_device = "cpu" 689 | self.last_device = "cpu" 690 | self.wte = self.wte.to("cpu") 691 | self.wpe = self.wpe.to("cpu") 692 | for index in range(len(self.h)): 693 | self.h[index] = self.h[index].to("cpu") 694 | self.ln_embed = self.ln_embed.to("cpu") 695 | torch.cuda.empty_cache() 696 | 697 | def get_input_embeddings(self): 698 | return self.wte 699 | 700 | def set_input_embeddings(self, new_embeddings): 701 | self.wte = new_embeddings 702 | 703 | 
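    # Note: head pruning is effectively disabled in this port. Attention.prune_heads()
    # begins with `assert(False)` (see above), so the _prune_heads() helper below will
    # raise immediately if it is ever called with a non-empty heads_to_prune dict.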
def _prune_heads(self, heads_to_prune): 704 | """ 705 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 706 | """ 707 | for layer, heads in heads_to_prune.items(): 708 | self.h[layer].attn.prune_heads(heads) 709 | 710 | @add_start_docstrings_to_model_forward(OpenGPT2_INPUTS_DOCSTRING) 711 | @add_code_sample_docstrings( 712 | tokenizer_class=_TOKENIZER_FOR_DOC, 713 | checkpoint="gpt2", 714 | output_type=BaseModelOutputWithPastAndCrossAttentions, 715 | config_class=_CONFIG_FOR_DOC, 716 | ) 717 | def forward( 718 | self, 719 | input_ids=None, 720 | past_key_values=None, 721 | attention_mask=None, 722 | token_type_ids=None, 723 | position_ids=None, 724 | head_mask=None, 725 | inputs_embeds=None, 726 | encoder_hidden_states=None, 727 | encoder_attention_mask=None, 728 | use_cache=None, 729 | output_attentions=None, 730 | output_hidden_states=None, 731 | return_dict=None, 732 | ): 733 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 734 | output_hidden_states = ( 735 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 736 | ) 737 | use_cache = use_cache if use_cache is not None else self.config.use_cache 738 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 739 | 740 | if input_ids is not None and inputs_embeds is not None: 741 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 742 | elif input_ids is not None: 743 | input_shape = input_ids.size() 744 | input_ids = input_ids.view(-1, input_shape[-1]) 745 | batch_size = input_ids.shape[0] 746 | elif inputs_embeds is not None: 747 | input_shape = inputs_embeds.size()[:-1] 748 | batch_size = inputs_embeds.shape[0] 749 | else: 750 | raise ValueError("You have to specify either input_ids or inputs_embeds") 751 | 752 | if token_type_ids is not None: 753 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 754 | if position_ids is not None: 755 | position_ids = position_ids.view(-1, input_shape[-1]) 756 | 757 | if past_key_values is None: 758 | past_length = 0 759 | past_key_values = [None] * len(self.h) 760 | else: 761 | past_length = past_key_values[0][0].size(-2) 762 | if position_ids is None: 763 | device = input_ids.device if input_ids is not None else inputs_embeds.device 764 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 765 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 766 | 767 | # Attention mask. 768 | if attention_mask is not None: 769 | assert batch_size > 0, "batch_size has to be defined and > 0" 770 | attention_mask = attention_mask.view(batch_size, -1) 771 | # We create a 3D attention mask from a 2D tensor mask. 772 | # Sizes are [batch_size, 1, 1, to_seq_length] 773 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 774 | # this attention mask is more simple than the triangular masking of causal attention 775 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 776 | attention_mask = attention_mask[:, None, None, :] 777 | 778 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 779 | # masked positions, this operation will create a tensor which is 0.0 for 780 | # positions we want to attend and -10000.0 for masked positions. 
781 | # Since we are adding it to the raw scores before the softmax, this is 782 | # effectively the same as removing these entirely. 783 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 784 | attention_mask = (1.0 - attention_mask) * -10000.0 785 | 786 | # If a 2D ou 3D attention mask is provided for the cross-attention 787 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 788 | if self.config.add_cross_attention and encoder_hidden_states is not None: 789 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 790 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 791 | if encoder_attention_mask is None: 792 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 793 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 794 | else: 795 | encoder_attention_mask = None 796 | 797 | # Prepare head mask if needed 798 | # 1.0 in head_mask indicate we keep the head 799 | # attention_probs has shape bsz x n_heads x N x N 800 | # head_mask has shape n_layer x batch x n_heads x N x N 801 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 802 | 803 | if inputs_embeds is None: 804 | inputs_embeds = self.wte(input_ids) 805 | position_embeds = self.wpe(position_ids) 806 | hidden_states = inputs_embeds + position_embeds 807 | 808 | if token_type_ids is not None: 809 | print('TOKEN TYPES') 810 | token_type_embeds = self.wte(token_type_ids) 811 | hidden_states = hidden_states + token_type_embeds 812 | 813 | hidden_states = self.ln_embed(hidden_states) 814 | 815 | hidden_states = self.drop(hidden_states) 816 | 817 | output_shape = input_shape + (hidden_states.size(-1),) 818 | 819 | presents = () if use_cache else None 820 | all_self_attentions = () if output_attentions else None 821 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 822 | all_hidden_states = () if output_hidden_states else None 823 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 824 | # Model parallel 825 | if self.model_parallel: 826 | torch.cuda.set_device(hidden_states.device) 827 | # Ensure layer_past is on same device as hidden_states (might not be correct) 828 | if layer_past is not None: 829 | layer_past = layer_past.to(hidden_states.device) 830 | # Ensure that attention_mask is always on the same device as hidden_states 831 | if attention_mask is not None: 832 | attention_mask = attention_mask.to(hidden_states.device) 833 | if isinstance(head_mask, torch.Tensor): 834 | head_mask = head_mask.to(hidden_states.device) 835 | if output_hidden_states: 836 | all_hidden_states = all_hidden_states + (hidden_states,) 837 | 838 | if getattr(self.config, "gradient_checkpointing", False): 839 | 840 | def create_custom_forward(module): 841 | def custom_forward(*inputs): 842 | # checkpointing only works with tuple returns, not with lists 843 | return tuple(output for output in module(*inputs, use_cache, output_attentions)) 844 | 845 | return custom_forward 846 | 847 | outputs = torch.utils.checkpoint.checkpoint( 848 | create_custom_forward(block), 849 | hidden_states, 850 | layer_past, 851 | attention_mask, 852 | head_mask[i], 853 | encoder_hidden_states, 854 | encoder_attention_mask, 855 | ) 856 | else: 857 | outputs = block( 858 | hidden_states, 859 | layer_past=layer_past, 860 | attention_mask=attention_mask, 861 | head_mask=head_mask[i], 862 | encoder_hidden_states=encoder_hidden_states, 863 | 
encoder_attention_mask=encoder_attention_mask, 864 | use_cache=use_cache, 865 | output_attentions=output_attentions, 866 | ) 867 | 868 | hidden_states = outputs[0] 869 | if use_cache is True: 870 | presents = presents + (outputs[1],) 871 | 872 | if output_attentions: 873 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 874 | if self.config.add_cross_attention: 875 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 876 | 877 | # Model Parallel: If it's the last layer for that device, put things on the next device 878 | if self.model_parallel: 879 | for k, v in self.device_map.items(): 880 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 881 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 882 | 883 | # PETER: actually, ln_embed should be applied way earlier... 884 | #hidden_states = self.ln_embed(hidden_states) 885 | # PETER: replaced below line with above for naming purposes 886 | # hidden_states = self.ln_f(hidden_states) 887 | 888 | hidden_states = hidden_states.view(*output_shape) 889 | # Add last hidden state 890 | if output_hidden_states: 891 | all_hidden_states = all_hidden_states + (hidden_states,) 892 | 893 | if not return_dict: 894 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 895 | 896 | return BaseModelOutputWithPastAndCrossAttentions( 897 | last_hidden_state=hidden_states, 898 | past_key_values=presents, 899 | hidden_states=all_hidden_states, 900 | attentions=all_self_attentions, 901 | cross_attentions=all_cross_attentions, 902 | ) 903 | 904 | 905 | @add_start_docstrings( 906 | """ 907 | The OpenGPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input 908 | embeddings). 
909 | """, 910 | OpenGPT2_START_DOCSTRING, 911 | ) 912 | class OpenGPT2LMHeadModel(OpenGPT2PreTrainedModel): 913 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 914 | 915 | def __init__(self, config): 916 | super().__init__(config) 917 | self.transformer = OpenGPT2Model(config) 918 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 919 | 920 | self.init_weights() 921 | 922 | self.model_parallel = False 923 | 924 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 925 | def parallelize(self, device_map=None): 926 | self.device_map = ( 927 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 928 | if device_map is None 929 | else device_map 930 | ) 931 | assert_device_map(self.device_map, len(self.transformer.h)) 932 | self.transformer.parallelize(self.device_map) 933 | self.lm_head = self.lm_head.to(self.transformer.first_device) 934 | self.model_parallel = True 935 | 936 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 937 | def deparallelize(self): 938 | self.transformer.deparallelize() 939 | self.transformer = self.transformer.to("cpu") 940 | self.lm_head = self.lm_head.to("cpu") 941 | self.model_parallel = False 942 | torch.cuda.empty_cache() 943 | 944 | def get_output_embeddings(self): 945 | return self.lm_head 946 | 947 | def set_output_embeddings(self, new_embeddings): 948 | self.lm_head = new_embeddings 949 | 950 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 951 | token_type_ids = kwargs.get("token_type_ids", None) 952 | # only last token for inputs_ids if past is defined in kwargs 953 | if past: 954 | input_ids = input_ids[:, -1].unsqueeze(-1) 955 | if token_type_ids is not None: 956 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 957 | 958 | attention_mask = kwargs.get("attention_mask", None) 959 | position_ids = kwargs.get("position_ids", None) 960 | 961 | if attention_mask is not None and position_ids is None: 962 | # create position_ids on the fly for batch generation 963 | position_ids = attention_mask.long().cumsum(-1) - 1 964 | position_ids.masked_fill_(attention_mask == 0, 1) 965 | if past: 966 | position_ids = position_ids[:, -1].unsqueeze(-1) 967 | else: 968 | position_ids = None 969 | return { 970 | "input_ids": input_ids, 971 | "past_key_values": past, 972 | "use_cache": kwargs.get("use_cache"), 973 | "position_ids": position_ids, 974 | "attention_mask": attention_mask, 975 | "token_type_ids": token_type_ids, 976 | } 977 | 978 | @add_start_docstrings_to_model_forward(OpenGPT2_INPUTS_DOCSTRING) 979 | @add_code_sample_docstrings( 980 | tokenizer_class=_TOKENIZER_FOR_DOC, 981 | checkpoint="gpt2", 982 | output_type=CausalLMOutputWithCrossAttentions, 983 | config_class=_CONFIG_FOR_DOC, 984 | ) 985 | def forward( 986 | self, 987 | input_ids=None, 988 | past_key_values=None, 989 | attention_mask=None, 990 | token_type_ids=None, 991 | position_ids=None, 992 | head_mask=None, 993 | inputs_embeds=None, 994 | encoder_hidden_states=None, 995 | encoder_attention_mask=None, 996 | labels=None, 997 | use_cache=None, 998 | output_attentions=None, 999 | output_hidden_states=None, 1000 | return_dict=None, 1001 | ): 1002 | r""" 1003 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1004 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set 1005 | ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to 1006 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 1007 | """ 1008 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1009 | 1010 | transformer_outputs = self.transformer( 1011 | input_ids, 1012 | past_key_values=past_key_values, 1013 | attention_mask=attention_mask, 1014 | token_type_ids=token_type_ids, 1015 | position_ids=position_ids, 1016 | head_mask=head_mask, 1017 | inputs_embeds=inputs_embeds, 1018 | encoder_hidden_states=encoder_hidden_states, 1019 | encoder_attention_mask=encoder_attention_mask, 1020 | use_cache=use_cache, 1021 | output_attentions=output_attentions, 1022 | output_hidden_states=output_hidden_states, 1023 | return_dict=return_dict, 1024 | ) 1025 | hidden_states = transformer_outputs[0] 1026 | 1027 | # Set device for model parallelism 1028 | if self.model_parallel: 1029 | torch.cuda.set_device(self.transformer.first_device) 1030 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1031 | 1032 | lm_logits = self.lm_head(hidden_states) 1033 | 1034 | loss = None 1035 | if labels is not None: 1036 | # Shift so that tokens < n predict n 1037 | shift_logits = lm_logits[..., :-1, :].contiguous() 1038 | shift_labels = labels[..., 1:].contiguous() 1039 | # Flatten the tokens 1040 | loss_fct = CrossEntropyLoss() 1041 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1042 | 1043 | if not return_dict: 1044 | output = (lm_logits,) + transformer_outputs[1:] 1045 | return ((loss,) + output) if loss is not None else output 1046 | 1047 | return CausalLMOutputWithCrossAttentions( 1048 | loss=loss, 1049 | logits=lm_logits, 1050 | past_key_values=transformer_outputs.past_key_values, 1051 | hidden_states=transformer_outputs.hidden_states, 1052 | attentions=transformer_outputs.attentions, 1053 | cross_attentions=transformer_outputs.cross_attentions, 1054 | ) 1055 | 1056 | 1057 | @add_start_docstrings( 1058 | """ 1059 | The OpenGPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for 1060 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the 1061 | input embeddings, the classification head takes as input the input of a specified classification token index in the 1062 | input sequence). 
1063 | """, 1064 | OpenGPT2_START_DOCSTRING, 1065 | ) 1066 | class OpenGPT2DoubleHeadsModel(OpenGPT2PreTrainedModel): 1067 | def __init__(self, config): 1068 | super().__init__(config) 1069 | config.num_labels = 1 1070 | self.transformer = OpenGPT2Model(config) 1071 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 1072 | self.multiple_choice_head = SequenceSummary(config) 1073 | 1074 | self.init_weights() 1075 | 1076 | def get_output_embeddings(self): 1077 | return self.lm_head 1078 | 1079 | def set_output_embeddings(self, new_embeddings): 1080 | self.lm_head = new_embeddings 1081 | 1082 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 1083 | token_type_ids = kwargs.get("token_type_ids", None) 1084 | # only last token for inputs_ids if past is defined in kwargs 1085 | if past: 1086 | input_ids = input_ids[:, -1].unsqueeze(-1) 1087 | if token_type_ids is not None: 1088 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 1089 | 1090 | attention_mask = kwargs.get("attention_mask", None) 1091 | position_ids = kwargs.get("position_ids", None) 1092 | 1093 | if attention_mask is not None and position_ids is None: 1094 | # create position_ids on the fly for batch generation 1095 | position_ids = attention_mask.long().cumsum(-1) - 1 1096 | position_ids.masked_fill_(attention_mask == 0, 1) 1097 | if past: 1098 | position_ids = position_ids[:, -1].unsqueeze(-1) 1099 | else: 1100 | position_ids = None 1101 | 1102 | return { 1103 | "input_ids": input_ids, 1104 | "past_key_values": past, 1105 | "use_cache": kwargs.get("use_cache"), 1106 | "position_ids": position_ids, 1107 | "attention_mask": attention_mask, 1108 | "token_type_ids": token_type_ids, 1109 | } 1110 | 1111 | @add_start_docstrings_to_model_forward(OpenGPT2_INPUTS_DOCSTRING) 1112 | @replace_return_docstrings(output_type=OpenGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 1113 | def forward( 1114 | self, 1115 | input_ids=None, 1116 | past_key_values=None, 1117 | attention_mask=None, 1118 | token_type_ids=None, 1119 | position_ids=None, 1120 | head_mask=None, 1121 | inputs_embeds=None, 1122 | mc_token_ids=None, 1123 | labels=None, 1124 | mc_labels=None, 1125 | use_cache=None, 1126 | output_attentions=None, 1127 | output_hidden_states=None, 1128 | return_dict=None, 1129 | **kwargs, 1130 | ): 1131 | r""" 1132 | mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input): 1133 | Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) - 1134 | 1[``. 1135 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1136 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 1137 | ``labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` All labels set to 1138 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 1139 | mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`): 1140 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., 1141 | num_choices]`` where `num_choices` is the size of the second dimension of the input tensors. 
(see 1142 | `input_ids` above) 1143 | 1144 | Return: 1145 | 1146 | Example:: 1147 | 1148 | >>> import torch 1149 | >>> from transformers import OpenGPT2Tokenizer, OpenGPT2DoubleHeadsModel 1150 | 1151 | >>> tokenizer = OpenGPT2Tokenizer.from_pretrained('gpt2') 1152 | >>> model = OpenGPT2DoubleHeadsModel.from_pretrained('gpt2') 1153 | 1154 | >>> # Add a [CLS] to the vocabulary (we should train it also!) 1155 | >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'}) 1156 | 1157 | >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size 1158 | 1159 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 1160 | >>> encoded_choices = [tokenizer.encode(s) for s in choices] 1161 | >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] 1162 | 1163 | >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 1164 | >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 1165 | 1166 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) 1167 | >>> lm_logits = outputs.logits 1168 | >>> mc_logits = outputs.mc_logits 1169 | 1170 | """ 1171 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1172 | 1173 | transformer_outputs = self.transformer( 1174 | input_ids, 1175 | past_key_values=past_key_values, 1176 | attention_mask=attention_mask, 1177 | token_type_ids=token_type_ids, 1178 | position_ids=position_ids, 1179 | head_mask=head_mask, 1180 | inputs_embeds=inputs_embeds, 1181 | use_cache=use_cache, 1182 | output_attentions=output_attentions, 1183 | output_hidden_states=output_hidden_states, 1184 | return_dict=return_dict, 1185 | ) 1186 | 1187 | hidden_states = transformer_outputs[0] 1188 | 1189 | lm_logits = self.lm_head(hidden_states) 1190 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 1191 | 1192 | mc_loss = None 1193 | if mc_labels is not None: 1194 | loss_fct = CrossEntropyLoss() 1195 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 1196 | lm_loss = None 1197 | if labels is not None: 1198 | shift_logits = lm_logits[..., :-1, :].contiguous() 1199 | shift_labels = labels[..., 1:].contiguous() 1200 | loss_fct = CrossEntropyLoss() 1201 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1202 | 1203 | if not return_dict: 1204 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 1205 | if mc_loss is not None: 1206 | output = (mc_loss,) + output 1207 | return ((lm_loss,) + output) if lm_loss is not None else output 1208 | 1209 | return OpenGPT2DoubleHeadsModelOutput( 1210 | loss=lm_loss, 1211 | mc_loss=mc_loss, 1212 | logits=lm_logits, 1213 | mc_logits=mc_logits, 1214 | past_key_values=transformer_outputs.past_key_values, 1215 | hidden_states=transformer_outputs.hidden_states, 1216 | attentions=transformer_outputs.attentions, 1217 | ) 1218 | 1219 | 1220 | @add_start_docstrings( 1221 | """ 1222 | The OpenGPT2 Model transformer with a sequence classification head on top (linear layer). 1223 | 1224 | :class:`~transformers.OpenGPT2ForSequenceClassification` uses the last token in order to do the classification, as 1225 | other causal models (e.g. GPT-1) do. 1226 | 1227 | Since it does classification on the last token, it requires to know the position of the last token. 
If a 1228 | :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each 1229 | row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot 1230 | guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take 1231 | the last value in each row of the batch). 1232 | """, 1233 | OpenGPT2_START_DOCSTRING, 1234 | ) 1235 | class OpenGPT2ForSequenceClassification(OpenGPT2PreTrainedModel): 1236 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 1237 | 1238 | def __init__(self, config): 1239 | super().__init__(config) 1240 | self.num_labels = config.num_labels 1241 | self.transformer = OpenGPT2Model(config) 1242 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 1243 | 1244 | self.init_weights() 1245 | 1246 | @add_start_docstrings_to_model_forward(OpenGPT2_INPUTS_DOCSTRING) 1247 | @add_code_sample_docstrings( 1248 | tokenizer_class=_TOKENIZER_FOR_DOC, 1249 | checkpoint="microsoft/dialogrpt", 1250 | output_type=SequenceClassifierOutputWithPast, 1251 | config_class=_CONFIG_FOR_DOC, 1252 | ) 1253 | def forward( 1254 | self, 1255 | input_ids=None, 1256 | past_key_values=None, 1257 | attention_mask=None, 1258 | token_type_ids=None, 1259 | position_ids=None, 1260 | head_mask=None, 1261 | inputs_embeds=None, 1262 | labels=None, 1263 | use_cache=None, 1264 | output_attentions=None, 1265 | output_hidden_states=None, 1266 | return_dict=None, 1267 | ): 1268 | r""" 1269 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1270 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1271 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 1272 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1273 | """ 1274 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1275 | 1276 | transformer_outputs = self.transformer( 1277 | input_ids, 1278 | past_key_values=past_key_values, 1279 | attention_mask=attention_mask, 1280 | token_type_ids=token_type_ids, 1281 | position_ids=position_ids, 1282 | head_mask=head_mask, 1283 | inputs_embeds=inputs_embeds, 1284 | use_cache=use_cache, 1285 | output_attentions=output_attentions, 1286 | output_hidden_states=output_hidden_states, 1287 | return_dict=return_dict, 1288 | ) 1289 | hidden_states = transformer_outputs[0] 1290 | logits = self.score(hidden_states) 1291 | 1292 | if input_ids is not None: 1293 | batch_size, sequence_length = input_ids.shape[:2] 1294 | else: 1295 | batch_size, sequence_length = inputs_embeds.shape[:2] 1296 | 1297 | assert ( 1298 | self.config.pad_token_id is not None or batch_size == 1 1299 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 1300 | if self.config.pad_token_id is None: 1301 | sequence_lengths = -1 1302 | else: 1303 | if input_ids is not None: 1304 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 1305 | else: 1306 | sequence_lengths = -1 1307 | logger.warning( 1308 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " 1309 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 1310 | ) 1311 | 1312 | pooled_logits = logits[range(batch_size), sequence_lengths] 1313 | 1314 | loss = None 1315 | if labels is not None: 1316 | if self.num_labels == 1: 1317 | # We are doing regression 1318 | loss_fct = MSELoss() 1319 | loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) 1320 | else: 1321 | loss_fct = CrossEntropyLoss() 1322 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1323 | 1324 | if not return_dict: 1325 | output = (pooled_logits,) + transformer_outputs[1:] 1326 | return ((loss,) + output) if loss is not None else output 1327 | 1328 | return SequenceClassifierOutputWithPast( 1329 | loss=loss, 1330 | logits=pooled_logits, 1331 | past_key_values=transformer_outputs.past_key_values, 1332 | hidden_states=transformer_outputs.hidden_states, 1333 | attentions=transformer_outputs.attentions, 1334 | ) 1335 | -------------------------------------------------------------------------------- /GPT2ForwardBackward/padded_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Oct 25 16:03:52 2019 5 | 6 | @author: peterawest 7 | """ 8 | from transformers import GPT2Tokenizer 9 | 10 | 11 | class Encoder(): 12 | def __init__(self): 13 | self.encoder = GPT2Tokenizer.from_pretrained('gpt2') 14 | 15 | assert(len(self.encoder.encode('<|endoftext|>')) ==1 ) 16 | self.endoftext = self.encoder.encode('<|endoftext|>')[0] 17 | self.padding = 0 18 | 19 | # TODO(qin) 20 | self.vocab_size = 50271 #self.encoder.vocab_size 21 | #print(self.vocab_size) 22 | #exit() 23 | 24 | def encode(self, text): 25 | return [t + 1 for t in self.encoder.encode(text)] 26 | 27 | def decode(self, tokens): 28 | tokens_shifted = [t - 1 for t in tokens if t !=0 ] 29 | if len(tokens_shifted) != len(tokens): 30 | print('WARNING: padding removed from sequence during decoding') 31 | 32 | return self.encoder.decode(tokens_shifted) 33 | 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for "COLD Decoding: Energy-based Constrained Text Generation with Langevin Dynamics" 2 | 3 | This is the code for the following paper: 4 | 5 | [COLD Decoding: Energy-based Constrained Text Generation with Langevin Dynamics](https://arxiv.org/pdf/2202.11705.pdf) \ 6 | Lianhui Qin, Sean Welleck, Daniel Khashabi, Yejin Choi 7 | 8 | 9 | **1) Setup Environment** 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | **2) Clone this Repository** 15 | ``` 16 | git clone https://github.com/qkaren/COLD_decoding.git 17 | ``` 18 | 19 | **3) Run Command for COLD Decoding** 20 | 21 | * CommonGen 22 | ``` 23 | sh commongen.sh 24 | ``` 25 | 26 | * Abductive Reasoning 27 | ``` 28 | sh abductive.sh 29 | ``` 30 | 31 | * Counterfactual Reasoning 32 | ``` 33 | sh counterfactual.sh 34 | ``` 35 | 36 | **4) Rank the generation** 37 | 38 | Will be added soon. 39 | -------------------------------------------------------------------------------- /abductive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Abductive 4 | 5 | python3 cold_decoding.py \ 6 | --seed 12 \ 7 | --mode abductive_langevin \ 8 | --pretrained_model gpt2-xl \ 9 | --init-temp 1 \ 10 | --length 10 \ 11 | --max-length 40 \ 12 | --num-iters 2000 \ 13 | --min-iters 1000 \ 14 | --constraint-weight 0.5 \ 15 | --abductive-c2-weight 0.05 \ 16 | --stepsize 0.1 \ 17 | --noise-iters 1 \ 18 | --win-anneal-iters 1000 \ 19 | --start 0 \ 20 | --end 5 \ 21 | --lr-nll-portion 0.6 \ 22 | --topk 2 \ 23 | --output-lgt-temp 1 \ 24 | --verbose \ 25 | --straight-through \ 26 | --large-noise-iters 50,500,1000,1500 \ 27 | --large_gs_std 1,0.5,0.1,0.05 \ 28 | --input-file "./data/abductive/small_data.json" \ 29 | --output-dir "./data/abductive/" \ 30 | --stepsize-ratio 1 \ 31 | --batch-size 16 \ 32 | --print-every 200 33 | -------------------------------------------------------------------------------- /bleuloss.py: -------------------------------------------------------------------------------- 1 | from torch.cuda import LongTensor, FloatTensor 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def batch_log_bleulosscnn_ae(decoder_outputs, target_idx, ngram_list, trans_len=None, pad=0, weight_list=None): 8 | """ 9 | decoder_outputs: [output_len, batch_size, vocab_size] 10 | - matrix with log probabilities 11 | target_idx: [batch_size, target_len] 12 | - reference batch 13 | ngram_list: int or List[int] 14 | - n-grams to consider 15 | pad: int 16 | the idx of the "pad" token 17 | weight_list : List 18 | corresponding weight of each n-gram 19 | 20 | NOTE: output_len == target_len 21 | """ 22 | decoder_outputs = decoder_outputs.transpose(0,1) 23 | batch_size, output_len, vocab_size = decoder_outputs.size() 24 | _, tgt_len = target_idx.size() 25 | if type(ngram_list) == int: 26 | ngram_list = [ngram_list] 27 | if ngram_list[0] <= 0: 28 | ngram_list[0] = output_len 29 | if weight_list is None: 30 | weight_list = [1. / len(ngram_list)] * len(ngram_list) 31 | decoder_outputs = torch.log_softmax(decoder_outputs,dim=-1) 32 | decoder_outputs = torch.relu(decoder_outputs + 20) - 20 33 | index = target_idx.unsqueeze(1).expand(-1, output_len, tgt_len) 34 | cost_nll = decoder_outputs.gather(dim=2, index=index) 35 | cost_nll = cost_nll.unsqueeze(1) 36 | out = cost_nll 37 | sum_gram = 0.
#FloatTensor([0.]) 38 | ########################### 39 | zero = torch.tensor(0.0).cuda() 40 | target_expand = target_idx.view(batch_size,1,1,-1).expand(-1,-1,output_len,-1) 41 | out = torch.where(target_expand==pad, zero, out) 42 | ############################ 43 | for cnt, ngram in enumerate(ngram_list): 44 | if ngram > output_len: 45 | continue 46 | eye_filter = torch.eye(ngram).view([1, 1, ngram, ngram]).cuda() 47 | term = nn.functional.conv2d(out, eye_filter)/ngram 48 | if ngram < decoder_outputs.size()[1]: 49 | term = term.squeeze(1) 50 | gum_tmp = F.gumbel_softmax(term, tau=1, dim=1) 51 | term = term.mul(gum_tmp).sum(1).mean(1) 52 | else: 53 | while len(term.shape) > 1: 54 | assert term.shape[-1] == 1, str(term.shape) 55 | term = term.sum(-1) 56 | try: 57 | sum_gram += weight_list[cnt] * term 58 | except: 59 | print(sum_gram.shape) 60 | print(term.shape) 61 | print((weight_list[cnt] * term).shape) 62 | print(ngram) 63 | print(decoder_outputs.size()[1]) 64 | assert False 65 | 66 | loss = - sum_gram 67 | return loss 68 | 69 | -------------------------------------------------------------------------------- /cold_decoding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 5 | import numpy as np 6 | import time 7 | import wandb 8 | import argparse 9 | 10 | import sys 11 | sys.path.insert(0, './GPT2ForwardBackward') 12 | 13 | import nltk 14 | nltk.download('punkt') 15 | nltk.download('stopwords') 16 | nltk.download('averaged_perceptron_tagger') 17 | 18 | from nltk import tokenize 19 | from nltk.corpus import stopwords 20 | from nltk.tokenize import word_tokenize 21 | 22 | from util import * 23 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 24 | from bleuloss import batch_log_bleulosscnn_ae 25 | from modeling_opengpt2 import OpenGPT2LMHeadModel 26 | 27 | stop_words = set(stopwords.words('english')) 28 | 29 | 30 | def options(): 31 | parser = argparse.ArgumentParser() 32 | ## setting 33 | parser.add_argument("--seed", type=int, default=-1) 34 | parser.add_argument("--no-cuda", action="store_true", help="no cuda") 35 | parser.add_argument("--verbose", action="store_true") 36 | parser.add_argument("--print-every", type=int, default=200) 37 | parser.add_argument("--pretrained_model", type=str, default="gpt2-large") 38 | parser.add_argument("--wandb", action="store_true") 39 | parser.add_argument("--straight-through", action="store_true") 40 | parser.add_argument("--topk", type=int, default=0) 41 | parser.add_argument("--rl-topk", type=int, default=0) 42 | parser.add_argument("--lexical", type=str, default='max', choices=['max', 'ppl_max', 'all', 'bleu']) 43 | parser.add_argument("--lexical-variants", action="store_true", help="") 44 | parser.add_argument("--if-zx", action="store_true") 45 | ## experiment 46 | parser.add_argument("--input-file", type=str, 47 | default="./data/lexical/commongen_data/test.multi.constraint.json") 48 | parser.add_argument("--output-dir", type=str, default="./data/commongen/") 49 | parser.add_argument("--fwd-model", type=str, 50 | default="/var/karen/workspace/GPT2ForwardBackward/opengpt2_pytorch_forward") 51 | parser.add_argument("--back-model", type=str, 52 | default="danyaljj/opengpt2_pytorch_backward") 53 | parser.add_argument("--version", type=str, default="") 54 | parser.add_argument("--start", type=int, default=1, help="loading data from ith examples.") 55 | parser.add_argument("--end", type=int, default=10, help="loading data util ith examples.") 56 | 
parser.add_argument("--repeat-batch", type=int, default=1, help="loading data util ith examples.") 57 | parser.add_argument("--mode", type=str, default='constrained_langevin', 58 | choices=['lexical_generation', 'counterfactual_langevin', 'abductive_langevin', 59 | 'grammar']) 60 | ## model 61 | parser.add_argument("--batch-size", type=int, default=1) 62 | parser.add_argument("--length", type=int, default=15, help="maximum length of optimized logits.") 63 | parser.add_argument("--max-length", type=int, default=50, help="maximum length of complete sentence.") 64 | parser.add_argument("--frozen-length", type=int, default=0, help="length of optimization window in sequence.") 65 | parser.add_argument("--constraint-weight", type=float, default=0.1) 66 | parser.add_argument("--abductive-c2-weight", type=float, default=0.05) 67 | parser.add_argument("--abductive-filterx", action="store_true", help="filter out keywords included in x") 68 | parser.add_argument("--lr-nll-portion", type=float, default=1) 69 | parser.add_argument("--prefix-length", type=int, default=0, help="length of prefix.") 70 | parser.add_argument("--counterfactual-max-ngram", type=int, default=6) 71 | parser.add_argument("--no-loss-rerank", action="store_true", help="") 72 | # temperature 73 | parser.add_argument("--input-lgt-temp", type=float, default=1, 74 | help="temperature of logits used for model input.") 75 | parser.add_argument("--output-lgt-temp", type=float, default=1, 76 | help="temperature of logits used for model output.") 77 | parser.add_argument("--rl-output-lgt-temp", type=float, default=1, 78 | help="temperature of logits used for model output.") 79 | parser.add_argument("--init-temp", type=float, default=0.1, 80 | help="temperature of logits used in the initialization pass. 
High => uniform init.") 81 | parser.add_argument("--init-mode", type=str, default='random', choices=['random', 'original']) 82 | # lr 83 | parser.add_argument("--stepsize", type=float, default=0.1, help="learning rate in the backward pass.") 84 | parser.add_argument("--stepsize-ratio", type=float, default=1, help="") 85 | parser.add_argument("--stepsize-iters", type=int, default=1000, help="") 86 | # iterations 87 | parser.add_argument("--num-iters", type=int, default=1000) 88 | parser.add_argument("--min-iters", type=int, default=0, help="record best only after N iterations") 89 | parser.add_argument("--noise-iters", type=int, default=1, help="add noise at every N iterations") 90 | parser.add_argument("--win-anneal-iters", type=int, default=-1, help="froze the optimization window after N iters") 91 | parser.add_argument("--constraint-iters", type=int, default=1000, 92 | help="add one more group of constraints from N iters") 93 | # gaussian noise 94 | parser.add_argument("--gs_mean", type=float, default=0.0) 95 | parser.add_argument("--gs_std", type=float, default=0.01) 96 | parser.add_argument("--large-noise-iters", type=str, default="-1", help="Example: '50,1000'") 97 | parser.add_argument("--large_gs_std", type=str, default="1", help="Example: '1,0.1'") 98 | 99 | args = parser.parse_args() 100 | return args 101 | 102 | 103 | def decode(model, tokenizer, device, x="", z="", constraints=None, args=None, model_back=None, zz=None): 104 | ''' 105 | x: left context (prompt in lexical lexical task) 106 | z: optimization target (original ending in counterfactual task) 107 | constraints: (constraint set in lexical constrained task) 108 | ''' 109 | 110 | x_ = tokenizer.encode(x) 111 | x_t = torch.tensor(x_, device=device, dtype=torch.long) 112 | x_onehot = one_hot(x_t, dimension=tokenizer.vocab_size) 113 | 114 | # repeat batch_size times 115 | x_t = x_t.unsqueeze(0).repeat(args.batch_size, 1) 116 | x_onehot = x_onehot.repeat(args.batch_size, 1, 1) 117 | 118 | z_mask = None 119 | 120 | if 'counterfactual' in args.mode: 121 | z_ = tokenizer.encode(z)[1:] # delete the "." token we appended before 122 | z_t = torch.tensor(z_, device=device, dtype=torch.long) 123 | 124 | z_onehot = one_hot(z_t, dimension=tokenizer.vocab_size) 125 | z_onehot = z_onehot.repeat(args.batch_size, 1, 1) 126 | 127 | z_t = z_t.unsqueeze(0).repeat(args.batch_size, 1) 128 | 129 | length = args.length 130 | if length <= 0: 131 | length = z_t.shape[1] - length 132 | if args.verbose: 133 | print("x:\t|%s|\nz:\t|%s|\nlength:\t%d\nconstraints:\t%s" % ( 134 | tokenizer.decode(x_), tokenizer.decode(z_), length, constraints)) 135 | 136 | # z_mask: [batch_size, vocab_size] 137 | z_words = word_tokenize(z[2:]) # delete the ". " token we appended before 138 | z_nonstop_words = [w.lower() for w in z_words if w.lower() not in stop_words and w.isalnum()] 139 | z_nonstop_words += [z_words[0]] # add the first token 140 | z_nonstop_words = ' ' + ' '.join(z_nonstop_words) 141 | z_nonstop_ = tokenizer.encode(z_nonstop_words) 142 | print('|' + z_nonstop_words + '|') 143 | 144 | z_mask = np.zeros([tokenizer.vocab_size]) 145 | z_mask[z_nonstop_] = 1. 146 | z_mask = torch.tensor(z_mask, device=device) 147 | z_mask = z_mask.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, length, 1) 148 | 149 | if 'abductive' in args.mode or 'lexical' in args.mode: 150 | length = args.length 151 | 152 | z_ = tokenizer.encode(z)[1:] # delete the "." 
token we appended before 153 | z_t = torch.tensor(z_, device=device, dtype=torch.long) 154 | z_onehot = one_hot(z_t, dimension=tokenizer.vocab_size) 155 | # repeat batch_size times 156 | z_t = z_t.unsqueeze(0).repeat(args.batch_size, 1) 157 | z_onehot = z_onehot.repeat(args.batch_size, 1, 1) 158 | 159 | zz_ = tokenizer.encode(zz)[1:] # delete the "." token we appended before 160 | zz_t = torch.tensor(zz_, device=device, dtype=torch.long) 161 | zz_t = zz_t.unsqueeze(0).repeat(args.batch_size, 1) 162 | 163 | z_mask = np.zeros([tokenizer.vocab_size]) 164 | z_mask[zz_] = 1. 165 | z_mask = torch.tensor(z_mask, device=device) 166 | z_mask = z_mask.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, length, 1) 167 | 168 | if args.verbose: 169 | print("x:\t|%s|\nz:\t|%s|\nzz:\t|%s|\nconstraints:\t%s" % ( 170 | tokenizer.decode(x_), tokenizer.decode(z_), tokenizer.decode(zz_), constraints)) 171 | 172 | cs_ = None 173 | cs_onehot = None 174 | 175 | model.eval() 176 | 177 | if args.init_mode == 'random': 178 | init_logits = initialize(model, x_t, length, args.init_temp, device) 179 | else: 180 | init_logits = z_onehot / 0.1 181 | init_logits = init_logits[:, :length, :] 182 | if length > init_logits.shape[1]: 183 | init_logits = torch.cat( 184 | [init_logits, 185 | torch.zeros([args.batch_size, length - init_logits.shape[1], tokenizer.vocab_size], device=device)], 186 | dim=1) 187 | text, _, _ = get_text_from_logits(init_logits, tokenizer) 188 | for bi in range(args.batch_size): 189 | print("[initial]: %s" % (text[bi])) 190 | 191 | if args.wandb: 192 | wandb.init( 193 | project='args.mode' + str(int(round(time.time() * 1000))), 194 | config=args) 195 | 196 | assert args.prefix_length <= 0 # Otherwise not compatible with batch mode 197 | 198 | if args.prefix_length > 0: 199 | prefix_logits = torch.nn.Parameter( 200 | torch.rand(x_onehot.shape[0], args.prefix_length, x_onehot.shape[2], dtype=init_logits.dtype, 201 | device=device)) 202 | 203 | y_logits = init_logits 204 | epsilon = torch.nn.Parameter(torch.zeros_like(y_logits)) 205 | if args.prefix_length > 0: 206 | optim = torch.optim.Adam([epsilon, prefix_logits], lr=args.stepsize) 207 | else: 208 | optim = torch.optim.Adam([epsilon], lr=args.stepsize) 209 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=args.stepsize_iters, 210 | gamma=args.stepsize_ratio) 211 | 212 | frozen_len = args.frozen_length 213 | 214 | y_logits_ = None 215 | noise_std = 0.0 216 | 217 | ## Encode x beforehand 218 | assert args.prefix_length <= 0, "The current code does not support prefix-length > 0" 219 | soft_forward_x = x_onehot[:, -1:, :] # The last token of x is used in soft_forward 220 | if x_t.shape[1] == 1: 221 | x_model_past = None 222 | else: 223 | x_model_outputs = model(x_t[:, :-1]) 224 | x_model_past = x_model_outputs.past_key_values 225 | x_model_past = [_.detach() for _ in x_model_past] 226 | 227 | # For right to left model 228 | rl_reverse_index = torch.arange(y_logits.shape[1] - 1, -1, -1) 229 | 230 | mask_t = None 231 | 232 | for iter in range(args.num_iters): 233 | optim.zero_grad() 234 | y_logits_ = y_logits + epsilon 235 | 236 | soft_forward_y = y_logits_ / 0.001 237 | if args.straight_through: 238 | if mask_t is None: 239 | soft_forward_y = (y_logits_.detach() / 0.001 - y_logits_).detach() + y_logits_ 240 | else: 241 | soft_forward_y = top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=z_mask) / 0.001 242 | 243 | y_logits_t = soft_forward(model, soft_forward_x, soft_forward_y, x_past=x_model_past) 244 | 245 | if args.topk 
== 0: 246 | mask_t = None 247 | else: 248 | _, indices_t = torch.topk(y_logits_t, args.topk) 249 | mask_t = torch.zeros_like(y_logits_t).scatter_(2, indices_t, 1) 250 | 251 | # Compute loss, gradients, and update. 252 | lr_nll_loss = soft_nll( 253 | top_k_filter_3d(y_logits_t / args.output_lgt_temp, args.topk, extra_mask=z_mask), 254 | y_logits_ / args.input_lgt_temp) 255 | 256 | if args.lr_nll_portion == 1.0: 257 | rl_nll_loss = lr_nll_loss 258 | else: 259 | # add right-to-left model (rl) 260 | if "counterfactual" in args.mode: 261 | y_logits_rev = y_logits_[:, rl_reverse_index, :] 262 | y_logits_rev_t = model_back(y_logits_rev.argmax(-1) + 1).logits[:, :-1, :] 263 | y_logits_rev_t = y_logits_rev_t[:, :, 1:y_logits_.shape[-1] + 1] 264 | rl_nll_loss = soft_nll( 265 | top_k_filter_3d(y_logits_rev_t / args.output_lgt_temp, args.rl_topk), 266 | y_logits_rev[:, 1:] / args.input_lgt_temp) 267 | elif "abductive" in args.mode or "lexical" in args.mode: 268 | yz_logits_rev = torch.flip(torch.cat([y_logits_, z_onehot], dim=1), [1]) 269 | yz_logits_rev_t = soft_backward(model_back, yz_logits_rev / 0.00001) 270 | yz_logits_rev_rev_t = torch.flip(yz_logits_rev_t, [1]) 271 | yz_logits_rev_rev_t = yz_logits_rev_rev_t[:, :, 1:y_logits_.shape[-1] + 1] 272 | yz_logits_rev_rev_t_ = yz_logits_rev_rev_t[:, :y_logits_.shape[1], :] 273 | 274 | tmp_logits = yz_logits_rev_rev_t_ 275 | repetition_mask = torch.cat([F.softmax(tmp_logits[:, 1:, :], dim=-1), 276 | torch.zeros_like(tmp_logits[:, -1:, :])], dim=1) 277 | yz_logits_rev_rev_t_ = yz_logits_rev_rev_t_ - repetition_mask * 1e4 278 | yz_logits_rev_rev_t_ = yz_logits_rev_rev_t_.detach() 279 | 280 | rl_nll_loss = soft_nll( 281 | top_k_filter_3d(yz_logits_rev_rev_t_ / args.rl_output_lgt_temp, args.rl_topk), 282 | y_logits_ / args.input_lgt_temp) 283 | 284 | 285 | if "counterfactual" in args.mode: 286 | c_loss = batch_log_bleulosscnn_ae( 287 | decoder_outputs=top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=z_mask).transpose(0, 1), 288 | target_idx=z_t, 289 | ngram_list=list(range(2, args.counterfactual_max_ngram + 1)) 290 | ) 291 | 292 | if "abductive" in args.mode or "lexical" in args.mode: 293 | soft_forward_y_ = (y_logits_.detach() / 0.3 - y_logits_).detach() + y_logits_ 294 | xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_x, soft_forward_y_, z_onehot) 295 | 296 | # Reshaping 297 | bz = args.batch_size 298 | lg = xyz_logits.shape[1] 299 | st = xy_length - 1 300 | ed = xyz_logits.shape[1] - 1 301 | xyz_logits = xyz_logits.view(-1, xyz_logits.shape[-1]) 302 | z_logits = torch.cat([xyz_logits[bi * lg + st:bi * lg + ed, :] for bi in range(bz)], dim=0) 303 | 304 | c_loss_1 = torch.nn.CrossEntropyLoss(reduction='none')( 305 | z_logits, 306 | z_t.view(-1)) 307 | c_loss_1 = c_loss_1.view(args.batch_size, -1).mean(-1) 308 | 309 | c_loss_2 = batch_log_bleulosscnn_ae( 310 | decoder_outputs=y_logits_.transpose(0, 1), 311 | target_idx=zz_t, 312 | ngram_list=[1] 313 | ) 314 | c_loss = c_loss_1 + args.abductive_c2_weight * c_loss_2 315 | 316 | loss = (1.0 - args.constraint_weight) * args.lr_nll_portion * lr_nll_loss \ 317 | + (1.0 - args.constraint_weight) * (1 - args.lr_nll_portion) * rl_nll_loss \ 318 | + args.constraint_weight * c_loss 319 | loss = loss.mean() 320 | 321 | if iter < args.num_iters - 1: # so that the mask_t at the last iteration will not change 322 | loss.backward() 323 | optim.step() 324 | scheduler.step() # turn off the scheduler 325 | last_lr = scheduler.get_last_lr()[0] 326 | 327 | if args.verbose and ((iter + 1) % 
args.print_every == 0 or iter == 0 or iter + 1 == args.num_iters): 328 | text, _, _ = decode_with_model_topk( 329 | model, y_logits_, args.topk, soft_forward_x, x_model_past, tokenizer, extra_mask=z_mask) 330 | for bi in range(args.batch_size): 331 | if "abductive" in args.mode or "lexical" in args.mode: 332 | print( 333 | "%d, loss: %.4f, lr_nll_loss: %.4f, rl_nll_loss: %.4f, c_loss_2: %.4f, lr: %.4f, |%s|" % ( 334 | iter + 1, loss.item(), lr_nll_loss[bi].item(), rl_nll_loss[bi].item(), 335 | c_loss_2[bi].item(), last_lr, text[bi])) 336 | # print("%d, loss: %.4f, lr_nll_loss: %.4f, rl_nll_loss: %.4f, c_loss_1: %.4f, c_loss_2: %.4f, lr: %.4f, |%s|" % (iter + 1, loss.item(), lr_nll_loss[bi].item(), rl_nll_loss[bi].item(), c_loss_1[bi].item(), c_loss_2[bi].item(), last_lr, text[bi])) 337 | else: 338 | print("%d, loss: %.4f, lr_nll_loss: %.4f, c_loss: %.4f, lr: %.4f, |%s|" % ( 339 | iter + 1, loss.item(), lr_nll_loss[bi].item(), c_loss[bi].item(), last_lr, text[bi])) 340 | 341 | if "abductive" in args.mode or "lexical" in args.mode: 342 | pass 343 | 344 | print() 345 | 346 | if args.wandb: 347 | wandb.log( 348 | {"Loss": loss.item(), 349 | "left-to-right nll loss": lr_nll_loss.item(), 350 | "right-to-left nll loss": rl_nll_loss.item(), 351 | "constraint loss": c_loss, 352 | "Gassian_Noise_STD": noise_std, 353 | "LR": last_lr, 354 | "Gradient": torch.norm(epsilon.grad).detach().clone().data.cpu().numpy()} 355 | ) 356 | 357 | ## noise 358 | if iter < args.num_iters - 1: 359 | 360 | if 'grammar' in args.mode: 361 | continue 362 | 363 | large_noise_iters = [int(_) for _ in args.large_noise_iters.split(',')] 364 | large_gs_stds = [float(_) for _ in args.large_gs_std.split(',')] 365 | noise_std = 0. 366 | if iter % args.noise_iters == 0: 367 | noise_last = True 368 | for ni in range(len(large_noise_iters)): 369 | if iter < large_noise_iters[ni]: 370 | noise_last = False 371 | break 372 | if noise_last: 373 | noise_std = args.gs_std 374 | else: 375 | noise_std = large_gs_stds[ni] 376 | 377 | noise = torch.normal(mean=args.gs_mean, std=noise_std, size=epsilon.size(), 378 | device='cuda', requires_grad=False) 379 | if args.win_anneal_iters >= 0 and iter >= args.win_anneal_iters: 380 | zeros = torch.zeros_like(noise) 381 | noise_mix = torch.cat([zeros[:, :frozen_len], noise[:, frozen_len:]], dim=1) 382 | y_logits = y_logits + noise_mix 383 | else: 384 | y_logits = y_logits + noise 385 | 386 | if args.wandb: 387 | wandb.finish() 388 | 389 | text, _, last_text_ids = decode_with_model_topk( 390 | model, y_logits_, args.topk, soft_forward_x, x_model_past, tokenizer, extra_mask=z_mask) 391 | 392 | last_rank_loss = model(input_ids=last_text_ids, labels=last_text_ids).loss 393 | last_rank_loss = last_rank_loss.detach().clone().data.cpu().numpy() 394 | text_post = post_process(last_text_ids, model, args.max_length, args.length, tokenizer, device) 395 | ppl_last = np.exp(last_rank_loss) 396 | 397 | if args.verbose: 398 | for bi in range(args.batch_size): 399 | print("[final]: %s\n%.4f" % (text[bi], ppl_last)) 400 | print("[final complete sentence]: %s\n" % text_post[bi]) 401 | 402 | return ppl_last, text, text_post 403 | 404 | 405 | def counterfactual_reasoning(model, tokenizer, device, args, model_back=None): 406 | fr = open(args.input_file, 'r') 407 | data = [json.loads(x) for x in fr.readlines()] 408 | loss_rerank = 'norerank' if args.no_loss_rerank else 'rerank' 409 | file_name = '%s_%s_seed%d_%d_%d_%s_ngram%d_cw%.3f_lrnllp%.3f_len%d_topk%d_niter%d_frozlen%d' \ 410 | 
'_winiter%d_noiseiter%d_gsstd%.4f_lr%.3f_%s_%s_output.json' % ( 411 | args.version, 412 | loss_rerank, 413 | args.seed, 414 | args.start, 415 | args.end, 416 | args.mode, 417 | args.counterfactual_max_ngram, 418 | args.constraint_weight, 419 | args.lr_nll_portion, 420 | args.length, 421 | args.topk, 422 | args.num_iters, 423 | args.frozen_length, 424 | args.win_anneal_iters, 425 | args.noise_iters, 426 | args.gs_std, 427 | args.stepsize, 428 | args.large_noise_iters, 429 | args.large_gs_std) 430 | 431 | outfile = os.path.join(args.output_dir, file_name) 432 | fw = open(outfile, 'w') 433 | fw_pretty = open(os.path.join(args.output_dir, 'pretty_' + file_name), 'w') 434 | fw_res = open(os.path.join(args.output_dir, 'res_' + file_name), 'w') 435 | 436 | procssed = set() 437 | for i, d in enumerate(data): 438 | if i < args.start or i > args.end: 439 | continue 440 | 441 | if args.seed != -1: 442 | torch.manual_seed(args.seed) 443 | np.random.seed(args.seed) 444 | 445 | print("%d / %d" % (i, len(data))) 446 | print('Output to: \t', outfile) 447 | print('output-lgt-temp:\t', args.output_lgt_temp) 448 | 449 | premise = d.get('premise', "") 450 | counterfactual = d.get('counterfactual', "") 451 | 452 | x = premise + ' ' + counterfactual 453 | ori_ending = d.get('original_ending', "") 454 | ori_endings = tokenize.sent_tokenize(ori_ending) 455 | 456 | if x in procssed: 457 | continue 458 | else: 459 | procssed.add(x) 460 | 461 | x_text_so_far = [""] 462 | x_addon = [[x]] 463 | 464 | outputs = [] 465 | for oi, z_sent in enumerate(ori_endings): 466 | print("Sentence %d" % oi) 467 | z_text_so_far = z_sent.strip() 468 | z_text_so_far = ". " + z_text_so_far 469 | 470 | assert len(x_text_so_far) == len(x_addon), "%d vs %d" % (len(x_text_so_far), len(x_addon)) 471 | 472 | new_x_text_so_far = [] 473 | new_x_addon = [] 474 | for ii, text_i in enumerate(x_text_so_far): 475 | for text_j in x_addon[ii]: 476 | text_ij = text_i.strip() + " " + text_j.strip() 477 | new_x_text_so_far.append(text_ij) 478 | 479 | text_ij = text_ij.strip() 480 | 481 | ppl_last, text, text_post = decode( 482 | model, tokenizer, device, text_ij, z_text_so_far, None, args, model_back=model_back) 483 | 484 | outputs.append([text_ij, text_post]) 485 | 486 | # Rank and filter text_post from util.py: 487 | text_post = [post_sent(x) for x in text_post] 488 | text_post = rank_and_filter(text_post, text_ij, z_text_so_far, model, tokenizer, device, 489 | args.no_loss_rerank) 490 | 491 | if ii == len(x_text_so_far) - 1 and oi == len(ori_endings) - 1: 492 | last_output = text_post 493 | final_res = ' '.join([text_ij, last_output]) 494 | outputs.append(final_res) 495 | fw_res.write(final_res + '\n') 496 | fw_res.flush() 497 | 498 | new_x_addon.append([text_post]) 499 | 500 | x_text_so_far = new_x_text_so_far 501 | x_addon = new_x_addon 502 | 503 | break 504 | 505 | complete_output = outputs 506 | out = { 507 | 'premise': premise, 508 | 'initial': d.get('initial', ""), 509 | 'counterfactual': counterfactual, 510 | 'original_ending': ori_ending, 511 | 'generation_complete': complete_output, 512 | } 513 | 514 | fw.write(json.dumps(out) + '\n') 515 | fw.flush() 516 | fw_pretty.write(json.dumps(out, indent=4) + '\n') 517 | fw_pretty.flush() 518 | 519 | print("outputs: %s" % outfile) 520 | 521 | 522 | def grammar_correction(model, tokenizer, device, args, model_back=None): 523 | fr = open(args.input_file, 'r') 524 | data = [x.strip() for x in fr.readlines()] 525 | file_name = '%s_seed%d_%d_%d_%s_cw%.3f_lrnllp%.3f_len%d_topk%d_niter%d_frozlen%d' \ 526 | 
'_winiter%d_noiseiter%d_gsstd%.4f_lr%.3f_%s_%s_output.json' % ( 527 | args.version, 528 | args.seed, 529 | args.start, 530 | args.end, 531 | args.mode, 532 | args.constraint_weight, 533 | args.lr_nll_portion, 534 | args.length, 535 | args.topk, 536 | args.num_iters, 537 | args.frozen_length, 538 | args.win_anneal_iters, 539 | args.noise_iters, 540 | args.gs_std, 541 | args.stepsize, 542 | args.large_noise_iters, 543 | args.large_gs_std) 544 | 545 | outfile = os.path.join(args.output_dir, file_name) 546 | fw = open(outfile, 'w') 547 | 548 | # Grammar 549 | data = [[' '.join(x.split()[:3]), ' '.join(x.split()[3:])] for x in data] 550 | print('#data: ', len(data)) 551 | 552 | for i, d in enumerate(data): 553 | if i < args.start or i > args.end: 554 | continue 555 | print("%d / %d" % (i, len(data))) 556 | print('Output to: \t', outfile) 557 | 558 | if len(d[1].split()) <= 4: 559 | text = [d[1][2:]] 560 | text_post = [d[1][2:]] 561 | continue 562 | 563 | x = d[0] 564 | y = d[1] 565 | 566 | y = ". " + y 567 | 568 | ppl_last, text, text_post = decode( 569 | model, tokenizer, device, x, y, None, args, model_back=model_back) 570 | out = { 571 | 'original': x + " " + y, 572 | 'generation': text, 573 | 'generation_complete': text_post, 574 | } 575 | 576 | fw.write(json.dumps(out) + '\n') 577 | 578 | print("outputs: %s" % outfile) 579 | 580 | 581 | 582 | def _get_adverbs_and_nnps(z_words): 583 | pos = nltk.pos_tag(z_words) 584 | adverbs = [w[0] for w in pos if 'RB' in w[1]] 585 | nnps = [w[0] for w in pos if 'NNP' in w[1]] 586 | return adverbs, nnps 587 | 588 | def _get_keywords(z, x, args): 589 | stop_words = set(stopwords.words('english')) 590 | z_words = word_tokenize(z) 591 | z_adverbs, z_nnps = _get_adverbs_and_nnps(z_words) 592 | ret_words = [] 593 | for w in z_words: 594 | if w in z_nnps: 595 | if w not in ret_words: 596 | ret_words.append(w) 597 | else: 598 | w = w.lower() 599 | if w not in stop_words and w.isalnum() and w not in z_adverbs and w not in ret_words: 600 | ret_words.append(w) 601 | 602 | if args.abductive_filterx: 603 | x_words = word_tokenize(x) 604 | ret_words = [w for w in ret_words if w not in x_words] 605 | 606 | return ' '.join(ret_words) 607 | 608 | def abductive_reasoning(model, tokenizer, device, args, model_back=None): 609 | with open(args.input_file, 'r') as f: 610 | lines = f.readlines() 611 | data = [json.loads(l.strip()) for l in lines] 612 | 613 | outfile = '%s_seed%d_%d_%d_%s_cw%.3f_c2w%.3f_lrnllp%.3f_len%d_topk%d_niter%d_frozlen%d' \ 614 | '_winiter%d_noiseiter%d_gsstd%.4f_lr%.3f_lrratio%.2f_lriter%d_%s_%s_output.json' % ( 615 | args.version, 616 | args.seed, 617 | args.start, 618 | args.end, 619 | args.mode, 620 | args.constraint_weight, 621 | args.abductive_c2_weight, 622 | args.lr_nll_portion, 623 | args.length, 624 | args.topk, 625 | args.num_iters, 626 | args.frozen_length, 627 | args.win_anneal_iters, 628 | args.noise_iters, 629 | args.gs_std, 630 | args.stepsize, 631 | args.stepsize_ratio, 632 | args.stepsize_iters, 633 | args.large_noise_iters, 634 | args.large_gs_std) 635 | print("outputs: %s" % outfile) 636 | 637 | fw = open(os.path.join(args.output_dir, outfile), 'w') 638 | 639 | procssed = set() 640 | for i, d in enumerate(data): 641 | if i < args.start or i > args.end: 642 | continue 643 | 644 | if args.if_zx: 645 | x = d["obs2"].strip() + '<|endoftext|>' + d["obs1"].strip() 646 | else: 647 | x = d["obs1"].strip() 648 | z = d["obs2"].strip() 649 | z_keywords = _get_keywords(z, d["obs1"].strip(), args) 650 | 651 | if ' '.join([x, z]) in procssed: 652 | 
continue 653 | procssed.add(' '.join([x, z])) 654 | 655 | print("%d / %d" % (i, len(data))) 656 | print('Output to: \t', outfile) 657 | 658 | z = ". " + z 659 | z_keywords = ". " + z_keywords 660 | 661 | text_candidates = [] 662 | text_complete_candidates = [] 663 | for _ in range(args.repeat_batch): 664 | ppl_last, text, text_post = decode(model, tokenizer, device, x, z, None, args, 665 | model_back=model_back, zz=z_keywords) 666 | text_candidates.extend(text) 667 | text_complete_candidates.extend(text_post) 668 | 669 | 670 | out = { 671 | 'x': x, 672 | 'z': z, 673 | 'z_keywords': z_keywords, 674 | 'generation': text_candidates, 675 | 'generation_complete': text_complete_candidates, 676 | } 677 | 678 | fw.write(json.dumps(out) + '\n') 679 | fw.flush() 680 | 681 | print("outputs: %s" % outfile) 682 | 683 | 684 | def lexical_generation(model, tokenizer, device, args, model_back=None): 685 | with open(args.input_file, 'r') as f: 686 | lines = f.readlines() 687 | data = [json.loads(l.strip()) for l in lines] 688 | 689 | outfile = '%if_zx%s_seed%d_%d_%d_%s_cw%.3f_c2w%.3f_lrnllp%.3f_len%d_topk%d_niter%d_frozlen%d' \ 690 | '_winiter%d_noiseiter%d_gsstd%.4f_lr%.3f_lrratio%.2f_lriter%d_%s_%s_output.json' % ( 691 | args.if_zx, 692 | args.version, 693 | args.seed, 694 | args.start, 695 | args.end, 696 | args.mode, 697 | args.constraint_weight, 698 | args.abductive_c2_weight, 699 | args.lr_nll_portion, 700 | args.length, 701 | args.topk, 702 | args.num_iters, 703 | args.frozen_length, 704 | args.win_anneal_iters, 705 | args.noise_iters, 706 | args.gs_std, 707 | args.stepsize, 708 | args.stepsize_ratio, 709 | args.stepsize_iters, 710 | args.large_noise_iters, 711 | args.large_gs_std) 712 | print("outputs: %s" % outfile) 713 | 714 | fw = open(os.path.join(args.output_dir, outfile), 'w') 715 | fw_pretty = open(os.path.join(args.output_dir, 'pretty_' + outfile), 'w') 716 | 717 | for i, d in enumerate(data): 718 | if i < args.start or i > args.end: 719 | continue 720 | print(d["concept_set"]) 721 | constraints = d["concept_set"].split("#") 722 | 723 | constraints = ' '.join(constraints) 724 | x = "<|endoftext|>" 725 | z = constraints 726 | z_keywords = constraints 727 | 728 | print("%d / %d" % (i, len(data))) 729 | print('Output to: \t', outfile) 730 | 731 | z = ". " + z 732 | z_keywords = ". 
" + z_keywords 733 | 734 | text_candidates = [] 735 | text_complete_candidates = [] 736 | for _ in range(args.repeat_batch): 737 | ppl_last, text, text_post = decode(model, tokenizer, device, x, z, None, args, model_back=model_back, 738 | zz=z_keywords) 739 | text_candidates.extend(text) 740 | text_complete_candidates.extend(text_post) 741 | 742 | out = { 743 | 'x': x, 744 | 'constraints': constraints, 745 | 'generation': text_candidates, 746 | 'generation_complete': text_complete_candidates, 747 | } 748 | print(out) 749 | print('Output to: \t', outfile) 750 | 751 | fw.write(json.dumps(out) + '\n') 752 | fw.flush() 753 | fw_pretty.write(json.dumps(out, indent=4) + '\n') 754 | fw_pretty.flush() 755 | 756 | print("outputs: %s" % outfile) 757 | 758 | 759 | def main(): 760 | args = options() 761 | device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 762 | 763 | if args.seed != -1: 764 | torch.manual_seed(args.seed) 765 | np.random.seed(args.seed) 766 | # Load pretrained model 767 | model = GPT2LMHeadModel.from_pretrained( 768 | args.pretrained_model, output_hidden_states=True, 769 | resid_pdrop=0, embd_pdrop=0, attn_pdrop=0, summary_first_dropout=0) 770 | model.to(device) 771 | model.eval() 772 | # Freeze GPT-2 weights 773 | for param in model.parameters(): 774 | param.requires_grad = False 775 | 776 | # Load tokenizer 777 | tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model) 778 | 779 | model_back = OpenGPT2LMHeadModel.from_pretrained( 780 | args.back_model, hidden_dropout_prob=0, attention_probs_dropout_prob=0, summary_first_dropout=0) 781 | model_back.to(device) 782 | model_back.eval() 783 | # Freeze GPT-2 weights 784 | for param in model_back.parameters(): 785 | param.requires_grad = False 786 | 787 | 788 | if "counterfactual" in args.mode: 789 | counterfactual_reasoning(model, tokenizer, device, args, model_back) 790 | if "abductive" in args.mode: 791 | abductive_reasoning(model, tokenizer, device, args, model_back) 792 | if "lexical" in args.mode: 793 | lexical_generation(model, tokenizer, device, args, model_back) 794 | if "grammar" in args.mode: 795 | grammar_correction(model, tokenizer, device, args, model_back) 796 | 797 | 798 | if __name__ == "__main__": 799 | main() 800 | -------------------------------------------------------------------------------- /commongen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## CommonGen 4 | 5 | python3 cold_decoding.py \ 6 | --seed 12 \ 7 | --mode lexical_generation \ 8 | --pretrained_model gpt2-xl \ 9 | --init-temp 1 \ 10 | --length 10 \ 11 | --max-length 40 \ 12 | --num-iters 2000 \ 13 | --min-iters 1000 \ 14 | --constraint-weight 0.5 \ 15 | --abductive-c2-weight 0.1 \ 16 | --stepsize 0.1 \ 17 | --noise-iters 1 \ 18 | --win-anneal-iters 1000 \ 19 | --start 0 \ 20 | --end 5 \ 21 | --lr-nll-portion 0.6 \ 22 | --topk 5 \ 23 | --output-lgt-temp 1 \ 24 | --verbose \ 25 | --straight-through \ 26 | --large-noise-iters 50,500,1000,1500 \ 27 | --large_gs_std 1,0.5,0.1,0.05 \ 28 | --stepsize-ratio 1 \ 29 | --batch-size 32 \ 30 | --repeat-batch 8 \ 31 | --print-every 200 \ 32 | --input-file "./data/commongen/commongen.dev.jsonl" \ 33 | --output-dir "./data/commongen/" \ 34 | 35 | -------------------------------------------------------------------------------- /counterfactual.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Counterfactual 4 | 5 | python3 cold_decoding.py \ 6 | --seed 12 \ 7 | 
--mode counterfactual_langevin \ 8 | --pretrained_model gpt2-xl \ 9 | --init-temp 1 \ 10 | --length 20 \ 11 | --max-length 50 \ 12 | --num-iters 2000 \ 13 | --min-iters 10 \ 14 | --constraint-weight 0.2 \ 15 | --counterfactual-max-ngram 3 \ 16 | --stepsize 0.1 \ 17 | --noise-iters 1 \ 18 | --win-anneal-iters 1000 \ 19 | --start 0 \ 20 | --end 5 \ 21 | --lr-nll-portion 0.9 \ 22 | --topk 5 \ 23 | --output-lgt-temp 1 \ 24 | --verbose \ 25 | --straight-through \ 26 | --large-noise-iters 50,200,500 \ 27 | --large_gs_std 0.5,0.1,0.05 \ 28 | --input-file "./data/counterfactual/dev_data.json" \ 29 | --output-dir "./data/counterfactual/" \ 30 | --stepsize-ratio 1 \ 31 | --batch-size 32 \ 32 | --print-every 200 33 | 34 | -------------------------------------------------------------------------------- /data/abductive/small_data.json: -------------------------------------------------------------------------------- 1 | {"story_id": "164f9ea8-c438-476f-860e-a027b1538507-1", "obs1": "Ray drive his car on a steep mountain road.", "obs2": "Ray was fine but his car was totaled.", "hyp1": "The car made it down with no problems.", "hyp2": "The car slipped down a hill.", "label": "2", "comet_preds": {"obs1": {"oEffect": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "oEffect", "beams": ["none", "gets into accident", "gets into car accident", "they get hurt", "gets hurt"]}, "oReact": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "oReact", "beams": ["none", "scared", "happy", "impressed", "worried"]}, "oWant": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "oWant", "beams": ["none", "to thank personx", "to be safe", "to get to their destination", "to get out of the car"]}, "xAttr": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xAttr", "beams": ["adventurous", "brave", "careless", "daring", "reckless"]}, "xEffect": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xEffect", "beams": ["gets a flat", "gets into accident", "gets hurt", "crashes car", "personx gets a flat"]}, "xIntent": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xIntent", "beams": ["to have fun", "to get somewhere", "none", "to get to the top", "to get to the bottom"]}, "xNeed": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xNeed", "beams": ["start the car", "to start the car", "to get in the car", "to have a car", "none"]}, "xReact": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xReact", "beams": ["excited", "nervous", "tired", "happy", "scared"]}, "xWant": {"event": "Ray drive his car on a steep mountain road.", "effect_type": "xWant", "beams": ["to slow down", "get out of car", "slow down", "to get out of the car", "to go to the hospital"]}}, "obs2": {"oEffect": {"event": "Ray was fine but his car was totaled.", "effect_type": "oEffect", "beams": ["none", "they get a new car", "loses money", "loses car", "they have to pay the mechanic"]}, "oReact": {"event": "Ray was fine but his car was totaled.", "effect_type": "oReact", "beams": ["none", "upset", "sad", "angry", "worried"]}, "oWant": {"event": "Ray was fine but his car was totaled.", "effect_type": "oWant", "beams": ["none", "to fix the car", "to fix it", "to get the car fixed", "to get their car fixed"]}, "xAttr": {"event": "Ray was fine but his car was totaled.", "effect_type": "xAttr", "beams": ["careless", "upset", "irresponsible", "unlucky", "broke"]}, "xEffect": {"event": "Ray was 
fine but his car was totaled.", "effect_type": "xEffect", "beams": ["has no insurance", "has no car", "has to pay the mechanic", "none", "cries"]}, "xIntent": {"event": "Ray was fine but his car was totaled.", "effect_type": "xIntent", "beams": ["none", "to get a new car", "to have a car", "to not have a car", "to not have to drive"]}, "xNeed": {"event": "Ray was fine but his car was totaled.", "effect_type": "xNeed", "beams": ["to have a car", "to be driving", "none", "to drive", "to get into a wreck"]}, "xReact": {"event": "Ray was fine but his car was totaled.", "effect_type": "xReact", "beams": ["upset", "sad", "worried", "angry", "frustrated"]}, "xWant": {"event": "Ray was fine but his car was totaled.", "effect_type": "xWant", "beams": ["to fix the car", "to repair the car", "to pay the mechanic", "to buy a new car", "to get their car fixed"]}}}} 2 | {"story_id": "11871d4b-6b94-4184-b36d-c4e76d311f3d1", "obs1": "Peter was excited to go to the Sanders rally in New Hampshire.", "obs2": "He couldn't wait to vote for him.", "hyp1": "He was 18 and was allowed to vote for the first time.", "hyp2": "He was 17 and was not allowed to vote.", "label": "1", "comet_preds": {"obs1": {"oEffect": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "oEffect", "beams": ["they have a good time", "they have fun", "they have to drive to the rally", "none", "the people of the new york area are invited to the rally"]}, "oReact": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "oReact", "beams": ["none", "happy", "excited", "also happy", "interested"]}, "oWant": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "oWant", "beams": ["none", "to have fun", "to have a good time", "to see them", "to go to the rally"]}, "xAttr": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xAttr", "beams": ["excited", "eager", "brave", "adventurous", "enthusiastic"]}, "xEffect": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xEffect", "beams": ["personx is excited", "gets sweaty", "gets tired", "none", "personx is excited to go to the rally"]}, "xIntent": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xIntent", "beams": ["to have fun", "to see the sights", "to have a good time", "to go to the rally", "none"]}, "xNeed": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xNeed", "beams": ["buy a ticket", "to buy tickets", "to buy a ticket", "none", "to have a ticket"]}, "xReact": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xReact", "beams": ["excited", "happy", "satisfied", "nervous", "tired"]}, "xWant": {"event": "Peter was excited to go to the Sanders rally in New Hampshire.", "effect_type": "xWant", "beams": ["to have fun", "to go home", "to see the sights", "to go to the rally", "to go to a rally"]}}, "obs2": {"oEffect": {"event": "He couldn't wait to vote for him.", "effect_type": "oEffect", "beams": ["none", "gets called out", "receives praise for his work", "gets called a racist", "receives praise"]}, "oReact": {"event": "He couldn't wait to vote for him.", "effect_type": "oReact", "beams": ["none", "happy", "excited", "grateful", "flattered"]}, "oWant": {"event": "He couldn't wait to vote for him.", "effect_type": "oWant", "beams": ["none", "to win the election", "to vote for him", 
"to protest", "to vote"]}, "xAttr": {"event": "He couldn't wait to vote for him.", "effect_type": "xAttr", "beams": ["eager", "motivated", "ambitious", "enthusiastic", "excited"]}, "xEffect": {"event": "He couldn't wait to vote for him.", "effect_type": "xEffect", "beams": ["gets stressed", "personx sweats from nervousness", "personx sweats", "none", "sweats"]}, "xIntent": {"event": "He couldn't wait to vote for him.", "effect_type": "xIntent", "beams": ["to vote for him", "none", "to win", "to vote", "to be a part of something"]}, "xNeed": {"event": "He couldn't wait to vote for him.", "effect_type": "xNeed", "beams": ["to be a politician", "to have a job", "to be a political leader", "none", "to be in a political party"]}, "xReact": {"event": "He couldn't wait to vote for him.", "effect_type": "xReact", "beams": ["excited", "happy", "eager", "satisfied", "proud"]}, "xWant": {"event": "He couldn't wait to vote for him.", "effect_type": "xWant", "beams": ["to win the election", "to vote for him", "to vote", "to win", "to go to the voting place"]}}}} 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk 2 | torch 3 | tqdm 4 | wandb 5 | transformers==4.2.1 6 | 7 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import json 4 | import os 5 | import nltk 6 | from nltk import tokenize 7 | import torch 8 | import numpy as np 9 | 10 | nltk.download('punkt') 11 | 12 | import sys 13 | import os 14 | if os.path.isdir('/var/karen'): 15 | os.environ['TRANSFORMERS_CACHE'] = '/var/karen/workspace/Refinement-Generation/cache' 16 | sys.path.insert(0, '/var/karen/workspace/Refinement-Generation/') 17 | 18 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 19 | from tqdm import tqdm 20 | from difflib import SequenceMatcher 21 | 22 | from bleuloss import batch_log_bleulosscnn_ae 23 | from util import * 24 | 25 | 26 | def embed_inputs(embedding, logits, x_onehot=None, z_onehot=None, device='cuda'): 27 | ''' 28 | embeds inputs in a dense representation, before passing them to the model 29 | ''' 30 | # typically we embed a one-hot vector. But here since we work we work with dense representations, 31 | # we have softmax here to make sure that all the values of the input logits sum to one (similar to a 1-hot vector). 
32 | probs = F.softmax(logits, dim=-1) 33 | 34 | if x_onehot is not None: 35 | probs = torch.cat((x_onehot.type(torch.FloatTensor), probs.type(torch.FloatTensor)), dim=1) 36 | if z_onehot is not None: 37 | probs = torch.cat((probs.type(torch.FloatTensor), z_onehot.type(torch.FloatTensor)), dim=1) 38 | probs = probs.to(device) 39 | 40 | return torch.matmul(probs, embedding) 41 | 42 | 43 | def _greedy(logits): 44 | _, last = torch.topk(logits, k=1, dim=-1) 45 | return last 46 | 47 | 48 | def top_k_filter_3d(logits, k, probs=False, mask=None, extra_mask=None): 49 | """ 50 | logits.shape = [batch_size, length, vocab_size] 51 | extra_mask: [batch_size, length, vocab_size], 1 if reserve 52 | """ 53 | BIG_CONST = 1e10 54 | if k == 0: 55 | return logits 56 | else: 57 | if mask is None: 58 | _, indices = torch.topk(logits, k) 59 | mask = torch.zeros_like(logits).scatter_(2, indices, 1) 60 | if extra_mask is not None: 61 | mask = ((mask + extra_mask) > 0).float() 62 | if probs: 63 | return logits * mask 64 | return logits * mask + -BIG_CONST * (1-mask) 65 | 66 | 67 | def top_k_filter(logits, k, probs=False): 68 | BIG_CONST = 1e10 69 | if k == 0: 70 | return logits 71 | else: 72 | values = torch.topk(logits, k)[0] 73 | batch_mins = values[:, -1].view(-1, 1).expand_as(logits) 74 | if probs: 75 | return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits) 76 | return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits) 77 | 78 | 79 | def _topk(logits, k=10): 80 | logits = top_k_filter(logits, k) 81 | probs = F.softmax(logits, dim=-1) 82 | last = torch.multinomial(probs, num_samples=1) 83 | return last 84 | 85 | 86 | def get_text_from_logits(logits, tokenizer): 87 | output_so_far = None 88 | last = None 89 | logp = 0 90 | for i in range(logits.shape[1]): 91 | last = _greedy(logits[:, i, :]) 92 | output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1) 93 | logp += logits[:, i, :].log_softmax(-1).data.cpu().numpy()[:, last.data.cpu().numpy()] 94 | 95 | nll = -logp 96 | batch_size = output_so_far.shape[0] 97 | text = [] 98 | for i in range(batch_size): 99 | text_i = tokenizer.decode(output_so_far[i].tolist()) 100 | text_i = text_i.replace('\n', ' ') 101 | text.append(text_i) 102 | 103 | return text, nll, output_so_far 104 | 105 | 106 | def get_text_from_logits_topk(logits, tokenizer, top_k=1): 107 | output_so_far = None 108 | last = None 109 | logp = 0 110 | 111 | for i in range(logits.shape[1]): 112 | last = _topk(logits[:, i, :], top_k) 113 | output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1) 114 | logp += logits[:, i, :].log_softmax(-1)[:, last.item()].item() 115 | 116 | nll = -logp 117 | text = tokenizer.decode(output_so_far.tolist()[0]) 118 | text = text.replace('\n', ' ') 119 | return text, nll, output_so_far 120 | 121 | 122 | def one_hot(tensor, dimension): 123 | while len(tensor.shape) < 2: 124 | tensor = tensor.unsqueeze(0) 125 | onehot = torch.LongTensor(tensor.shape[0], tensor.shape[1], dimension).to(tensor.device) 126 | onehot.zero_().scatter_(2, tensor.unsqueeze(-1), 1) 127 | onehot.to(tensor.device) 128 | return onehot 129 | 130 | 131 | def initialize(model, x, length, temperature, device): 132 | if x.dim() == 1: 133 | x = x.unsqueeze(0) 134 | past = None 135 | last_token_embedding = None 136 | logits_so_far = None 137 | for i in range(length): 138 | # for the first iteration, `past` is None 139 | if past is None: 140 | x_last_token = x[:, -1:] 141 | last_token_embedding = 
model.get_input_embeddings()(x_last_token) 142 | 143 | # if the input length is longer than a single token 144 | if x.shape[1] > 1: 145 | x_except_last_token = x[:, :-1] 146 | model_outputs = model(x_except_last_token) 147 | past = model_outputs.past_key_values 148 | 149 | model_outputs = model(past_key_values=past, inputs_embeds=last_token_embedding) 150 | logits = model_outputs.logits 151 | past = model_outputs.past_key_values 152 | 153 | logits = logits[:, -1, :] / temperature 154 | logits = logits.unsqueeze(1) 155 | logits_so_far = logits if logits_so_far is None else torch.cat((logits_so_far, logits), dim=1) 156 | last_token_embedding = embed_inputs(embedding=model.get_input_embeddings().weight, logits=logits, device=device) 157 | 158 | return logits_so_far 159 | 160 | 161 | def decode_with_model_topk(model, y_logits, topk, x_onehot, x_past, tokenizer, extra_mask=None): 162 | assert x_onehot.shape[1] == 1, x_onehot.shape 163 | length = y_logits.shape[1] 164 | past = x_past 165 | input_embeds = torch.matmul(x_onehot.float(), model.get_input_embeddings().weight) 166 | mask_t_all = None 167 | logits_so_far = None 168 | for i in range(length): 169 | model_outputs = model(past_key_values=past, inputs_embeds=input_embeds) 170 | past = model_outputs.past_key_values 171 | logits_t = model_outputs.logits[:, -1:, :] 172 | assert logits_t.shape[1] == 1, logits_t.shape 173 | _, indices_t = torch.topk(logits_t, topk) 174 | mask_t = torch.zeros_like(logits_t).scatter_(2, indices_t, 1) 175 | mask_t_all = mask_t if mask_t_all is None else torch.cat((mask_t_all, mask_t), dim=1) 176 | logits_so_far = logits_t if logits_so_far is None else torch.cat((logits_so_far, logits_t), dim=1) 177 | if i < length - 1: 178 | if extra_mask is None: 179 | y_logits_i_topk = top_k_filter_3d(y_logits[:,i:i+1,:], topk, mask=mask_t) / 0.001 180 | else: 181 | y_logits_i_topk = top_k_filter_3d(y_logits[:,i:i+1,:], topk, mask=mask_t, extra_mask=extra_mask[:,i:i+1,:]) / 0.001 182 | input_embeds = torch.matmul(F.softmax(y_logits_i_topk, dim=-1), model.get_input_embeddings().weight) 183 | return get_text_from_logits( 184 | top_k_filter_3d(y_logits, topk, mask=mask_t_all, extra_mask=extra_mask), 185 | tokenizer) 186 | 187 | 188 | def post_process(text_ids, model, max_length, length, tokenizer, device): 189 | # sentence completion 190 | text_ids_complete = sentence_completion(text_ids, model, max_length, device) 191 | batch_size = text_ids.shape[0] 192 | text_so_far_all = [] 193 | for bi in range(batch_size): 194 | text_complete = tokenizer.decode(text_ids_complete[bi].tolist()) 195 | text_complete = text_complete.replace('\n', ' ') 196 | 197 | # truncate to minimal complete text 198 | sents = nltk.sent_tokenize(text_complete) 199 | text_so_far = None 200 | length_so_far = 0 201 | for i, sent in enumerate(sents): 202 | text_so_far = sent if text_so_far is None else text_so_far + ' ' + sent 203 | sent_length = len(sent.split()) 204 | length_so_far += sent_length 205 | if length_so_far >= length: 206 | break 207 | text_so_far_all.append(text_so_far) 208 | return text_so_far_all 209 | 210 | 211 | def sentence_completion(text_ids, model, max_length, device): 212 | output_so_far = text_ids 213 | past = None 214 | last_embeds = None 215 | # logits_so_far = None 216 | for i in range(max_length - text_ids.shape[1]): 217 | if past is None and output_so_far is not None: 218 | last = output_so_far[:, -1:] 219 | last_embeds = model.get_input_embeddings()(last) 220 | 221 | if output_so_far.shape[1] > 1: 222 | model_outputs = 
model(output_so_far[:, :-1]) 223 | past = model_outputs.past_key_values 224 | 225 | model_outputs = model(past_key_values=past, inputs_embeds=last_embeds) 226 | logits = model_outputs.logits 227 | past = model_outputs.past_key_values 228 | 229 | last = _greedy(logits[:, -1, :]) 230 | output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1) 231 | # last_embeds = get_input_embeds(model.get_input_embeddings(), logits[:, -1:, :], device=device) 232 | last_embeds = model.get_input_embeddings()(last) 233 | 234 | return output_so_far 235 | 236 | 237 | def soft_distance(logits_perturbed, logits): 238 | return torch.nn.MSELoss()(logits_perturbed, logits) 239 | 240 | 241 | def soft_nll(logits_perturbed, logits): 242 | p = F.softmax(logits_perturbed, dim=-1) 243 | logp = F.log_softmax(logits, dim=-1) 244 | return -(p * logp).sum(dim=-1).mean(dim=-1) 245 | 246 | 247 | def soft_nll_detach(logits_perturbed, logits): 248 | p = F.softmax(logits_perturbed, dim=-1).detach() 249 | logp = F.log_softmax(logits, dim=-1) 250 | return -(p * logp).sum(dim=-1).mean() 251 | 252 | 253 | def additional_nll(logits, cur_text_ids): 254 | return torch.nn.CrossEntropyLoss()( 255 | logits.view(-1, logits.shape[-1]), 256 | cur_text_ids.view(-1) 257 | ) 258 | 259 | 260 | def soft_forward(model, x_onehot, y_logits, x_past=None, detach=True): 261 | ''' 262 | computes logits for $y$, based on a fixed context $y$ and the current logit distribution of $y$ 263 | :param model: 264 | :param x_onehot: 265 | :param y_logits: 266 | :return: 267 | ''' 268 | xy_embeds = embed_inputs( 269 | model.get_input_embeddings().weight, 270 | y_logits, 271 | x_onehot=x_onehot, 272 | device=x_onehot.device 273 | ) 274 | xy_logits = model(past_key_values=x_past, inputs_embeds=xy_embeds).logits 275 | x_length = x_onehot.shape[1] 276 | y_logits = xy_logits[:, x_length - 1:-1, :] 277 | if detach: 278 | return y_logits.detach() 279 | else: 280 | return y_logits 281 | 282 | 283 | def soft_forward_xyz(model, x_onehot, y_logits, z_onehot): 284 | ''' 285 | computes logits for $y$, based on a fixed context $y$ and the current logit distribution of $y$ 286 | :param model: 287 | :param x_onehot: 288 | :param y_logits: 289 | :return: 290 | ''' 291 | xyz_embeds = embed_inputs( 292 | model.get_input_embeddings().weight, 293 | y_logits, 294 | x_onehot=x_onehot, 295 | z_onehot=z_onehot, 296 | device=y_logits.device 297 | ) 298 | xyz_logits = model(inputs_embeds=xyz_embeds).logits 299 | if x_onehot is not None: 300 | xy_length = x_onehot.shape[1] + y_logits.shape[1] 301 | else: 302 | xy_length = y_logits.shape[1] 303 | return xyz_logits, xy_length 304 | 305 | 306 | 307 | def soft_backward(model, y_logits_rev): 308 | embeddings_weight = model.get_input_embeddings().weight[1:y_logits_rev.shape[-1]+1] 309 | y_embeds = embed_inputs( 310 | embeddings_weight, 311 | y_logits_rev, 312 | device=y_logits_rev.device 313 | ) 314 | y_logits_ = model(inputs_embeds=y_embeds).logits 315 | return y_logits_[:, :-1, :] 316 | 317 | 318 | def soft_backward_steps(model, y_logits): 319 | device = y_logits.device 320 | past = None 321 | last_embeds = None 322 | logits_so_far = None 323 | for i in range(y_logits.shape[1]-2, -1, -1): 324 | last = y_logits[:, i:i+1] 325 | last_embeds = embed_inputs(model.get_input_embeddings(), last, device=device) 326 | 327 | model_outputs = model(past_key_values=past, inputs_embeds=last_embeds) 328 | past = model_outputs.past_key_values 329 | 330 | logits = model_outputs.logits 331 | logits = logits[:, -1, :] 332 | logits = 
logits.unsqueeze(1) 333 | logits_so_far = logits if logits_so_far is None else torch.cat((logits_so_far, logits), dim=1) 334 | 335 | return logits_so_far 336 | 337 | 338 | 339 | def constraint_loss(logits, cs_onehot, cs_ids): 340 | """ 341 | constraint loss with mask 342 | cs_ids: [batch_size, num_cs] 343 | """ 344 | log_ps = logits.log_softmax(-1).unsqueeze(2) # shape: [batch_size, length, 1, vocab_size] 345 | constraint_max_log_ps_ = (log_ps * cs_onehot.unsqueeze(1)).max(1)[0].sum(-1) # shape: [batch_size, num_cs] 346 | 347 | log_ps_max_ids = log_ps[:, :, 0, :].argmax(-1) # shape: [batch_size, length] 348 | cs_ids_repeat = cs_ids.unsqueeze(2).repeat([1, 1, log_ps_max_ids.shape[1]]) # shape: [batch_size, num_cs, length] 349 | mask = (log_ps_max_ids.unsqueeze(1) == cs_ids_repeat).type(torch.FloatTensor).sum(-1) # shape: [batch_size, num_cs] 350 | mask = (mask < 1).type(torch.FloatTensor) 351 | mask = mask.to(constraint_max_log_ps_.device) 352 | 353 | loss = - (constraint_max_log_ps_ * mask).sum() 354 | 355 | if mask.sum() != 0: 356 | loss = loss / mask.sum() 357 | else: 358 | loss = 0 359 | 360 | return loss 361 | 362 | 363 | def constraint_loss_with_variants(logits, cs_onehot_all, cs_ids_all): 364 | """ 365 | constraint loss with mask 366 | cs_ids_all: list of tensor [batch_size, num_variants], of length num_cs 367 | """ 368 | device = logits.device 369 | log_ps = logits.log_softmax(-1).unsqueeze(2) # shape: [batch_size, length, 1, vocab_size] 370 | 371 | num_cs = len(cs_onehot_all) 372 | loss_all = 0 373 | mask_sum = 0 374 | for i in range(num_cs): 375 | cs_onehot = cs_onehot_all[i] 376 | cs_ids = cs_ids_all[i] 377 | constraint_max_log_ps_ = (log_ps * cs_onehot.unsqueeze(1)).max(1)[0].sum(-1) # shape: [batch_size, num_variants] 378 | 379 | log_ps_max_ids = log_ps[:, :, 0, :].argmax(-1) # shape: [batch_size, length] 380 | cs_ids_repeat = cs_ids.unsqueeze(2).repeat([1, 1, log_ps_max_ids.shape[1]]) # shape: [batch_size, num_variants, length] 381 | mask = (log_ps_max_ids.unsqueeze(1) == cs_ids_repeat).type(torch.FloatTensor).sum(-1) # shape: [batch_size, num_variants] 382 | #mask = (mask >= 1).type(torch.FloatTensor) 383 | mask = (mask.sum(1) < 1).type(torch.FloatTensor) # shape: [batch_size]. 
mask = 0 if any of the variants already occurs 384 | mask = mask.to(device) 385 | 386 | loss_i = - (constraint_max_log_ps_.max(1)[0] * mask).mean() # average over batch_size 387 | 388 | loss_all += loss_i 389 | mask_sum += mask 390 | 391 | if mask_sum != 0: 392 | loss_all = loss_all / mask_sum 393 | 394 | return loss_all #, mask_sum 395 | 396 | 397 | def constraint_loss_with_variants_by_ppl(logits, cs_onehot_all, cs_ids_all, probs_t): 398 | device = logits.device 399 | batch_size = logits.shape[0] 400 | log_ps = logits.log_softmax(-1).unsqueeze(2) 401 | ps_t = probs_t.unsqueeze(2) 402 | 403 | num_cs = len(cs_onehot_all) 404 | loss_all = 0 405 | mask_sum = 0 406 | for i in range(num_cs): 407 | cs_onehot = cs_onehot_all[i] 408 | cs_ids = cs_ids_all[i] 409 | 410 | cs_onehot_ = cs_onehot.unsqueeze(1).type(torch.FloatTensor).to(device) 411 | cs_onehot_ = cs_onehot_.repeat(batch_size, 1, 1, 1).type(torch.FloatTensor).to(device) 412 | ppl_max_idx = (ps_t * cs_onehot_).argmax(1) # [batch_size, num_variants, vocab_size] 413 | ppl_max_idx_onehot = torch.zeros_like(log_ps * cs_onehot_).scatter_(1, ppl_max_idx.unsqueeze(1), cs_onehot_) 414 | 415 | constraint_max_log_ps_ = (log_ps * ppl_max_idx_onehot).sum(1).sum(-1) # shape: [batch_size, num_variants] 416 | 417 | ## Mask 418 | log_ps_max_ids = log_ps[:, :, 0, :].argmax(-1) # shape: [batch_size, length] 419 | cs_ids_repeat = cs_ids.unsqueeze(2).repeat([1, 1, log_ps_max_ids.shape[1]]) # shape: [batch_size, num_variants, length] 420 | mask = (log_ps_max_ids.unsqueeze(1) == cs_ids_repeat).type(torch.FloatTensor).sum(-1) # shape: [batch_size, num_variants] 421 | mask = (mask.sum(1) < 1).type(torch.FloatTensor) # shape: [batch_size]. mask = 0 if any of the variants already occurs 422 | mask = mask.to(device) 423 | 424 | loss_i = - constraint_max_log_ps_.max(1)[0] * mask 425 | 426 | loss_all += loss_i # shape: [batch_size] 427 | mask_sum += mask # shape: [batch_size] 428 | 429 | loss_all = loss_all / (mask_sum + 1e-8) 430 | 431 | return loss_all 432 | 433 | 434 | def constraint_loss_by_ppl(logits, cs_onehot, cs_ids, logits_t): 435 | device = logits.device 436 | log_ps = logits.log_softmax(-1).unsqueeze(2) 437 | 438 | cs_onehot_ = cs_onehot.unsqueeze(1).type(torch.FloatTensor).to(device) 439 | ps_t = logits_t.softmax(-1).unsqueeze(2) 440 | ppl_max_idx = (ps_t * cs_onehot_).argmax(1) # [batch_size, num_cs, vocab_size] 441 | ppl_max_idx_onehot = torch.zeros_like(log_ps * cs_onehot_).scatter_(1, ppl_max_idx.unsqueeze(1), cs_onehot_) 442 | 443 | constraint_max_log_ps_ = (log_ps * ppl_max_idx_onehot).sum(1).sum(-1) # shape: [batch_size, num_cs] 444 | 445 | ## Mask 446 | log_ps_max_ids = log_ps[:, :, 0, :].argmax(-1) # shape: [batch_size, length] 447 | cs_ids_repeat = cs_ids.unsqueeze(2).repeat([1, 1, log_ps_max_ids.shape[1]]) # shape: [batch_size, num_cs, length] 448 | mask = (log_ps_max_ids.unsqueeze(1) == cs_ids_repeat).type(torch.FloatTensor).sum(-1) # shape: [batch_size, num_cs] 449 | mask = (mask < 1).type(torch.FloatTensor) 450 | mask = mask.to(device) 451 | 452 | loss = - (constraint_max_log_ps_ * mask).sum() 453 | 454 | if mask.sum() != 0: 455 | loss = loss / mask.sum() 456 | else: 457 | loss = 0 458 | 459 | return loss 460 | 461 | 462 | def constraint_loss_all(logits, cs_onehot, cs_ids): 463 | device = logits.device 464 | 465 | log_ps = logits.log_softmax(-1).unsqueeze(2) 466 | constraint_max_log_ps_ = (log_ps * cs_onehot.unsqueeze(1)).mean(1).sum(-1) # shape: [batch_size, num_cs] 467 | 468 | ## Mask 469 | log_ps_max_ids = log_ps[:, :, 0, :].argmax(-1) # 
shape: [batch_size, length] 470 | cs_ids_repeat = cs_ids.unsqueeze(2).repeat([1, 1, log_ps_max_ids.shape[1]]) # shape: [batch_size, num_cs, length] 471 | mask = (log_ps_max_ids.unsqueeze(1) == cs_ids_repeat).type(torch.FloatTensor).sum(-1) # shape: [batch_size, num_cs] 472 | mask = (mask < 1).type(torch.FloatTensor) 473 | mask = mask.to(device) 474 | 475 | loss = - (constraint_max_log_ps_ * mask).sum() 476 | 477 | if mask.sum() != 0: 478 | loss = loss / mask.sum() 479 | else: 480 | loss = 0 481 | 482 | return loss 483 | 484 | def _constraint_loss2(logits, cs_onehot): 485 | ''' 486 | a re-implementation of `_constraint_loss` with a slightly different logic. 487 | TODO: keep only one of these functions 488 | ''' 489 | logits = logits.squeeze(0) # drop the empty dimension 490 | cs_onehot = cs_onehot.float().squeeze(0) # drop the empty dimension and change into float (since torch matrix multiplication does not support integers) 491 | cs_onehot = torch.transpose(cs_onehot, 0, 1) 492 | selected_logits = torch.matmul(logits, cs_onehot) # dim: length x # of constraints 493 | max_logits_per_constraint, _ = selected_logits.max(0) # select the highest logits for each constraint 494 | loss = - max_logits_per_constraint.sum() / selected_logits.size(1) 495 | return loss 496 | 497 | def print_topk_stats(logits, tokenizer): 498 | logits_lg, topk_index_y = torch.topk(F.softmax(logits[0, :3, :], dim=-1), 3) 499 | print(logits_lg.data.cpu().numpy()) 500 | print(topk_index_y.data.cpu().numpy()) 501 | lgs = [int(x[0]) for x in topk_index_y.data.cpu().numpy()] 502 | for a in lgs: 503 | print('|', tokenizer.decode(a), '| ', end='', flush=True) 504 | print() 505 | print("===============================") 506 | return topk_index_y 507 | 508 | 509 | 510 | def collect_json_lines(model_output_json_file): 511 | with open(model_output_json_file, 'r') as fr: 512 | lines = fr.readlines() 513 | json_lines = [json.loads(x.strip()) for x in lines] 514 | return json_lines 515 | 516 | def post_sent(text_complete): 517 | sents = nltk.sent_tokenize(text_complete) 518 | sent = ' '.join(sents[0].strip().split()) 519 | return sent 520 | # return sents[0] 521 | 522 | def _has_repeat_sent(hyp): 523 | """ 524 | Detect if the sentences in `hyp` are repeat. 525 | Args: 526 | hyp: A list of three sentences. 527 | """ 528 | if len(hyp) <= 1: 529 | return False 530 | 531 | for i in range(1, len(hyp)): 532 | a = hyp[i-1] 533 | b = hyp[i] 534 | 535 | if a == b: 536 | return True 537 | 538 | s = SequenceMatcher(None, a, b) 539 | if len(a) > 5 and len(b) > 5 and s.ratio() >= 0.85: 540 | return True 541 | 542 | return False 543 | 544 | 545 | def _has_repeat_substring(s, MINLEN=4, MINCNT=4): 546 | d = {} 547 | has_repeat = False 548 | for sublen in range(int(len(s)/MINCNT)-1, MINLEN-1, -1): 549 | for i in range(0, len(s)-sublen): 550 | sub = s[i:i+sublen] 551 | if len(sub.strip()) < sublen: 552 | continue 553 | cnt = s.count(sub) 554 | if cnt >= MINCNT and sub not in d: 555 | d[sub] = cnt 556 | print('repeat_substring: |' + sub + '| in |' + s + '|') 557 | has_repeat = True 558 | break 559 | if has_repeat: 560 | break 561 | return has_repeat 562 | 563 | 564 | def has_repeat(sents_for_substr, sents_for_sent): 565 | """ 566 | Detect if the hypothesis text has repeat patterns. 
567 | """ 568 | has_repeat_substring = False 569 | for h in sents_for_substr: 570 | has_repeat_substring = has_repeat_substring or _has_repeat_substring(h) or _has_repeat_substring(h, MINLEN=20, MINCNT=2) 571 | # print(has_repeat_substring) 572 | # print(_has_repeat_sent(hyp)) 573 | return has_repeat_substring or _has_repeat_sent(sents_for_sent) 574 | 575 | 576 | def write_json_lines(json_lines, fout, model, tokenizer, device): 577 | with open(fout, 'w') as fw: 578 | for line in json_lines: 579 | input_text = line['generation_complete'][0][0] 580 | # input_text = line['counterfactual'] 581 | 582 | ori_ending = line['original_ending'] 583 | ori_endings = tokenize.sent_tokenize(ori_ending) 584 | z = ori_endings[0].strip() 585 | 586 | gens = line['generation_complete'][0][1] 587 | proc_gens = [post_sent(x) for x in gens] 588 | pg_dict, gens_ranked, pg_dict_top, gens_ranked_top = process_batching_counterfactual_outputs( 589 | proc_gens, input_text, z, model, tokenizer, device) 590 | line['proced'] = proc_gens 591 | line['ppl_gens'] = pg_dict 592 | line['gens_ranked'] = gens_ranked 593 | line['ppl_gens_top'] = pg_dict_top 594 | line['gens_ranked_top'] = gens_ranked_top 595 | # print(line) 596 | # exit() 597 | fw.write(json.dumps(line) + '\n') 598 | 599 | 600 | def compute_ppl_line(model, tokenizer, device, line): 601 | line = line.strip() 602 | #print(line) 603 | line_ = tokenizer.encode(line) 604 | line_t = torch.tensor(line_, device=device, dtype=torch.long) 605 | loss = model(input_ids=line_t, labels=line_t).loss 606 | loss = loss.detach().clone().data.cpu().numpy() 607 | ppl = np.exp(loss) 608 | return ppl 609 | 610 | 611 | def compute_loss(model, tokenizer, device, x="", z="", y="", constraints=None, args=None, model_back=None, zz=None): 612 | ''' 613 | x: left context (prompt in lexical constrained task) 614 | z: optimization target (original ending in counterfactual task) 615 | constraints: (constraint set in lexical constrained task) 616 | ''' 617 | batch_size = 2 618 | 619 | x_ = tokenizer.encode(x) 620 | x_t = torch.tensor(x_, device=device, dtype=torch.long) 621 | x_onehot = one_hot(x_t, dimension=tokenizer.vocab_size) 622 | 623 | # repeat batch_size times 624 | x_t = x_t.unsqueeze(0).repeat(batch_size, 1) 625 | x_onehot = x_onehot.repeat(batch_size, 1, 1) 626 | 627 | z_ = tokenizer.encode(z)[1:] # delete the "." token we appended before 628 | z_t = torch.tensor(z_, device=device, dtype=torch.long) 629 | z_t = z_t.unsqueeze(0).repeat(batch_size, 1) 630 | 631 | y_ = tokenizer.encode(y)[1:] # delete the "." 
token we appended before 632 | y_t = torch.tensor(y_, device=device, dtype=torch.long) 633 | y_onehot = one_hot(y_t, dimension=tokenizer.vocab_size) 634 | y_onehot = y_onehot.repeat(batch_size, 1, 1) 635 | y_t = y_t.unsqueeze(0).repeat(batch_size, 1) 636 | 637 | y_logits_ = y_onehot / 0.0001 638 | 639 | c_loss = batch_log_bleulosscnn_ae( 640 | decoder_outputs=y_logits_.transpose(0, 1), 641 | target_idx=z_t, 642 | ngram_list=[2, 3] 643 | ) 644 | 645 | return c_loss.mean().item() 646 | 647 | 648 | def rank_and_filter(candidates, input_text, z, model, tokenizer, device, no_loss_rerank): 649 | 650 | # de-duplicate 651 | candidates = list(dict.fromkeys(candidates)) 652 | 653 | ppl_list = [] 654 | ppl_y_list = [] 655 | loss_list = [] 656 | for line in candidates: 657 | line = line.strip() 658 | y = ' '.join(line.split()) 659 | # y = line 660 | xy = input_text + ' ' + line 661 | # print(xy) 662 | # exit() 663 | x_sents = nltk.sent_tokenize(input_text) 664 | if has_repeat(sents_for_substr=[y], sents_for_sent=x_sents+[y]) or len(tokenizer.encode(y)) <= 4: 665 | ppl_list.append(10000.0) 666 | ppl_y_list.append(10000.0) 667 | loss_list.append(10000.0) 668 | else: 669 | ppl = compute_ppl_line(model, tokenizer, device, xy) 670 | ppl_list.append(round(ppl, 2)) 671 | 672 | ppl_y = compute_ppl_line(model, tokenizer, device, y) 673 | ppl_y_list.append(round(ppl_y, 2)) 674 | 675 | loss = compute_loss(model, tokenizer, device, 676 | x=input_text, z=". " + z, y=". " + y) 677 | loss_list.append(loss) 678 | 679 | sort_index = sorted(range(len(ppl_list)), key=lambda k: ppl_list[k]) 680 | ppls_reorder = [ppl_list[i] for i in sort_index] 681 | ppls_y_reorder = [ppl_y_list[i] for i in sort_index] 682 | loss_reorder = [loss_list[i] for i in sort_index] 683 | gens_complete_reorder = [candidates[i] for i in sort_index] 684 | 685 | pg_dict = [] 686 | for p, py, l, g in zip(ppls_reorder, ppls_y_reorder, loss_reorder, gens_complete_reorder): 687 | pg_dict.append({"ppl": str(p), "ppl_y": str(py), "loss": str(l), "gen": g}) 688 | 689 | if len(ppls_reorder) <= 1: 690 | sort_len = 1 691 | elif ppls_reorder[1]-ppls_reorder[0] > 10: 692 | sort_len = 1 693 | elif len(ppls_reorder) <= 2: 694 | sort_len = 1 695 | elif ppls_reorder[2]-ppls_reorder[0] > 10: 696 | sort_len = 2 697 | else: 698 | sort_len = 3 699 | 700 | if no_loss_rerank: 701 | return gens_complete_reorder[0] 702 | 703 | sort_index = sorted(range(sort_len), key=lambda k: loss_reorder[k]) 704 | sort_index = sort_index 705 | ppls_reorder_top = [ppls_reorder[i] for i in sort_index] 706 | ppls_y_reorder_top = [ppls_y_reorder[i] for i in sort_index] 707 | loss_reorder_top = [loss_reorder[i] for i in sort_index] 708 | gens_complete_reorder_top = [gens_complete_reorder[i] for i in sort_index] 709 | 710 | pg_dict_top = [] 711 | for p, py, l, g in zip(ppls_reorder_top, ppls_y_reorder_top, loss_reorder_top, gens_complete_reorder_top): 712 | pg_dict_top.append({"ppl": str(p), "ppl_y": str(py), "loss": str(l), "gen": g}) 713 | 714 | return gens_complete_reorder_top[0] 715 | 716 | 717 | --------------------------------------------------------------------------------
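
Editor's note: as a quick orientation to the helpers defined in util.py above, here is a minimal, hypothetical usage sketch; it is not part of the repository. Assumptions: the small "gpt2" checkpoint stands in for the gpt2-xl model used by the shipped scripts, the candidate logits are random rather than the iteratively optimized ones produced in cold_decoding.py, and everything runs on CPU; the exact energy/loss composition used by COLD decoding is not reproduced here.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from util import embed_inputs, get_text_from_logits, soft_nll, top_k_filter_3d

device = "cpu"  # assumption: CPU for illustration; the shipped scripts target CUDA
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)  # assumption: small GPT-2
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Stand-in for the "soft" continuation: a [batch, length, vocab] tensor of logits.
# Here it is random; cold_decoding.py refines such logits iteratively.
y_logits = torch.randn(1, 5, model.config.vocab_size)

# Dense-input trick from embed_inputs(): softmax over the vocabulary, then a
# probability-weighted mixture of embedding rows, fed to the frozen LM via inputs_embeds.
inputs_embeds = embed_inputs(model.get_input_embeddings().weight, y_logits, device=device)
with torch.no_grad():  # for actual optimization, gradients would need to flow back into y_logits
    lm_logits = model(inputs_embeds=inputs_embeds).logits

# A fluency-style term: soft cross-entropy between the top-k-restricted candidate
# distribution at position i+1 and the LM's prediction made at position i.
fluency = soft_nll(top_k_filter_3d(y_logits[:, 1:, :], 5), lm_logits[:, :-1, :])

# Greedy readout of the current soft sequence.
text, _, _ = get_text_from_logits(y_logits, tokenizer)
print(fluency.item(), text)

The softmax-weighted embedding mixture is what lets a frozen GPT-2 score a differentiable relaxation of the text, which appears to be the mechanism that the gradient-based updates in cold_decoding.py rely on.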